TraceStudio/server/app/core/node_base.py

743 lines
25 KiB
Python
Raw Normal View History

2026-01-12 21:51:45 +08:00
"""
TraceStudio 节点基类 (v2.0)
四大属性规范InputSpec, OutputSpec, ParamSpec, ContextSpec
支持高级特性维度转换特殊节点函数嵌套
"""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from .cache_manager import NodeCacheAdapter
import inspect
class NodeType(Enum):
"""节点类型"""
NORMAL = "normal" # 普通节点
INPUT = "input" # 输入节点
OUTPUT = "output" # 输出节点
FUNCTION = "function" # 函数节点(由子工作流包装而成)
COMPOSITE = "composite" # 复合/聚合节点(如 ConcatNode
class DimensionMode(Enum):
"""维度转换模式v2.0 新增)"""
NONE = "none" # 无转换
EXPAND = "down" # 升维:数组→单个元素(遍历)
COLLAPSE = "up" # 降维:单个元素→数组(打包)
BROADCAST = "broadcast" # 广播:多条线→数组(展开+打包)
@classmethod
def from_str(cls, label: str):
# 统一转大写处理,防止前端大小写传错导致崩溃
try:
print("DimensionMode::from_str", label, label.lower())
en = cls[label.lower()]
print("DimensionMode::from_str", label, en)
return en
except (KeyError, AttributeError):
return cls.NONE # 提供一个默认值
class CachePolicy(Enum):
"""缓存策略"""
NONE = "none" # 不缓存
MEMORY = "memory" # 内存缓存
DISK = "disk" # 磁盘缓存
@dataclass
class EdgeMetadata:
"""连线元数据v2.0 新增,用于连线分类和维度转换)"""
source_node: str # 源节点ID
source_port: str # 源端口名
target_node: str # 目标节点ID
target_port: str # 目标端口名
dimension_mode: DimensionMode = DimensionMode.NONE
def to_dict(self) -> Dict:
return {
"source_node": self.source_node,
"source_port": self.source_port,
"target_node": self.target_node,
"target_port": self.target_port,
"dimension_mode": self.dimension_mode
}
@dataclass
class NodeMetadata:
"""节点元数据v2.0 新增,用于函数节点系统)"""
node_id: str # 节点ID
node_type: str # 节点类型来自NodeRegistry
display_name: Optional[str] = None
description: Optional[str] = None
params: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict:
return {
"node_id": self.node_id,
"node_type": self.node_type,
"display_name": self.display_name,
"description": self.description,
"params": self.params,
}
class TraceNode(ABC):
"""
TraceStudio 节点基类 (v2.0)
四大属性
- InputSpec: 主输入必须通过连线
- OutputSpec: 主输出供下游连接
- ParamSpec: 控制参数面板配置
- ContextSpec: 上下文/元数据自动广播
示例用法
@register_node
class AddNode(TraceNode):
CATEGORY = "Math/Basic"
DISPLAY_NAME = "加法"
DESCRIPTION = "计算两个数的和"
# 自动收集的四大属性
@input_port("a", "Number", description="加数A")
@input_port("b", "Number", description="加数B")
@output_port("result", "Number", description="")
@param("offset", "Number", default=0, description="偏移量")
@context_var("count", "Integer", description="计算次数")
def process(self, inputs, context):
a = inputs["a"]
b = inputs["b"]
offset = self.get_param("offset", 0)
result = a + b + offset
# 返回输出和上下文
return {
"outputs": {"result": result},
"context": {"count": 1}
}
"""
# ============= 元数据定义 (类属性) =============
CATEGORY = "Uncategorized" # 分类路径,如 "Data/Transform"
DISPLAY_NAME = None # 显示名称None 则使用类名
DESCRIPTION = "" # 节点描述
ICON = "📦" # 图标Emoji 或图标名)
VERSION = "1.0.0" # 版本号
AUTHOR = "" # 作者
# 节点行为配置
NODE_TYPE = NodeType.NORMAL # 节点类型
CACHE_POLICY = CachePolicy.NONE # 缓存策略
SUPPORTS_PREVIEW = True # 是否支持预览模式
# ============= v2.0 四大属性(自动收集) =============
# 格式: {"port_name": (data_type, config_dict)}
InputSpec: Dict[str, tuple] = {}
OutputSpec: Dict[str, tuple] = {}
ParamSpec: Dict[str, tuple] = {}
ContextSpec: Dict[str, tuple] = {}
def __init__(self, node_id: str, params: Optional[Dict] = None):
"""
初始化节点实例
Args:
node_id: 节点实例 ID前端生成
params: 参数值字典 ParamSpec 中的参数
"""
self.node_id = node_id
self.params = params or {}
self._cache = None
@abstractmethod
def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
处理节点核心逻辑子类必须实现
Args:
inputs: 主输入数据字典 {"port_name": data}
所有数据由连线提供对应 InputSpec
context: 上下文字典 {"$Global.var": value, "$NodeID.var": value}
包含全局变量和上游节点的 ContextSpec
Returns:
返回字典包含两个键
{
"outputs": {"port_name": data}, # 主输出,对应 OutputSpec
"context": {"var_name": value} # 上下文,对应 ContextSpec
}
简化写法仅返回 outputs
return {"result": data} # 自动转换为 {"outputs": {"result": data}, "context": {}}
"""
pass
def wrap_process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
self.temp_inputs = inputs
self.temp_context = context
result = {}
try:
result = self.process(inputs, context)
except Exception as e:
raise e
finally:
self.temp_inputs = None
self.temp_context = None
return result
def preview(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None, limit: int = 10) -> Dict[str, Any]:
"""
预览模式执行可选重写
默认实现调用 process 后截取数据
子类可重写此方法优化预览逻辑如提前终止采样等
Args:
inputs: 主输入数据
context: 上下文
limit: 预览数据量限制
Returns:
预览结果
"""
if not self.SUPPORTS_PREVIEW:
raise NotImplementedError(f"节点 {self.__class__.__name__} 不支持预览")
return self.wrap_process(inputs, context)
# ============= 辅助方法 =============
def get_param(self, name: str, default: Any = None) -> Any:
"""
获取参数值支持静态值Context 引用暴露端口
"""
param = self.params.get(name)
if not param:
return default
mode = param.get("mode")
if mode == "static":
return param.get("value", default)
if mode == "context":
#todo: 支持命名空间
return default
if mode == "exposed":
return self.temp_inputs.get(name, default)
return default
def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
"""
验证输入数据完整性
Args:
inputs: 输入数据字典
Returns:
验证是否通过
Raises:
ValueError: 缺少必需输入或类型不匹配
"""
for port_name, (data_type, config) in self.InputSpec.items():
required = config.get("required", True)
# 检查必需输入
if required and port_name not in inputs:
raise ValueError(f"缺少必需的输入端口: {port_name}")
# 检查列表类型
is_list = config.get("list", False)
if port_name in inputs and is_list and not isinstance(inputs[port_name], list):
raise TypeError(f"端口 {port_name} 需要列表类型数据,但收到 {type(inputs[port_name])}")
return True
def execute_with_cache(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
带缓存的执行
Args:
inputs: 输入数据
context: 上下文
Returns:
执行结果
"""
if self.CACHE_POLICY == CachePolicy.NONE:
return self.wrap_process(inputs, context)
# 生成缓存键
cache_key = self._generate_cache_key(inputs, context)
print("***********************",self.__class__.__name__, cache_key)
# 检查缓存
cached = self._get_from_cache(cache_key)
if cached is not None:
print("***********************",self.__class__.__name__, "read cache:", cache_key)
return cached
# 执行处理
result = self.wrap_process(inputs, context)
# 保存缓存
self._save_to_cache(cache_key, result)
return result
def _generate_cache_key(self, inputs: Dict, context: Optional[Dict] = None) -> str:
"""生成缓存键"""
import hashlib
import json
cache_data = {
"class": self.__class__.__name__,
"params": self.params,
"inputs": self._serialize_inputs(inputs),
}
json_str = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(json_str.encode()).hexdigest()
def _serialize_inputs(self, inputs: Dict) -> Dict:
"""序列化输入用于缓存键生成"""
import hashlib
serialized = {}
for key, value in inputs.items():
if isinstance(value, (str, int, float, bool, type(None))):
serialized[key] = value
elif isinstance(value, list):
# 列表转为哈希
serialized[key] = hashlib.md5(str(value).encode()).hexdigest()
else:
# 复杂对象转为哈希
serialized[key] = hashlib.md5(str(value).encode()).hexdigest()
return serialized
def _get_from_cache(self, key: str) -> Optional[Dict]:
"""从缓存获取"""
if self._cache is None:
self._init_cache()
return self._cache.get(key) if self._cache else None
def _save_to_cache(self, key: str, value: Dict):
"""保存到缓存"""
if self._cache is None:
self._init_cache()
if self._cache:
self._cache.set(key, value)
def _init_cache(self):
storage = 'memory' if self.CACHE_POLICY == CachePolicy.MEMORY else 'disk'
self._cache = NodeCacheAdapter(storage)
# ============= 元数据生成(供前端使用) =============
@classmethod
def get_metadata(cls) -> Dict:
"""
生成节点元数据供前端渲染
Returns:
节点元数据字典
"""
return {
"class_name": cls.__name__,
"display_name": cls.DISPLAY_NAME or cls.__name__,
"category": cls.CATEGORY,
"description": cls.DESCRIPTION,
"icon": cls.ICON,
"version": cls.VERSION,
"author": cls.AUTHOR,
"node_type": cls.NODE_TYPE.value if hasattr(cls.NODE_TYPE, 'value') else str(cls.NODE_TYPE),
"cache_policy": cls.CACHE_POLICY.value if hasattr(cls.CACHE_POLICY, 'value') else str(cls.CACHE_POLICY),
"supports_preview": cls.SUPPORTS_PREVIEW,
# 四大属性
"inputs": cls._format_spec(cls.InputSpec, "input"),
"outputs": cls._format_spec(cls.OutputSpec, "output"),
"params": cls._format_spec(cls.ParamSpec, "param"),
"context": cls._format_spec(cls.ContextSpec, "context"),
}
@classmethod
def _format_spec(cls, spec: Dict, spec_type: str) -> List[Dict]:
"""格式化属性定义为前端格式"""
formatted = []
for name, (data_type, config) in spec.items():
item = {
"name": name,
"type": data_type,
**config # 展开配置字典
}
formatted.append(item)
return formatted
def __repr__(self):
return f"<{self.__class__.__name__}(id={self.node_id})>"
# ============= 装饰器:自动收集属性 =============
def input_port(name: str, data_type: str, **config):
"""
输入端口装饰器自动收集到 InputSpec
Args:
name: 端口名称
data_type: 数据类型 "DataTable", "Number", "String"
**config: 配置项
- description: 描述
- required: 是否必需默认 True
- list: 是否为列表类型默认 False
用法:
@input_port("data", "DataTable", description="输入数据")
def process(self, inputs, context):
...
"""
def decorator(func):
# 将属性存储在方法对象上,稍后由类装饰器收集
if not hasattr(func, '_pending_input_specs'):
func._pending_input_specs = {}
func._pending_input_specs[name] = (data_type, config)
return func
return decorator
def output_port(name: str, data_type: str, **config):
"""
输出端口装饰器自动收集到 OutputSpec
用法与 `input_port` 对称
"""
def decorator(func):
if not hasattr(func, '_pending_output_specs'):
func._pending_output_specs = {}
func._pending_output_specs[name] = (data_type, config)
return func
return decorator
def param(name: str, param_type: str, **config):
"""
参数装饰器自动收集到 ParamSpec
Args:
name: 参数名称
param_type: 参数类型 "Number", "String", "Boolean", "Dropdown"
**config: 配置项
- default: 默认值
- description: 描述
- widget: 控件类型 "slider", "text", "dropdown"
- min: 最小值数值类型
- max: 最大值数值类型
- step: 步长数值类型
- options: 选项列表下拉框
- multiline: 多行文本文本类型
用法:
@param("threshold", "Number", default=0.5, min=0, max=1, step=0.1)
def process(self, inputs, context):
threshold = self.get_param("threshold")
...
"""
def decorator(func):
if not hasattr(func, '_pending_param_specs'):
func._pending_param_specs = {}
func._pending_param_specs[name] = (param_type, config)
return func
return decorator
def context_var(name: str, var_type: str, **config):
"""
上下文变量装饰器自动收集到 ContextSpec
Args:
name: 变量名称
var_type: 变量类型
**config: 配置项
- description: 描述
用法:
@context_var("row_count", "Integer", description="数据行数")
def process(self, inputs, context):
...
return {
"outputs": {...},
"context": {"row_count": 100}
}
"""
def decorator(func):
if not hasattr(func, '_pending_context_specs'):
func._pending_context_specs = {}
func._pending_context_specs[name] = (var_type, config)
return func
return decorator
def auto_collect_specs(cls):
"""
自动收集装饰器标记的属性到四大 Spec
此装饰器会在类定义完成后自动调用
从方法装饰器中收集 _pending_*_specs 并合并到类的 Spec
"""
# 遍历类的所有方法,收集装饰器标记的属性
for attr_name in dir(cls):
try:
attr = getattr(cls, attr_name)
except AttributeError:
continue
# 收集 InputSpec
if hasattr(attr, '_pending_input_specs'):
if not hasattr(cls, 'InputSpec') or cls.InputSpec is TraceNode.InputSpec:
cls.InputSpec = {}
cls.InputSpec = {**cls.InputSpec, **attr._pending_input_specs}
# 收集 OutputSpec
if hasattr(attr, '_pending_output_specs'):
if not hasattr(cls, 'OutputSpec') or cls.OutputSpec is TraceNode.OutputSpec:
cls.OutputSpec = {}
cls.OutputSpec = {**cls.OutputSpec, **attr._pending_output_specs}
# 收集 ParamSpec
if hasattr(attr, '_pending_param_specs'):
if not hasattr(cls, 'ParamSpec') or cls.ParamSpec is TraceNode.ParamSpec:
cls.ParamSpec = {}
cls.ParamSpec = {**cls.ParamSpec, **attr._pending_param_specs}
# 收集 ContextSpec
if hasattr(attr, '_pending_context_specs'):
if not hasattr(cls, 'ContextSpec') or cls.ContextSpec is TraceNode.ContextSpec:
cls.ContextSpec = {}
cls.ContextSpec = {**cls.ContextSpec, **attr._pending_context_specs}
return cls
# ============= 特殊节点类v2.0 新增,用于函数节点系统)=============
class InputNode(TraceNode, ABC):
"""
输入节点 - 子工作流的入口
在工作流中的作用
- 将外部输入映射到工作流内部
- 不执行任何业务逻辑
- 输出接收到的输入数据
使用场景
函数节点需要从外部接收参数输入节点作为入口
"""
NODE_TYPE = NodeType.INPUT
CATEGORY = "Meta/Input"
DISPLAY_NAME = "输入"
DESCRIPTION = "工作流输入入口"
# 输入节点不需要输入端口(只有输出)
InputSpec = {}
OutputSpec = {} # 由工作流动态定义
ParamSpec = {} # 输入节点通常没有参数
ContextSpec = {}
def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
"""
直接返回接收到的输入数据
"""
return {
"outputs": inputs,
"context": context or {}
}
def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
"""输入节点不需要验证"""
return True
class OutputNode(TraceNode, ABC):
"""
输出节点 - 子工作流的出口
在工作流中的作用
- 将工作流内部结果映射到外部
- 不执行任何业务逻辑
- 直接返回接收到的数据
使用场景
函数节点需要向外部返回结果输出节点作为出口
"""
NODE_TYPE = NodeType.OUTPUT
CATEGORY = "Meta/Output"
DISPLAY_NAME = "输出"
DESCRIPTION = "工作流输出出口"
# 输出节点只有输入端口(没有输出)
InputSpec = {} # 由工作流动态定义
OutputSpec = {} # 输出节点没有后续输出
ParamSpec = {}
ContextSpec = {}
def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
"""
直接返回接收到的数据
"""
return {
"outputs": inputs,
"context": context or {}
}
def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
"""输出节点接受任何输入"""
return True
class FunctionNode(TraceNode, ABC):
"""
函数节点 - 由子工作流包装而成的可复用节点
在工作流中的作用
- 将一个完整的子工作流包装成单个节点
- 支持无限嵌套函数节点内可包含函数节点
- 输入/输出映射到子工作流的 InputNode/OutputNode
使用场景
创建可复用的工作流模板"数据清洗""特征提取"
"""
NODE_TYPE = NodeType.FUNCTION
CATEGORY = "Meta/Function"
DISPLAY_NAME = "函数"
DESCRIPTION = "可复用的子工作流"
# 函数节点的输入输出由子工作流定义
InputSpec = {}
OutputSpec = {}
ParamSpec = {}
ContextSpec = {}
def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
"""
函数节点的执行由 WorkflowExecutor 负责
这里仅作为占位符
"""
raise NotImplementedError(
"函数节点必须由 WorkflowExecutor 执行,"
"不能直接调用 process() 方法"
)
class WorkflowPackager:
"""工作流打包器 - 将子工作流打包为函数节点"""
@staticmethod
def validate_function_workflow(
nodes: List[Dict],
edges: List[Dict]
) -> Tuple[bool, str]:
"""
验证工作流是否可以打包为函数节点
要求
1. 必须包含至少一个InputNode
2. 必须包含至少一个OutputNode
3. 所有节点都必须是可连接的
4. 不能有孤立节点
Returns:
(valid, error_message)
"""
node_ids = {n["id"] for n in nodes}
def _is_input_node(n: Dict) -> bool:
# 严格检查 type 是否为 NodeType.INPUT.value
t = n.get("type")
if t == NodeType.INPUT.value:
return True
# 向上兼容:如果提供了实现类名或 class_name使用关键字判断
c = str(n.get("class_name", n.get("class", ""))).lower()
return "input" in c
def _is_output_node(n: Dict) -> bool:
t = n.get("type")
if t == NodeType.OUTPUT.value:
return True
c = str(n.get("class_name", n.get("class", ""))).lower()
return "output" in c
has_input = any(_is_input_node(n) for n in nodes)
has_output = any(_is_output_node(n) for n in nodes)
if not has_input:
return False, "函数节点工作流必须包含至少一个InputNode"
if not has_output:
return False, "函数节点工作流必须包含至少一个OutputNode"
# 检查所有连线的节点存在性
for edge in edges:
src = edge.get("source")
tgt = edge.get("target")
if src not in node_ids or tgt not in node_ids:
return False, f"连线引用不存在的节点: {src}{tgt}"
# 检查是否有孤立节点(可选)
connected_nodes = set()
for edge in edges:
connected_nodes.add(edge.get("source"))
connected_nodes.add(edge.get("target"))
# 输入输出节点可以孤立(作为入口/出口)
isolated = node_ids - connected_nodes
# 如果有孤立节点,允许它们仅当它们是明确定义为 Input/Output
for node_id in isolated:
node = next((n for n in nodes if n.get("id") == node_id), None)
if node is None:
continue
t = node.get("type")
c = str(node.get("class_name", node.get("class", ""))).lower()
if t not in (NodeType.INPUT.value, NodeType.OUTPUT.value) and not ("input" in c or "output" in c):
# 允许孤立的普通节点(可能是后续连接)
pass
return True, ""
@staticmethod
def package_as_function(
node_id: str,
nodes: List[Dict],
edges: List[Dict],
display_name: str = "",
description: str = ""
) -> Dict[str, Any]:
"""
将工作流打包为函数节点
Args:
node_id: 新函数节点的ID
nodes: 子工作流节点
edges: 子工作流连线
display_name: 显示名称
description: 描述
Returns:
函数节点定义
"""
valid, error = WorkflowPackager.validate_function_workflow(nodes, edges)
if not valid:
raise ValueError(f"无法打包工作流: {error}")
# 返回符合执行器严格约定的函数节点定义:
# - type 必须为 NodeType.FUNCTION.value
# - class 字段指定实现类名(这里使用 FunctionNodeImpl
return {
"id": node_id,
"type": NodeType.FUNCTION.value,
"class_name": "FunctionNodeImpl",
"display_name": display_name or "函数工作流",
"description": description or "通过工作流定义的函数节点",
"params": {},
}