from typing import Any, Dict from server.app.core.node_base import ( TraceNode, input_port, output_port, param, context_var, NodeType ) from server.app.core.node_registry import register_node # ============= 聚合节点 ============= @register_node class DataFrameConcatNode(TraceNode): """聚合节点 - 将多个 polars DataFrame 按指定列聚合为一张表""" CATEGORY = "Transform" DISPLAY_NAME = "表聚合" DESCRIPTION = "将多个表合并为一张表,可指定聚合方式和列" ICON = "🧩" @input_port("tables", "Array", description="待聚合的表数组", required=True) @output_port("result", "DataTable", description="聚合后的表") @param("how", "String", default="vertical", description="聚合方式,可选 vertical/horizontal/groupby") @param("on", "String", default="", description="聚合列(横向聚合/join 或 groupby 时用逗号分隔)") @param("agg_mode", "String", default="{}", description="""聚合模式,json字符串,如{"Duration":"mean","Score":"sum"},未指定的列默认first。若需在输出中加入分组计数列,请在 agg_mode 中添加 "my_count_col":"count"(键为希望输出的列名,值为 "count")。""") def process(self, inputs, context=None): import polars as pl import json tables = inputs.get("tables", []) if not tables: raise ValueError("输入 tables 不能为空") non_pl_types = [type(df).__name__ for df in tables if not isinstance(df, pl.DataFrame)] if non_pl_types: raise ValueError(f"输入必须为 polars.DataFrame 的数组,发现类型: {non_pl_types}. 后端仅支持 polars。") how = self.get_param("how", "vertical") on = self.get_param("on", "") agg_mode = self.get_param("agg_mode", "{}") if how == "vertical": # 纵向拼接 result = pl.concat(tables, how="vertical") elif how == "horizontal": # 横向聚合,按 on 指定的列做 join on_cols = [col.strip() for col in on.split(",") if col.strip()] if not on_cols: raise ValueError("横向聚合必须指定聚合列 on") result = tables[0] for df in tables[1:]: result = result.join(df, on=on_cols, how="inner") elif how == "groupby": # 分组聚合,on 作为分组列 result = pl.concat(tables, how="diagonal") group_cols = [col.strip() for col in on.split(",") if col.strip()] if not group_cols: raise ValueError("分组聚合必须指定 on") # 解析聚合表达式 try: agg_dict = json.loads(agg_mode) if isinstance(agg_mode, str) else (agg_mode or {}) except Exception as e: raise ValueError(f"聚合模式(agg_mode)必须为合法json字符串: {e} {agg_mode}") agg_exprs = [] # 1) 先处理 agg_mode 中显式指定的聚合(支持为不存在的键指定 count,生成自定义计数列) if isinstance(agg_dict, dict): for out_col, method in agg_dict.items(): # 忽略分组列在 agg_dict 中的指定 if out_col in group_cols: continue if isinstance(method, str) and method == "count": # 如果方法为 count,不要求 out_col 存在于原始列,可以作为自定义计数列名 agg_exprs.append(pl.count().alias(out_col)) continue # 如果 out_col 对应原始列,则按方法聚合 if out_col in result.columns: m = (method or "first") if m == "mean": agg_exprs.append(pl.col(out_col).mean().alias(out_col)) elif m == "sum": agg_exprs.append(pl.col(out_col).sum().alias(out_col)) elif m == "min": agg_exprs.append(pl.col(out_col).min().alias(out_col)) elif m == "max": agg_exprs.append(pl.col(out_col).max().alias(out_col)) elif m == "first": agg_exprs.append(pl.col(out_col).first().alias(out_col)) elif m == "last": agg_exprs.append(pl.col(out_col).last().alias(out_col)) elif m == "count": agg_exprs.append(pl.col(out_col).count().alias(out_col)) else: raise ValueError(f"不支持的聚合方式: {m} (列: {out_col})") else: # 非 count 的自定义列名不支持(没有对应源列) raise ValueError(f"未知聚合列: {out_col},仅支持对存在列的聚合或使用 'count' 生成计数列") # 2) 对未在 agg_mode 中显式指定的原始列使用默认聚合 first for col in result.columns: if col in group_cols: continue if isinstance(agg_dict, dict) and col in agg_dict: continue agg_exprs.append(pl.col(col).first().alias(col)) result = result.group_by(group_cols).agg(agg_exprs) else: raise ValueError("不支持的聚合方式: %s" % how) return {"result": result} # ============= 列操作节点 ============= @register_node class ColumnOpsNode(TraceNode): """列操作节点 - 支持新增/删除/修改列""" CATEGORY = "Transform" DISPLAY_NAME = "列操作" DESCRIPTION = "支持新增、删除、修改列,表达式灵活" ICON = "🛠️" @input_port("table", "DataTable", description="待变换的表", required=True) @output_port("result", "DataTable", description="变换后的表") @param("ops", "String", default="{}", description="""操作表,json字符串,支持新增/修改/删除列。如 {"add": {"Duration": "(EndTime-StartTime)*1000"}, "drop": ["EndTime", "StartTime", "Depth", "ThreadId"]} """) def process(self, inputs, context=None): import polars as pl import json table = inputs.get("table", None) if table is None or not isinstance(table, pl.DataFrame): raise ValueError("输入必须为 polars.DataFrame") ops = self.get_param("ops", "{}") try: ops_dict = json.loads(ops) if isinstance(ops, str) else (ops or {}) except Exception as e: raise ValueError(f"操作表(ops)必须为合法json字符串: {e} {ops}") df = table # 新增/修改列 add_ops = ops_dict.get("add", {}) new_columns = [] for col, expr in add_ops.items(): # expr 支持简单表达式,如 '(EndTime-StartTime)*1000' # 用 eval 安全地解析表达式 try: # 构造 polars 表达式 expr_code = expr.replace(" ", "") # 支持加减乘除和括号 # 只允许列名和运算符 import re tokens = re.split(r'([\+\-\*/\(\)])', expr_code) pl_expr = "" for t in tokens: if t in {"+", "-", "*", "/", "(", ")"}: pl_expr += t elif t: pl_expr += f"pl.col('{t}')" if t.isidentifier() else t # 通过 eval 构造 polars 表达式 new_columns.append(eval(pl_expr).alias(col)) except Exception as e: raise ValueError(f"新增/修改列 {col} 表达式错误: {expr}, {e}") if new_columns: df = df.with_columns(new_columns) # 删除列 drop_ops = ops_dict.get("drop", []) keep_cols = [col for col in df.columns if col not in drop_ops] df = df.select(keep_cols) return {"result": df}