288 lines
10 KiB
Python
288 lines
10 KiB
Python
"""
|
||
TraceStudio 内置节点示例 (v2.0)
|
||
演示如何使用装饰器简化节点开发
|
||
"""
|
||
from typing import Any, Dict
|
||
from server.app.core.node_base import (
|
||
TraceNode,
|
||
input_port,
|
||
output_port,
|
||
param,
|
||
context_var,
|
||
NodeType,
|
||
CachePolicy
|
||
)
|
||
from server.app.core.node_registry import register_node
|
||
|
||
|
||
# ============= 示例 1: 简单数学节点 =============
|
||
|
||
@register_node
class AddNode(TraceNode):
    """Addition node: the most basic example of a node definition."""

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "加法"
    DESCRIPTION = "计算两个数的和,支持偏移量"
    ICON = "➕"

    # Ports, parameters and context variables are declared through the
    # decorators stacked on ``process``.
    @input_port("a", "Number", description="加数 A", required=True)
    @input_port("b", "Number", description="加数 B", required=True)
    @output_port("result", "Number", description="计算结果")
    @param("offset", "Number", default=0, description="偏移量", widget="slider", min=-100, max=100, step=1)
    @context_var("operation", "String", description="执行的运算")
    def process(self, inputs, context=None):
        """Add the two inputs and the configured offset.

        Args:
            inputs: Mapping that must contain the keys ``"a"`` and ``"b"``.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"result"`` holding ``a + b + offset``.

        Raises:
            KeyError: If ``"a"`` or ``"b"`` is missing from ``inputs``.
        """
        lhs, rhs = inputs["a"], inputs["b"]
        shift = self.get_param("offset", 0)
        return {"result": lhs + rhs + shift}
|
||
|
||
|
||
@register_node
class MultiplyNode(TraceNode):
    """Multiplication node."""

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "乘法"
    DESCRIPTION = "计算两个数的乘积"
    ICON = "✖️"

    @input_port("a", "Number", description="乘数 A")
    @input_port("b", "Number", description="乘数 B")
    @output_port("result", "Number", description="乘积")
    @param("scale", "Number", default=1.0, description="缩放系数")
    @context_var("operation", "String", description="运算描述")
    def process(self, inputs, context=None):
        """Multiply the two inputs and apply the configured scale factor.

        Args:
            inputs: Mapping that must contain the keys ``"a"`` and ``"b"``.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"result"`` holding ``a * b * scale``.

        Raises:
            KeyError: If ``"a"`` or ``"b"`` is missing from ``inputs``.
        """
        factor_a, factor_b = inputs["a"], inputs["b"]
        scale_factor = self.get_param("scale", 1.0)
        product = factor_a * factor_b * scale_factor
        return {"result": product}
|
||
|
||
|
||
# ============= 示例 2: 数据加载节点 =============
|
||
|
||
@register_node
class CSVLoaderNode(TraceNode):
    """CSV data loader: demonstrates an input-type node."""

    CATEGORY = "Data/Loader"
    DISPLAY_NAME = "CSV 加载器"
    DESCRIPTION = "从 CSV 文件加载数据表"
    ICON = "📊"
    NODE_TYPE = NodeType.INPUT  # input node (output ports only)
    CACHE_POLICY = CachePolicy.MEMORY  # use the in-memory cache policy

    # No input ports are declared for this node.
    @output_port("table", "DataTable", description="加载的数据表")
    @param("file_path", "String", default="", description="文件路径", widget="file")
    @param("delimiter", "String", default=",", description="分隔符")
    @param("skip_rows", "Number", default=0, description="跳过行数", min=0, step=1)
    @context_var("row_count", "Integer", description="数据行数")
    @context_var("column_count", "Integer", description="列数")
    @context_var("file_name", "String", description="文件名")
    def process(self, inputs, context=None):
        """Read the configured CSV file into a polars DataFrame.

        The ``file_path``, ``delimiter`` and ``skip_rows`` parameters are read
        via ``self.get_param`` and forwarded to ``polars.read_csv``.

        Args:
            inputs: Not used; this node has no input ports.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"table"`` holding the loaded frame.

        NOTE(review): the declared context variables (``row_count``,
        ``column_count``, ``file_name``) are not populated by this method.
        """
        import polars as pl

        file_path = self.get_param("file_path")
        delimiter = self.get_param("delimiter", ",")
        # Coerce defensively: the parameter is declared as a generic "Number",
        # while polars expects an integer row count. A falsy value means 0.
        skip_rows = int(self.get_param("skip_rows", 0) or 0)

        # Previously ``skip_rows`` was read but never passed on, so the
        # user-facing setting had no effect.
        df = pl.read_csv(file_path, separator=delimiter, skip_rows=skip_rows)

        return {"table": df}
|
||
|
||
|
||
# ============= 示例 3: 数据转换节点 =============
|
||
|
||
@register_node
class FilterRowsNode(TraceNode):
    """Row filter: demonstrates a standard pipeline node."""

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "行过滤"
    DESCRIPTION = "根据条件过滤数据行"
    ICON = "🔍"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("filtered", "DataTable", description="过滤后的数据表")
    @param("column", "String", default="", description="过滤列名")
    @param("operator", "String", default=">", description="运算符", widget="dropdown",
           options=[">", ">=", "<", "<=", "==", "!="])
    @param("value", "Number", default=0, description="比较值")
    @context_var("filtered_count", "Integer", description="过滤后行数")
    @context_var("removed_count", "Integer", description="移除行数")
    def process(self, inputs, context=None):
        """Keep only the rows whose ``column`` satisfies ``operator value``.

        The table is returned unchanged when no column is configured or when
        the operator is not one of ``>``, ``>=``, ``<``, ``<=``, ``==``, ``!=``.

        Args:
            inputs: Mapping that must contain the key ``"table"``; the value
                must support ``.filter(expr)`` with a polars expression.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"filtered"``.

        Raises:
            KeyError: If ``"table"`` is missing from ``inputs``.

        NOTE(review): the declared context variables (``filtered_count``,
        ``removed_count``) are not populated by this method.
        """
        import operator as op

        import polars as pl

        table = inputs["table"]
        column = self.get_param("column")
        op_symbol = self.get_param("operator", ">")
        value = self.get_param("value", 0)

        # Dispatch table replaces the former six-branch if/elif chain; each
        # entry builds the same polars expression the chain did.
        comparisons = {
            ">": op.gt,
            ">=": op.ge,
            "<": op.lt,
            "<=": op.le,
            "==": op.eq,
            "!=": op.ne,
        }
        compare = comparisons.get(op_symbol)

        # No column configured or unknown operator: pass the table through.
        if not column or compare is None:
            return {"filtered": table}

        return {"filtered": table.filter(compare(pl.col(column), value))}
|
||
|
||
|
||
@register_node
class SelectColumnsNode(TraceNode):
    """Column selector."""

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "列选择"
    DESCRIPTION = "选择指定的列"
    ICON = "📋"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("selected", "DataTable", description="选择后的数据表")
    @param("columns", "String", default="", description="列名(逗号分隔)", multiline=True)
    @context_var("selected_columns", "List", description="选择的列名列表")
    def process(self, inputs, context=None):
        """Project the input table onto the configured columns.

        ``columns`` is a comma-separated string; surrounding whitespace is
        stripped from each name and empty entries are dropped.

        Args:
            inputs: Mapping that must contain the key ``"table"``; the value
                must support ``.select(list_of_names)``.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"selected"``. When no column name is
            configured the input table is returned unchanged.

        Raises:
            KeyError: If ``"table"`` is missing from ``inputs``.

        NOTE(review): the declared ``selected_columns`` context variable is
        not populated by this method.
        """
        table = inputs["table"]
        # ``or ""`` guards against a parameter explicitly stored as None.
        columns_str = self.get_param("columns", "") or ""

        # Parse the comma-separated list, ignoring blanks.
        names = [part.strip() for part in columns_str.split(",") if part.strip()]

        # With the default (empty) parameter the old code ran
        # ``table.select([])`` and silently dropped every column. Pass the
        # table through instead, mirroring how FilterRowsNode treats an empty
        # column name.
        if not names:
            return {"selected": table}

        return {"selected": table.select(names)}
|
||
|
||
|
||
# ============= 示例 4: 聚合节点(多输入) =============
|
||
|
||
@register_node
class ConcatNode(TraceNode):
    """Table concatenation node: demonstrates multi-input aggregation."""

    CATEGORY = "Data/Aggregate"
    DISPLAY_NAME = "数据合并"
    DESCRIPTION = "合并多个数据表(垂直拼接)"
    ICON = "🔗"
    NODE_TYPE = NodeType.COMPOSITE

    @input_port("tables", "DataTable", description="输入数据表", list=True)  # list-valued input
    @output_port("concatenated", "DataTable", description="合并后的数据表")
    @param("ignore_index", "Boolean", default=True, description="忽略原索引")
    @context_var("total_rows", "Integer", description="合并后总行数")
    @context_var("input_count", "Integer", description="输入表数量")
    def process(self, inputs, context=None):
        """Vertically stack every table supplied on the ``tables`` port.

        Args:
            inputs: Mapping that must contain the key ``"tables"``; the value
                is handed to ``polars.concat`` as-is.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"concatenated"``.

        Raises:
            KeyError: If ``"tables"`` is missing from ``inputs``.
        """
        import polars as pl

        frames = inputs["tables"]
        # NOTE(review): the parameter is read but never applied to the
        # concatenation below; kept so the call sequence is unchanged.
        _ignore_index = self.get_param("ignore_index", True)

        # Vertical concatenation via polars.
        return {"concatenated": pl.concat(frames, how="vertical")}
|
||
|
||
|
||
# ============= 示例 5: 输出节点 =============
|
||
|
||
@register_node
class TableOutputNode(TraceNode):
    """Table display node: demonstrates an output-type node."""

    CATEGORY = "Output/Display"
    DISPLAY_NAME = "表格显示"
    DESCRIPTION = "在前端显示数据表"
    ICON = "🖥️"
    NODE_TYPE = NodeType.OUTPUT  # output node (input ports only)

    @input_port("table", "DataTable", description="要显示的数据表")
    # No output ports are declared for this node.
    @param("max_rows", "Number", default=100, description="最大显示行数", min=1, max=10000, step=10)
    @param("show_index", "Boolean", default=True, description="显示索引列")
    @context_var("displayed_rows", "Integer", description="实际显示行数")
    def process(self, inputs, context=None):
        """Build a display payload from the first ``max_rows`` rows.

        Args:
            inputs: Mapping that must contain the key ``"table"``; the value
                must support ``.head(n)``, and the truncated result must
                expose ``.columns`` and ``.to_dicts()``.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"display"`` whose value holds
            ``columns``, ``data``, ``total_rows`` and ``show_index``.

        Raises:
            KeyError: If ``"table"`` is missing from ``inputs``.
        """
        table = inputs["table"]
        row_limit = self.get_param("max_rows", 100)
        with_index = self.get_param("show_index", True)

        # Truncate to the configured number of rows.
        preview = table.head(row_limit)

        # Assemble the front-end payload key by key, in the same order the
        # values were originally evaluated.
        payload = {"columns": list(preview.columns)}
        payload["data"] = preview.to_dicts()
        if hasattr(table, "height"):
            payload["total_rows"] = table.height
        else:
            payload["total_rows"] = len(table)
        payload["show_index"] = with_index

        return {"display": payload}
|
||
|
||
|
||
# ============= 示例 6: 带缓存的计算密集型节点 =============
|
||
|
||
@register_node
class StatisticsNode(TraceNode):
    """Statistics node: demonstrates use of the cache policy."""

    CATEGORY = "Data/Analysis"
    DISPLAY_NAME = "统计分析"
    DESCRIPTION = "计算数据表的统计信息"
    ICON = "📈"
    CACHE_POLICY = CachePolicy.MEMORY  # enable the in-memory cache policy

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("stats", "DataTable", description="统计结果")
    @param("columns", "String", default="", description="分析列名(逗号分隔,留空则全部)")
    @context_var("mean", "Number", description="平均值")
    @context_var("std", "Number", description="标准差")
    def process(self, inputs, context=None):
        """Return ``describe()`` statistics for the chosen columns.

        ``columns`` is a comma-separated string. When it yields no column
        names, every column whose dtype is one of the polars integer or float
        types listed below is analysed instead.

        Args:
            inputs: Mapping that must contain the key ``"table"``; the value
                must support ``.select(...)`` and expose ``.schema``.
            context: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the single key ``"stats"`` holding the result of
            ``describe()`` on the selected columns.

        Raises:
            KeyError: If ``"table"`` is missing from ``inputs``.

        NOTE(review): the declared context variables (``mean``, ``std``) are
        not populated by this method.
        """
        import polars as pl

        table = inputs["table"]
        # ``or ""`` guards against a parameter explicitly stored as None.
        columns_str = self.get_param("columns", "") or ""

        # Parse before branching: a whitespace- or comma-only value is truthy
        # as a string but names no columns, and used to reach ``select([])``.
        requested = [part.strip() for part in columns_str.split(",") if part.strip()]

        if requested:
            data = table.select(requested)
        else:
            # Nothing requested: pick the numeric columns from the schema.
            numeric_dtypes = {
                pl.Int8, pl.Int16, pl.Int32, pl.Int64,
                pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
                pl.Float32, pl.Float64,
            }
            numeric_cols = [
                name for name, dtype in table.schema.items()
                if dtype in numeric_dtypes
            ]
            data = table.select(numeric_cols)

        # Compute the summary statistics.
        stats = data.describe()

        return {"stats": stats}
|