TraceStudio-dist/cloud/custom_nodes/assert_loading_test.py
2026-01-13 19:26:36 +08:00

109 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict, Optional
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
)
from server.app.core.node_registry import register_node
import polars as pl
@register_node
class DungeonFilterNode2(TraceNode):
    """Dungeon filter node.

    Splits the main-thread rows of the input metadata table into
    per-dungeon intervals (enter event .. leave event) and returns the
    concatenated rows with an added ``Dungeon`` column that numbers each
    interval starting from 1.
    """

    # NOTE(review): "Filiter" looks like a typo for "Filter", but this string
    # may be matched by the frontend / node registry — confirm before renaming.
    CATEGORY = "Filiter"
    DISPLAY_NAME = "副本过滤器2"
    DESCRIPTION = "依据规则筛选出副本内的数据"
    ICON = "📥"

    @input_port("metadata", "DataTable", description="输入的元数据表")
    @param("sep", "String", default=",", description="分隔符,例如 , 或 \t")
    @output_port("metadata", "DataTable", description="导出的元数据表")
    @output_port("table", "DataTable", description="加载得到的表polars.DataFrame")
    def process(self, inputs: Dict[str, Any], context: Optional[object] = None) -> Dict[str, Any]:
        """Filter the metadata table down to dungeon intervals.

        Args:
            inputs: Must contain ``"metadata"`` as a ``polars.DataFrame``.
            context: Optional execution context — either an object exposing
                ``update_node_private`` / ``update_global`` / ``get_global``,
                or a plain dict (legacy implementations).

        Returns:
            Dict with the filtered ``polars.DataFrame`` under both declared
            output ports, ``"metadata"`` and ``"table"``.

        Raises:
            ValueError: If the input is missing/not a DataFrame, the
                ``Metadata`` column is absent, no enter event is found,
                or no enter/leave interval can be matched.
        """
        metadata = inputs.get("metadata", None)
        if metadata is None or not isinstance(metadata, pl.DataFrame):
            raise ValueError("metadata 输入必须为 polars.DataFrame")

        # 1. Keep only main-thread rows (hard-coded thread id 2); if the
        #    ThreadId column is absent, fall back to the whole table.
        main_thread_id = 2
        if "ThreadId" in metadata.columns:
            main_thread_df = metadata.filter(pl.col("ThreadId") == main_thread_id)
        else:
            main_thread_df = metadata

        # 2. Locate every "enter dungeon" event row by substring match.
        if "Metadata" not in main_thread_df.columns:
            raise ValueError("缺少 Metadata 列")
        enter_mask = main_thread_df["Metadata"].str.contains(
            "PackageName:/Game/Asset/Audio/FMOD/Events/bgm/cbt01/level/zhuiji"
        )
        enter_indices = [i for i, hit in enumerate(enter_mask.to_list()) if hit]
        if not enter_indices:
            raise ValueError("未找到进入副本的事件")

        # 3. Locate every "leave dungeon" event row. This list may be empty:
        #    the last interval then extends to the final row of the table.
        leave_mask = main_thread_df["Metadata"].str.contains("Token:TriggerPlayerFinish")
        leave_indices = [i for i, hit in enumerate(leave_mask.to_list()) if hit]

        # 4. Pair each enter event with the first later leave event to build
        #    (start, end, dungeon_number) intervals.
        dungeon_ranges = []
        leave_ptr = 0
        n_rows = main_thread_df.height
        for dungeon_num, enter_index in enumerate(enter_indices, 1):
            # Advance to the first leave event strictly after this enter.
            while leave_ptr < len(leave_indices) and leave_indices[leave_ptr] <= enter_index:
                leave_ptr += 1
            if leave_ptr < len(leave_indices):
                # End one row before the leave event so the event row itself
                # is excluded from the interval.
                leave_index = leave_indices[leave_ptr] - 1
                leave_ptr += 1
            else:
                # No matching leave event: the interval runs to the last row.
                leave_index = n_rows - 1
            # Emit the interval only when the enter point precedes the end.
            if enter_index <= leave_index:
                dungeon_ranges.append((enter_index, leave_index, dungeon_num))
        if not dungeon_ranges:
            raise ValueError("未能匹配到任何副本区间(所有进入点都在数据结尾之后)")

        # 5. Slice each interval out of the table and tag it with its dungeon
        #    number, then concatenate. dungeon_ranges is non-empty here, so
        #    the original dead `else: pl.DataFrame([])` branch was removed.
        dfs = [
            main_thread_df.slice(start, end - start + 1).with_columns(
                [pl.lit(num).alias("Dungeon")]
            )
            for start, end, num in dungeon_ranges
        ]
        filtered_df = pl.concat(dfs)

        # Best-effort context bookkeeping: record the interval count in the
        # node's private context and accumulate a global total. Supports both
        # ExecutionContext-like objects and legacy plain-dict contexts.
        try:
            node_id = getattr(self, "node_id", None) or "unknown"
            dungeon_count = len(dungeon_ranges)
            if hasattr(context, "update_node_private"):
                context.update_node_private(node_id, "dungeon_count", dungeon_count)
            elif isinstance(context, dict):
                context.setdefault("nodes", {}).setdefault(node_id, {})["dungeon_count"] = dungeon_count
            if hasattr(context, "update_global"):
                prev = context.get_global("total_dungeons", 0)
                context.update_global("total_dungeons", prev + dungeon_count)
            elif isinstance(context, dict):
                globals_ns = context.setdefault("global", {})
                globals_ns["total_dungeons"] = globals_ns.get("total_dungeons", 0) + dungeon_count
        except Exception:
            # Context updates must never block the main data flow.
            pass

        # BUG FIX: the node declares both "metadata" and "table" output ports
        # but previously returned only "metadata"; populate both so downstream
        # nodes wired to "table" receive the DataFrame.
        return {"metadata": filtered_df, "table": filtered_df}