381 lines
12 KiB
Python
381 lines
12 KiB
Python
"""
|
||
特殊节点示例和高级使用场景
|
||
"""
|
||
from ..core.node_base import (
|
||
TraceNode, CachePolicy, input_port, output_port, param, context_var,
|
||
InputNode, OutputNode, FunctionNode
|
||
)
|
||
from ..core.node_registry import NodeRegistry, register_node
|
||
from typing import Any, Dict, Optional, List
|
||
|
||
|
||
# ===========================
|
||
# 注册特殊节点
|
||
# ===========================
|
||
|
||
@register_node
|
||
class InputNodeImpl(InputNode):
|
||
"""
|
||
输入节点实现
|
||
|
||
特殊逻辑:
|
||
- 没有输入端口(inputs 为空)
|
||
- 从 context(全局上下文)中提取所有字段作为输出端口
|
||
- 每个全局上下文字段都映射为一个输出端口
|
||
"""
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
# 如果没有连线到此节点,inputs 会是空的
|
||
# 我们从 context(全局上下文)中提取数据
|
||
# context 中的数据会由工作流执行器自动传递
|
||
|
||
# 返回整个全局上下文作为输出(严格返回 dict)
|
||
# 如果 context 是 ExecutionContext 实例,读取其 global_context。
|
||
if context is None:
|
||
return {}
|
||
|
||
# Duck-typing: prefer `global_context` attribute when present
|
||
gc = None
|
||
try:
|
||
if hasattr(context, "global_context") and isinstance(getattr(context, "global_context"), dict):
|
||
gc = dict(getattr(context, "global_context"))
|
||
except Exception:
|
||
gc = None
|
||
|
||
if gc is None and isinstance(context, dict):
|
||
gc = dict(context)
|
||
|
||
return gc or {}
|
||
|
||
@register_node
|
||
class OutputNodeImpl(OutputNode):
|
||
"""输出节点实现"""
|
||
@input_port("input", "Any", description="要输出的数据")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
return inputs
|
||
|
||
|
||
@register_node
|
||
class FunctionNodeImpl(FunctionNode):
|
||
"""函数节点实现"""
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
# 由 AdvancedWorkflowExecutor 处理
|
||
return inputs
|
||
|
||
|
||
# ===========================
|
||
# 示例:支持数组操作的节点
|
||
# ===========================
|
||
|
||
@register_node
|
||
class ArrayMapNode(TraceNode):
|
||
"""
|
||
数组映射节点 - 支持升维操作
|
||
|
||
场景:
|
||
- 输入:整数数组 [1, 2, 3]
|
||
- 操作:map(x => x * 2)
|
||
- 输出:[2, 4, 6]
|
||
|
||
通过升维机制实现:
|
||
AddNode作为内部逻辑,升维(EXPAND)处理数组遍历
|
||
"""
|
||
NODE_TYPE = "ArrayMapNode"
|
||
CATEGORY = "Array/Transform"
|
||
DISPLAY_NAME = "数组映射"
|
||
DESCRIPTION = "对数组中的每个元素应用操作"
|
||
|
||
@input_port("values", "Array[Number]", description="输入数组")
|
||
@output_port("mapped", "Array[Number]", description="映射后的数组")
|
||
@param("multiplier", "Number", default=1, description="乘数")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
values = inputs.get("values", [])
|
||
multiplier = self.get_param("multiplier", 1)
|
||
|
||
if not isinstance(values, list):
|
||
values = [values]
|
||
|
||
mapped = [v * multiplier for v in values]
|
||
|
||
return {"mapped": mapped}
|
||
|
||
|
||
@register_node
|
||
class ArrayConcatNode(TraceNode):
|
||
"""
|
||
数组连接节点 - 支持多线汇聚
|
||
|
||
场景:
|
||
- 输入1:[1, 2] (来自节点A)
|
||
- 输入2:[3, 4] (来自节点B)
|
||
- 操作:concat
|
||
- 输出:[1, 2, 3, 4]
|
||
|
||
通过BROADCAST机制实现:
|
||
多条线汇聚到数组输入,自动打包为嵌套数组再展平
|
||
"""
|
||
NODE_TYPE = "ArrayConcatNode"
|
||
CATEGORY = "Array/Combine"
|
||
DISPLAY_NAME = "数组连接"
|
||
DESCRIPTION = "将多个数组合并为一个"
|
||
|
||
@input_port("arrays", "Array[Array]", description="要连接的数组们")
|
||
@output_port("result", "Array", description="连接后的数组")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
arrays_input = inputs.get("arrays", [])
|
||
|
||
# 展平:如果输入是 [[1,2], [3,4]],输出 [1,2,3,4]
|
||
result = []
|
||
for arr in arrays_input:
|
||
if isinstance(arr, list):
|
||
result.extend(arr)
|
||
else:
|
||
result.append(arr)
|
||
|
||
return {"result": result}
|
||
|
||
|
||
@register_node
|
||
class ArrayFilterNode(TraceNode):
|
||
"""
|
||
数组过滤节点 - 数组到数组
|
||
|
||
场景:
|
||
- 输入:[1, 2, 3, 4, 5]
|
||
- 操作:filter(x > 2)
|
||
- 输出:[3, 4, 5]
|
||
"""
|
||
NODE_TYPE = "ArrayFilterNode"
|
||
CATEGORY = "Array/Filter"
|
||
DISPLAY_NAME = "数组过滤"
|
||
DESCRIPTION = "根据条件过滤数组元素"
|
||
|
||
@input_port("values", "Array[Number]", description="输入数组")
|
||
@output_port("filtered", "Array[Number]", description="过滤后的数组")
|
||
@param("threshold", "Number", default=0, description="阈值(保留 > threshold 的元素)")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
values = inputs.get("values", [])
|
||
threshold = self.get_param("threshold", 0)
|
||
|
||
if not isinstance(values, list):
|
||
values = [values]
|
||
|
||
filtered = [v for v in values if v > threshold]
|
||
|
||
return {"filtered": filtered}
|
||
|
||
|
||
@register_node
|
||
class ArrayReduceNode(TraceNode):
|
||
"""
|
||
数组规约节点 - 数组到标量
|
||
|
||
场景:
|
||
- 输入:[1, 2, 3, 4, 5]
|
||
- 操作:sum(reduce)
|
||
- 输出:15
|
||
"""
|
||
NODE_TYPE = "ArrayReduceNode"
|
||
CATEGORY = "Array/Reduce"
|
||
DISPLAY_NAME = "数组规约"
|
||
DESCRIPTION = "将数组规约为单个值"
|
||
|
||
@input_port("values", "Array[Number]", description="输入数组")
|
||
@output_port("result", "Number", description="规约后的值")
|
||
@param("operation", "String", default="sum", description="操作类型:sum、product、max、min")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
values = inputs.get("values", [])
|
||
operation = self.get_param("operation", "sum")
|
||
|
||
if not isinstance(values, list):
|
||
values = [values]
|
||
|
||
if operation == "sum":
|
||
result = sum(values)
|
||
elif operation == "product":
|
||
result = 1
|
||
for v in values:
|
||
result *= v
|
||
elif operation == "max":
|
||
result = max(values) if values else 0
|
||
elif operation == "min":
|
||
result = min(values) if values else 0
|
||
else:
|
||
result = sum(values)
|
||
|
||
return {"result": result}
|
||
|
||
|
||
@register_node
|
||
class BroadcastNode(TraceNode):
|
||
"""
|
||
广播节点 - 标量到数组
|
||
|
||
场景:
|
||
- 输入:5
|
||
- 操作:broadcast(3次)
|
||
- 输出:[5, 5, 5]
|
||
"""
|
||
NODE_TYPE = "BroadcastNode"
|
||
CATEGORY = "Array/Broadcast"
|
||
DISPLAY_NAME = "广播"
|
||
DESCRIPTION = "将单个值广播为数组"
|
||
|
||
@input_port("value", "Any", description="要广播的值")
|
||
@output_port("broadcast", "Array", description="广播后的数组")
|
||
@param("count", "Integer", default=3, description="广播次数")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
value = inputs.get("value")
|
||
count = self.get_param("count", 3)
|
||
|
||
result = [value] * count
|
||
|
||
return {"broadcast": result}
|
||
|
||
|
||
@register_node
|
||
class ArrayZipNode(TraceNode):
|
||
"""
|
||
数组拉链节点 - 多数组合并
|
||
|
||
场景:
|
||
- 输入1:[1, 2, 3]
|
||
- 输入2:['a', 'b', 'c']
|
||
- 操作:zip
|
||
- 输出:[[1, 'a'], [2, 'b'], [3, 'c']]
|
||
"""
|
||
NODE_TYPE = "ArrayZipNode"
|
||
CATEGORY = "Array/Combine"
|
||
DISPLAY_NAME = "数组拉链"
|
||
DESCRIPTION = "将多个数组按位置组合"
|
||
|
||
@input_port("array1", "Array", description="第一个数组")
|
||
@input_port("array2", "Array", description="第二个数组")
|
||
@output_port("zipped", "Array[Array]", description="拉链后的数组")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
arr1 = inputs.get("array1", [])
|
||
arr2 = inputs.get("array2", [])
|
||
|
||
if not isinstance(arr1, list):
|
||
arr1 = [arr1]
|
||
if not isinstance(arr2, list):
|
||
arr2 = [arr2]
|
||
|
||
# 拉链操作
|
||
zipped = list(zip(arr1, arr2))
|
||
result = [list(pair) for pair in zipped]
|
||
|
||
return {"zipped": result}
|
||
|
||
|
||
@register_node
|
||
class ConditionalBranchNode(TraceNode):
|
||
"""
|
||
条件分支节点 - 根据条件选择输出
|
||
|
||
场景:
|
||
- 输入:值
|
||
- 条件:值 > 10
|
||
- 输出1:值(如果满足条件)
|
||
- 输出2:值(如果不满足条件)
|
||
"""
|
||
NODE_TYPE = "ConditionalBranchNode"
|
||
CATEGORY = "Control/Branch"
|
||
DISPLAY_NAME = "条件分支"
|
||
DESCRIPTION = "根据条件选择不同的输出路径"
|
||
|
||
@input_port("value", "Any", description="输入值")
|
||
@output_port("true_output", "Any", description="条件为真时的输出")
|
||
@output_port("false_output", "Any", description="条件为假时的输出")
|
||
@param("threshold", "Number", default=0, description="阈值")
|
||
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
||
value = inputs.get("value", 0)
|
||
threshold = self.get_param("threshold", 0)
|
||
|
||
outputs = {}
|
||
if value > threshold:
|
||
outputs["true_output"] = value
|
||
outputs["false_output"] = None
|
||
else:
|
||
outputs["true_output"] = None
|
||
outputs["false_output"] = value
|
||
|
||
return outputs
|
||
|
||
|
||
# ===========================
|
||
# 示例:工作流打包为函数节点
|
||
# ===========================
|
||
|
||
"""
|
||
示例工作流定义(可打包为函数节点):
|
||
|
||
{
|
||
"id": "multiply_and_sum",
|
||
"type": "FunctionNode",
|
||
"display_name": "乘积求和",
|
||
"description": "先将数组中的每个数乘以2,再求和",
|
||
"sub_workflow": {
|
||
"nodes": [
|
||
{
|
||
"id": "input",
|
||
"type": "InputNodeImpl",
|
||
"params": {}
|
||
},
|
||
{
|
||
"id": "multiply",
|
||
"type": "ArrayMapNode",
|
||
"params": {"multiplier": 2}
|
||
},
|
||
{
|
||
"id": "sum",
|
||
"type": "ArrayReduceNode",
|
||
"params": {"operation": "sum"}
|
||
},
|
||
{
|
||
"id": "output",
|
||
"type": "OutputNodeImpl",
|
||
"params": {}
|
||
}
|
||
],
|
||
"edges": [
|
||
{
|
||
"source": "input",
|
||
"sourcePort": "output",
|
||
"target": "multiply",
|
||
"targetPort": "values",
|
||
},
|
||
{
|
||
"source": "multiply",
|
||
"sourcePort": "mapped",
|
||
"target": "sum",
|
||
"targetPort": "values",
|
||
},
|
||
{
|
||
"source": "sum",
|
||
"sourcePort": "result",
|
||
"target": "output",
|
||
"targetPort": "input",
|
||
}
|
||
]
|
||
}
|
||
}
|
||
|
||
使用方式:
|
||
1. 从上面的定义创建 AdvancedWorkflowGraph
|
||
2. 调用 AdvancedWorkflowExecutor.execute()
|
||
3. 函数节点会自动展开子工作流并执行
|
||
"""
|
||
|
||
__all__ = [
|
||
"InputNodeImpl",
|
||
"OutputNodeImpl",
|
||
"FunctionNodeImpl",
|
||
"ArrayMapNode",
|
||
"ArrayConcatNode",
|
||
"ArrayFilterNode",
|
||
"ArrayReduceNode",
|
||
"BroadcastNode",
|
||
"ArrayZipNode",
|
||
"ConditionalBranchNode"
|
||
]
|