109 lines
4.8 KiB
Python
109 lines
4.8 KiB
Python
import logging
from typing import Any, Dict, Optional

import polars as pl

from server.app.core.node_base import (
    TraceNode,
    input_port,
    output_port,
    param,
)
from server.app.core.node_registry import register_node
|
||
|
||
@register_node
class DungeonFilterNode2(TraceNode):
    """Dungeon filter node.

    Keeps only the rows of the input table that lie between a dungeon
    "enter" marker and the following "leave" marker in the ``Metadata``
    column, and tags every kept row with a 1-based ``Dungeon`` number.
    """

    CATEGORY = "Filiter"
    DISPLAY_NAME = "副本过滤器2"
    DESCRIPTION = "依据规则筛选出副本内的数据"
    ICON = "📥"

    # NOTE(review): the ``sep`` parameter is declared but never read, and the
    # ``table`` output port is declared but never populated by ``process``.
    # Both are kept so the declared node interface stays unchanged; confirm
    # whether they can be dropped.
    @input_port("metadata", "DataTable", description="输入的元数据表")
    @param("sep", "String", default=",", description="分隔符,例如 , 或 \t")
    @output_port("metadata", "DataTable", description="导出的元数据表")
    @output_port("table", "DataTable", description="加载得到的表(polars.DataFrame)")
    def process(self, inputs: Dict[str, Any], context: Optional[object] = None) -> Dict[str, Any]:
        """Slice the input table down to the rows recorded inside dungeons.

        Args:
            inputs: Port values; ``inputs["metadata"]`` must be a
                ``polars.DataFrame`` that has a ``Metadata`` column.
            context: Optional execution context. Either an object exposing
                ``update_node_private`` / ``update_global`` / ``get_global``
                or a plain ``dict``. Only used to record how many dungeon
                ranges were found; failures while recording are logged and
                ignored.

        Returns:
            ``{"metadata": df}`` where ``df`` is the concatenation of all
            dungeon ranges with an added integer ``Dungeon`` column.

        Raises:
            ValueError: If the input is not a ``polars.DataFrame``, the
                ``Metadata`` column is missing, or no enter marker is found.
        """
        metadata = inputs.get("metadata")
        if not isinstance(metadata, pl.DataFrame):
            raise ValueError("metadata 输入必须为 polars.DataFrame")

        # 1. Keep only the rows of thread 2 (treated as the main thread).
        #    Tables without a ThreadId column are used as-is.
        main_thread_id = 2
        if "ThreadId" in metadata.columns:
            main_thread_df = metadata.filter(pl.col("ThreadId") == main_thread_id)
        else:
            main_thread_df = metadata

        if "Metadata" not in main_thread_df.columns:
            raise ValueError("缺少 Metadata 列")
        events = main_thread_df["Metadata"]

        # 2. Row positions of every "enter dungeon" marker.
        enter_rows = self._marker_rows(
            events,
            "PackageName:/Game/Asset/Audio/FMOD/Events/bgm/cbt01/level/zhuiji",
        )
        if not enter_rows:
            raise ValueError("未找到进入副本的事件")

        # 3. Row positions of every "leave dungeon" marker. An empty list is
        #    fine: unmatched enters run to the last row of the table.
        leave_rows = self._marker_rows(events, "Token:TriggerPlayerFinish")

        # 4. Pair each enter with the next leave to get inclusive row ranges.
        dungeon_ranges = self._pair_ranges(enter_rows, leave_rows, main_thread_df.height)

        # 5. Cut out every range, tag it with its dungeon number, and stack.
        frames = [
            main_thread_df.slice(start, end - start + 1).with_columns(
                [pl.lit(dungeon_num).alias("Dungeon")]
            )
            for start, end, dungeon_num in dungeon_ranges
        ]
        filtered_df = pl.concat(frames)

        self._record_dungeon_count(context, len(dungeon_ranges))

        return {"metadata": filtered_df}

    @staticmethod
    def _marker_rows(column, marker):
        """Return the row positions of ``column`` whose text contains ``marker``."""
        # literal=True: the markers are plain substrings, not regex patterns.
        hits = column.str.contains(marker, literal=True).to_list()
        # Null cells come back as None and are skipped by the truth test.
        return [row for row, hit in enumerate(hits) if hit]

    @staticmethod
    def _pair_ranges(enter_rows, leave_rows, n_rows):
        """Pair each enter row with the next unused leave row after it.

        Returns a list of ``(start, end, dungeon_num)`` tuples with inclusive
        row bounds. The leave row itself is excluded from its range. An enter
        with no leave left to match runs to the last row (``n_rows - 1``).

        Both inputs are ascending, so ``end >= start`` always holds and every
        enter row yields exactly one range.
        """
        # NOTE(review): each leave row is consumed by at most one enter. When
        # a second enter marker appears before the leave that the previous
        # enter consumed, its range overlaps the previous one and the shared
        # rows are emitted twice under different dungeon numbers. Confirm
        # whether that is intended.
        ranges = []
        leave_ptr = 0
        n_leaves = len(leave_rows)
        for dungeon_num, start in enumerate(enter_rows, 1):
            # Skip leave markers located at or before this enter marker.
            while leave_ptr < n_leaves and leave_rows[leave_ptr] <= start:
                leave_ptr += 1
            if leave_ptr < n_leaves:
                end = leave_rows[leave_ptr] - 1
                leave_ptr += 1
            else:
                end = n_rows - 1
            ranges.append((start, end, dungeon_num))
        return ranges

    def _record_dungeon_count(self, context, dungeon_count):
        """Best-effort: store the range count in the node-private and global context."""
        try:
            # The context may be an execution-context object or a plain dict
            # (older callers still pass a dict).
            node_id = getattr(self, "node_id", None) or "unknown"

            if hasattr(context, "update_node_private"):
                context.update_node_private(node_id, "dungeon_count", dungeon_count)
            elif isinstance(context, dict):
                per_node = context.setdefault("nodes", {})
                per_node.setdefault(node_id, {})["dungeon_count"] = dungeon_count

            if hasattr(context, "update_global"):
                previous = context.get_global("total_dungeons", 0)
                context.update_global("total_dungeons", previous + dungeon_count)
            elif isinstance(context, dict):
                shared = context.setdefault("global", {})
                shared["total_dungeons"] = shared.get("total_dungeons", 0) + dungeon_count
        except Exception:
            # A failed context update must not block the main flow, but it
            # should leave a trace instead of vanishing silently.
            logging.getLogger(__name__).warning(
                "DungeonFilterNode2: failed to record dungeon_count in context",
                exc_info=True,
            )
|
||
|