2026-01-09 21:37:02 +08:00
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
TraceStudio 节点基类 (v2.0)
|
|
|
|
|
|
四大属性规范:InputSpec, OutputSpec, ParamSpec, ContextSpec
|
|
|
|
|
|
支持高级特性:维度转换、特殊节点、函数嵌套
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
from enum import Enum
|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
from .cache_manager import NodeCacheAdapter
|
|
|
|
|
|
import inspect
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeType(Enum):
|
|
|
|
|
|
"""节点类型"""
|
|
|
|
|
|
NORMAL = "normal" # 普通节点
|
|
|
|
|
|
INPUT = "input" # 输入节点
|
|
|
|
|
|
OUTPUT = "output" # 输出节点
|
|
|
|
|
|
FUNCTION = "function" # 函数节点(由子工作流包装而成)
|
|
|
|
|
|
COMPOSITE = "composite" # 复合/聚合节点(如 ConcatNode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DimensionMode(Enum):
|
|
|
|
|
|
"""维度转换模式(v2.0 新增)"""
|
|
|
|
|
|
NONE = "none" # 无转换
|
|
|
|
|
|
EXPAND = "down" # 升维:数组→单个元素(遍历)
|
|
|
|
|
|
COLLAPSE = "up" # 降维:单个元素→数组(打包)
|
|
|
|
|
|
BROADCAST = "broadcast" # 广播:多条线→数组(展开+打包)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def from_str(cls, label: str):
|
|
|
|
|
|
# 统一转大写处理,防止前端大小写传错导致崩溃
|
|
|
|
|
|
try:
|
|
|
|
|
|
print("DimensionMode::from_str", label, label.lower())
|
|
|
|
|
|
en = cls[label.lower()]
|
|
|
|
|
|
print("DimensionMode::from_str", label, en)
|
|
|
|
|
|
return en
|
|
|
|
|
|
except (KeyError, AttributeError):
|
|
|
|
|
|
return cls.NONE # 提供一个默认值
|
|
|
|
|
|
|
|
|
|
|
|
class CachePolicy(Enum):
|
|
|
|
|
|
"""缓存策略"""
|
|
|
|
|
|
NONE = "none" # 不缓存
|
|
|
|
|
|
MEMORY = "memory" # 内存缓存
|
|
|
|
|
|
DISK = "disk" # 磁盘缓存
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class EdgeMetadata:
|
|
|
|
|
|
"""连线元数据(v2.0 新增,用于连线分类和维度转换)"""
|
|
|
|
|
|
source_node: str # 源节点ID
|
|
|
|
|
|
source_port: str # 源端口名
|
|
|
|
|
|
target_node: str # 目标节点ID
|
|
|
|
|
|
target_port: str # 目标端口名
|
|
|
|
|
|
dimension_mode: DimensionMode = DimensionMode.NONE
|
|
|
|
|
|
|
|
|
|
|
|
def to_dict(self) -> Dict:
|
|
|
|
|
|
return {
|
|
|
|
|
|
"source_node": self.source_node,
|
|
|
|
|
|
"source_port": self.source_port,
|
|
|
|
|
|
"target_node": self.target_node,
|
|
|
|
|
|
"target_port": self.target_port,
|
|
|
|
|
|
"dimension_mode": self.dimension_mode
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class NodeMetadata:
|
|
|
|
|
|
"""节点元数据(v2.0 新增,用于函数节点系统)"""
|
|
|
|
|
|
node_id: str # 节点ID
|
|
|
|
|
|
node_type: str # 节点类型(来自NodeRegistry)
|
|
|
|
|
|
display_name: Optional[str] = None
|
|
|
|
|
|
description: Optional[str] = None
|
|
|
|
|
|
params: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
def to_dict(self) -> Dict:
|
|
|
|
|
|
return {
|
|
|
|
|
|
"node_id": self.node_id,
|
|
|
|
|
|
"node_type": self.node_type,
|
|
|
|
|
|
"display_name": self.display_name,
|
|
|
|
|
|
"description": self.description,
|
|
|
|
|
|
"params": self.params,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TraceNode(ABC):
|
|
|
|
|
|
"""
|
|
|
|
|
|
TraceStudio 节点基类 (v2.0)
|
|
|
|
|
|
|
|
|
|
|
|
四大属性:
|
|
|
|
|
|
- InputSpec: 主输入(必须通过连线)
|
|
|
|
|
|
- OutputSpec: 主输出(供下游连接)
|
|
|
|
|
|
- ParamSpec: 控制参数(面板配置)
|
|
|
|
|
|
- ContextSpec: 上下文/元数据(自动广播)
|
|
|
|
|
|
|
|
|
|
|
|
示例用法:
|
|
|
|
|
|
@register_node
|
|
|
|
|
|
class AddNode(TraceNode):
|
|
|
|
|
|
CATEGORY = "Math/Basic"
|
|
|
|
|
|
DISPLAY_NAME = "加法"
|
|
|
|
|
|
DESCRIPTION = "计算两个数的和"
|
|
|
|
|
|
|
|
|
|
|
|
# 自动收集的四大属性
|
|
|
|
|
|
@input_port("a", "Number", description="加数A")
|
|
|
|
|
|
@input_port("b", "Number", description="加数B")
|
|
|
|
|
|
@output_port("result", "Number", description="和")
|
|
|
|
|
|
@param("offset", "Number", default=0, description="偏移量")
|
|
|
|
|
|
@context_var("count", "Integer", description="计算次数")
|
|
|
|
|
|
def process(self, inputs, context):
|
|
|
|
|
|
a = inputs["a"]
|
|
|
|
|
|
b = inputs["b"]
|
|
|
|
|
|
offset = self.get_param("offset", 0)
|
|
|
|
|
|
result = a + b + offset
|
|
|
|
|
|
|
|
|
|
|
|
# 返回输出和上下文
|
|
|
|
|
|
return {
|
|
|
|
|
|
"outputs": {"result": result},
|
|
|
|
|
|
"context": {"count": 1}
|
|
|
|
|
|
}
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
# ============= 元数据定义 (类属性) =============
|
|
|
|
|
|
CATEGORY = "Uncategorized" # 分类路径,如 "Data/Transform"
|
|
|
|
|
|
DISPLAY_NAME = None # 显示名称,None 则使用类名
|
|
|
|
|
|
DESCRIPTION = "" # 节点描述
|
|
|
|
|
|
ICON = "📦" # 图标(Emoji 或图标名)
|
|
|
|
|
|
VERSION = "1.0.0" # 版本号
|
|
|
|
|
|
AUTHOR = "" # 作者
|
|
|
|
|
|
|
|
|
|
|
|
# 节点行为配置
|
|
|
|
|
|
NODE_TYPE = NodeType.NORMAL # 节点类型
|
|
|
|
|
|
CACHE_POLICY = CachePolicy.NONE # 缓存策略
|
|
|
|
|
|
SUPPORTS_PREVIEW = True # 是否支持预览模式
|
|
|
|
|
|
|
|
|
|
|
|
# ============= v2.0 四大属性(自动收集) =============
|
|
|
|
|
|
# 格式: {"port_name": (data_type, config_dict)}
|
|
|
|
|
|
InputSpec: Dict[str, tuple] = {}
|
|
|
|
|
|
OutputSpec: Dict[str, tuple] = {}
|
|
|
|
|
|
ParamSpec: Dict[str, tuple] = {}
|
|
|
|
|
|
ContextSpec: Dict[str, tuple] = {}
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, node_id: str, params: Optional[Dict] = None):
|
|
|
|
|
|
"""
|
|
|
|
|
|
初始化节点实例
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
node_id: 节点实例 ID(前端生成)
|
|
|
|
|
|
params: 参数值字典(仅 ParamSpec 中的参数)
|
|
|
|
|
|
"""
|
|
|
|
|
|
self.node_id = node_id
|
|
|
|
|
|
self.params = params or {}
|
|
|
|
|
|
self._cache = None
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
处理节点核心逻辑(子类必须实现)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
inputs: 主输入数据字典 {"port_name": data}
|
|
|
|
|
|
所有数据由连线提供,对应 InputSpec
|
|
|
|
|
|
context: 上下文字典 {"$Global.var": value, "$NodeID.var": value}
|
|
|
|
|
|
包含全局变量和上游节点的 ContextSpec
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
返回字典,包含两个键:
|
|
|
|
|
|
{
|
|
|
|
|
|
"outputs": {"port_name": data}, # 主输出,对应 OutputSpec
|
|
|
|
|
|
"context": {"var_name": value} # 上下文,对应 ContextSpec
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
简化写法(仅返回 outputs):
|
|
|
|
|
|
return {"result": data} # 自动转换为 {"outputs": {"result": data}, "context": {}}
|
|
|
|
|
|
"""
|
|
|
|
|
|
pass
|
2026-01-10 19:08:49 +08:00
|
|
|
|
|
2026-01-09 21:37:02 +08:00
|
|
|
|
def wrap_process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
|
|
|
|
|
self.temp_inputs = inputs
|
|
|
|
|
|
self.temp_context = context
|
2026-01-10 19:08:49 +08:00
|
|
|
|
result = {}
|
2026-01-09 21:37:02 +08:00
|
|
|
|
try:
|
2026-01-10 19:08:49 +08:00
|
|
|
|
result = self.process(inputs, context)
|
2026-01-09 21:37:02 +08:00
|
|
|
|
except Exception as e:
|
2026-01-10 19:08:49 +08:00
|
|
|
|
raise e
|
|
|
|
|
|
finally:
|
|
|
|
|
|
self.temp_inputs = None
|
|
|
|
|
|
self.temp_context = None
|
2026-01-09 21:37:02 +08:00
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
def preview(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None, limit: int = 10) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
预览模式执行(可选重写)
|
|
|
|
|
|
|
|
|
|
|
|
默认实现:调用 process 后截取数据
|
|
|
|
|
|
子类可重写此方法优化预览逻辑(如提前终止、采样等)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
inputs: 主输入数据
|
|
|
|
|
|
context: 上下文
|
|
|
|
|
|
limit: 预览数据量限制
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
预览结果
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not self.SUPPORTS_PREVIEW:
|
|
|
|
|
|
raise NotImplementedError(f"节点 {self.__class__.__name__} 不支持预览")
|
|
|
|
|
|
|
|
|
|
|
|
return self.wrap_process(inputs, context)
|
|
|
|
|
|
|
|
|
|
|
|
# ============= 辅助方法 =============
|
|
|
|
|
|
def get_param(self, name: str, default: Any = None) -> Any:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取参数值(支持静态值、Context 引用、暴露端口)
|
|
|
|
|
|
"""
|
2026-01-12 03:32:51 +08:00
|
|
|
|
param = self.params.get(name)
|
|
|
|
|
|
if not param:
|
|
|
|
|
|
return default
|
|
|
|
|
|
mode = param.get("mode")
|
|
|
|
|
|
if mode == "static":
|
|
|
|
|
|
return param.get("value", default)
|
|
|
|
|
|
if mode == "context":
|
|
|
|
|
|
#todo: 支持命名空间
|
|
|
|
|
|
return default
|
|
|
|
|
|
if mode == "exposed":
|
|
|
|
|
|
return self.temp_inputs.get(name, default)
|
|
|
|
|
|
return default
|
2026-01-09 21:37:02 +08:00
|
|
|
|
|
|
|
|
|
|
def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
验证输入数据完整性
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
inputs: 输入数据字典
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
验证是否通过
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
|
ValueError: 缺少必需输入或类型不匹配
|
|
|
|
|
|
"""
|
|
|
|
|
|
for port_name, (data_type, config) in self.InputSpec.items():
|
|
|
|
|
|
required = config.get("required", True)
|
|
|
|
|
|
|
|
|
|
|
|
# 检查必需输入
|
|
|
|
|
|
if required and port_name not in inputs:
|
|
|
|
|
|
raise ValueError(f"缺少必需的输入端口: {port_name}")
|
|
|
|
|
|
|
|
|
|
|
|
# 检查列表类型
|
|
|
|
|
|
is_list = config.get("list", False)
|
|
|
|
|
|
if port_name in inputs and is_list and not isinstance(inputs[port_name], list):
|
|
|
|
|
|
raise TypeError(f"端口 {port_name} 需要列表类型数据,但收到 {type(inputs[port_name])}")
|
|
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
def execute_with_cache(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
带缓存的执行
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
inputs: 输入数据
|
|
|
|
|
|
context: 上下文
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
执行结果
|
|
|
|
|
|
"""
|
|
|
|
|
|
if self.CACHE_POLICY == CachePolicy.NONE:
|
|
|
|
|
|
return self.wrap_process(inputs, context)
|
|
|
|
|
|
|
|
|
|
|
|
# 生成缓存键
|
|
|
|
|
|
cache_key = self._generate_cache_key(inputs, context)
|
|
|
|
|
|
print("***********************",self.__class__.__name__, cache_key)
|
|
|
|
|
|
# 检查缓存
|
|
|
|
|
|
cached = self._get_from_cache(cache_key)
|
|
|
|
|
|
if cached is not None:
|
|
|
|
|
|
print("***********************",self.__class__.__name__, "read cache:", cache_key)
|
|
|
|
|
|
return cached
|
|
|
|
|
|
|
|
|
|
|
|
# 执行处理
|
|
|
|
|
|
result = self.wrap_process(inputs, context)
|
|
|
|
|
|
|
|
|
|
|
|
# 保存缓存
|
|
|
|
|
|
self._save_to_cache(cache_key, result)
|
|
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_cache_key(self, inputs: Dict, context: Optional[Dict] = None) -> str:
|
|
|
|
|
|
"""生成缓存键"""
|
|
|
|
|
|
import hashlib
|
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
|
|
cache_data = {
|
|
|
|
|
|
"class": self.__class__.__name__,
|
|
|
|
|
|
"params": self.params,
|
|
|
|
|
|
"inputs": self._serialize_inputs(inputs),
|
|
|
|
|
|
}
|
|
|
|
|
|
json_str = json.dumps(cache_data, sort_keys=True)
|
|
|
|
|
|
return hashlib.sha256(json_str.encode()).hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
def _serialize_inputs(self, inputs: Dict) -> Dict:
|
|
|
|
|
|
"""序列化输入用于缓存键生成"""
|
|
|
|
|
|
import hashlib
|
|
|
|
|
|
|
|
|
|
|
|
serialized = {}
|
|
|
|
|
|
for key, value in inputs.items():
|
|
|
|
|
|
if isinstance(value, (str, int, float, bool, type(None))):
|
|
|
|
|
|
serialized[key] = value
|
|
|
|
|
|
elif isinstance(value, list):
|
|
|
|
|
|
# 列表转为哈希
|
|
|
|
|
|
serialized[key] = hashlib.md5(str(value).encode()).hexdigest()
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 复杂对象转为哈希
|
|
|
|
|
|
serialized[key] = hashlib.md5(str(value).encode()).hexdigest()
|
|
|
|
|
|
return serialized
|
|
|
|
|
|
|
|
|
|
|
|
def _get_from_cache(self, key: str) -> Optional[Dict]:
|
|
|
|
|
|
"""从缓存获取"""
|
|
|
|
|
|
if self._cache is None:
|
|
|
|
|
|
self._init_cache()
|
|
|
|
|
|
return self._cache.get(key) if self._cache else None
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_cache(self, key: str, value: Dict):
|
|
|
|
|
|
"""保存到缓存"""
|
|
|
|
|
|
if self._cache is None:
|
|
|
|
|
|
self._init_cache()
|
|
|
|
|
|
if self._cache:
|
|
|
|
|
|
self._cache.set(key, value)
|
|
|
|
|
|
|
|
|
|
|
|
def _init_cache(self):
|
|
|
|
|
|
storage = 'memory' if self.CACHE_POLICY == CachePolicy.MEMORY else 'disk'
|
|
|
|
|
|
self._cache = NodeCacheAdapter(storage)
|
|
|
|
|
|
|
|
|
|
|
|
# ============= 元数据生成(供前端使用) =============
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def get_metadata(cls) -> Dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
生成节点元数据(供前端渲染)
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
节点元数据字典
|
|
|
|
|
|
"""
|
|
|
|
|
|
return {
|
|
|
|
|
|
"class_name": cls.__name__,
|
|
|
|
|
|
"display_name": cls.DISPLAY_NAME or cls.__name__,
|
|
|
|
|
|
"category": cls.CATEGORY,
|
|
|
|
|
|
"description": cls.DESCRIPTION,
|
|
|
|
|
|
"icon": cls.ICON,
|
|
|
|
|
|
"version": cls.VERSION,
|
|
|
|
|
|
"author": cls.AUTHOR,
|
|
|
|
|
|
"node_type": cls.NODE_TYPE.value if hasattr(cls.NODE_TYPE, 'value') else str(cls.NODE_TYPE),
|
|
|
|
|
|
"cache_policy": cls.CACHE_POLICY.value if hasattr(cls.CACHE_POLICY, 'value') else str(cls.CACHE_POLICY),
|
|
|
|
|
|
"supports_preview": cls.SUPPORTS_PREVIEW,
|
|
|
|
|
|
|
|
|
|
|
|
# 四大属性
|
|
|
|
|
|
"inputs": cls._format_spec(cls.InputSpec, "input"),
|
|
|
|
|
|
"outputs": cls._format_spec(cls.OutputSpec, "output"),
|
|
|
|
|
|
"params": cls._format_spec(cls.ParamSpec, "param"),
|
|
|
|
|
|
"context": cls._format_spec(cls.ContextSpec, "context"),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def _format_spec(cls, spec: Dict, spec_type: str) -> List[Dict]:
|
|
|
|
|
|
"""格式化属性定义为前端格式"""
|
|
|
|
|
|
formatted = []
|
|
|
|
|
|
for name, (data_type, config) in spec.items():
|
|
|
|
|
|
item = {
|
|
|
|
|
|
"name": name,
|
|
|
|
|
|
"type": data_type,
|
|
|
|
|
|
**config # 展开配置字典
|
|
|
|
|
|
}
|
|
|
|
|
|
formatted.append(item)
|
|
|
|
|
|
return formatted
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
|
return f"<{self.__class__.__name__}(id={self.node_id})>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============= 装饰器:自动收集属性 =============
|
|
|
|
|
|
|
|
|
|
|
|
def input_port(name: str, data_type: str, **config):
|
|
|
|
|
|
"""
|
|
|
|
|
|
输入端口装饰器(自动收集到 InputSpec)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
name: 端口名称
|
|
|
|
|
|
data_type: 数据类型(如 "DataTable", "Number", "String")
|
|
|
|
|
|
**config: 配置项
|
|
|
|
|
|
- description: 描述
|
|
|
|
|
|
- required: 是否必需(默认 True)
|
|
|
|
|
|
- list: 是否为列表类型(默认 False)
|
|
|
|
|
|
|
|
|
|
|
|
用法:
|
|
|
|
|
|
@input_port("data", "DataTable", description="输入数据")
|
|
|
|
|
|
def process(self, inputs, context):
|
|
|
|
|
|
...
|
|
|
|
|
|
"""
|
|
|
|
|
|
def decorator(func):
|
|
|
|
|
|
# 将属性存储在方法对象上,稍后由类装饰器收集
|
|
|
|
|
|
if not hasattr(func, '_pending_input_specs'):
|
|
|
|
|
|
func._pending_input_specs = {}
|
|
|
|
|
|
func._pending_input_specs[name] = (data_type, config)
|
|
|
|
|
|
return func
|
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def output_port(name: str, data_type: str, **config):
|
|
|
|
|
|
"""
|
|
|
|
|
|
输出端口装饰器(自动收集到 OutputSpec)
|
|
|
|
|
|
|
|
|
|
|
|
用法与 `input_port` 对称。
|
|
|
|
|
|
"""
|
|
|
|
|
|
def decorator(func):
|
|
|
|
|
|
if not hasattr(func, '_pending_output_specs'):
|
|
|
|
|
|
func._pending_output_specs = {}
|
|
|
|
|
|
func._pending_output_specs[name] = (data_type, config)
|
|
|
|
|
|
return func
|
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def param(name: str, param_type: str, **config):
|
|
|
|
|
|
"""
|
|
|
|
|
|
参数装饰器(自动收集到 ParamSpec)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
name: 参数名称
|
|
|
|
|
|
param_type: 参数类型(如 "Number", "String", "Boolean", "Dropdown")
|
|
|
|
|
|
**config: 配置项
|
|
|
|
|
|
- default: 默认值
|
|
|
|
|
|
- description: 描述
|
|
|
|
|
|
- widget: 控件类型(如 "slider", "text", "dropdown")
|
|
|
|
|
|
- min: 最小值(数值类型)
|
|
|
|
|
|
- max: 最大值(数值类型)
|
|
|
|
|
|
- step: 步长(数值类型)
|
|
|
|
|
|
- options: 选项列表(下拉框)
|
|
|
|
|
|
- multiline: 多行文本(文本类型)
|
|
|
|
|
|
|
|
|
|
|
|
用法:
|
|
|
|
|
|
@param("threshold", "Number", default=0.5, min=0, max=1, step=0.1)
|
|
|
|
|
|
def process(self, inputs, context):
|
|
|
|
|
|
threshold = self.get_param("threshold")
|
|
|
|
|
|
...
|
|
|
|
|
|
"""
|
|
|
|
|
|
def decorator(func):
|
|
|
|
|
|
if not hasattr(func, '_pending_param_specs'):
|
|
|
|
|
|
func._pending_param_specs = {}
|
|
|
|
|
|
func._pending_param_specs[name] = (param_type, config)
|
|
|
|
|
|
return func
|
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def context_var(name: str, var_type: str, **config):
|
|
|
|
|
|
"""
|
|
|
|
|
|
上下文变量装饰器(自动收集到 ContextSpec)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
name: 变量名称
|
|
|
|
|
|
var_type: 变量类型
|
|
|
|
|
|
**config: 配置项
|
|
|
|
|
|
- description: 描述
|
|
|
|
|
|
|
|
|
|
|
|
用法:
|
|
|
|
|
|
@context_var("row_count", "Integer", description="数据行数")
|
|
|
|
|
|
def process(self, inputs, context):
|
|
|
|
|
|
...
|
|
|
|
|
|
return {
|
|
|
|
|
|
"outputs": {...},
|
|
|
|
|
|
"context": {"row_count": 100}
|
|
|
|
|
|
}
|
|
|
|
|
|
"""
|
|
|
|
|
|
def decorator(func):
|
|
|
|
|
|
if not hasattr(func, '_pending_context_specs'):
|
|
|
|
|
|
func._pending_context_specs = {}
|
|
|
|
|
|
func._pending_context_specs[name] = (var_type, config)
|
|
|
|
|
|
return func
|
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_collect_specs(cls):
|
|
|
|
|
|
"""
|
|
|
|
|
|
自动收集装饰器标记的属性到四大 Spec
|
|
|
|
|
|
|
|
|
|
|
|
此装饰器会在类定义完成后自动调用,
|
|
|
|
|
|
从方法装饰器中收集 _pending_*_specs 并合并到类的 Spec 中
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 遍历类的所有方法,收集装饰器标记的属性
|
|
|
|
|
|
for attr_name in dir(cls):
|
|
|
|
|
|
try:
|
|
|
|
|
|
attr = getattr(cls, attr_name)
|
|
|
|
|
|
except AttributeError:
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 收集 InputSpec
|
|
|
|
|
|
if hasattr(attr, '_pending_input_specs'):
|
|
|
|
|
|
if not hasattr(cls, 'InputSpec') or cls.InputSpec is TraceNode.InputSpec:
|
|
|
|
|
|
cls.InputSpec = {}
|
|
|
|
|
|
cls.InputSpec = {**cls.InputSpec, **attr._pending_input_specs}
|
|
|
|
|
|
|
|
|
|
|
|
# 收集 OutputSpec
|
|
|
|
|
|
if hasattr(attr, '_pending_output_specs'):
|
|
|
|
|
|
if not hasattr(cls, 'OutputSpec') or cls.OutputSpec is TraceNode.OutputSpec:
|
|
|
|
|
|
cls.OutputSpec = {}
|
|
|
|
|
|
cls.OutputSpec = {**cls.OutputSpec, **attr._pending_output_specs}
|
|
|
|
|
|
|
|
|
|
|
|
# 收集 ParamSpec
|
|
|
|
|
|
if hasattr(attr, '_pending_param_specs'):
|
|
|
|
|
|
if not hasattr(cls, 'ParamSpec') or cls.ParamSpec is TraceNode.ParamSpec:
|
|
|
|
|
|
cls.ParamSpec = {}
|
|
|
|
|
|
cls.ParamSpec = {**cls.ParamSpec, **attr._pending_param_specs}
|
|
|
|
|
|
|
|
|
|
|
|
# 收集 ContextSpec
|
|
|
|
|
|
if hasattr(attr, '_pending_context_specs'):
|
|
|
|
|
|
if not hasattr(cls, 'ContextSpec') or cls.ContextSpec is TraceNode.ContextSpec:
|
|
|
|
|
|
cls.ContextSpec = {}
|
|
|
|
|
|
cls.ContextSpec = {**cls.ContextSpec, **attr._pending_context_specs}
|
|
|
|
|
|
|
|
|
|
|
|
return cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============= 特殊节点类(v2.0 新增,用于函数节点系统)=============
|
|
|
|
|
|
|
|
|
|
|
|
class InputNode(TraceNode, ABC):
|
|
|
|
|
|
"""
|
|
|
|
|
|
输入节点 - 子工作流的入口
|
|
|
|
|
|
|
|
|
|
|
|
在工作流中的作用:
|
|
|
|
|
|
- 将外部输入映射到工作流内部
|
|
|
|
|
|
- 不执行任何业务逻辑
|
|
|
|
|
|
- 输出接收到的输入数据
|
|
|
|
|
|
|
|
|
|
|
|
使用场景:
|
|
|
|
|
|
函数节点需要从外部接收参数,输入节点作为入口
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
NODE_TYPE = NodeType.INPUT
|
|
|
|
|
|
CATEGORY = "Meta/Input"
|
|
|
|
|
|
DISPLAY_NAME = "输入"
|
|
|
|
|
|
DESCRIPTION = "工作流输入入口"
|
|
|
|
|
|
|
|
|
|
|
|
# 输入节点不需要输入端口(只有输出)
|
|
|
|
|
|
InputSpec = {}
|
|
|
|
|
|
OutputSpec = {} # 由工作流动态定义
|
|
|
|
|
|
ParamSpec = {} # 输入节点通常没有参数
|
|
|
|
|
|
ContextSpec = {}
|
|
|
|
|
|
|
|
|
|
|
|
def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
直接返回接收到的输入数据
|
|
|
|
|
|
"""
|
|
|
|
|
|
return {
|
|
|
|
|
|
"outputs": inputs,
|
|
|
|
|
|
"context": context or {}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
|
|
|
|
|
|
"""输入节点不需要验证"""
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OutputNode(TraceNode, ABC):
|
|
|
|
|
|
"""
|
|
|
|
|
|
输出节点 - 子工作流的出口
|
|
|
|
|
|
|
|
|
|
|
|
在工作流中的作用:
|
|
|
|
|
|
- 将工作流内部结果映射到外部
|
|
|
|
|
|
- 不执行任何业务逻辑
|
|
|
|
|
|
- 直接返回接收到的数据
|
|
|
|
|
|
|
|
|
|
|
|
使用场景:
|
|
|
|
|
|
函数节点需要向外部返回结果,输出节点作为出口
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
NODE_TYPE = NodeType.OUTPUT
|
|
|
|
|
|
CATEGORY = "Meta/Output"
|
|
|
|
|
|
DISPLAY_NAME = "输出"
|
|
|
|
|
|
DESCRIPTION = "工作流输出出口"
|
|
|
|
|
|
|
|
|
|
|
|
# 输出节点只有输入端口(没有输出)
|
|
|
|
|
|
InputSpec = {} # 由工作流动态定义
|
|
|
|
|
|
OutputSpec = {} # 输出节点没有后续输出
|
|
|
|
|
|
ParamSpec = {}
|
|
|
|
|
|
ContextSpec = {}
|
|
|
|
|
|
|
|
|
|
|
|
def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
直接返回接收到的数据
|
|
|
|
|
|
"""
|
|
|
|
|
|
return {
|
|
|
|
|
|
"outputs": inputs,
|
|
|
|
|
|
"context": context or {}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
|
|
|
|
|
|
"""输出节点接受任何输入"""
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FunctionNode(TraceNode, ABC):
|
|
|
|
|
|
"""
|
|
|
|
|
|
函数节点 - 由子工作流包装而成的可复用节点
|
|
|
|
|
|
|
|
|
|
|
|
在工作流中的作用:
|
|
|
|
|
|
- 将一个完整的子工作流包装成单个节点
|
|
|
|
|
|
- 支持无限嵌套(函数节点内可包含函数节点)
|
|
|
|
|
|
- 输入/输出映射到子工作流的 InputNode/OutputNode
|
|
|
|
|
|
|
|
|
|
|
|
使用场景:
|
|
|
|
|
|
创建可复用的工作流模板,如"数据清洗"、"特征提取"等
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
NODE_TYPE = NodeType.FUNCTION
|
|
|
|
|
|
CATEGORY = "Meta/Function"
|
|
|
|
|
|
DISPLAY_NAME = "函数"
|
|
|
|
|
|
DESCRIPTION = "可复用的子工作流"
|
|
|
|
|
|
|
|
|
|
|
|
# 函数节点的输入输出由子工作流定义
|
|
|
|
|
|
InputSpec = {}
|
|
|
|
|
|
OutputSpec = {}
|
|
|
|
|
|
ParamSpec = {}
|
|
|
|
|
|
ContextSpec = {}
|
|
|
|
|
|
|
|
|
|
|
|
def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
函数节点的执行由 WorkflowExecutor 负责
|
|
|
|
|
|
这里仅作为占位符
|
|
|
|
|
|
"""
|
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
|
"函数节点必须由 WorkflowExecutor 执行,"
|
|
|
|
|
|
"不能直接调用 process() 方法"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class WorkflowPackager:
|
|
|
|
|
|
"""工作流打包器 - 将子工作流打包为函数节点"""
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def validate_function_workflow(
|
|
|
|
|
|
nodes: List[Dict],
|
|
|
|
|
|
edges: List[Dict]
|
|
|
|
|
|
) -> Tuple[bool, str]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
验证工作流是否可以打包为函数节点
|
|
|
|
|
|
|
|
|
|
|
|
要求:
|
|
|
|
|
|
1. 必须包含至少一个InputNode
|
|
|
|
|
|
2. 必须包含至少一个OutputNode
|
|
|
|
|
|
3. 所有节点都必须是可连接的
|
|
|
|
|
|
4. 不能有孤立节点
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
(valid, error_message)
|
|
|
|
|
|
"""
|
|
|
|
|
|
node_ids = {n["id"] for n in nodes}
|
|
|
|
|
|
|
|
|
|
|
|
def _is_input_node(n: Dict) -> bool:
|
|
|
|
|
|
# 严格检查 type 是否为 NodeType.INPUT.value
|
|
|
|
|
|
t = n.get("type")
|
|
|
|
|
|
if t == NodeType.INPUT.value:
|
|
|
|
|
|
return True
|
|
|
|
|
|
# 向上兼容:如果提供了实现类名或 class_name,使用关键字判断
|
|
|
|
|
|
c = str(n.get("class_name", n.get("class", ""))).lower()
|
|
|
|
|
|
return "input" in c
|
|
|
|
|
|
|
|
|
|
|
|
def _is_output_node(n: Dict) -> bool:
|
|
|
|
|
|
t = n.get("type")
|
|
|
|
|
|
if t == NodeType.OUTPUT.value:
|
|
|
|
|
|
return True
|
|
|
|
|
|
c = str(n.get("class_name", n.get("class", ""))).lower()
|
|
|
|
|
|
return "output" in c
|
|
|
|
|
|
|
|
|
|
|
|
has_input = any(_is_input_node(n) for n in nodes)
|
|
|
|
|
|
has_output = any(_is_output_node(n) for n in nodes)
|
|
|
|
|
|
|
|
|
|
|
|
if not has_input:
|
|
|
|
|
|
return False, "函数节点工作流必须包含至少一个InputNode"
|
|
|
|
|
|
if not has_output:
|
|
|
|
|
|
return False, "函数节点工作流必须包含至少一个OutputNode"
|
|
|
|
|
|
|
|
|
|
|
|
# 检查所有连线的节点存在性
|
|
|
|
|
|
for edge in edges:
|
|
|
|
|
|
src = edge.get("source")
|
|
|
|
|
|
tgt = edge.get("target")
|
|
|
|
|
|
if src not in node_ids or tgt not in node_ids:
|
|
|
|
|
|
return False, f"连线引用不存在的节点: {src} → {tgt}"
|
|
|
|
|
|
|
|
|
|
|
|
# 检查是否有孤立节点(可选)
|
|
|
|
|
|
connected_nodes = set()
|
|
|
|
|
|
for edge in edges:
|
|
|
|
|
|
connected_nodes.add(edge.get("source"))
|
|
|
|
|
|
connected_nodes.add(edge.get("target"))
|
|
|
|
|
|
|
|
|
|
|
|
# 输入输出节点可以孤立(作为入口/出口)
|
|
|
|
|
|
isolated = node_ids - connected_nodes
|
|
|
|
|
|
# 如果有孤立节点,允许它们仅当它们是明确定义为 Input/Output
|
|
|
|
|
|
for node_id in isolated:
|
|
|
|
|
|
node = next((n for n in nodes if n.get("id") == node_id), None)
|
|
|
|
|
|
if node is None:
|
|
|
|
|
|
continue
|
|
|
|
|
|
t = node.get("type")
|
|
|
|
|
|
c = str(node.get("class_name", node.get("class", ""))).lower()
|
|
|
|
|
|
if t not in (NodeType.INPUT.value, NodeType.OUTPUT.value) and not ("input" in c or "output" in c):
|
|
|
|
|
|
# 允许孤立的普通节点(可能是后续连接)
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
return True, ""
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def package_as_function(
|
|
|
|
|
|
node_id: str,
|
|
|
|
|
|
nodes: List[Dict],
|
|
|
|
|
|
edges: List[Dict],
|
|
|
|
|
|
display_name: str = "",
|
|
|
|
|
|
description: str = ""
|
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
将工作流打包为函数节点
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
node_id: 新函数节点的ID
|
|
|
|
|
|
nodes: 子工作流节点
|
|
|
|
|
|
edges: 子工作流连线
|
|
|
|
|
|
display_name: 显示名称
|
|
|
|
|
|
description: 描述
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
函数节点定义
|
|
|
|
|
|
"""
|
|
|
|
|
|
valid, error = WorkflowPackager.validate_function_workflow(nodes, edges)
|
|
|
|
|
|
if not valid:
|
|
|
|
|
|
raise ValueError(f"无法打包工作流: {error}")
|
|
|
|
|
|
|
|
|
|
|
|
# 返回符合执行器严格约定的函数节点定义:
|
|
|
|
|
|
# - type 必须为 NodeType.FUNCTION.value
|
|
|
|
|
|
# - class 字段指定实现类名(这里使用 FunctionNodeImpl)
|
|
|
|
|
|
return {
|
|
|
|
|
|
"id": node_id,
|
|
|
|
|
|
"type": NodeType.FUNCTION.value,
|
|
|
|
|
|
"class_name": "FunctionNodeImpl",
|
|
|
|
|
|
"display_name": display_name or "函数工作流",
|
|
|
|
|
|
"description": description or "通过工作流定义的函数节点",
|
|
|
|
|
|
"params": {},
|
|
|
|
|
|
}
|