12 KiB
12 KiB
🧩 自定义节点开发指南
📖 概述
TraceStudio 支持用户创建自定义节点,扩展系统功能。本文档介绍如何开发、测试和部署自定义节点。
🏗️ 架构设计
目录结构
server/
├── custom_nodes/ # 自定义节点目录
│ ├── __init__.py # 自动生成
│ ├── example_nodes.py # 示例节点
│ └── your_node.py # 你的节点
├── app/
│ ├── core/
│ │ ├── node_base.py # 节点基类
│ │ ├── node_validator.py # 代码验证器
│ │ └── node_loader.py # 动态加载器
│ └── api/
│ └── endpoints_custom_nodes.py # 管理 API
安全机制
✅ 代码验证
- AST 语法树分析
- 危险模块黑名单(os, subprocess, eval等)
- 必须继承
TraceNode基类 - 必须实现
execute()方法
✅ 操作确认
- 保存/覆盖文件需确认
- 删除文件需二次确认
- 自动备份旧文件
✅ 沙箱执行
- 执行超时限制(30秒)
- 内存限制(512MB)
- 禁止危险操作
🚀 快速开始
1. 创建节点文件
在 custom_nodes/ 目录创建 .py 文件:
# custom_nodes/my_filter.py
from app.core.node_base import TraceNode
class MyFilterNode(TraceNode):
"""我的过滤节点"""
@staticmethod
def get_metadata():
return {
"display_name": "我的过滤器",
"description": "根据条件过滤数据",
"category": "Custom/Data",
"author": "Your Name",
"version": "1.0.0",
"inputs": [
{
"name": "data",
"type": "DataFrame",
"description": "输入数据",
"required": True
}
],
"outputs": [
{
"name": "result",
"type": "DataFrame",
"description": "过滤结果"
}
],
"params": [
{
"name": "threshold",
"type": "float",
"default": 0.5,
"description": "过滤阈值"
}
]
}
def execute(self, inputs):
"""执行节点逻辑"""
data = inputs.get('data')
threshold = self.get_param('threshold', 0.5)
# 实现过滤逻辑
result = data[data['value'] > threshold]
return {"result": result}
2. 使用前端编辑器
- 打开 TraceStudio Web 界面
- 进入 自定义节点编辑器
- 点击 "新建节点" 或 "加载示例"
- 编写代码
- 点击 "验证代码" 检查语法
- 点击 "保存" 保存并自动加载
3. API 直接操作
# 列出所有节点
curl http://localhost:8000/api/custom-nodes/list
# 验证代码
curl -X POST http://localhost:8000/api/custom-nodes/validate \
-H "Content-Type: application/json" \
-d '{"code": "...", "filename": "test.py"}'
# 保存节点
curl -X POST http://localhost:8000/api/custom-nodes/save \
-H "Content-Type: application/json" \
-d '{"filename": "my_node.py", "code": "...", "force": false}'
# 加载节点
curl -X POST http://localhost:8000/api/custom-nodes/action \
-H "Content-Type: application/json" \
-d '{"filename": "my_node.py", "action": "load"}'
📝 节点开发规范
必须实现的方法
1. get_metadata() 静态方法
返回节点的元数据信息:
@staticmethod
def get_metadata():
return {
"display_name": str, # 显示名称(必需)
"description": str, # 功能描述(推荐)
"category": str, # 分类路径(如 "Data/Transform")
"author": str, # 作者(可选)
"version": str, # 版本号(可选)
"inputs": [...], # 输入端口定义
"outputs": [...], # 输出端口定义
"params": [...] # 参数定义
}
输入端口格式:
{
"name": "input_name", # 端口名(唯一)
"type": "DataFrame", # 数据类型
"description": "说明文字", # 描述(可选)
"required": True # 是否必需(默认True)
}
输出端口格式:
{
"name": "output_name", # 端口名(唯一)
"type": "Any", # 数据类型
"description": "说明文字" # 描述(可选)
}
参数格式:
{
"name": "param_name", # 参数名
"type": "float", # 类型(str/int/float/bool)
"default": 0.5, # 默认值
"description": "说明文字", # 描述(可选)
"options": [0.1, 0.5, 1.0] # 可选值列表(下拉框)
}
2. execute(inputs) 方法
执行节点的核心逻辑:
def execute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
执行节点
Args:
inputs: 输入数据字典 {"input_name": data}
Returns:
输出数据字典 {"output_name": data}
Raises:
Exception: 执行失败时抛出异常
"""
# 1. 获取输入
data = inputs.get('data')
# 2. 获取参数
threshold = self.get_param('threshold', 0.5)
# 3. 执行逻辑
result = process(data, threshold)
# 4. 返回输出
return {"result": result}
辅助方法
# 获取参数值
value = self.get_param('param_name', default_value)
# 验证输入(可选重写)
def validate_inputs(self, inputs):
if 'required_input' not in inputs:
raise ValueError("缺少必需输入")
return True
🔒 安全限制
禁止使用的模块
❌ 系统操作
import os # 禁止
import subprocess # 禁止
import sys # 禁止
❌ 动态执行
eval(code) # 禁止
exec(code) # 禁止
__import__(module) # 禁止
❌ 文件操作(需谨慎)
open(file) # 警告:确保路径安全
允许使用的模块
✅ 数据处理
import pandas as pd
import numpy as np
import json
✅ 数学计算
import math
from scipy import stats
✅ 可视化
import matplotlib.pyplot as plt
import seaborn as sns
📦 内置示例节点
1. DataFilterNode(数据过滤器)
from app.core.node_base import TraceNode
class DataFilterNode(TraceNode):
"""根据条件过滤DataFrame"""
def execute(self, inputs):
data = inputs.get('data')
column = self.get_param('column', 'value')
operator = self.get_param('operator', '>')
threshold = self.get_param('threshold', 0.5)
if operator == '>':
filtered = data[data[column] > threshold]
# ... 其他运算符
return {'filtered_data': filtered, 'count': len(filtered)}
使用场景:过滤性能数据、筛选事件记录
2. TextProcessorNode(文本处理器)
class TextProcessorNode(TraceNode):
"""文本转换处理"""
def execute(self, inputs):
text = inputs.get('text', '')
operation = self.get_param('operation', 'uppercase')
if operation == 'uppercase':
result = text.upper()
elif operation == 'lowercase':
result = text.lower()
# ...
return {'result': result}
使用场景:日志处理、字符串标准化
🧪 测试节点
单元测试
# tests/test_custom_nodes.py
import pytest
from custom_nodes.my_node import MyFilterNode
def test_my_filter_node():
node = MyFilterNode('test_node', {'threshold': 0.5})
# 准备测试数据
import pandas as pd
data = pd.DataFrame({'value': [0.3, 0.7, 0.9]})
# 执行节点
result = node.execute({'data': data})
# 断言结果
assert len(result['result']) == 2
assert result['result']['value'].min() > 0.5
验证测试
# 使用验证器测试
python -c "
from app.core.node_validator import validate_node_file
result = validate_node_file('custom_nodes/my_node.py')
print(result)
"
🛠️ API 文档
节点管理端点
| 端点 | 方法 | 说明 |
|---|---|---|
/api/custom-nodes/list |
GET | 列出所有节点 |
/api/custom-nodes/validate |
POST | 验证代码 |
/api/custom-nodes/read/{filename} |
GET | 读取节点代码 |
/api/custom-nodes/save |
POST | 保存节点 |
/api/custom-nodes/action |
POST | 操作节点(load/unload/delete) |
/api/custom-nodes/download/{filename} |
GET | 下载节点文件 |
/api/custom-nodes/upload |
POST | 上传节点文件 |
/api/custom-nodes/loaded |
GET | 获取已加载节点 |
/api/custom-nodes/reload-all |
POST | 重新加载所有节点 |
/api/custom-nodes/example |
GET | 获取示例代码 |
响应格式
成功响应:
{
"success": true,
"message": "操作成功",
"data": {...}
}
验证响应:
{
"success": true,
"validation": {
"valid": true,
"errors": [],
"warnings": ["建议实现 get_metadata()"],
"node_classes": ["MyNode"],
"metadata": {...}
}
}
🐛 常见问题
1. 验证失败:未找到 TraceNode 基类
问题:
❌ 未找到继承自 TraceNode 的节点类
解决:
# ❌ 错误
class MyNode:
pass
# ✅ 正确
from app.core.node_base import TraceNode
class MyNode(TraceNode):
pass
2. 导入错误:模块找不到
问题:
ModuleNotFoundError: No module named 'app'
解决:
# ❌ 错误(绝对导入)
from server.app.core.node_base import TraceNode
# ✅ 正确(相对于项目根目录)
from app.core.node_base import TraceNode
3. 安全检查失败:禁止导入模块
问题:
❌ 禁止导入危险模块: os (行 3)
解决:移除危险模块导入,使用安全替代方案。
4. 节点未加载到系统
解决步骤:
- 检查文件名格式(必须是
.py) - 验证代码通过(
/api/custom-nodes/validate) - 手动加载节点(
/api/custom-nodes/actionaction=load) - 重启服务器
📚 进阶技巧
1. 状态管理
class StatefulNode(TraceNode):
def __init__(self, node_id, params):
super().__init__(node_id, params)
self.cache = {} # 节点私有缓存
def execute(self, inputs):
# 使用缓存加速
if 'data' in self.cache:
return self.cache['data']
result = expensive_computation(inputs)
self.cache['data'] = result
return result
2. 多输出节点
def execute(self, inputs):
data = inputs.get('data')
# 多个输出
return {
'output1': process_a(data),
'output2': process_b(data),
'stats': {'count': len(data)}
}
3. 异常处理
def execute(self, inputs):
try:
data = inputs.get('data')
if data is None:
raise ValueError("输入数据不能为空")
result = risky_operation(data)
return {'result': result}
except KeyError as e:
raise ValueError(f"缺少必需的键: {e}")
except Exception as e:
raise RuntimeError(f"处理失败: {str(e)}")
🎓 最佳实践
- 清晰的命名:类名和文件名使用描述性名称
- 完整的文档:添加 docstring 说明功能
- 错误处理:提供明确的错误信息
- 输入验证:检查输入数据类型和范围
- 性能优化:避免不必要的计算
- 测试覆盖:编写单元测试
- 版本管理:在 metadata 中记录版本号
📞 技术支持
遇到问题?
Happy Coding! 🚀