""" TraceStudio 内置节点示例 (v2.0) 演示如何使用装饰器简化节点开发 """ from typing import Any, Dict from server.app.core.node_base import ( TraceNode, input_port, output_port, param, context_var, NodeType, CachePolicy ) from server.app.core.node_registry import register_node # ============= 示例 1: 简单数学节点 ============= @register_node class AddNode(TraceNode): """加法节点 - 演示最基础的节点定义""" CATEGORY = "Math/Basic" DISPLAY_NAME = "加法" DESCRIPTION = "计算两个数的和,支持偏移量" ICON = "➕" # 使用装饰器自动收集属性 @input_port("a", "Number", description="加数 A", required=True) @input_port("b", "Number", description="加数 B", required=True) @output_port("result", "Number", description="计算结果") @param("offset", "Number", default=0, description="偏移量", widget="slider", min=-100, max=100, step=1) @context_var("operation", "String", description="执行的运算") def process(self, inputs, context=None): a = inputs["a"] b = inputs["b"] offset = self.get_param("offset", 0) result = a + b + offset return {"result": result} @register_node class MultiplyNode(TraceNode): """乘法节点""" CATEGORY = "Math/Basic" DISPLAY_NAME = "乘法" DESCRIPTION = "计算两个数的乘积" ICON = "✖️" @input_port("a", "Number", description="乘数 A") @input_port("b", "Number", description="乘数 B") @output_port("result", "Number", description="乘积") @param("scale", "Number", default=1.0, description="缩放系数") @context_var("operation", "String", description="运算描述") def process(self, inputs, context=None): a = inputs["a"] b = inputs["b"] scale = self.get_param("scale", 1.0) result = a * b * scale return {"result": result} # ============= 示例 2: 数据加载节点 ============= @register_node class CSVLoaderNode(TraceNode): """CSV 数据加载器 - 演示输入节点""" CATEGORY = "Data/Loader" DISPLAY_NAME = "CSV 加载器" DESCRIPTION = "从 CSV 文件加载数据表" ICON = "📊" NODE_TYPE = NodeType.INPUT # 输入节点(仅输出) CACHE_POLICY = CachePolicy.MEMORY # 使用内存缓存 # 无输入端口 @output_port("table", "DataTable", description="加载的数据表") @param("file_path", "String", default="", description="文件路径", widget="file") @param("delimiter", "String", default=",", description="分隔符") @param("skip_rows", "Number", default=0, description="跳过行数", min=0, step=1) @context_var("row_count", "Integer", description="数据行数") @context_var("column_count", "Integer", description="列数") @context_var("file_name", "String", description="文件名") def process(self, inputs, context=None): import polars as pl from pathlib import Path file_path = self.get_param("file_path") delimiter = self.get_param("delimiter", ",") skip_rows = self.get_param("skip_rows", 0) # 加载 CSV(使用 polars) df = pl.read_csv(file_path, separator=delimiter) return {"table": df} # ============= 示例 3: 数据转换节点 ============= @register_node class FilterRowsNode(TraceNode): """行过滤器 - 演示标准流水线节点""" CATEGORY = "Data/Transform" DISPLAY_NAME = "行过滤" DESCRIPTION = "根据条件过滤数据行" ICON = "🔍" @input_port("table", "DataTable", description="输入数据表") @output_port("filtered", "DataTable", description="过滤后的数据表") @param("column", "String", default="", description="过滤列名") @param("operator", "String", default=">", description="运算符", widget="dropdown", options=[">", ">=", "<", "<=", "==", "!="]) @param("value", "Number", default=0, description="比较值") @context_var("filtered_count", "Integer", description="过滤后行数") @context_var("removed_count", "Integer", description="移除行数") def process(self, inputs, context=None): import polars as pl table = inputs["table"] column = self.get_param("column") operator = self.get_param("operator", ">") value = self.get_param("value", 0) # 应用过滤条件(使用 polars expressions) if not column: filtered = table else: if operator == ">": filtered = table.filter(pl.col(column) > value) elif operator == ">=": filtered = table.filter(pl.col(column) >= value) elif operator == "<": filtered = table.filter(pl.col(column) < value) elif operator == "<=": filtered = table.filter(pl.col(column) <= value) elif operator == "==": filtered = table.filter(pl.col(column) == value) elif operator == "!=": filtered = table.filter(pl.col(column) != value) else: filtered = table original_count = table.height if hasattr(table, 'height') else len(table) filtered_count = filtered.height if hasattr(filtered, 'height') else len(filtered) return {"filtered": filtered} @register_node class SelectColumnsNode(TraceNode): """列选择器""" CATEGORY = "Data/Transform" DISPLAY_NAME = "列选择" DESCRIPTION = "选择指定的列" ICON = "📋" @input_port("table", "DataTable", description="输入数据表") @output_port("selected", "DataTable", description="选择后的数据表") @param("columns", "String", default="", description="列名(逗号分隔)", multiline=True) @context_var("selected_columns", "List", description="选择的列名列表") def process(self, inputs, context=None): table = inputs["table"] columns_str = self.get_param("columns", "") # 解析列名 columns = [col.strip() for col in columns_str.split(",") if col.strip()] # 选择列 selected = table.select(columns) return {"selected": selected} # ============= 示例 4: 聚合节点(多输入) ============= @register_node class ConcatNode(TraceNode): """数据合并节点 - 演示多输入聚合""" CATEGORY = "Data/Aggregate" DISPLAY_NAME = "数据合并" DESCRIPTION = "合并多个数据表(垂直拼接)" ICON = "🔗" NODE_TYPE = NodeType.COMPOSITE @input_port("tables", "DataTable", description="输入数据表", list=True) # 列表输入 @output_port("concatenated", "DataTable", description="合并后的数据表") @param("ignore_index", "Boolean", default=True, description="忽略原索引") @context_var("total_rows", "Integer", description="合并后总行数") @context_var("input_count", "Integer", description="输入表数量") def process(self, inputs, context=None): import polars as pl tables = inputs["tables"] ignore_index = self.get_param("ignore_index", True) # 合并数据表(使用 polars) concatenated = pl.concat(tables, how="vertical") return {"concatenated": concatenated} # ============= 示例 5: 输出节点 ============= @register_node class TableOutputNode(TraceNode): """表格输出节点 - 演示输出节点""" CATEGORY = "Output/Display" DISPLAY_NAME = "表格显示" DESCRIPTION = "在前端显示数据表" ICON = "🖥️" NODE_TYPE = NodeType.OUTPUT # 输出节点(仅输入) @input_port("table", "DataTable", description="要显示的数据表") # 无输出端口 @param("max_rows", "Number", default=100, description="最大显示行数", min=1, max=10000, step=10) @param("show_index", "Boolean", default=True, description="显示索引列") @context_var("displayed_rows", "Integer", description="实际显示行数") def process(self, inputs, context=None): table = inputs["table"] max_rows = self.get_param("max_rows", 100) show_index = self.get_param("show_index", True) # 截取数据 displayed = table.head(max_rows) # 转换为前端格式 result = { "columns": list(displayed.columns), "data": displayed.to_dicts(), "total_rows": table.height if hasattr(table, 'height') else len(table), "show_index": show_index } return {"display": result} # ============= 示例 6: 带缓存的计算密集型节点 ============= @register_node class StatisticsNode(TraceNode): """统计分析节点 - 演示缓存使用""" CATEGORY = "Data/Analysis" DISPLAY_NAME = "统计分析" DESCRIPTION = "计算数据表的统计信息" ICON = "📈" CACHE_POLICY = CachePolicy.MEMORY # 启用内存缓存 @input_port("table", "DataTable", description="输入数据表") @output_port("stats", "DataTable", description="统计结果") @param("columns", "String", default="", description="分析列名(逗号分隔,留空则全部)") @context_var("mean", "Number", description="平均值") @context_var("std", "Number", description="标准差") def process(self, inputs, context=None): import polars as pl table = inputs["table"] columns_str = self.get_param("columns", "") # 选择列 if columns_str: columns = [col.strip() for col in columns_str.split(",") if col.strip()] data = table.select(columns) else: # 根据 schema 选择数值列 NUM_DTYPES = { pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.Float32, pl.Float64 } num_cols = [c for c, t in table.schema.items() if t in NUM_DTYPES] data = table.select(num_cols) # 计算统计信息 stats = data.describe() return {"stats": stats}