# TraceStudio-dev/server/app/nodes/example_nodes.py
# (288 lines, 10 KiB, Python)

"""
TraceStudio 内置节点示例 (v2.0)
演示如何使用装饰器简化节点开发
"""
from typing import Any, Dict
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
context_var,
NodeType,
CachePolicy
)
from server.app.core.node_registry import register_node
# ============= 示例 1: 简单数学节点 =============
@register_node
class AddNode(TraceNode):
    """Addition node - demonstrates the most basic node definition.

    Sums the ``a`` and ``b`` inputs and adds the ``offset`` parameter.
    """

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "加法"
    DESCRIPTION = "计算两个数的和,支持偏移量"
    ICON = ""

    # Ports, parameters and context variables are declared through the
    # decorators stacked on process().
    @input_port("a", "Number", description="加数 A", required=True)
    @input_port("b", "Number", description="加数 B", required=True)
    @output_port("result", "Number", description="计算结果")
    @param(
        "offset",
        "Number",
        default=0,
        description="偏移量",
        widget="slider",
        min=-100,
        max=100,
        step=1,
    )
    @context_var("operation", "String", description="执行的运算")
    def process(self, inputs, context=None):
        """Return ``{"result": a + b + offset}``.

        Args:
            inputs: Mapping that must contain the keys ``"a"`` and ``"b"``.
            context: Accepted but not used here; the declared ``operation``
                context variable is not written by this method.
        """
        lhs, rhs = inputs["a"], inputs["b"]
        bias = self.get_param("offset", 0)
        return {"result": lhs + rhs + bias}
@register_node
class MultiplyNode(TraceNode):
    """Multiplication node.

    Multiplies the ``a`` and ``b`` inputs and applies the ``scale`` parameter.
    """

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "乘法"
    DESCRIPTION = "计算两个数的乘积"
    ICON = "✖️"

    @input_port("a", "Number", description="乘数 A")
    @input_port("b", "Number", description="乘数 B")
    @output_port("result", "Number", description="乘积")
    @param("scale", "Number", default=1.0, description="缩放系数")
    @context_var("operation", "String", description="运算描述")
    def process(self, inputs, context=None):
        """Return ``{"result": a * b * scale}``.

        Args:
            inputs: Mapping that must contain the keys ``"a"`` and ``"b"``.
            context: Accepted but not used here; the declared ``operation``
                context variable is not written by this method.
        """
        lhs, rhs = inputs["a"], inputs["b"]
        factor = self.get_param("scale", 1.0)
        return {"result": lhs * rhs * factor}
# ============= 示例 2: 数据加载节点 =============
@register_node
class CSVLoaderNode(TraceNode):
    """CSV data loader - demonstrates an input (source) node.

    Reads the file named by the ``file_path`` parameter with polars and
    emits the resulting frame on the ``table`` output port.
    """

    CATEGORY = "Data/Loader"
    DISPLAY_NAME = "CSV 加载器"
    DESCRIPTION = "从 CSV 文件加载数据表"
    ICON = "📊"
    NODE_TYPE = NodeType.INPUT  # input node (outputs only)
    CACHE_POLICY = CachePolicy.MEMORY  # use the in-memory cache

    # No input ports.
    @output_port("table", "DataTable", description="加载的数据表")
    @param("file_path", "String", default="", description="文件路径", widget="file")
    @param("delimiter", "String", default=",", description="分隔符")
    @param("skip_rows", "Number", default=0, description="跳过行数", min=0, step=1)
    @context_var("row_count", "Integer", description="数据行数")
    @context_var("column_count", "Integer", description="列数")
    @context_var("file_name", "String", description="文件名")
    def process(self, inputs, context=None):
        """Load the configured CSV file into a polars frame.

        Args:
            inputs: Unused; this node declares no input ports.
            context: Accepted but not used here; the declared ``row_count``,
                ``column_count`` and ``file_name`` context variables are not
                written by this method.

        Returns:
            ``{"table": df}`` where ``df`` is the result of
            ``polars.read_csv``.
        """
        import polars as pl

        file_path = self.get_param("file_path")
        delimiter = self.get_param("delimiter", ",")
        # The parameter is declared as a generic "Number"; polars expects an
        # int, so normalise it (and treat a missing/None value as 0).
        skip_rows = int(self.get_param("skip_rows", 0) or 0)

        # Forward skip_rows so the user-facing setting actually takes effect.
        df = pl.read_csv(file_path, separator=delimiter, skip_rows=skip_rows)
        return {"table": df}
# ============= 示例 3: 数据转换节点 =============
@register_node
class FilterRowsNode(TraceNode):
    """Row filter - demonstrates a standard pipeline node.

    Keeps the rows of ``table`` for which ``<column> <operator> <value>``
    holds, using polars expressions.
    """

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "行过滤"
    DESCRIPTION = "根据条件过滤数据行"
    ICON = "🔍"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("filtered", "DataTable", description="过滤后的数据表")
    @param("column", "String", default="", description="过滤列名")
    @param("operator", "String", default=">", description="运算符", widget="dropdown",
           options=[">", ">=", "<", "<=", "==", "!="])
    @param("value", "Number", default=0, description="比较值")
    @context_var("filtered_count", "Integer", description="过滤后行数")
    @context_var("removed_count", "Integer", description="移除行数")
    def process(self, inputs, context=None):
        """Filter the input table by a single column comparison.

        Args:
            inputs: Mapping that must contain the key ``"table"``.
            context: Accepted but not used here; the declared
                ``filtered_count`` and ``removed_count`` context variables
                are not written by this method.

        Returns:
            ``{"filtered": table}``. The table is returned unchanged when no
            column is configured or the operator is not one of the six
            supported symbols.
        """
        import operator

        import polars as pl

        table = inputs["table"]
        column = self.get_param("column")
        op_symbol = self.get_param("operator", ">")
        value = self.get_param("value", 0)

        # Nothing to compare against: pass the table through untouched.
        if not column:
            return {"filtered": table}

        # Map each supported symbol to the matching rich-comparison function;
        # applied to a polars expression it yields the filter predicate.
        comparators = {
            ">": operator.gt,
            ">=": operator.ge,
            "<": operator.lt,
            "<=": operator.le,
            "==": operator.eq,
            "!=": operator.ne,
        }
        compare = comparators.get(op_symbol)
        if compare is None:
            # Unknown operator: keep the original best-effort pass-through.
            return {"filtered": table}

        return {"filtered": table.filter(compare(pl.col(column), value))}
@register_node
class SelectColumnsNode(TraceNode):
    """Column selector.

    Projects the input table onto the comma-separated column names given
    in the ``columns`` parameter.
    """

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "列选择"
    DESCRIPTION = "选择指定的列"
    ICON = "📋"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("selected", "DataTable", description="选择后的数据表")
    @param("columns", "String", default="", description="列名(逗号分隔)", multiline=True)
    @context_var("selected_columns", "List", description="选择的列名列表")
    def process(self, inputs, context=None):
        """Select the configured columns from the input table.

        Args:
            inputs: Mapping that must contain the key ``"table"``.
            context: Accepted but not used here; the declared
                ``selected_columns`` context variable is not written by this
                method.

        Returns:
            ``{"selected": table}``. When no column name is configured the
            input table is passed through unchanged.
        """
        table = inputs["table"]
        # Tolerate an explicit None stored for the parameter.
        columns_str = self.get_param("columns", "") or ""

        # Parse the names, dropping blanks produced by stray commas/spaces.
        columns = [name.strip() for name in columns_str.split(",") if name.strip()]

        # With the default (empty) parameter, selecting [] would emit an empty
        # frame; pass the table through instead.
        if not columns:
            return {"selected": table}

        return {"selected": table.select(columns)}
# ============= 示例 4: 聚合节点(多输入) =============
@register_node
class ConcatNode(TraceNode):
    """Data merge node - demonstrates multi-input aggregation.

    Stacks the incoming tables vertically with ``polars.concat``.
    """

    CATEGORY = "Data/Aggregate"
    DISPLAY_NAME = "数据合并"
    DESCRIPTION = "合并多个数据表(垂直拼接)"
    ICON = "🔗"
    NODE_TYPE = NodeType.COMPOSITE

    # "tables" is declared as a list input.
    @input_port("tables", "DataTable", description="输入数据表", list=True)
    @output_port("concatenated", "DataTable", description="合并后的数据表")
    @param("ignore_index", "Boolean", default=True, description="忽略原索引")
    @context_var("total_rows", "Integer", description="合并后总行数")
    @context_var("input_count", "Integer", description="输入表数量")
    def process(self, inputs, context=None):
        """Vertically concatenate the tables found under ``inputs["tables"]``.

        Args:
            inputs: Mapping that must contain the key ``"tables"``.
            context: Accepted but not used here; the declared ``total_rows``
                and ``input_count`` context variables are not written by
                this method.

        Returns:
            ``{"concatenated": frame}`` with the result of
            ``pl.concat(..., how="vertical")``.
        """
        import polars as pl

        frames = inputs["tables"]
        # The parameter is read but not forwarded to pl.concat.
        self.get_param("ignore_index", True)

        # Merge the tables (using polars).
        return {"concatenated": pl.concat(frames, how="vertical")}
# ============= 示例 5: 输出节点 =============
@register_node
class TableOutputNode(TraceNode):
    """Table output node - demonstrates an output (sink) node.

    Truncates the input table and converts it into a plain dict payload.
    """

    CATEGORY = "Output/Display"
    DISPLAY_NAME = "表格显示"
    DESCRIPTION = "在前端显示数据表"
    ICON = "🖥️"
    NODE_TYPE = NodeType.OUTPUT  # output node (inputs only)

    @input_port("table", "DataTable", description="要显示的数据表")
    # No output ports.
    @param("max_rows", "Number", default=100, description="最大显示行数", min=1, max=10000, step=10)
    @param("show_index", "Boolean", default=True, description="显示索引列")
    @context_var("displayed_rows", "Integer", description="实际显示行数")
    def process(self, inputs, context=None):
        """Build the display payload for the input table.

        Args:
            inputs: Mapping that must contain the key ``"table"``.
            context: Accepted but not used here; the declared
                ``displayed_rows`` context variable is not written by this
                method.

        Returns:
            ``{"display": payload}`` where ``payload`` holds ``columns``,
            ``data`` (row dicts of the first ``max_rows`` rows),
            ``total_rows`` (size of the untruncated table) and ``show_index``.
        """
        table = inputs["table"]
        row_limit = self.get_param("max_rows", 100)
        with_index = self.get_param("show_index", True)

        # Truncate to the configured number of rows.
        preview = table.head(row_limit)

        # Convert to the front-end format.
        column_names = list(preview.columns)
        rows = preview.to_dicts()
        if hasattr(table, 'height'):
            total = table.height
        else:
            total = len(table)

        payload = {
            "columns": column_names,
            "data": rows,
            "total_rows": total,
            "show_index": with_index,
        }
        return {"display": payload}
# ============= 示例 6: 带缓存的计算密集型节点 =============
@register_node
class StatisticsNode(TraceNode):
    """Statistics node - demonstrates cache usage.

    Runs ``describe()`` on either the configured columns or, when none are
    configured, on the numeric columns of the input table.
    """

    CATEGORY = "Data/Analysis"
    DISPLAY_NAME = "统计分析"
    DESCRIPTION = "计算数据表的统计信息"
    ICON = "📈"
    CACHE_POLICY = CachePolicy.MEMORY  # enable in-memory caching

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("stats", "DataTable", description="统计结果")
    @param("columns", "String", default="", description="分析列名(逗号分隔,留空则全部)")
    @context_var("mean", "Number", description="平均值")
    @context_var("std", "Number", description="标准差")
    def process(self, inputs, context=None):
        """Compute summary statistics for the input table.

        Args:
            inputs: Mapping that must contain the key ``"table"``.
            context: Accepted but not used here; the declared ``mean`` and
                ``std`` context variables are not written by this method.

        Returns:
            ``{"stats": frame}`` with the result of ``describe()`` on the
            selected columns.
        """
        import polars as pl

        table = inputs["table"]
        # Tolerate an explicit None stored for the parameter.
        columns_str = self.get_param("columns", "") or ""

        # Parse first and branch on the parsed list, so that a value made up
        # only of whitespace/commas counts as "left blank" instead of
        # selecting zero columns.
        requested = [name.strip() for name in columns_str.split(",") if name.strip()]

        if requested:
            data = table.select(requested)
        else:
            # Pick the numeric columns based on the schema.
            numeric_dtypes = {
                pl.Int8, pl.Int16, pl.Int32, pl.Int64,
                pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
                pl.Float32, pl.Float64,
            }
            numeric_cols = [
                name for name, dtype in table.schema.items()
                if dtype in numeric_dtypes
            ]
            data = table.select(numeric_cols)

        # Compute the summary statistics.
        return {"stats": data.describe()}