TraceStudio-dev/cloud/custom_nodes/assert_loading.py
Boshuang Zhao 5790ec164f add web v2
2026-01-10 19:08:49 +08:00

183 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import Any, Dict, Optional

import polars as pl

from server.app.core.node_base import (
    TraceNode,
    input_port,
    output_port,
)
from server.app.core.node_registry import register_node
# Only rows of this thread are considered (the original code comment calls it the main thread).
_DUNGEON_MAIN_THREAD_ID = 2
# Metadata substring whose rows mark the start of a dungeon session.
_DUNGEON_ENTER_MARKER = "PackageName:/Game/Asset/Audio/FMOD/Events/bgm/cbt01/level/zhuiji"
# Metadata substring whose rows mark the end of a dungeon session.
_DUNGEON_LEAVE_MARKER = "Token:TriggerPlayerFinish"


def _match_dungeon_ranges(enter_indices, leave_indices, n_rows):
    """Pair enter rows with leave rows into non-overlapping, inclusive row ranges.

    Each range starts at an enter row and ends on the row just before the first
    leave row after it (the leave row itself is excluded). If no such leave row
    exists, the range runs to the last row. An enter row that falls inside an
    already-emitted range is treated as part of that dungeon instead of opening
    an overlapping range, so no row is emitted twice.

    Args:
        enter_indices: ascending row positions of enter events.
        leave_indices: ascending row positions of leave events (may be empty).
        n_rows: total number of rows the positions refer to.

    Returns:
        List of ``(start, end, dungeon_num)`` tuples with inclusive bounds and
        ``dungeon_num`` counting emitted ranges from 1.
    """
    ranges = []
    leave_ptr = 0
    last_end = -1
    for enter_index in enter_indices:
        if enter_index <= last_end:
            # Repeated enter marker inside the current dungeon: do not start an overlapping range.
            continue
        # Skip leave events at or before this enter; they cannot close it.
        while leave_ptr < len(leave_indices) and leave_indices[leave_ptr] <= enter_index:
            leave_ptr += 1
        if leave_ptr < len(leave_indices):
            end = leave_indices[leave_ptr] - 1
            leave_ptr += 1
        else:
            # No matching leave event: the dungeon runs to the end of the data.
            end = n_rows - 1
        if enter_index <= end:  # defensive; always true given the skip above
            ranges.append((enter_index, end, len(ranges) + 1))
            last_end = end
    return ranges


def _record_dungeon_count(node, context, dungeon_count):
    """Best-effort: record the range count in the execution context.

    Writes ``dungeon_count`` into the node's private namespace and adds it to the
    global ``total_dungeons`` counter. Works with context objects exposing
    ``update_node_private`` / ``get_global`` / ``update_global`` and with plain
    dicts (keys ``"nodes"`` and ``"global"``). Any failure is logged and ignored
    so bookkeeping never blocks the node's main output.
    """
    try:
        node_id = getattr(node, "node_id", None) or "unknown"
        if hasattr(context, "update_node_private"):
            context.update_node_private(node_id, "dungeon_count", dungeon_count)
        elif isinstance(context, dict):
            context.setdefault("nodes", {}).setdefault(node_id, {})["dungeon_count"] = dungeon_count
        if hasattr(context, "update_global"):
            prev = context.get_global("total_dungeons", 0)
            context.update_global("total_dungeons", prev + dungeon_count)
        elif isinstance(context, dict):
            global_ns = context.setdefault("global", {})
            global_ns["total_dungeons"] = global_ns.get("total_dungeons", 0) + dungeon_count
    except Exception:
        logging.getLogger(__name__).debug("DungeonFilterNode: context update failed", exc_info=True)


@register_node
class DungeonFilterNode(TraceNode):
    """
    Dungeon filter node.

    Keeps the rows of one thread that lie between dungeon enter and leave
    markers and tags each kept row with a 1-based ``Dungeon`` number.
    """

    # NOTE(review): "Filiter" looks like a typo for "Filter"; kept unchanged because the
    # category string may be relied on for grouping elsewhere — confirm before renaming.
    CATEGORY = "Filiter"
    DISPLAY_NAME = "副本过滤器"
    DESCRIPTION = "依据规则筛选出副本内的数据"
    ICON = "📥"

    @input_port("metadata", "DataTable", description="输入的元数据表")
    @output_port("metadata", "DataTable", description="导出的元数据表")
    def process(self, inputs: Dict[str, Any], context: Optional[object] = None) -> Dict[str, Any]:
        """Filter ``inputs["metadata"]`` down to rows inside dungeon sessions.

        Args:
            inputs: must contain ``"metadata"``, a ``polars.DataFrame`` with a
                ``Metadata`` string column and optionally a ``ThreadId`` column.
            context: optional execution context (object or dict) used only for
                best-effort statistics.

        Returns:
            ``{"metadata": DataFrame}`` holding the concatenated dungeon slices,
            each with an added integer ``Dungeon`` column.

        Raises:
            ValueError: input is not a DataFrame, the ``Metadata`` column is
                missing, no enter event is found, or no range could be formed.
        """
        metadata = inputs.get("metadata", None)
        if metadata is None or not isinstance(metadata, pl.DataFrame):
            raise ValueError("metadata 输入必须为 polars.DataFrame")

        # 1. Restrict to the main thread when the table carries thread ids.
        if "ThreadId" in metadata.columns:
            main_thread_df = metadata.filter(pl.col("ThreadId") == _DUNGEON_MAIN_THREAD_ID)
        else:
            main_thread_df = metadata

        if "Metadata" not in main_thread_df.columns:
            raise ValueError("缺少 Metadata 列")
        meta_col = main_thread_df["Metadata"]

        # 2./3. Row positions of enter and leave markers. The markers are plain
        # substrings, so match literally rather than as regular expressions.
        enter_indices = meta_col.str.contains(_DUNGEON_ENTER_MARKER, literal=True).arg_true().to_list()
        if not enter_indices:
            raise ValueError("未找到进入副本的事件")
        # An empty leave list is allowed: the last dungeon then runs to the final row.
        leave_indices = meta_col.str.contains(_DUNGEON_LEAVE_MARKER, literal=True).arg_true().to_list()

        # 4. Pair enters with leaves into non-overlapping ranges.
        dungeon_ranges = _match_dungeon_ranges(enter_indices, leave_indices, main_thread_df.height)
        if not dungeon_ranges:
            raise ValueError("未能匹配到任何副本区间(所有进入点都在数据结尾之后)")

        # 5. Slice each range, tag it with its dungeon number, and concatenate.
        filtered_df = pl.concat([
            main_thread_df.slice(start, end - start + 1).with_columns(
                pl.lit(dungeon_num).alias("Dungeon")
            )
            for start, end, dungeon_num in dungeon_ranges
        ])

        _record_dungeon_count(self, context, len(dungeon_ranges))

        return {"metadata": filtered_df}
# ============= 资源标签分类节点 =============
# Ordered (directory, label) rules; the first rule whose directory contains the path wins.
_ASSET_LABEL_RULES = (
    ("/Game/Asset/Audio", "Audio"),
    ("/Game/Asset/Char/Player", "Player"),
    ("/Game/Asset/Effect/Niagara", "Niagara"),
    ("/Game/UI/Texture", "Texture"),
    ("/Game/Blueprints/Combat", "Combat"),
)
_PACKAGE_NAME_PREFIX = "PackageName:"


def _is_under(path, directory):
    """Return True if ``path`` equals ``directory`` or lies beneath it.

    Matching is on ``/`` boundaries, so ``/Game/Asset/AudioX`` is not under
    ``/Game/Asset/Audio``.
    """
    return path == directory or path.startswith(directory + "/")


def _asset_label(meta):
    """Classify a single Metadata value into a label.

    - Non-string values -> ``"Unknown"``.
    - A leading ``"PackageName:"`` is stripped before matching.
    - If the path lies under a rule directory, that rule's label is returned.
    - Otherwise the first path component whose prefix is no longer an ancestor
      of any rule directory is returned, i.e. the first directory the rules do
      not cover (``/Game/Asset/Char/NPC/x`` -> ``"NPC"``).
    - If every prefix is a rule ancestor, the last component is returned; an
      empty path yields ``"Unknown"``.
    """
    if not isinstance(meta, str):
        return "Unknown"
    path = meta[len(_PACKAGE_NAME_PREFIX):] if meta.startswith(_PACKAGE_NAME_PREFIX) else meta

    for rule_dir, label in _ASSET_LABEL_RULES:
        if _is_under(path, rule_dir):
            return label

    parts = [part for part in path.split("/") if part]
    if not parts:
        return "Unknown"
    prefix = ""
    for part in parts:
        prefix = f"{prefix}/{part}"
        # Still an ancestor of some rule directory -> rules may cover deeper levels; keep descending.
        if not any(_is_under(rule_dir, prefix) for rule_dir, _ in _ASSET_LABEL_RULES):
            return part
    return parts[-1]


@register_node
class AssetLabelNode(TraceNode):
    """Asset label node: derives a ``Label`` column from the ``Metadata`` asset path."""

    CATEGORY = "Transform"
    DISPLAY_NAME = "资源标签分类"
    DESCRIPTION = "根据资源路径对数据进行标签分类,支持自定义规则"
    ICON = "🏷️"

    @input_port("table", "DataTable", description="待分类的表,需包含 Metadata 列", required=True)
    @output_port("result", "DataTable", description="带标签的表")
    def process(self, inputs: Dict[str, Any], context: Optional[object] = None) -> Dict[str, Any]:
        """Add a ``Label`` column derived from ``Metadata`` and sort by ``Metadata``.

        Args:
            inputs: must contain ``"table"``, a ``polars.DataFrame`` with a
                ``Metadata`` column.
            context: optional execution context; only used for a best-effort
                statistics example.

        Returns:
            ``{"result": DataFrame}`` — the input table plus a string ``Label``
            column, sorted by ``Metadata``.

        Raises:
            ValueError: input is not a DataFrame or lacks the ``Metadata`` column.
        """
        table = inputs.get("table", None)
        if table is None or not isinstance(table, pl.DataFrame):
            raise ValueError("输入必须为 polars.DataFrame")
        if "Metadata" not in table.columns:
            raise ValueError("表中缺少 Metadata 列")

        # Label in Python space: Expr.apply/map_elements availability differs across polars versions.
        labels = [_asset_label(meta) for meta in table["Metadata"].to_list()]
        # Explicit dtype so an empty table still yields a string-typed Label column.
        result = table.with_columns(pl.Series("Label", labels, dtype=pl.Utf8))

        # Sort by Metadata in lexicographic order.
        result = result.sort("Metadata")

        # Optional example of reading node-private context data; must never fail the node.
        try:
            node_id = getattr(self, "node_id", None) or "unknown"
            if hasattr(context, "get_node_private"):
                # NOTE(review): this reads under this node's own id, so values other nodes
                # wrote under their own ids are not seen here — confirm the intended source node.
                prev = context.get_node_private(node_id, "dungeon_count", None)
                if prev is not None and hasattr(context, "update_global"):
                    context.update_global("seen_dungeon_count_for_" + node_id, prev)
        except Exception:
            logging.getLogger(__name__).debug("AssetLabelNode: context access failed", exc_info=True)

        return {"result": result}