TraceStudio-dev/docs/server1.2/NODE_DEVELOPMENT_GUIDE_v2.md
2026-01-09 21:37:02 +08:00

12 KiB
Raw Blame History

TraceStudio 节点开发指南 (v2.0)

🎯 核心设计理念

让节点开发者专注于业务逻辑,框架自动处理元数据收集和注册。


📦 快速开始

1. 最简单的节点

from server.app.core.node_base import TraceNode, input_port, output_port, param
from server.app.core.node_registry import register_node

@register_node
class AddNode(TraceNode):
    CATEGORY = "Math"
    DISPLAY_NAME = "加法"
    
    @input_port("a", "Number", description="加数 A")
    @input_port("b", "Number", description="加数 B")
    @output_port("result", "Number", description="和")
    @param("offset", "Number", default=0, description="偏移量")
    def process(self, inputs, context=None):
        a = inputs["a"]
        b = inputs["b"]
        offset = self.get_param("offset", 0)
        
        return {
            "outputs": {"result": a + b + offset},
            "context": {}
        }

就这么简单! 装饰器会自动收集所有属性到 InputSpecOutputSpecParamSpec


🧩 四大属性规范

InputSpec - 主输入端口

@input_port(
    "data",                    # 端口名称
    "DataTable",               # 数据类型
    description="输入数据表",   # 描述
    required=True,             # 是否必需(默认 True
    list=False                 # 是否为列表类型(默认 False
)

特点:

  • 必须通过连线获得数据
  • 不会显示在参数面板中
  • 位于节点左侧

OutputSpec - 主输出端口

@output_port(
    "filtered",                # 端口名称
    "DataTable",               # 数据类型
    description="过滤后的数据", # 描述
    list=False                 # 是否输出列表
)

特点:

  • 供下游节点连接
  • 位于节点右侧
  • 支持多输出(复合节点)

ParamSpec - 控制参数

@param(
    "threshold",               # 参数名
    "Number",                  # 参数类型
    default=0.5,               # 默认值
    description="阈值",        # 描述
    widget="slider",           # 控件类型
    min=0, max=1, step=0.1     # 数值限制
)

支持的参数类型和控件:

类型 widget 配置项 示例
Number slider min, max, step @param("value", "Number", default=0.5, widget="slider", min=0, max=1, step=0.1)
String text multiline @param("text", "String", default="", multiline=False)
String file - @param("path", "String", widget="file")
String dropdown options @param("mode", "String", widget="dropdown", options=["A", "B", "C"])
Boolean checkbox - @param("enabled", "Boolean", default=True)

ContextSpec - 上下文变量

@context_var(
    "row_count",               # 变量名
    "Integer",                 # 变量类型
    description="数据行数"     # 描述
)

特点:

  • 节点计算的副产物/元数据
  • 自动沿连线向下游广播
  • 下游节点可在参数中引用:$NodeID.row_count
  • 可提升为全局变量:$Global.row_count

🎨 装饰器详解

@register_node - 注册节点

@register_node
class MyNode(TraceNode):
    ...

作用:

  • 自动注册节点到全局注册表
  • 自动收集所有 @input_port@output_port@param@context_var 装饰器
  • 合并到 InputSpecOutputSpecParamSpecContextSpec

装饰器链式调用

@input_port("a", "Number", description="输入A")
@input_port("b", "Number", description="输入B")
@output_port("result", "Number", description="结果")
@param("offset", "Number", default=0)
@context_var("count", "Integer", description="计算次数")
def process(self, inputs, context=None):
    ...

注意: 装饰器可以任意顺序,框架会自动分类收集。


📋 节点类型

1. 标准流水线节点1进1出

@register_node
class FilterNode(TraceNode):
    NODE_TYPE = NodeType.NORMAL  # 默认值
    
    @input_port("data", "DataTable")
    @output_port("filtered", "DataTable")
    def process(self, inputs, context=None):
        ...

2. 输入节点(仅输出)

@register_node
class LoaderNode(TraceNode):
    NODE_TYPE = NodeType.INPUT
    
    # 无 @input_port
    @output_port("data", "DataTable")
    def process(self, inputs, context=None):
        ...

3. 输出节点(仅输入)

@register_node
class DisplayNode(TraceNode):
    NODE_TYPE = NodeType.OUTPUT
    
    @input_port("data", "DataTable")
    # 无 @output_port
    def process(self, inputs, context=None):
        ...

4. 聚合节点(多输入)

@register_node
class ConcatNode(TraceNode):
    NODE_TYPE = NodeType.COMPOSITE
    
    @input_port("tables", "DataTable", list=True)  # 列表输入
    @output_port("concatenated", "DataTable")
    def process(self, inputs, context=None):
        tables = inputs["tables"]  # 这是一个列表
        ...

🔄 process 方法规范

方法签名

def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:

参数说明

  • inputs: 主输入数据字典,键为端口名,值为数据
  • context: 上下文字典
    • 全局变量:context["$Global.variable_name"]
    • 上游节点变量:context["$NodeID.variable_name"]

返回值格式

标准格式(推荐):

return {
    "outputs": {"port_name": data},  # 对应 OutputSpec
    "context": {"var_name": value}   # 对应 ContextSpec
}

简化格式(仅输出):

return {"result": data}  # 自动转换为 {"outputs": {"result": data}, "context": {}}

💾 缓存策略

@register_node
class ExpensiveNode(TraceNode):
    CACHE_POLICY = CachePolicy.MEMORY  # 或 DISK 或 NONE
    
    def process(self, inputs, context=None):
        # 计算密集型操作
        ...

缓存键组成:

  • 节点类名
  • 参数值
  • 输入数据(哈希)
  • 上下文(哈希)

注意: 相同输入和参数会直接返回缓存结果,不重复计算。


🔍 预览模式

默认实现

def preview(self, inputs, context=None, limit=10):
    # 默认:调用 process 然后截取数据
    result = self.process(inputs, context)
    return self._truncate_preview_data(result, limit)

自定义优化

@register_node
class BigDataNode(TraceNode):
    def preview(self, inputs, context=None, limit=10):
        # 优化:仅处理前 N 行
        data = inputs["data"].head(limit)
        return self.process({"data": data}, context)

📊 完整示例

示例CSV 过滤器

@register_node
class AdvancedFilterNode(TraceNode):
    """高级数据过滤器"""
    
    CATEGORY = "Data/Transform"
    DISPLAY_NAME = "高级过滤"
    DESCRIPTION = "支持多条件过滤的数据节点"
    ICON = "🔍"
    CACHE_POLICY = CachePolicy.MEMORY
    
    @input_port("table", "DataTable", description="输入数据表", required=True)
    @output_port("filtered", "DataTable", description="过滤后的数据")
    @output_port("removed", "DataTable", description="被移除的数据")
    
    @param("column", "String", default="", description="过滤列名")
    @param("operator", "String", default=">", widget="dropdown", 
           options=[">", ">=", "<", "<=", "==", "!="], description="运算符")
    @param("value", "Number", default=0, description="比较值")
    @param("return_removed", "Boolean", default=False, description="输出被移除的数据")
    
    @context_var("filtered_count", "Integer", description="保留行数")
    @context_var("removed_count", "Integer", description="移除行数")
    @context_var("filter_expression", "String", description="过滤表达式")
    
    def process(self, inputs, context=None):
        import pandas as pd
        
        table = inputs["table"]
        column = self.get_param("column")
        operator = self.get_param("operator", ">")
        value = self.get_param("value", 0)
        return_removed = self.get_param("return_removed", False)
        
        # 应用过滤
        if operator == ">":
            mask = table[column] > value
        elif operator == ">=":
            mask = table[column] >= value
        elif operator == "<":
            mask = table[column] < value
        elif operator == "<=":
            mask = table[column] <= value
        elif operator == "==":
            mask = table[column] == value
        else:  # !=
            mask = table[column] != value
        
        filtered = table[mask]
        removed = table[~mask]
        
        outputs = {"filtered": filtered}
        if return_removed:
            outputs["removed"] = removed
        
        return {
            "outputs": outputs,
            "context": {
                "filtered_count": len(filtered),
                "removed_count": len(removed),
                "filter_expression": f"{column} {operator} {value}"
            }
        }

最佳实践

1. 命名规范

  • 类名:使用 XxxNode 后缀(如 FilterRowsNode
  • 端口名:小写字母 + 下划线(如 input_data
  • 参数名:小写字母 + 下划线(如 max_value

2. 分类路径

CATEGORY = "Data/Transform/Filter"  # 多级分类

3. 描述文档

  • DESCRIPTION简短说明1-2 句话)
  • 端口 description:说明数据用途
  • 参数 description:说明参数含义

4. 错误处理

def process(self, inputs, context=None):
    try:
        # 处理逻辑
        ...
    except Exception as e:
        raise ValueError(f"节点执行失败: {e}")

5. 类型提示

def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    ...

🚀 部署自定义节点

文件位置

TraceStudio/
├── cloud/
│   └── custom_nodes/        # 用户自定义节点目录
│       └── my_nodes.py      # 你的节点文件
└── server/
    └── app/
        └── nodes/
            └── example_nodes.py  # 内置节点示例

自动加载

服务器启动时会自动扫描 cloud/custom_nodes/ 目录,加载所有带 @register_node 的节点。


🧪 测试节点

cd server
python tests/test_node_system.py

测试内容:

  • 节点注册
  • 元数据生成
  • 节点执行
  • 输入验证
  • 多输入聚合
  • 缓存机制

🎓 对比:旧版 vs 新版

旧版方式(繁琐)

class OldNode(TraceNode):
    InputSpec = {
        "data": ("DataTable", {"description": "输入数据", "required": True})
    }
    OutputSpec = {
        "result": ("DataTable", {"description": "输出结果"})
    }
    ParamSpec = {
        "threshold": ("Number", {"default": 0.5, "min": 0, "max": 1})
    }
    
    def process(self, inputs, context=None):
        ...

新版方式(简洁)

@register_node
class NewNode(TraceNode):
    @input_port("data", "DataTable", description="输入数据")
    @output_port("result", "DataTable", description="输出结果")
    @param("threshold", "Number", default=0.5, min=0, max=1)
    def process(self, inputs, context=None):
        ...

优势:

  • 代码更简洁(减少 40% 代码量)
  • 语义更清晰(装饰器即文档)
  • 易于维护(属性定义靠近使用位置)
  • 自动注册(无需手动调用)

📚 进阶主题

动态参数选项

def get_column_options(self, upstream_data):
    """动态获取列名选项"""
    if "table" in upstream_data:
        return list(upstream_data["table"].columns)
    return []

条件输出

def process(self, inputs, context=None):
    outputs = {"main": result}
    
    if self.get_param("debug_mode"):
        outputs["debug_info"] = debug_data
    
    return {"outputs": outputs, "context": {}}

上下文引用

def process(self, inputs, context=None):
    # 引用全局变量
    user_name = context.get("$Global.user_name", "guest")
    
    # 引用上游节点变量
    upstream_count = context.get("$Loader_1.row_count", 0)
    
    ...

祝你开发愉快!🎉