TraceStudio-dev/docs/studio1.3/CUSTOM_NODES.md
2026-01-09 21:37:02 +08:00

12 KiB
Raw Blame History

🧩 自定义节点开发指南

📖 概述

TraceStudio 支持用户创建自定义节点,扩展系统功能。本文档介绍如何开发、测试和部署自定义节点。

🏗️ 架构设计

目录结构

server/
├── custom_nodes/           # 自定义节点目录
│   ├── __init__.py         # 自动生成
│   ├── example_nodes.py    # 示例节点
│   └── your_node.py        # 你的节点
├── app/
│   ├── core/
│   │   ├── node_base.py        # 节点基类
│   │   ├── node_validator.py   # 代码验证器
│   │   └── node_loader.py      # 动态加载器
│   └── api/
│       └── endpoints_custom_nodes.py  # 管理 API

安全机制

代码验证

  • AST 语法树分析
  • 危险模块黑名单os, subprocess, eval等
  • 必须继承 TraceNode 基类
  • 必须实现 execute() 方法

操作确认

  • 保存/覆盖文件需确认
  • 删除文件需二次确认
  • 自动备份旧文件

沙箱执行

  • 执行超时限制30秒
  • 内存限制512MB
  • 禁止危险操作

🚀 快速开始

1. 创建节点文件

custom_nodes/ 目录创建 .py 文件:

# custom_nodes/my_filter.py
from app.core.node_base import TraceNode


class MyFilterNode(TraceNode):
    """我的过滤节点"""
    
    @staticmethod
    def get_metadata():
        return {
            "display_name": "我的过滤器",
            "description": "根据条件过滤数据",
            "category": "Custom/Data",
            "author": "Your Name",
            "version": "1.0.0",
            "inputs": [
                {
                    "name": "data",
                    "type": "DataFrame",
                    "description": "输入数据",
                    "required": True
                }
            ],
            "outputs": [
                {
                    "name": "result",
                    "type": "DataFrame",
                    "description": "过滤结果"
                }
            ],
            "params": [
                {
                    "name": "threshold",
                    "type": "float",
                    "default": 0.5,
                    "description": "过滤阈值"
                }
            ]
        }
    
    def execute(self, inputs):
        """执行节点逻辑"""
        data = inputs.get('data')
        threshold = self.get_param('threshold', 0.5)
        
        # 实现过滤逻辑
        result = data[data['value'] > threshold]
        
        return {"result": result}

2. 使用前端编辑器

  1. 打开 TraceStudio Web 界面
  2. 进入 自定义节点编辑器
  3. 点击 "新建节点""加载示例"
  4. 编写代码
  5. 点击 "验证代码" 检查语法
  6. 点击 "保存" 保存并自动加载

3. API 直接操作

# 列出所有节点
curl http://localhost:8000/api/custom-nodes/list

# 验证代码
curl -X POST http://localhost:8000/api/custom-nodes/validate \
  -H "Content-Type: application/json" \
  -d '{"code": "...", "filename": "test.py"}'

# 保存节点
curl -X POST http://localhost:8000/api/custom-nodes/save \
  -H "Content-Type: application/json" \
  -d '{"filename": "my_node.py", "code": "...", "force": false}'

# 加载节点
curl -X POST http://localhost:8000/api/custom-nodes/action \
  -H "Content-Type: application/json" \
  -d '{"filename": "my_node.py", "action": "load"}'

📝 节点开发规范

必须实现的方法

1. get_metadata() 静态方法

返回节点的元数据信息:

@staticmethod
def get_metadata():
    return {
        "display_name": str,      # 显示名称(必需)
        "description": str,       # 功能描述(推荐)
        "category": str,          # 分类路径(如 "Data/Transform"
        "author": str,            # 作者(可选)
        "version": str,           # 版本号(可选)
        "inputs": [...],          # 输入端口定义
        "outputs": [...],         # 输出端口定义
        "params": [...]           # 参数定义
    }

输入端口格式

{
    "name": "input_name",         # 端口名(唯一)
    "type": "DataFrame",          # 数据类型
    "description": "说明文字",    # 描述(可选)
    "required": True              # 是否必需默认True
}

输出端口格式

{
    "name": "output_name",        # 端口名(唯一)
    "type": "Any",                # 数据类型
    "description": "说明文字"     # 描述(可选)
}

参数格式

{
    "name": "param_name",         # 参数名
    "type": "float",              # 类型str/int/float/bool
    "default": 0.5,               # 默认值
    "description": "说明文字",    # 描述(可选)
    "options": [0.1, 0.5, 1.0]    # 可选值列表(下拉框)
}

2. execute(inputs) 方法

执行节点的核心逻辑:

def execute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """
    执行节点
    
    Args:
        inputs: 输入数据字典 {"input_name": data}
    
    Returns:
        输出数据字典 {"output_name": data}
    
    Raises:
        Exception: 执行失败时抛出异常
    """
    # 1. 获取输入
    data = inputs.get('data')
    
    # 2. 获取参数
    threshold = self.get_param('threshold', 0.5)
    
    # 3. 执行逻辑
    result = process(data, threshold)
    
    # 4. 返回输出
    return {"result": result}

辅助方法

# 获取参数值
value = self.get_param('param_name', default_value)

# 验证输入(可选重写)
def validate_inputs(self, inputs):
    if 'required_input' not in inputs:
        raise ValueError("缺少必需输入")
    return True

🔒 安全限制

禁止使用的模块

系统操作

import os            # 禁止
import subprocess    # 禁止
import sys          # 禁止

动态执行

eval(code)          # 禁止
exec(code)          # 禁止
__import__(module)  # 禁止

文件操作(需谨慎)

open(file)          # 警告:确保路径安全

允许使用的模块

数据处理

import pandas as pd
import numpy as np
import json

数学计算

import math
from scipy import stats

可视化

import matplotlib.pyplot as plt
import seaborn as sns

📦 内置示例节点

1. DataFilterNode数据过滤器

from app.core.node_base import TraceNode

class DataFilterNode(TraceNode):
    """根据条件过滤DataFrame"""
    
    def execute(self, inputs):
        data = inputs.get('data')
        column = self.get_param('column', 'value')
        operator = self.get_param('operator', '>')
        threshold = self.get_param('threshold', 0.5)
        
        if operator == '>':
            filtered = data[data[column] > threshold]
        # ... 其他运算符
        
        return {'filtered_data': filtered, 'count': len(filtered)}

使用场景:过滤性能数据、筛选事件记录

2. TextProcessorNode文本处理器

class TextProcessorNode(TraceNode):
    """文本转换处理"""
    
    def execute(self, inputs):
        text = inputs.get('text', '')
        operation = self.get_param('operation', 'uppercase')
        
        if operation == 'uppercase':
            result = text.upper()
        elif operation == 'lowercase':
            result = text.lower()
        # ...
        
        return {'result': result}

使用场景:日志处理、字符串标准化

🧪 测试节点

单元测试

# tests/test_custom_nodes.py
import pytest
from custom_nodes.my_node import MyFilterNode

def test_my_filter_node():
    node = MyFilterNode('test_node', {'threshold': 0.5})
    
    # 准备测试数据
    import pandas as pd
    data = pd.DataFrame({'value': [0.3, 0.7, 0.9]})
    
    # 执行节点
    result = node.execute({'data': data})
    
    # 断言结果
    assert len(result['result']) == 2
    assert result['result']['value'].min() > 0.5

验证测试

# 使用验证器测试
python -c "
from app.core.node_validator import validate_node_file
result = validate_node_file('custom_nodes/my_node.py')
print(result)
"

🛠️ API 文档

节点管理端点

端点 方法 说明
/api/custom-nodes/list GET 列出所有节点
/api/custom-nodes/validate POST 验证代码
/api/custom-nodes/read/{filename} GET 读取节点代码
/api/custom-nodes/save POST 保存节点
/api/custom-nodes/action POST 操作节点load/unload/delete
/api/custom-nodes/download/{filename} GET 下载节点文件
/api/custom-nodes/upload POST 上传节点文件
/api/custom-nodes/loaded GET 获取已加载节点
/api/custom-nodes/reload-all POST 重新加载所有节点
/api/custom-nodes/example GET 获取示例代码

响应格式

成功响应

{
  "success": true,
  "message": "操作成功",
  "data": {...}
}

验证响应

{
  "success": true,
  "validation": {
    "valid": true,
    "errors": [],
    "warnings": ["建议实现 get_metadata()"],
    "node_classes": ["MyNode"],
    "metadata": {...}
  }
}

🐛 常见问题

1. 验证失败:未找到 TraceNode 基类

问题

❌ 未找到继承自 TraceNode 的节点类

解决

# ❌ 错误
class MyNode:
    pass

# ✅ 正确
from app.core.node_base import TraceNode

class MyNode(TraceNode):
    pass

2. 导入错误:模块找不到

问题

ModuleNotFoundError: No module named 'app'

解决

# ❌ 错误(绝对导入)
from server.app.core.node_base import TraceNode

# ✅ 正确(相对于项目根目录)
from app.core.node_base import TraceNode

3. 安全检查失败:禁止导入模块

问题

❌ 禁止导入危险模块: os (行 3)

解决:移除危险模块导入,使用安全替代方案。

4. 节点未加载到系统

解决步骤

  1. 检查文件名格式(必须是 .py
  2. 验证代码通过(/api/custom-nodes/validate
  3. 手动加载节点(/api/custom-nodes/action action=load
  4. 重启服务器

📚 进阶技巧

1. 状态管理

class StatefulNode(TraceNode):
    def __init__(self, node_id, params):
        super().__init__(node_id, params)
        self.cache = {}  # 节点私有缓存
    
    def execute(self, inputs):
        # 使用缓存加速
        if 'data' in self.cache:
            return self.cache['data']
        
        result = expensive_computation(inputs)
        self.cache['data'] = result
        return result

2. 多输出节点

def execute(self, inputs):
    data = inputs.get('data')
    
    # 多个输出
    return {
        'output1': process_a(data),
        'output2': process_b(data),
        'stats': {'count': len(data)}
    }

3. 异常处理

def execute(self, inputs):
    try:
        data = inputs.get('data')
        if data is None:
            raise ValueError("输入数据不能为空")
        
        result = risky_operation(data)
        return {'result': result}
    
    except KeyError as e:
        raise ValueError(f"缺少必需的键: {e}")
    except Exception as e:
        raise RuntimeError(f"处理失败: {str(e)}")

🎓 最佳实践

  1. 清晰的命名:类名和文件名使用描述性名称
  2. 完整的文档:添加 docstring 说明功能
  3. 错误处理:提供明确的错误信息
  4. 输入验证:检查输入数据类型和范围
  5. 性能优化:避免不必要的计算
  6. 测试覆盖:编写单元测试
  7. 版本管理:在 metadata 中记录版本号

📞 技术支持

遇到问题?


Happy Coding! 🚀