"""
TraceStudio 内置节点示例 (v2.0)

演示如何使用装饰器简化节点开发

TraceStudio built-in node examples (v2.0): shows how decorators simplify node development.
"""
import operator
from typing import Any, Dict

from server.app.core.node_base import (
    TraceNode,
    input_port,
    output_port,
    param,
    context_var,
    NodeType,
    CachePolicy,
)
from server.app.core.node_registry import register_node
# ============= Example 1: simple math nodes =============

@register_node
class AddNode(TraceNode):
    """Addition node: the most basic decorator-defined node.

    Emits ``a + b + offset`` on the ``result`` port.
    """

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "加法"
    DESCRIPTION = "计算两个数的和,支持偏移量"
    ICON = "➕"

    # Ports, params and context variables are collected from these decorators.
    @input_port("a", "Number", description="加数 A", required=True)
    @input_port("b", "Number", description="加数 B", required=True)
    @output_port("result", "Number", description="计算结果")
    @param("offset", "Number", default=0, description="偏移量", widget="slider", min=-100, max=100, step=1)
    @context_var("operation", "String", description="执行的运算")
    def process(self, inputs, context=None):
        """Add the two operands and the ``offset`` parameter.

        Args:
            inputs: mapping that must contain the ``"a"`` and ``"b"`` operands.
            context: accepted for the node interface; not used here.

        Returns:
            ``{"result": a + b + offset}``.
        """
        left, right = inputs["a"], inputs["b"]
        shift = self.get_param("offset", 0)
        total = left + right + shift
        return {"result": total}


@register_node
class MultiplyNode(TraceNode):
    """Multiplication node: emits ``a * b * scale`` on the ``result`` port."""

    CATEGORY = "Math/Basic"
    DISPLAY_NAME = "乘法"
    DESCRIPTION = "计算两个数的乘积"
    ICON = "✖️"

    @input_port("a", "Number", description="乘数 A")
    @input_port("b", "Number", description="乘数 B")
    @output_port("result", "Number", description="乘积")
    @param("scale", "Number", default=1.0, description="缩放系数")
    @context_var("operation", "String", description="运算描述")
    def process(self, inputs, context=None):
        """Multiply the two operands, then apply the ``scale`` factor.

        Args:
            inputs: mapping that must contain the ``"a"`` and ``"b"`` operands.
            context: accepted for the node interface; not used here.

        Returns:
            ``{"result": a * b * scale}``.
        """
        first = inputs["a"]
        second = inputs["b"]
        factor = self.get_param("scale", 1.0)
        product = first * second
        return {"result": product * factor}


# ============= Example 2: data loader node =============

@register_node
class CSVLoaderNode(TraceNode):
    """CSV loader: an input node (no input ports) that reads a CSV file into a table."""

    CATEGORY = "Data/Loader"
    DISPLAY_NAME = "CSV 加载器"
    DESCRIPTION = "从 CSV 文件加载数据表"
    ICON = "📊"
    NODE_TYPE = NodeType.INPUT  # input node: outputs only
    CACHE_POLICY = CachePolicy.MEMORY  # in-memory cache policy

    # No input ports.
    @output_port("table", "DataTable", description="加载的数据表")
    @param("file_path", "String", default="", description="文件路径", widget="file")
    @param("delimiter", "String", default=",", description="分隔符")
    @param("skip_rows", "Number", default=0, description="跳过行数", min=0, step=1)
    @context_var("row_count", "Integer", description="数据行数")
    @context_var("column_count", "Integer", description="列数")
    @context_var("file_name", "String", description="文件名")
    def process(self, inputs, context=None):
        """Load the configured CSV file with polars.

        Args:
            inputs: unused (this node has no input ports).
            context: accepted for the node interface; not used here.

        Returns:
            ``{"table": polars.DataFrame}``.

        Raises:
            ValueError: if the ``file_path`` parameter is empty.
        """
        # NOTE(review): the context variables declared above are not populated
        # here; how nodes publish them is not visible in this file.
        import polars as pl

        file_path = self.get_param("file_path")
        if not file_path:
            # The default is "", which polars would reject with an unclear error.
            raise ValueError("CSVLoaderNode: parameter 'file_path' is empty")

        # An empty separator is invalid for polars; fall back to the default comma.
        delimiter = self.get_param("delimiter", ",") or ","

        # "Number" params may arrive as floats (e.g. 2.0) or None; polars needs a
        # non-negative int. Negative values are clamped to 0 (param min is 0).
        skip_rows = max(0, int(self.get_param("skip_rows", 0) or 0))

        # skip_rows must be forwarded to read_csv, otherwise the parameter has no effect.
        df = pl.read_csv(file_path, separator=delimiter, skip_rows=skip_rows)

        return {"table": df}


# ============= Example 3: data transform nodes =============

# Comparison symbols offered by FilterRowsNode's "operator" dropdown, mapped to
# the matching binary function. Applied to a polars expression, e.g.
# operator.gt(pl.col(c), v) builds the same expression as pl.col(c) > v.
_FILTER_COMPARATORS = {
    ">": operator.gt,
    ">=": operator.ge,
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}


@register_node
class FilterRowsNode(TraceNode):
    """Row filter: a standard pipeline node keeping rows where ``column <op> value``."""

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "行过滤"
    DESCRIPTION = "根据条件过滤数据行"
    ICON = "🔍"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("filtered", "DataTable", description="过滤后的数据表")
    @param("column", "String", default="", description="过滤列名")
    # Options are derived from the dispatch table so the two can never drift apart.
    @param("operator", "String", default=">", description="运算符", widget="dropdown",
           options=list(_FILTER_COMPARATORS))
    @param("value", "Number", default=0, description="比较值")
    @context_var("filtered_count", "Integer", description="过滤后行数")
    @context_var("removed_count", "Integer", description="移除行数")
    def process(self, inputs, context=None):
        """Filter the input table using a single-column comparison.

        When the ``column`` parameter is empty, the table is passed through unchanged.

        Args:
            inputs: mapping that must contain the ``"table"`` input (a polars DataFrame).
            context: accepted for the node interface; not used here.

        Returns:
            ``{"filtered": DataFrame}``.

        Raises:
            ValueError: if the ``operator`` parameter is not a supported comparison symbol.
        """
        # NOTE(review): the context variables declared above are not populated
        # here; how nodes publish them is not visible in this file.
        import polars as pl

        table = inputs["table"]
        column = self.get_param("column")
        op_symbol = self.get_param("operator", ">")
        value = self.get_param("value", 0)

        # No column configured: nothing to filter on.
        if not column:
            return {"filtered": table}

        compare = _FILTER_COMPARATORS.get(op_symbol)
        if compare is None:
            # Silently returning the unfiltered table would hide a misconfigured
            # graph behind plausible-looking output, so fail loudly instead.
            raise ValueError(
                f"FilterRowsNode: unsupported operator {op_symbol!r}; "
                f"expected one of {list(_FILTER_COMPARATORS)}"
            )

        return {"filtered": table.filter(compare(pl.col(column), value))}


@register_node
class SelectColumnsNode(TraceNode):
    """Column selector: keeps only the listed columns, in the listed order."""

    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "列选择"
    DESCRIPTION = "选择指定的列"
    ICON = "📋"

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("selected", "DataTable", description="选择后的数据表")
    @param("columns", "String", default="", description="列名(逗号分隔)", multiline=True)
    @context_var("selected_columns", "List", description="选择的列名列表")
    def process(self, inputs, context=None):
        """Select the configured columns from the input table.

        The ``columns`` parameter is a multiline text field. Names may be separated
        by commas and/or newlines. Surrounding whitespace and empty entries are
        ignored, and repeated names keep only their first occurrence. If no names
        are given, the table is passed through unchanged rather than reduced to
        zero columns.

        Args:
            inputs: mapping that must contain the ``"table"`` input.
            context: accepted for the node interface; not used here.

        Returns:
            ``{"selected": DataFrame}``.
        """
        table = inputs["table"]
        columns_str = self.get_param("columns", "") or ""

        # Treat newlines like commas, since the param uses a multiline widget.
        raw_names = columns_str.replace("\n", ",").split(",")
        # dict.fromkeys de-duplicates while preserving order; polars rejects
        # duplicate column names in select().
        columns = list(dict.fromkeys(name.strip() for name in raw_names if name.strip()))

        if not columns:
            # Consistent with FilterRowsNode/StatisticsNode: empty config means "no restriction".
            return {"selected": table}

        return {"selected": table.select(columns)}


# ============= Example 4: aggregate node (multiple inputs) =============

@register_node
class ConcatNode(TraceNode):
    """Concatenation node: vertically stacks a list of input tables."""

    CATEGORY = "Data/Aggregate"
    DISPLAY_NAME = "数据合并"
    DESCRIPTION = "合并多个数据表(垂直拼接)"
    ICON = "🔗"
    NODE_TYPE = NodeType.COMPOSITE

    @input_port("tables", "DataTable", description="输入数据表", list=True)  # list input
    @output_port("concatenated", "DataTable", description="合并后的数据表")
    @param("ignore_index", "Boolean", default=True, description="忽略原索引")
    @context_var("total_rows", "Integer", description="合并后总行数")
    @context_var("input_count", "Integer", description="输入表数量")
    def process(self, inputs, context=None):
        """Vertically concatenate ``inputs["tables"]`` using ``pl.concat(how="vertical")``.

        The ``ignore_index`` parameter is declared for interface parity but has
        no effect: polars DataFrames have no row index to keep or reset.

        Args:
            inputs: mapping whose ``"tables"`` entry is a list of polars DataFrames
                (a single DataFrame is also accepted).
            context: accepted for the node interface; not used here.

        Returns:
            ``{"concatenated": DataFrame}``. This is an empty DataFrame when no
            tables are supplied.
        """
        # NOTE(review): the context variables declared above are not populated
        # here; how nodes publish them is not visible in this file.
        import polars as pl

        tables = inputs["tables"]

        # A bare DataFrame would otherwise be iterated column-by-column by pl.concat.
        if isinstance(tables, pl.DataFrame):
            tables = [tables]

        frames = list(tables or [])
        if not frames:
            # pl.concat raises on an empty list; with nothing to merge, emit an empty table.
            return {"concatenated": pl.DataFrame()}

        return {"concatenated": pl.concat(frames, how="vertical")}


# ============= Example 5: output node =============

@register_node
class TableOutputNode(TraceNode):
    """Table display node: an output node (no output ports) that prepares a table for the frontend."""

    CATEGORY = "Output/Display"
    DISPLAY_NAME = "表格显示"
    DESCRIPTION = "在前端显示数据表"
    ICON = "🖥️"
    NODE_TYPE = NodeType.OUTPUT  # output node: inputs only

    @input_port("table", "DataTable", description="要显示的数据表")
    # No output ports.
    @param("max_rows", "Number", default=100, description="最大显示行数", min=1, max=10000, step=10)
    @param("show_index", "Boolean", default=True, description="显示索引列")
    @context_var("displayed_rows", "Integer", description="实际显示行数")
    def process(self, inputs, context=None):
        """Convert the first ``max_rows`` rows of the input table into a display payload.

        Args:
            inputs: mapping that must contain the ``"table"`` input (a polars DataFrame).
            context: accepted for the node interface; not used here.

        Returns:
            ``{"display": {"columns", "data", "total_rows", "show_index"}}``, where
            ``data`` is a list of row dicts.
            NOTE(review): no output port named "display" is declared. Presumably the
            framework forwards this payload to the frontend; confirm.
        """
        table = inputs["table"]

        raw_max_rows = self.get_param("max_rows", 100)
        # "Number" params may arrive as floats (e.g. 100.0), but DataFrame.head needs
        # an int. Negative values are clamped to 0, because polars' head(-n) means
        # "all rows except the last n", which is not what a row limit should do.
        max_rows = 100 if raw_max_rows is None else max(0, int(raw_max_rows))
        show_index = self.get_param("show_index", True)

        # Truncate to the rows that will actually be shown.
        displayed = table.head(max_rows)

        # Convert to the frontend payload format.
        result = {
            "columns": list(displayed.columns),
            "data": displayed.to_dicts(),
            "total_rows": table.height if hasattr(table, 'height') else len(table),
            "show_index": show_index,
        }

        return {"display": result}


# ============= Example 6: compute-heavy node with caching =============

@register_node
class StatisticsNode(TraceNode):
    """Statistics node: runs ``DataFrame.describe()`` over selected columns; declares memory caching."""

    CATEGORY = "Data/Analysis"
    DISPLAY_NAME = "统计分析"
    DESCRIPTION = "计算数据表的统计信息"
    ICON = "📈"
    CACHE_POLICY = CachePolicy.MEMORY  # enable in-memory cache policy

    @input_port("table", "DataTable", description="输入数据表")
    @output_port("stats", "DataTable", description="统计结果")
    @param("columns", "String", default="", description="分析列名(逗号分隔,留空则全部)")
    @context_var("mean", "Number", description="平均值")
    @context_var("std", "Number", description="标准差")
    def process(self, inputs, context=None):
        """Compute summary statistics for the chosen columns.

        ``columns`` is a comma-separated list of names. If it yields no names
        (empty, or only whitespace/commas), every integer or float column in the
        table's schema is analysed instead.

        Args:
            inputs: mapping that must contain the ``"table"`` input (a polars DataFrame).
            context: accepted for the node interface; not used here.

        Returns:
            ``{"stats": DataFrame}``, which is the result of ``describe()``.
        """
        # NOTE(review): the context variables declared above are not populated
        # here; how nodes publish them is not visible in this file.
        import polars as pl

        table = inputs["table"]
        columns_str = self.get_param("columns", "") or ""

        # dict.fromkeys de-duplicates while preserving order; polars rejects
        # duplicate column names in select().
        columns = list(dict.fromkeys(
            name.strip() for name in columns_str.split(",") if name.strip()
        ))

        if columns:
            data = table.select(columns)
        else:
            # Fall back to numeric columns. This also covers strings such as " , "
            # that are truthy but contain no names, which previously selected zero
            # columns and broke describe().
            numeric_dtypes = {
                pl.Int8, pl.Int16, pl.Int32, pl.Int64,
                pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
                pl.Float32, pl.Float64,
            }
            numeric_columns = [
                name for name, dtype in table.schema.items() if dtype in numeric_dtypes
            ]
            data = table.select(numeric_columns)

        # Compute the summary statistics.
        stats = data.describe()

        return {"stats": stats}