TraceStudio-dev/server/app/nodes/transform_nodes.py

177 lines
8.0 KiB
Python
Raw Normal View History

from typing import Any, Dict
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
context_var,
NodeType
)
from server.app.core.node_registry import register_node
# ============= 聚合节点 =============
@register_node
class DataFrameConcatNode(TraceNode):
"""聚合节点 - 将多个 polars DataFrame 按指定列聚合为一张表"""
CATEGORY = "Transform"
DISPLAY_NAME = "表聚合"
DESCRIPTION = "将多个表合并为一张表,可指定聚合方式和列"
ICON = "🧩"
@input_port("tables", "Array<DataTable>", description="待聚合的表数组", required=True)
@output_port("result", "DataTable", description="聚合后的表")
@param("how", "String", default="vertical", description="聚合方式,可选 vertical/horizontal/groupby")
@param("on", "String", default="", description="聚合列(横向聚合/join 或 groupby 时用逗号分隔)")
@param("agg_mode", "String", default="{}", description="""聚合模式json字符串{"Duration":"mean","Score":"sum"}未指定的列默认first。若需在输出中加入分组计数列请在 agg_mode 中添加 "my_count_col":"count"(键为希望输出的列名,值为 "count")。""")
def process(self, inputs, context=None):
import polars as pl
import json
tables = inputs.get("tables", [])
if not tables:
raise ValueError("输入 tables 不能为空")
non_pl_types = [type(df).__name__ for df in tables if not isinstance(df, pl.DataFrame)]
if non_pl_types:
raise ValueError(f"输入必须为 polars.DataFrame 的数组,发现类型: {non_pl_types}. 后端仅支持 polars。")
how = self.get_param("how", "vertical")
on = self.get_param("on", "")
agg_mode = self.get_param("agg_mode", "{}")
if how == "vertical":
# 纵向拼接
result = pl.concat(tables, how="vertical")
elif how == "horizontal":
# 横向聚合,按 on 指定的列做 join
on_cols = [col.strip() for col in on.split(",") if col.strip()]
if not on_cols:
raise ValueError("横向聚合必须指定聚合列 on")
result = tables[0]
for df in tables[1:]:
result = result.join(df, on=on_cols, how="inner")
elif how == "groupby":
# 分组聚合on 作为分组列
result = pl.concat(tables, how="vertical")
group_cols = [col.strip() for col in on.split(",") if col.strip()]
if not group_cols:
raise ValueError("分组聚合必须指定 on")
# 解析聚合表达式
try:
agg_dict = json.loads(agg_mode) if isinstance(agg_mode, str) else (agg_mode or {})
except Exception as e:
raise ValueError(f"聚合模式(agg_mode)必须为合法json字符串: {e} {agg_mode}")
agg_exprs = []
# 1) 先处理 agg_mode 中显式指定的聚合(支持为不存在的键指定 count生成自定义计数列
if isinstance(agg_dict, dict):
for out_col, method in agg_dict.items():
# 忽略分组列在 agg_dict 中的指定
if out_col in group_cols:
continue
if isinstance(method, str) and method == "count":
# 如果方法为 count不要求 out_col 存在于原始列,可以作为自定义计数列名
agg_exprs.append(pl.count().alias(out_col))
continue
# 如果 out_col 对应原始列,则按方法聚合
if out_col in result.columns:
m = (method or "first")
if m == "mean":
agg_exprs.append(pl.col(out_col).mean().alias(out_col))
elif m == "sum":
agg_exprs.append(pl.col(out_col).sum().alias(out_col))
elif m == "min":
agg_exprs.append(pl.col(out_col).min().alias(out_col))
elif m == "max":
agg_exprs.append(pl.col(out_col).max().alias(out_col))
elif m == "first":
agg_exprs.append(pl.col(out_col).first().alias(out_col))
elif m == "last":
agg_exprs.append(pl.col(out_col).last().alias(out_col))
elif m == "count":
agg_exprs.append(pl.col(out_col).count().alias(out_col))
else:
raise ValueError(f"不支持的聚合方式: {m} (列: {out_col})")
else:
# 非 count 的自定义列名不支持(没有对应源列)
raise ValueError(f"未知聚合列: {out_col},仅支持对存在列的聚合或使用 'count' 生成计数列")
# 2) 对未在 agg_mode 中显式指定的原始列使用默认聚合 first
for col in result.columns:
if col in group_cols:
continue
if isinstance(agg_dict, dict) and col in agg_dict:
continue
agg_exprs.append(pl.col(col).first().alias(col))
result = result.group_by(group_cols).agg(agg_exprs)
else:
raise ValueError("不支持的聚合方式: %s" % how)
return {"result": result}
# ============= 列操作节点 =============
@register_node
class ColumnOpsNode(TraceNode):
"""列操作节点 - 支持新增/删除/修改列"""
CATEGORY = "Transform"
DISPLAY_NAME = "列操作"
DESCRIPTION = "支持新增、删除、修改列,表达式灵活"
ICON = "🛠️"
@input_port("table", "DataTable", description="待变换的表", required=True)
@output_port("result", "DataTable", description="变换后的表")
@param("ops", "String", default="{}", description="""操作表json字符串支持新增/修改/删除列。如 {"add": {"Duration": "(EndTime-StartTime)*1000"}, "drop": ["EndTime", "StartTime", "Depth", "ThreadId"]} """)
def process(self, inputs, context=None):
import polars as pl
import json
table = inputs.get("table", None)
if table is None or not isinstance(table, pl.DataFrame):
raise ValueError("输入必须为 polars.DataFrame")
ops = self.get_param("ops", "{}")
try:
ops_dict = json.loads(ops) if isinstance(ops, str) else (ops or {})
except Exception as e:
raise ValueError(f"操作表(ops)必须为合法json字符串: {e} {ops}")
df = table
# 新增/修改列
add_ops = ops_dict.get("add", {})
new_columns = []
for col, expr in add_ops.items():
# expr 支持简单表达式,如 '(EndTime-StartTime)*1000'
# 用 eval 安全地解析表达式
try:
# 构造 polars 表达式
expr_code = expr.replace(" ", "")
# 支持加减乘除和括号
# 只允许列名和运算符
import re
tokens = re.split(r'([\+\-\*/\(\)])', expr_code)
pl_expr = ""
for t in tokens:
if t in {"+", "-", "*", "/", "(", ")"}:
pl_expr += t
elif t:
pl_expr += f"pl.col('{t}')" if t.isidentifier() else t
# 通过 eval 构造 polars 表达式
new_columns.append(eval(pl_expr).alias(col))
except Exception as e:
raise ValueError(f"新增/修改列 {col} 表达式错误: {expr}, {e}")
if new_columns:
df = df.with_columns(new_columns)
# 删除列
drop_ops = ops_dict.get("drop", [])
keep_cols = [col for col in df.columns if col not in drop_ops]
df = df.select(keep_cols)
return {"result": df}