""" 特殊节点示例和高级使用场景 """ from ..core.node_base import ( TraceNode, CachePolicy, input_port, output_port, param, context_var, InputNode, OutputNode, FunctionNode ) from ..core.node_registry import NodeRegistry, register_node from typing import Any, Dict, Optional, List # =========================== # 注册特殊节点 # =========================== @register_node class InputNodeImpl(InputNode): """ 输入节点实现 特殊逻辑: - 没有输入端口(inputs 为空) - 从 context(全局上下文)中提取所有字段作为输出端口 - 每个全局上下文字段都映射为一个输出端口 """ async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: # 如果没有连线到此节点,inputs 会是空的 # 我们从 context(全局上下文)中提取数据 # context 中的数据会由工作流执行器自动传递 # 返回整个全局上下文作为输出(严格返回 dict) # 如果 context 是 ExecutionContext 实例,读取其 global_context。 if context is None: return {} # Duck-typing: prefer `global_context` attribute when present gc = None try: if hasattr(context, "global_context") and isinstance(getattr(context, "global_context"), dict): gc = dict(getattr(context, "global_context")) except Exception: gc = None if gc is None and isinstance(context, dict): gc = dict(context) return gc or {} @register_node class OutputNodeImpl(OutputNode): """输出节点实现""" @input_port("input", "Any", description="要输出的数据") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: return inputs @register_node class FunctionNodeImpl(FunctionNode): """函数节点实现""" async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: # 由 AdvancedWorkflowExecutor 处理 return inputs # =========================== # 示例:支持数组操作的节点 # =========================== @register_node class ArrayMapNode(TraceNode): """ 数组映射节点 - 支持升维操作 场景: - 输入:整数数组 [1, 2, 3] - 操作:map(x => x * 2) - 输出:[2, 4, 6] 通过升维机制实现: AddNode作为内部逻辑,升维(EXPAND)处理数组遍历 """ NODE_TYPE = "ArrayMapNode" CATEGORY = "Array/Transform" DISPLAY_NAME = "数组映射" DESCRIPTION = "对数组中的每个元素应用操作" @input_port("values", "Array[Number]", description="输入数组") @output_port("mapped", "Array[Number]", description="映射后的数组") @param("multiplier", "Number", default=1, description="乘数") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: values = inputs.get("values", []) multiplier = self.get_param("multiplier", 1) if not isinstance(values, list): values = [values] mapped = [v * multiplier for v in values] return {"mapped": mapped} @register_node class ArrayConcatNode(TraceNode): """ 数组连接节点 - 支持多线汇聚 场景: - 输入1:[1, 2] (来自节点A) - 输入2:[3, 4] (来自节点B) - 操作:concat - 输出:[1, 2, 3, 4] 通过BROADCAST机制实现: 多条线汇聚到数组输入,自动打包为嵌套数组再展平 """ NODE_TYPE = "ArrayConcatNode" CATEGORY = "Array/Combine" DISPLAY_NAME = "数组连接" DESCRIPTION = "将多个数组合并为一个" @input_port("arrays", "Array[Array]", description="要连接的数组们") @output_port("result", "Array", description="连接后的数组") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: arrays_input = inputs.get("arrays", []) # 展平:如果输入是 [[1,2], [3,4]],输出 [1,2,3,4] result = [] for arr in arrays_input: if isinstance(arr, list): result.extend(arr) else: result.append(arr) return {"result": result} @register_node class ArrayFilterNode(TraceNode): """ 数组过滤节点 - 数组到数组 场景: - 输入:[1, 2, 3, 4, 5] - 操作:filter(x > 2) - 输出:[3, 4, 5] """ NODE_TYPE = "ArrayFilterNode" CATEGORY = "Array/Filter" DISPLAY_NAME = "数组过滤" DESCRIPTION = "根据条件过滤数组元素" @input_port("values", "Array[Number]", description="输入数组") @output_port("filtered", "Array[Number]", description="过滤后的数组") @param("threshold", "Number", default=0, description="阈值(保留 > threshold 的元素)") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: values = inputs.get("values", []) threshold = self.get_param("threshold", 0) if not isinstance(values, list): values = [values] filtered = [v for v in values if v > threshold] return {"filtered": filtered} @register_node class ArrayReduceNode(TraceNode): """ 数组规约节点 - 数组到标量 场景: - 输入:[1, 2, 3, 4, 5] - 操作:sum(reduce) - 输出:15 """ NODE_TYPE = "ArrayReduceNode" CATEGORY = "Array/Reduce" DISPLAY_NAME = "数组规约" DESCRIPTION = "将数组规约为单个值" @input_port("values", "Array[Number]", description="输入数组") @output_port("result", "Number", description="规约后的值") @param("operation", "String", default="sum", description="操作类型:sum、product、max、min") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: values = inputs.get("values", []) operation = self.get_param("operation", "sum") if not isinstance(values, list): values = [values] if operation == "sum": result = sum(values) elif operation == "product": result = 1 for v in values: result *= v elif operation == "max": result = max(values) if values else 0 elif operation == "min": result = min(values) if values else 0 else: result = sum(values) return {"result": result} @register_node class BroadcastNode(TraceNode): """ 广播节点 - 标量到数组 场景: - 输入:5 - 操作:broadcast(3次) - 输出:[5, 5, 5] """ NODE_TYPE = "BroadcastNode" CATEGORY = "Array/Broadcast" DISPLAY_NAME = "广播" DESCRIPTION = "将单个值广播为数组" @input_port("value", "Any", description="要广播的值") @output_port("broadcast", "Array", description="广播后的数组") @param("count", "Integer", default=3, description="广播次数") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: value = inputs.get("value") count = self.get_param("count", 3) result = [value] * count return {"broadcast": result} @register_node class ArrayZipNode(TraceNode): """ 数组拉链节点 - 多数组合并 场景: - 输入1:[1, 2, 3] - 输入2:['a', 'b', 'c'] - 操作:zip - 输出:[[1, 'a'], [2, 'b'], [3, 'c']] """ NODE_TYPE = "ArrayZipNode" CATEGORY = "Array/Combine" DISPLAY_NAME = "数组拉链" DESCRIPTION = "将多个数组按位置组合" @input_port("array1", "Array", description="第一个数组") @input_port("array2", "Array", description="第二个数组") @output_port("zipped", "Array[Array]", description="拉链后的数组") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: arr1 = inputs.get("array1", []) arr2 = inputs.get("array2", []) if not isinstance(arr1, list): arr1 = [arr1] if not isinstance(arr2, list): arr2 = [arr2] # 拉链操作 zipped = list(zip(arr1, arr2)) result = [list(pair) for pair in zipped] return {"zipped": result} @register_node class ConditionalBranchNode(TraceNode): """ 条件分支节点 - 根据条件选择输出 场景: - 输入:值 - 条件:值 > 10 - 输出1:值(如果满足条件) - 输出2:值(如果不满足条件) """ NODE_TYPE = "ConditionalBranchNode" CATEGORY = "Control/Branch" DISPLAY_NAME = "条件分支" DESCRIPTION = "根据条件选择不同的输出路径" @input_port("value", "Any", description="输入值") @output_port("true_output", "Any", description="条件为真时的输出") @output_port("false_output", "Any", description="条件为假时的输出") @param("threshold", "Number", default=0, description="阈值") async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]: value = inputs.get("value", 0) threshold = self.get_param("threshold", 0) outputs = {} if value > threshold: outputs["true_output"] = value outputs["false_output"] = None else: outputs["true_output"] = None outputs["false_output"] = value return outputs # =========================== # 示例:工作流打包为函数节点 # =========================== """ 示例工作流定义(可打包为函数节点): { "id": "multiply_and_sum", "type": "FunctionNode", "display_name": "乘积求和", "description": "先将数组中的每个数乘以2,再求和", "sub_workflow": { "nodes": [ { "id": "input", "type": "InputNodeImpl", "params": {} }, { "id": "multiply", "type": "ArrayMapNode", "params": {"multiplier": 2} }, { "id": "sum", "type": "ArrayReduceNode", "params": {"operation": "sum"} }, { "id": "output", "type": "OutputNodeImpl", "params": {} } ], "edges": [ { "source": "input", "sourcePort": "output", "target": "multiply", "targetPort": "values", }, { "source": "multiply", "sourcePort": "mapped", "target": "sum", "targetPort": "values", }, { "source": "sum", "sourcePort": "result", "target": "output", "targetPort": "input", } ] } } 使用方式: 1. 从上面的定义创建 AdvancedWorkflowGraph 2. 调用 AdvancedWorkflowExecutor.execute() 3. 函数节点会自动展开子工作流并执行 """ __all__ = [ "InputNodeImpl", "OutputNodeImpl", "FunctionNodeImpl", "ArrayMapNode", "ArrayConcatNode", "ArrayFilterNode", "ArrayReduceNode", "BroadcastNode", "ArrayZipNode", "ConditionalBranchNode" ]