TraceStudio/server/app/nodes/example_nodes.py
2026-01-12 21:51:45 +08:00

288 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TraceStudio 内置节点示例 (v2.0)
演示如何使用装饰器简化节点开发
"""
from typing import Any, Dict
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
context_var,
NodeType,
CachePolicy
)
from server.app.core.node_registry import register_node
# ============= 示例 1: 简单数学节点 =============
@register_node
class AddNode(TraceNode):
    """Addition node: the most basic example of a node definition.

    Outputs the sum of the two inputs ``a`` and ``b`` plus the ``offset``
    parameter.
    """

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "加法"
    DESCRIPTION = "计算两个数的和,支持偏移量"
    ICON = ""

    # Ports, parameters and context variables are declared with decorators
    # stacked on process().
    @input_port("a", "Number", description="加数 A", required=True)
    @input_port("b", "Number", description="加数 B", required=True)
    @output_port("result", "Number", description="计算结果")
    @param(
        "offset",
        "Number",
        default=0,
        description="偏移量",
        widget="slider",
        min=-100,
        max=100,
        step=1,
    )
    @context_var("operation", "String", description="执行的运算")
    def process(self, inputs, context=None):
        """Add the two inputs and the ``offset`` parameter.

        Args:
            inputs: Object indexed by port name; must provide ``"a"`` and
                ``"b"``.
            context: Not used by this method (the declared ``operation``
                context variable is never written here).

        Returns:
            A dict with the single key ``"result"`` holding
            ``a + b + offset``.
        """
        lhs, rhs = inputs["a"], inputs["b"]
        shift = self.get_param("offset", 0)
        return {"result": lhs + rhs + shift}
@register_node
class MultiplyNode(TraceNode):
    """Multiplication node.

    Outputs the product of the two inputs ``a`` and ``b`` multiplied by the
    ``scale`` parameter.
    """

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "乘法"
    DESCRIPTION = "计算两个数的乘积"
    ICON = "✖️"

    @input_port("a", "Number", description="乘数 A")
    @input_port("b", "Number", description="乘数 B")
    @output_port("result", "Number", description="乘积")
    @param("scale", "Number", default=1.0, description="缩放系数")
    @context_var("operation", "String", description="运算描述")
    def process(self, inputs, context=None):
        """Multiply the two inputs and the ``scale`` parameter.

        Args:
            inputs: Object indexed by port name; must provide ``"a"`` and
                ``"b"``.
            context: Not used by this method (the declared ``operation``
                context variable is never written here).

        Returns:
            A dict with the single key ``"result"`` holding
            ``a * b * scale``.
        """
        lhs, rhs = inputs["a"], inputs["b"]
        factor = self.get_param("scale", 1.0)
        return {"result": lhs * rhs * factor}
# ============= 示例 2: 数据加载节点 =============
@register_node
class CSVLoaderNode(TraceNode):
    """CSV data loader: example of an input node.

    Reads a CSV file with polars and emits the resulting frame on the
    ``table`` output port.
    """

    CATEGORY = "Data/Loader"
    DISPLAY_NAME = "CSV 加载器"
    DESCRIPTION = "从 CSV 文件加载数据表"
    ICON = "📊"
    NODE_TYPE = NodeType.INPUT  # input node (outputs only)
    CACHE_POLICY = CachePolicy.MEMORY  # use the in-memory cache

    # No input ports.
    @output_port("table", "DataTable", description="加载的数据表")
    @param("file_path", "String", default="", description="文件路径", widget="file")
    @param("delimiter", "String", default=",", description="分隔符")
    @param("skip_rows", "Number", default=0, description="跳过行数", min=0, step=1)
    @context_var("row_count", "Integer", description="数据行数")
    @context_var("column_count", "Integer", description="列数")
    @context_var("file_name", "String", description="文件名")
    def process(self, inputs, context=None):
        """Load the configured CSV file into a polars DataFrame.

        Args:
            inputs: Not used; this node declares no input ports.
            context: Not used by this method (the declared context
                variables are never written here).

        Returns:
            A dict with the single key ``"table"`` holding the frame
            returned by ``polars.read_csv``.

        Raises:
            Whatever ``polars.read_csv`` raises for a missing or malformed
            file; ``ValueError``/``TypeError`` if ``skip_rows`` cannot be
            converted to an integer.
        """
        import polars as pl

        file_path = self.get_param("file_path")
        delimiter = self.get_param("delimiter", ",")
        # The parameter is declared as a generic "Number"; polars requires an
        # int, so normalise it (a missing/None value falls back to 0).
        # NOTE(review): assumes get_param may hand back a float or None.
        skip_rows = int(self.get_param("skip_rows", 0) or 0)

        # Forward skip_rows so the user-facing setting actually takes effect.
        df = pl.read_csv(file_path, separator=delimiter, skip_rows=skip_rows)
        return {"table": df}
# ============= 示例 3: 数据转换节点 =============
@register_node
class FilterRowsNode(TraceNode):
    """Row filter: example of a standard pipeline node.

    Keeps the rows of the input table for which
    ``<column> <operator> <value>`` holds.
    """

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "行过滤"
    DESCRIPTION = "根据条件过滤数据行"
    ICON = "🔍"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("filtered", "DataTable", description="过滤后的数据表")
    @param("column", "String", default="", description="过滤列名")
    @param("operator", "String", default=">", description="运算符", widget="dropdown",
           options=[">", ">=", "<", "<=", "==", "!="])
    @param("value", "Number", default=0, description="比较值")
    @context_var("filtered_count", "Integer", description="过滤后行数")
    @context_var("removed_count", "Integer", description="移除行数")
    def process(self, inputs, context=None):
        """Filter the input table with a single column comparison.

        Args:
            inputs: Object indexed by port name; must provide ``"table"``.
            context: Not used by this method (the declared context
                variables are never written here).

        Returns:
            A dict with the single key ``"filtered"``. The input table is
            returned unchanged when no column is configured or the operator
            is not one of ``>``, ``>=``, ``<``, ``<=``, ``==``, ``!=``.
        """
        # Aliased because the local parameter below is also called "operator".
        import operator as op_funcs

        import polars as pl

        table = inputs["table"]
        column = self.get_param("column")
        op_name = self.get_param("operator", ">")
        value = self.get_param("value", 0)

        # Operator symbol -> binary function; applied to a polars expression
        # it produces the same predicate as writing the comparison inline.
        comparisons = {
            ">": op_funcs.gt,
            ">=": op_funcs.ge,
            "<": op_funcs.lt,
            "<=": op_funcs.le,
            "==": op_funcs.eq,
            "!=": op_funcs.ne,
        }
        compare = comparisons.get(op_name)

        # Pass-through: nothing to filter on, or an unrecognised operator.
        if not column or compare is None:
            return {"filtered": table}

        return {"filtered": table.filter(compare(pl.col(column), value))}
@register_node
class SelectColumnsNode(TraceNode):
    """Column selector.

    Emits the input table restricted to the comma-separated column names
    given in the ``columns`` parameter.
    """

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "列选择"
    DESCRIPTION = "选择指定的列"
    ICON = "📋"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("selected", "DataTable", description="选择后的数据表")
    @param("columns", "String", default="", description="列名(逗号分隔)", multiline=True)
    @context_var("selected_columns", "List", description="选择的列名列表")
    def process(self, inputs, context=None):
        """Select the configured columns from the input table.

        Args:
            inputs: Object indexed by port name; must provide ``"table"``.
            context: Not used by this method (the declared
                ``selected_columns`` context variable is never written here).

        Returns:
            A dict with the single key ``"selected"``. When no column names
            are configured (the parameter default), the input table is
            passed through unchanged instead of being reduced to a
            column-less table.
        """
        table = inputs["table"]
        columns_str = self.get_param("columns", "")

        # Split on commas, trimming whitespace and dropping empty entries.
        columns = [
            name.strip()
            for name in (columns_str or "").split(",")
            if name.strip()
        ]

        # An empty selection means "no restriction", matching how the other
        # nodes in this file treat an empty column setting.
        if not columns:
            return {"selected": table}

        return {"selected": table.select(columns)}
# ============= 示例 4: 聚合节点(多输入) =============
@register_node
class ConcatNode(TraceNode):
    """Table concatenation node: example of multi-input aggregation.

    Stacks the tables received on the ``tables`` list port vertically.
    """

    CATEGORY = "Data/Aggregate"
    DISPLAY_NAME = "数据合并"
    DESCRIPTION = "合并多个数据表(垂直拼接)"
    ICON = "🔗"
    NODE_TYPE = NodeType.COMPOSITE

    # list=True marks this port as a list input.
    @input_port("tables", "DataTable", description="输入数据表", list=True)
    @output_port("concatenated", "DataTable", description="合并后的数据表")
    @param("ignore_index", "Boolean", default=True, description="忽略原索引")
    @context_var("total_rows", "Integer", description="合并后总行数")
    @context_var("input_count", "Integer", description="输入表数量")
    def process(self, inputs, context=None):
        """Concatenate the input tables vertically with polars.

        Args:
            inputs: Object indexed by port name; ``"tables"`` is handed to
                ``polars.concat`` as-is.
            context: Not used by this method (the declared context
                variables are never written here).

        Returns:
            A dict with the single key ``"concatenated"`` holding the result
            of ``pl.concat(tables, how="vertical")``.
        """
        import polars as pl

        frames = inputs["tables"]
        # Read for parity with the declared parameter; it is not forwarded
        # to pl.concat below.
        ignore_index = self.get_param("ignore_index", True)

        return {"concatenated": pl.concat(frames, how="vertical")}
# ============= 示例 5: 输出节点 =============
@register_node
class TableOutputNode(TraceNode):
    """Table output node: example of an output node.

    Converts the first ``max_rows`` rows of the input table into a plain
    dict payload for display.
    """

    CATEGORY = "Output/Display"
    DISPLAY_NAME = "表格显示"
    DESCRIPTION = "在前端显示数据表"
    ICON = "🖥️"
    NODE_TYPE = NodeType.OUTPUT  # output node (inputs only)

    @input_port("table", "DataTable", description="要显示的数据表")
    # No output ports.
    @param("max_rows", "Number", default=100, description="最大显示行数", min=1, max=10000, step=10)
    @param("show_index", "Boolean", default=True, description="显示索引列")
    @context_var("displayed_rows", "Integer", description="实际显示行数")
    def process(self, inputs, context=None):
        """Build the display payload for the input table.

        Args:
            inputs: Object indexed by port name; must provide ``"table"``.
            context: Not used by this method (the declared
                ``displayed_rows`` context variable is never written here).

        Returns:
            A dict with the single key ``"display"`` whose value contains
            ``columns`` (column names of the truncated table), ``data``
            (its rows as produced by ``to_dicts()``), ``total_rows`` (row
            count of the full input table) and ``show_index``.
        """
        table = inputs["table"]
        limit = self.get_param("max_rows", 100)
        with_index = self.get_param("show_index", True)

        # Truncate before converting so only the visible rows are serialised.
        preview = table.head(limit)
        column_names = list(preview.columns)
        rows = preview.to_dicts()

        # Prefer the table's own ``height`` attribute; fall back to len().
        if hasattr(table, 'height'):
            total = table.height
        else:
            total = len(table)

        payload = {
            "columns": column_names,
            "data": rows,
            "total_rows": total,
            "show_index": with_index,
        }
        return {"display": payload}
# ============= 示例 6: 带缓存的计算密集型节点 =============
@register_node
class StatisticsNode(TraceNode):
    """Statistics node: example of a node using the cache policy.

    Runs ``describe()`` on either the configured columns or, when none are
    configured, on the columns whose dtype is one of the listed integer or
    float types.
    """

    CATEGORY = "Data/Analysis"
    DISPLAY_NAME = "统计分析"
    DESCRIPTION = "计算数据表的统计信息"
    ICON = "📈"
    CACHE_POLICY = CachePolicy.MEMORY  # enable the in-memory cache

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("stats", "DataTable", description="统计结果")
    @param("columns", "String", default="", description="分析列名(逗号分隔,留空则全部)")
    @context_var("mean", "Number", description="平均值")
    @context_var("std", "Number", description="标准差")
    def process(self, inputs, context=None):
        """Compute summary statistics for the selected columns.

        Args:
            inputs: Object indexed by port name; must provide ``"table"``.
            context: Not used by this method (the declared ``mean`` and
                ``std`` context variables are never written here).

        Returns:
            A dict with the single key ``"stats"`` holding the result of
            ``describe()`` on the selected columns.
        """
        import polars as pl

        frame = inputs["table"]
        spec = self.get_param("columns", "")

        if spec:
            # Explicit selection: comma-separated names, blanks dropped.
            wanted = [name.strip() for name in spec.split(",") if name.strip()]
        else:
            # No selection: pick numeric columns based on the schema.
            numeric_types = {
                pl.Int8, pl.Int16, pl.Int32, pl.Int64,
                pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
                pl.Float32, pl.Float64,
            }
            wanted = [
                name
                for name, dtype in frame.schema.items()
                if dtype in numeric_types
            ]

        summary = frame.select(wanted).describe()
        return {"stats": summary}