TraceStudio-dev/server/app/nodes/transform_nodes.py
2026-01-09 21:37:02 +08:00

177 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
context_var,
NodeType
)
from server.app.core.node_registry import register_node
# ============= 聚合节点 =============
@register_node
class DataFrameConcatNode(TraceNode):
    """Aggregation node - merge multiple polars DataFrames into one table.

    Three modes, selected by the ``how`` param:
      - ``vertical``:   stack the tables row-wise.
      - ``horizontal``: inner-join all tables on the columns listed in ``on``.
      - ``groupby``:    stack vertically, then group by ``on`` and aggregate
                        each column per the ``agg_mode`` JSON mapping
                        (unlisted columns default to ``first``).
    """
    CATEGORY = "Transform"
    DISPLAY_NAME = "表聚合"
    DESCRIPTION = "将多个表合并为一张表,可指定聚合方式和列"
    ICON = "🧩"

    # Aggregation methods that map 1:1 onto polars Expr methods.
    # ("count" is handled separately: it may name a brand-new output column.)
    _AGG_METHODS = frozenset({"mean", "sum", "min", "max", "first", "last"})

    @input_port("tables", "Array<DataTable>", description="待聚合的表数组", required=True)
    @output_port("result", "DataTable", description="聚合后的表")
    @param("how", "String", default="vertical", description="聚合方式,可选 vertical/horizontal/groupby")
    @param("on", "String", default="", description="聚合列(横向聚合/join 或 groupby 时用逗号分隔)")
    @param("agg_mode", "String", default="{}", description="""聚合模式json字符串{"Duration":"mean","Score":"sum"}未指定的列默认first。若需在输出中加入分组计数列请在 agg_mode 中添加 "my_count_col":"count"(键为希望输出的列名,值为 "count")。""")
    def process(self, inputs, context=None):
        """Merge the input DataFrames according to the ``how`` parameter.

        Args:
            inputs: dict with key "tables" -> list of ``pl.DataFrame``.
            context: execution context (unused; kept for the node interface).

        Returns:
            ``{"result": pl.DataFrame}`` — the merged table.

        Raises:
            ValueError: empty/non-polars input, unknown ``how`` mode, missing
                key columns, malformed ``agg_mode`` JSON, or an unsupported
                aggregation method.
        """
        import polars as pl
        import json

        tables = inputs.get("tables", [])
        if not tables:
            raise ValueError("输入 tables 不能为空")
        non_pl_types = [type(df).__name__ for df in tables if not isinstance(df, pl.DataFrame)]
        if non_pl_types:
            raise ValueError(f"输入必须为 polars.DataFrame 的数组,发现类型: {non_pl_types}. 后端仅支持 polars。")

        how = self.get_param("how", "vertical")
        on = self.get_param("on", "")
        agg_mode = self.get_param("agg_mode", "{}")

        if how == "vertical":
            # Plain row-wise stacking.
            result = pl.concat(tables, how="vertical")
        elif how == "horizontal":
            # Inner-join every table on the key columns given in `on`.
            on_cols = [col.strip() for col in on.split(",") if col.strip()]
            if not on_cols:
                raise ValueError("横向聚合必须指定聚合列 on")
            result = tables[0]
            for df in tables[1:]:
                result = result.join(df, on=on_cols, how="inner")
        elif how == "groupby":
            # Stack first, then group by the columns given in `on`.
            result = pl.concat(tables, how="vertical")
            group_cols = [col.strip() for col in on.split(",") if col.strip()]
            if not group_cols:
                raise ValueError("分组聚合必须指定 on")
            # Parse the per-column aggregation mapping.
            try:
                agg_dict = json.loads(agg_mode) if isinstance(agg_mode, str) else (agg_mode or {})
            except Exception as e:
                raise ValueError(f"聚合模式(agg_mode)必须为合法json字符串: {e} {agg_mode}")

            agg_exprs = []
            # 1) Explicit aggregations from agg_mode.  A value of "count" is
            #    intercepted first: it counts rows per group and may target a
            #    column name that does not exist in the source (custom
            #    group-size column), so the existence check is skipped.
            if isinstance(agg_dict, dict):
                for out_col, method in agg_dict.items():
                    if out_col in group_cols:
                        # Group keys are never aggregated; ignore them here.
                        continue
                    if isinstance(method, str) and method == "count":
                        agg_exprs.append(pl.count().alias(out_col))
                        continue
                    if out_col not in result.columns:
                        # Non-count methods require a real source column.
                        raise ValueError(f"未知聚合列: {out_col},仅支持对存在列的聚合或使用 'count' 生成计数列")
                    m = method or "first"
                    # The supported methods map 1:1 onto Expr methods, so a
                    # single getattr dispatch replaces a six-way if/elif chain.
                    # (The old chain also had a dead `count` branch — "count"
                    # is always caught above before reaching this point.)
                    if not isinstance(m, str) or m not in self._AGG_METHODS:
                        raise ValueError(f"不支持的聚合方式: {m} (列: {out_col})")
                    agg_exprs.append(getattr(pl.col(out_col), m)().alias(out_col))
            # 2) Every column not mentioned in agg_mode defaults to first().
            for col in result.columns:
                if col in group_cols:
                    continue
                if isinstance(agg_dict, dict) and col in agg_dict:
                    continue
                agg_exprs.append(pl.col(col).first().alias(col))
            result = result.group_by(group_cols).agg(agg_exprs)
        else:
            raise ValueError("不支持的聚合方式: %s" % how)
        return {"result": result}
# ============= 列操作节点 =============
@register_node
class ColumnOpsNode(TraceNode):
    """Column-ops node - add, modify, or drop DataFrame columns.

    The ``ops`` param is a JSON object such as::

        {"add": {"Duration": "(EndTime-StartTime)*1000"},
         "drop": ["EndTime", "StartTime"]}

    "add" expressions may contain only column names, numeric literals and
    the arithmetic operators ``+ - * / ( )``; anything else is rejected.
    """
    CATEGORY = "Transform"
    DISPLAY_NAME = "列操作"
    DESCRIPTION = "支持新增、删除、修改列,表达式灵活"
    ICON = "🛠️"

    @input_port("table", "DataTable", description="待变换的表", required=True)
    @output_port("result", "DataTable", description="变换后的表")
    @param("ops", "String", default="{}", description="""操作表json字符串支持新增/修改/删除列。如 {"add": {"Duration": "(EndTime-StartTime)*1000"}, "drop": ["EndTime", "StartTime", "Depth", "ThreadId"]} """)
    def process(self, inputs, context=None):
        """Apply the add/drop operations from ``ops`` to the input table.

        Args:
            inputs: dict with key "table" -> ``pl.DataFrame``.
            context: execution context (unused; kept for the node interface).

        Returns:
            ``{"result": pl.DataFrame}`` — the transformed table.

        Raises:
            ValueError: non-polars input, malformed ``ops`` JSON, or an
                invalid/unsafe "add" expression.
        """
        import polars as pl
        import json
        import re

        table = inputs.get("table", None)
        if table is None or not isinstance(table, pl.DataFrame):
            raise ValueError("输入必须为 polars.DataFrame")
        ops = self.get_param("ops", "{}")
        try:
            ops_dict = json.loads(ops) if isinstance(ops, str) else (ops or {})
        except Exception as e:
            raise ValueError(f"操作表(ops)必须为合法json字符串: {e} {ops}")

        df = table
        # --- add / modify columns ---
        operators = {"+", "-", "*", "/", "(", ")"}
        # Accept int/float literals, incl. scientific notation and ".5".
        number_re = re.compile(r"(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?")
        add_ops = ops_dict.get("add", {})
        new_columns = []
        for col, expr in add_ops.items():
            try:
                expr_code = expr.replace(" ", "")
                tokens = re.split(r'([\+\-\*/\(\)])', expr_code)
                pl_expr = ""
                for t in tokens:
                    if t in operators:
                        pl_expr += t
                    elif not t:
                        continue
                    elif t.isidentifier():
                        # Bare identifiers refer to columns of the table.
                        pl_expr += f"pl.col('{t}')"
                    elif number_re.fullmatch(t):
                        pl_expr += t
                    else:
                        # SECURITY FIX: the old code passed any unrecognized
                        # token (e.g. quoted strings, attribute chains like
                        # os.system) straight into eval() with full builtins,
                        # allowing arbitrary code execution.  Reject instead.
                        raise ValueError(f"非法表达式片段: {t!r}")
                # eval() now only ever sees validated tokens, and runs with
                # builtins stripped so nothing but `pl` is reachable.
                built = eval(pl_expr, {"__builtins__": {}, "pl": pl})
                new_columns.append(built.alias(col))
            except Exception as e:
                raise ValueError(f"新增/修改列 {col} 表达式错误: {expr}, {e}")
        if new_columns:
            df = df.with_columns(new_columns)

        # --- drop columns (names not present are silently ignored) ---
        drop_ops = ops_dict.get("drop", [])
        df = df.select([c for c in df.columns if c not in drop_ops])
        return {"result": df}