""" TraceStudio 节点基类 (v2.0) 四大属性规范:InputSpec, OutputSpec, ParamSpec, ContextSpec 支持高级特性:维度转换、特殊节点、函数嵌套 """ from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple from dataclasses import dataclass, field from .cache_manager import NodeCacheAdapter import inspect class NodeType(Enum): """节点类型""" NORMAL = "normal" # 普通节点 INPUT = "input" # 输入节点 OUTPUT = "output" # 输出节点 FUNCTION = "function" # 函数节点(由子工作流包装而成) COMPOSITE = "composite" # 复合/聚合节点(如 ConcatNode) class DimensionMode(Enum): """维度转换模式(v2.0 新增)""" NONE = "none" # 无转换 EXPAND = "down" # 升维:数组→单个元素(遍历) COLLAPSE = "up" # 降维:单个元素→数组(打包) BROADCAST = "broadcast" # 广播:多条线→数组(展开+打包) @classmethod def from_str(cls, label: str): # 统一转大写处理,防止前端大小写传错导致崩溃 try: print("DimensionMode::from_str", label, label.lower()) en = cls[label.lower()] print("DimensionMode::from_str", label, en) return en except (KeyError, AttributeError): return cls.NONE # 提供一个默认值 class CachePolicy(Enum): """缓存策略""" NONE = "none" # 不缓存 MEMORY = "memory" # 内存缓存 DISK = "disk" # 磁盘缓存 @dataclass class EdgeMetadata: """连线元数据(v2.0 新增,用于连线分类和维度转换)""" source_node: str # 源节点ID source_port: str # 源端口名 target_node: str # 目标节点ID target_port: str # 目标端口名 dimension_mode: DimensionMode = DimensionMode.NONE def to_dict(self) -> Dict: return { "source_node": self.source_node, "source_port": self.source_port, "target_node": self.target_node, "target_port": self.target_port, "dimension_mode": self.dimension_mode } @dataclass class NodeMetadata: """节点元数据(v2.0 新增,用于函数节点系统)""" node_id: str # 节点ID node_type: str # 节点类型(来自NodeRegistry) display_name: Optional[str] = None description: Optional[str] = None params: Dict[str, Any] = field(default_factory=dict) def to_dict(self) -> Dict: return { "node_id": self.node_id, "node_type": self.node_type, "display_name": self.display_name, "description": self.description, "params": self.params, } class TraceNode(ABC): """ TraceStudio 节点基类 (v2.0) 四大属性: - InputSpec: 主输入(必须通过连线) - OutputSpec: 主输出(供下游连接) - ParamSpec: 控制参数(面板配置) - ContextSpec: 上下文/元数据(自动广播) 示例用法: @register_node class AddNode(TraceNode): CATEGORY = "Math/Basic" DISPLAY_NAME = "加法" DESCRIPTION = "计算两个数的和" # 自动收集的四大属性 @input_port("a", "Number", description="加数A") @input_port("b", "Number", description="加数B") @output_port("result", "Number", description="和") @param("offset", "Number", default=0, description="偏移量") @context_var("count", "Integer", description="计算次数") def process(self, inputs, context): a = inputs["a"] b = inputs["b"] offset = self.get_param("offset", 0) result = a + b + offset # 返回输出和上下文 return { "outputs": {"result": result}, "context": {"count": 1} } """ # ============= 元数据定义 (类属性) ============= CATEGORY = "Uncategorized" # 分类路径,如 "Data/Transform" DISPLAY_NAME = None # 显示名称,None 则使用类名 DESCRIPTION = "" # 节点描述 ICON = "📦" # 图标(Emoji 或图标名) VERSION = "1.0.0" # 版本号 AUTHOR = "" # 作者 # 节点行为配置 NODE_TYPE = NodeType.NORMAL # 节点类型 CACHE_POLICY = CachePolicy.NONE # 缓存策略 SUPPORTS_PREVIEW = True # 是否支持预览模式 # ============= v2.0 四大属性(自动收集) ============= # 格式: {"port_name": (data_type, config_dict)} InputSpec: Dict[str, tuple] = {} OutputSpec: Dict[str, tuple] = {} ParamSpec: Dict[str, tuple] = {} ContextSpec: Dict[str, tuple] = {} def __init__(self, node_id: str, params: Optional[Dict] = None): """ 初始化节点实例 Args: node_id: 节点实例 ID(前端生成) params: 参数值字典(仅 ParamSpec 中的参数) """ self.node_id = node_id self.params = params or {} self._cache = None @abstractmethod def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ 处理节点核心逻辑(子类必须实现) Args: inputs: 主输入数据字典 {"port_name": data} 所有数据由连线提供,对应 InputSpec context: 上下文字典 {"$Global.var": value, "$NodeID.var": value} 包含全局变量和上游节点的 ContextSpec Returns: 返回字典,包含两个键: { "outputs": {"port_name": data}, # 主输出,对应 OutputSpec "context": {"var_name": value} # 上下文,对应 ContextSpec } 简化写法(仅返回 outputs): return {"result": data} # 自动转换为 {"outputs": {"result": data}, "context": {}} """ pass def wrap_process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: self.temp_inputs = inputs self.temp_context = context result = {} try: result = self.process(inputs, context) except Exception as e: raise e finally: self.temp_inputs = None self.temp_context = None return result def preview(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None, limit: int = 10) -> Dict[str, Any]: """ 预览模式执行(可选重写) 默认实现:调用 process 后截取数据 子类可重写此方法优化预览逻辑(如提前终止、采样等) Args: inputs: 主输入数据 context: 上下文 limit: 预览数据量限制 Returns: 预览结果 """ if not self.SUPPORTS_PREVIEW: raise NotImplementedError(f"节点 {self.__class__.__name__} 不支持预览") return self.wrap_process(inputs, context) # ============= 辅助方法 ============= def get_param(self, name: str, default: Any = None) -> Any: """ 获取参数值(支持静态值、Context 引用、暴露端口) """ inputs = self.temp_inputs if name in inputs: return inputs[name] #context = self.temp_context #if context and name in context: # return context[name] return self.params.get(name, default) def validate_inputs(self, inputs: Dict[str, Any]) -> bool: """ 验证输入数据完整性 Args: inputs: 输入数据字典 Returns: 验证是否通过 Raises: ValueError: 缺少必需输入或类型不匹配 """ for port_name, (data_type, config) in self.InputSpec.items(): required = config.get("required", True) # 检查必需输入 if required and port_name not in inputs: raise ValueError(f"缺少必需的输入端口: {port_name}") # 检查列表类型 is_list = config.get("list", False) if port_name in inputs and is_list and not isinstance(inputs[port_name], list): raise TypeError(f"端口 {port_name} 需要列表类型数据,但收到 {type(inputs[port_name])}") return True def execute_with_cache(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ 带缓存的执行 Args: inputs: 输入数据 context: 上下文 Returns: 执行结果 """ if self.CACHE_POLICY == CachePolicy.NONE: return self.wrap_process(inputs, context) # 生成缓存键 cache_key = self._generate_cache_key(inputs, context) print("***********************",self.__class__.__name__, cache_key) # 检查缓存 cached = self._get_from_cache(cache_key) if cached is not None: print("***********************",self.__class__.__name__, "read cache:", cache_key) return cached # 执行处理 result = self.wrap_process(inputs, context) # 保存缓存 self._save_to_cache(cache_key, result) return result def _generate_cache_key(self, inputs: Dict, context: Optional[Dict] = None) -> str: """生成缓存键""" import hashlib import json cache_data = { "class": self.__class__.__name__, "params": self.params, "inputs": self._serialize_inputs(inputs), } json_str = json.dumps(cache_data, sort_keys=True) return hashlib.sha256(json_str.encode()).hexdigest() def _serialize_inputs(self, inputs: Dict) -> Dict: """序列化输入用于缓存键生成""" import hashlib serialized = {} for key, value in inputs.items(): if isinstance(value, (str, int, float, bool, type(None))): serialized[key] = value elif isinstance(value, list): # 列表转为哈希 serialized[key] = hashlib.md5(str(value).encode()).hexdigest() else: # 复杂对象转为哈希 serialized[key] = hashlib.md5(str(value).encode()).hexdigest() return serialized def _get_from_cache(self, key: str) -> Optional[Dict]: """从缓存获取""" if self._cache is None: self._init_cache() return self._cache.get(key) if self._cache else None def _save_to_cache(self, key: str, value: Dict): """保存到缓存""" if self._cache is None: self._init_cache() if self._cache: self._cache.set(key, value) def _init_cache(self): storage = 'memory' if self.CACHE_POLICY == CachePolicy.MEMORY else 'disk' self._cache = NodeCacheAdapter(storage) # ============= 元数据生成(供前端使用) ============= @classmethod def get_metadata(cls) -> Dict: """ 生成节点元数据(供前端渲染) Returns: 节点元数据字典 """ return { "class_name": cls.__name__, "display_name": cls.DISPLAY_NAME or cls.__name__, "category": cls.CATEGORY, "description": cls.DESCRIPTION, "icon": cls.ICON, "version": cls.VERSION, "author": cls.AUTHOR, "node_type": cls.NODE_TYPE.value if hasattr(cls.NODE_TYPE, 'value') else str(cls.NODE_TYPE), "cache_policy": cls.CACHE_POLICY.value if hasattr(cls.CACHE_POLICY, 'value') else str(cls.CACHE_POLICY), "supports_preview": cls.SUPPORTS_PREVIEW, # 四大属性 "inputs": cls._format_spec(cls.InputSpec, "input"), "outputs": cls._format_spec(cls.OutputSpec, "output"), "params": cls._format_spec(cls.ParamSpec, "param"), "context": cls._format_spec(cls.ContextSpec, "context"), } @classmethod def _format_spec(cls, spec: Dict, spec_type: str) -> List[Dict]: """格式化属性定义为前端格式""" formatted = [] for name, (data_type, config) in spec.items(): item = { "name": name, "type": data_type, **config # 展开配置字典 } formatted.append(item) return formatted def __repr__(self): return f"<{self.__class__.__name__}(id={self.node_id})>" # ============= 装饰器:自动收集属性 ============= def input_port(name: str, data_type: str, **config): """ 输入端口装饰器(自动收集到 InputSpec) Args: name: 端口名称 data_type: 数据类型(如 "DataTable", "Number", "String") **config: 配置项 - description: 描述 - required: 是否必需(默认 True) - list: 是否为列表类型(默认 False) 用法: @input_port("data", "DataTable", description="输入数据") def process(self, inputs, context): ... """ def decorator(func): # 将属性存储在方法对象上,稍后由类装饰器收集 if not hasattr(func, '_pending_input_specs'): func._pending_input_specs = {} func._pending_input_specs[name] = (data_type, config) return func return decorator def output_port(name: str, data_type: str, **config): """ 输出端口装饰器(自动收集到 OutputSpec) 用法与 `input_port` 对称。 """ def decorator(func): if not hasattr(func, '_pending_output_specs'): func._pending_output_specs = {} func._pending_output_specs[name] = (data_type, config) return func return decorator def param(name: str, param_type: str, **config): """ 参数装饰器(自动收集到 ParamSpec) Args: name: 参数名称 param_type: 参数类型(如 "Number", "String", "Boolean", "Dropdown") **config: 配置项 - default: 默认值 - description: 描述 - widget: 控件类型(如 "slider", "text", "dropdown") - min: 最小值(数值类型) - max: 最大值(数值类型) - step: 步长(数值类型) - options: 选项列表(下拉框) - multiline: 多行文本(文本类型) 用法: @param("threshold", "Number", default=0.5, min=0, max=1, step=0.1) def process(self, inputs, context): threshold = self.get_param("threshold") ... """ def decorator(func): if not hasattr(func, '_pending_param_specs'): func._pending_param_specs = {} func._pending_param_specs[name] = (param_type, config) return func return decorator def context_var(name: str, var_type: str, **config): """ 上下文变量装饰器(自动收集到 ContextSpec) Args: name: 变量名称 var_type: 变量类型 **config: 配置项 - description: 描述 用法: @context_var("row_count", "Integer", description="数据行数") def process(self, inputs, context): ... return { "outputs": {...}, "context": {"row_count": 100} } """ def decorator(func): if not hasattr(func, '_pending_context_specs'): func._pending_context_specs = {} func._pending_context_specs[name] = (var_type, config) return func return decorator def auto_collect_specs(cls): """ 自动收集装饰器标记的属性到四大 Spec 此装饰器会在类定义完成后自动调用, 从方法装饰器中收集 _pending_*_specs 并合并到类的 Spec 中 """ # 遍历类的所有方法,收集装饰器标记的属性 for attr_name in dir(cls): try: attr = getattr(cls, attr_name) except AttributeError: continue # 收集 InputSpec if hasattr(attr, '_pending_input_specs'): if not hasattr(cls, 'InputSpec') or cls.InputSpec is TraceNode.InputSpec: cls.InputSpec = {} cls.InputSpec = {**cls.InputSpec, **attr._pending_input_specs} # 收集 OutputSpec if hasattr(attr, '_pending_output_specs'): if not hasattr(cls, 'OutputSpec') or cls.OutputSpec is TraceNode.OutputSpec: cls.OutputSpec = {} cls.OutputSpec = {**cls.OutputSpec, **attr._pending_output_specs} # 收集 ParamSpec if hasattr(attr, '_pending_param_specs'): if not hasattr(cls, 'ParamSpec') or cls.ParamSpec is TraceNode.ParamSpec: cls.ParamSpec = {} cls.ParamSpec = {**cls.ParamSpec, **attr._pending_param_specs} # 收集 ContextSpec if hasattr(attr, '_pending_context_specs'): if not hasattr(cls, 'ContextSpec') or cls.ContextSpec is TraceNode.ContextSpec: cls.ContextSpec = {} cls.ContextSpec = {**cls.ContextSpec, **attr._pending_context_specs} return cls # ============= 特殊节点类(v2.0 新增,用于函数节点系统)============= class InputNode(TraceNode, ABC): """ 输入节点 - 子工作流的入口 在工作流中的作用: - 将外部输入映射到工作流内部 - 不执行任何业务逻辑 - 输出接收到的输入数据 使用场景: 函数节点需要从外部接收参数,输入节点作为入口 """ NODE_TYPE = NodeType.INPUT CATEGORY = "Meta/Input" DISPLAY_NAME = "输入" DESCRIPTION = "工作流输入入口" # 输入节点不需要输入端口(只有输出) InputSpec = {} OutputSpec = {} # 由工作流动态定义 ParamSpec = {} # 输入节点通常没有参数 ContextSpec = {} def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: """ 直接返回接收到的输入数据 """ return { "outputs": inputs, "context": context or {} } def validate_inputs(self, inputs: Dict[str, Any]) -> bool: """输入节点不需要验证""" return True class OutputNode(TraceNode, ABC): """ 输出节点 - 子工作流的出口 在工作流中的作用: - 将工作流内部结果映射到外部 - 不执行任何业务逻辑 - 直接返回接收到的数据 使用场景: 函数节点需要向外部返回结果,输出节点作为出口 """ NODE_TYPE = NodeType.OUTPUT CATEGORY = "Meta/Output" DISPLAY_NAME = "输出" DESCRIPTION = "工作流输出出口" # 输出节点只有输入端口(没有输出) InputSpec = {} # 由工作流动态定义 OutputSpec = {} # 输出节点没有后续输出 ParamSpec = {} ContextSpec = {} def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: """ 直接返回接收到的数据 """ return { "outputs": inputs, "context": context or {} } def validate_inputs(self, inputs: Dict[str, Any]) -> bool: """输出节点接受任何输入""" return True class FunctionNode(TraceNode, ABC): """ 函数节点 - 由子工作流包装而成的可复用节点 在工作流中的作用: - 将一个完整的子工作流包装成单个节点 - 支持无限嵌套(函数节点内可包含函数节点) - 输入/输出映射到子工作流的 InputNode/OutputNode 使用场景: 创建可复用的工作流模板,如"数据清洗"、"特征提取"等 """ NODE_TYPE = NodeType.FUNCTION CATEGORY = "Meta/Function" DISPLAY_NAME = "函数" DESCRIPTION = "可复用的子工作流" # 函数节点的输入输出由子工作流定义 InputSpec = {} OutputSpec = {} ParamSpec = {} ContextSpec = {} def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: """ 函数节点的执行由 WorkflowExecutor 负责 这里仅作为占位符 """ raise NotImplementedError( "函数节点必须由 WorkflowExecutor 执行," "不能直接调用 process() 方法" ) class WorkflowPackager: """工作流打包器 - 将子工作流打包为函数节点""" @staticmethod def validate_function_workflow( nodes: List[Dict], edges: List[Dict] ) -> Tuple[bool, str]: """ 验证工作流是否可以打包为函数节点 要求: 1. 必须包含至少一个InputNode 2. 必须包含至少一个OutputNode 3. 所有节点都必须是可连接的 4. 不能有孤立节点 Returns: (valid, error_message) """ node_ids = {n["id"] for n in nodes} def _is_input_node(n: Dict) -> bool: # 严格检查 type 是否为 NodeType.INPUT.value t = n.get("type") if t == NodeType.INPUT.value: return True # 向上兼容:如果提供了实现类名或 class_name,使用关键字判断 c = str(n.get("class_name", n.get("class", ""))).lower() return "input" in c def _is_output_node(n: Dict) -> bool: t = n.get("type") if t == NodeType.OUTPUT.value: return True c = str(n.get("class_name", n.get("class", ""))).lower() return "output" in c has_input = any(_is_input_node(n) for n in nodes) has_output = any(_is_output_node(n) for n in nodes) if not has_input: return False, "函数节点工作流必须包含至少一个InputNode" if not has_output: return False, "函数节点工作流必须包含至少一个OutputNode" # 检查所有连线的节点存在性 for edge in edges: src = edge.get("source") tgt = edge.get("target") if src not in node_ids or tgt not in node_ids: return False, f"连线引用不存在的节点: {src} → {tgt}" # 检查是否有孤立节点(可选) connected_nodes = set() for edge in edges: connected_nodes.add(edge.get("source")) connected_nodes.add(edge.get("target")) # 输入输出节点可以孤立(作为入口/出口) isolated = node_ids - connected_nodes # 如果有孤立节点,允许它们仅当它们是明确定义为 Input/Output for node_id in isolated: node = next((n for n in nodes if n.get("id") == node_id), None) if node is None: continue t = node.get("type") c = str(node.get("class_name", node.get("class", ""))).lower() if t not in (NodeType.INPUT.value, NodeType.OUTPUT.value) and not ("input" in c or "output" in c): # 允许孤立的普通节点(可能是后续连接) pass return True, "" @staticmethod def package_as_function( node_id: str, nodes: List[Dict], edges: List[Dict], display_name: str = "", description: str = "" ) -> Dict[str, Any]: """ 将工作流打包为函数节点 Args: node_id: 新函数节点的ID nodes: 子工作流节点 edges: 子工作流连线 display_name: 显示名称 description: 描述 Returns: 函数节点定义 """ valid, error = WorkflowPackager.validate_function_workflow(nodes, edges) if not valid: raise ValueError(f"无法打包工作流: {error}") # 返回符合执行器严格约定的函数节点定义: # - type 必须为 NodeType.FUNCTION.value # - class 字段指定实现类名(这里使用 FunctionNodeImpl) return { "id": node_id, "type": NodeType.FUNCTION.value, "class_name": "FunctionNodeImpl", "display_name": display_name or "函数工作流", "description": description or "通过工作流定义的函数节点", "params": {}, }