TraceStudio-dev/cloud/custom_nodes/assert_loading.py

183 lines
8.0 KiB
Python
Raw Normal View History

from typing import Any, Dict, Optional
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
)
from server.app.core.node_registry import register_node
import polars as pl
@register_node
class DungeonFilterNode(TraceNode):
    """Dungeon filter node.

    Keeps only the rows that lie between a "dungeon entered" marker and the
    matching "dungeon left" marker in the ``Metadata`` column, and tags every
    kept row with a 1-based ``Dungeon`` number.
    """

    # NOTE(review): "Filiter" looks like a typo for "Filter". It is kept
    # unchanged because the value is a runtime string that other code or
    # saved graphs may match on -- confirm before renaming.
    CATEGORY = "Filiter"
    DISPLAY_NAME = "副本过滤器"
    DESCRIPTION = "依据规则筛选出副本内的数据"
    ICON = "📥"

    # Value of the ThreadId column that identifies the main thread.
    _MAIN_THREAD_ID = 2
    # Pattern whose presence in Metadata marks entering a dungeon.
    _ENTER_MARKER = "PackageName:/Game/Asset/Audio/FMOD/Events/bgm/cbt01/level/zhuiji"
    # Pattern whose presence in Metadata marks leaving a dungeon.
    _LEAVE_MARKER = "Token:TriggerPlayerFinish"

    @staticmethod
    def _find_marker_rows(frame, marker):
        """Return the row positions of ``frame`` whose Metadata contains ``marker``.

        ``str.contains`` is called with its default arguments, exactly as
        before. Null results are falsy and are therefore skipped.
        """
        hits = frame["Metadata"].str.contains(marker).to_list()
        return [row for row, hit in enumerate(hits) if hit]

    @staticmethod
    def _pair_dungeon_ranges(enter_rows, leave_rows, last_row):
        """Pair each enter row with the next unused leave row after it.

        Args:
            enter_rows: ascending row positions of enter markers.
            leave_rows: ascending row positions of leave markers (may be empty).
            last_row: position used as the end when no leave marker is left.

        Returns:
            A list of ``(enter_row, leave_row, dungeon_number)`` tuples with
            inclusive bounds and numbers starting at 1.

        Each leave row is consumed by at most one range. An enter row that
        occurs before the previous range's leave row still opens its own
        range, so ranges can overlap.
        """
        ranges = []
        cursor = 0
        for number, start in enumerate(enter_rows, 1):
            # Skip leave markers at or before this enter marker.
            while cursor < len(leave_rows) and leave_rows[cursor] <= start:
                cursor += 1
            if cursor < len(leave_rows):
                stop = leave_rows[cursor]
                cursor += 1
            else:
                # No leave marker left: the range runs to the final row.
                stop = last_row
            # ``stop`` is either a leave row greater than ``start`` or the
            # final row index, so ``start <= stop`` always holds here.
            ranges.append((start, stop, number))
        return ranges

    def _publish_dungeon_count(self, context, dungeon_count):
        """Best-effort: store the number of ranges in the execution context.

        Supports an object exposing ``update_node_private`` /
        ``get_global`` / ``update_global`` as well as a plain ``dict``.
        Failures are logged and never propagate, so they cannot block the
        main flow.
        """
        import logging

        try:
            node_id = getattr(self, "node_id", None) or "unknown"
            if hasattr(context, "update_node_private"):
                context.update_node_private(node_id, "dungeon_count", dungeon_count)
            elif isinstance(context, dict):
                per_node = context.setdefault("nodes", {})
                per_node.setdefault(node_id, {})["dungeon_count"] = dungeon_count
            if hasattr(context, "update_global"):
                previous = context.get_global("total_dungeons", 0)
                context.update_global("total_dungeons", previous + dungeon_count)
            elif isinstance(context, dict):
                previous = context.get("global", {}).get("total_dungeons", 0)
                context.setdefault("global", {})["total_dungeons"] = previous + dungeon_count
        except Exception:
            # Context bookkeeping is non-fatal, but no longer silent.
            logging.getLogger(__name__).warning(
                "DungeonFilterNode: failed to update execution context",
                exc_info=True,
            )

    @input_port("metadata", "DataTable", description="输入的元数据表")
    @output_port("metadata", "DataTable", description="导出的元数据表")
    def process(self, inputs: Dict[str, Any], context: Optional[object] = None) -> Dict[str, Any]:
        """Extract the rows belonging to dungeon runs.

        Args:
            inputs: mapping whose ``"metadata"`` entry is a polars DataFrame
                with a ``Metadata`` column and, optionally, a ``ThreadId``
                column.
            context: optional execution context (object or dict) that
                receives the dungeon count; may be ``None``.

        Returns:
            ``{"metadata": frame}`` where ``frame`` is the concatenation of
            all dungeon ranges with an added ``Dungeon`` column.

        Raises:
            ValueError: if the input is not a polars DataFrame, the
                ``Metadata`` column is missing, or no enter marker is found.
        """
        metadata = inputs.get("metadata", None)
        if not isinstance(metadata, pl.DataFrame):
            raise ValueError("metadata 输入必须为 polars.DataFrame")

        # 1. Keep only main-thread rows when a ThreadId column is available.
        if "ThreadId" in metadata.columns:
            main_thread_df = metadata.filter(pl.col("ThreadId") == self._MAIN_THREAD_ID)
        else:
            main_thread_df = metadata

        if "Metadata" not in main_thread_df.columns:
            raise ValueError("缺少 Metadata 列")

        # 2. Rows at which a dungeon is entered.
        enter_rows = self._find_marker_rows(main_thread_df, self._ENTER_MARKER)
        if not enter_rows:
            raise ValueError("未找到进入副本的事件")

        # 3. Rows at which a dungeon is left; an empty list is allowed.
        leave_rows = self._find_marker_rows(main_thread_df, self._LEAVE_MARKER)

        # 4. Pair enter/leave rows into inclusive ranges.
        dungeon_ranges = self._pair_dungeon_ranges(
            enter_rows, leave_rows, main_thread_df.height - 1
        )

        # 5. Slice every range, tag it with its number and concatenate.
        #    ``dungeon_ranges`` has one entry per enter row, so it is
        #    non-empty at this point.
        pieces = [
            main_thread_df.slice(start, stop - start + 1).with_columns(
                [pl.lit(number).alias("Dungeon")]
            )
            for start, stop, number in dungeon_ranges
        ]
        filtered_df = pl.concat(pieces)

        self._publish_dungeon_count(context, len(dungeon_ranges))
        return {"metadata": filtered_df}
# ============= 资源标签分类节点 =============
@register_node
class AssetLabelNode(TraceNode):
    """Asset label node: derives a ``Label`` column from the ``Metadata`` column."""

    CATEGORY = "Transform"
    DISPLAY_NAME = "资源标签分类"
    DESCRIPTION = "根据资源路径对数据进行标签分类,支持自定义规则"
    ICON = "🏷️"

    # (path prefix, label) pairs, checked in order; the first match wins.
    _LABEL_RULES = (
        ("/Game/Asset/Audio", "Audio"),
        ("/Game/Asset/Char/Player", "Player"),
        ("/Game/Asset/Effect/Niagara", "Niagara"),
        ("/Game/UI/Texture", "Texture"),
        ("/Game/Blueprints/Combat", "Combat"),
    )

    @classmethod
    def _label_for(cls, meta):
        """Return the label for a single Metadata value.

        Resolution order:
          1. Non-string values yield ``"Unknown"``.
          2. A leading ``"PackageName:"`` is stripped.
          3. If the path starts with a rule prefix, that rule's label is
             returned.
          4. Otherwise the path is walked directory by directory and the
             first directory that no rule path passes through is returned,
             e.g. ``/Game/Asset/Mesh/Rock`` -> ``"Mesh"``.
          5. If every directory lies on a rule path, the last directory is
             returned; an empty path yields ``"Unknown"``.

        Step 4 previously tested ``prefix.startswith(rule_path)``, which
        cannot be true once step 3 has failed, so every unmatched path was
        labelled with its first component. The test is now "does a rule path
        equal this prefix or continue below it".
        """
        path = meta
        if isinstance(meta, str) and meta.startswith("PackageName:"):
            path = meta[len("PackageName:"):]
        if not isinstance(path, str):
            return "Unknown"

        for rule_path, label in cls._LABEL_RULES:
            if path.startswith(rule_path):
                return label

        # Drop empty segments so "" and "//" do not produce empty labels.
        parts = [part for part in path.strip("/").split("/") if part]
        prefix = ""
        for part in parts:
            prefix = prefix + "/" + part
            covered = any(
                rule_path == prefix or rule_path.startswith(prefix + "/")
                for rule_path, _ in cls._LABEL_RULES
            )
            if not covered:
                return part
        return parts[-1] if parts else "Unknown"

    @input_port("table", "DataTable", description="待分类的表,需包含 Metadata 列", required=True)
    @output_port("result", "DataTable", description="带标签的表")
    def process(self, inputs, context=None):
        """Add a ``Label`` column and sort the table by ``Metadata``.

        Args:
            inputs: mapping whose ``"table"`` entry is a polars DataFrame
                containing a ``Metadata`` column.
            context: optional execution context; only used for the
                best-effort example lookup at the end. May be ``None``.

        Returns:
            ``{"result": frame}`` where ``frame`` is the input plus a
            ``Label`` column, sorted by ``Metadata``.

        Raises:
            ValueError: if the input is not a polars DataFrame or lacks the
                ``Metadata`` column.
        """
        import logging

        import polars as pl

        table = inputs.get("table", None)
        if not isinstance(table, pl.DataFrame):
            raise ValueError("输入必须为 polars.DataFrame")
        if "Metadata" not in table.columns:
            raise ValueError("表中缺少 Metadata 列")

        # Some polars versions have no ``Expr.apply``; compute the labels in
        # Python and attach them as a new column instead.
        labels = [self._label_for(meta) for meta in table["Metadata"].to_list()]
        result = table.with_columns([pl.Series("Label", labels)])

        # Sort by Metadata.
        result = result.sort("Metadata")

        # Example only (not a hard dependency): read this node's private
        # "dungeon_count" from the context and mirror it into the globals.
        try:
            node_id = getattr(self, "node_id", None) or "unknown"
            if hasattr(context, "get_node_private"):
                previous = context.get_node_private(node_id, "dungeon_count", None)
                if previous is not None and hasattr(context, "update_global"):
                    context.update_global("seen_dungeon_count_for_" + node_id, previous)
        except Exception:
            # Context access is non-fatal, but no longer silent.
            logging.getLogger(__name__).warning(
                "AssetLabelNode: failed to read/update execution context",
                exc_info=True,
            )

        return {"result": result}