381 lines
12 KiB
Python
381 lines
12 KiB
Python
|
|
"""
|
|||
|
|
特殊节点示例和高级使用场景
|
|||
|
|
"""
|
|||
|
|
from ..core.node_base import (
|
|||
|
|
TraceNode, CachePolicy, input_port, output_port, param, context_var,
|
|||
|
|
InputNode, OutputNode, FunctionNode
|
|||
|
|
)
|
|||
|
|
from ..core.node_registry import NodeRegistry, register_node
|
|||
|
|
from typing import Any, Dict, Optional, List
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===========================
|
|||
|
|
# 注册特殊节点
|
|||
|
|
# ===========================
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class InputNodeImpl(InputNode):
|
|||
|
|
"""
|
|||
|
|
输入节点实现
|
|||
|
|
|
|||
|
|
特殊逻辑:
|
|||
|
|
- 没有输入端口(inputs 为空)
|
|||
|
|
- 从 context(全局上下文)中提取所有字段作为输出端口
|
|||
|
|
- 每个全局上下文字段都映射为一个输出端口
|
|||
|
|
"""
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
# 如果没有连线到此节点,inputs 会是空的
|
|||
|
|
# 我们从 context(全局上下文)中提取数据
|
|||
|
|
# context 中的数据会由工作流执行器自动传递
|
|||
|
|
|
|||
|
|
# 返回整个全局上下文作为输出(严格返回 dict)
|
|||
|
|
# 如果 context 是 ExecutionContext 实例,读取其 global_context。
|
|||
|
|
if context is None:
|
|||
|
|
return {}
|
|||
|
|
|
|||
|
|
# Duck-typing: prefer `global_context` attribute when present
|
|||
|
|
gc = None
|
|||
|
|
try:
|
|||
|
|
if hasattr(context, "global_context") and isinstance(getattr(context, "global_context"), dict):
|
|||
|
|
gc = dict(getattr(context, "global_context"))
|
|||
|
|
except Exception:
|
|||
|
|
gc = None
|
|||
|
|
|
|||
|
|
if gc is None and isinstance(context, dict):
|
|||
|
|
gc = dict(context)
|
|||
|
|
|
|||
|
|
return gc or {}
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class OutputNodeImpl(OutputNode):
|
|||
|
|
"""输出节点实现"""
|
|||
|
|
@input_port("input", "Any", description="要输出的数据")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
return inputs
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class FunctionNodeImpl(FunctionNode):
|
|||
|
|
"""函数节点实现"""
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
# 由 AdvancedWorkflowExecutor 处理
|
|||
|
|
return inputs
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===========================
|
|||
|
|
# 示例:支持数组操作的节点
|
|||
|
|
# ===========================
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class ArrayMapNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
数组映射节点 - 支持升维操作
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入:整数数组 [1, 2, 3]
|
|||
|
|
- 操作:map(x => x * 2)
|
|||
|
|
- 输出:[2, 4, 6]
|
|||
|
|
|
|||
|
|
通过升维机制实现:
|
|||
|
|
AddNode作为内部逻辑,升维(EXPAND)处理数组遍历
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "ArrayMapNode"
|
|||
|
|
CATEGORY = "Array/Transform"
|
|||
|
|
DISPLAY_NAME = "数组映射"
|
|||
|
|
DESCRIPTION = "对数组中的每个元素应用操作"
|
|||
|
|
|
|||
|
|
@input_port("values", "Array[Number]", description="输入数组")
|
|||
|
|
@output_port("mapped", "Array[Number]", description="映射后的数组")
|
|||
|
|
@param("multiplier", "Number", default=1, description="乘数")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
values = inputs.get("values", [])
|
|||
|
|
multiplier = self.get_param("multiplier", 1)
|
|||
|
|
|
|||
|
|
if not isinstance(values, list):
|
|||
|
|
values = [values]
|
|||
|
|
|
|||
|
|
mapped = [v * multiplier for v in values]
|
|||
|
|
|
|||
|
|
return {"mapped": mapped}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class ArrayConcatNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
数组连接节点 - 支持多线汇聚
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入1:[1, 2] (来自节点A)
|
|||
|
|
- 输入2:[3, 4] (来自节点B)
|
|||
|
|
- 操作:concat
|
|||
|
|
- 输出:[1, 2, 3, 4]
|
|||
|
|
|
|||
|
|
通过BROADCAST机制实现:
|
|||
|
|
多条线汇聚到数组输入,自动打包为嵌套数组再展平
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "ArrayConcatNode"
|
|||
|
|
CATEGORY = "Array/Combine"
|
|||
|
|
DISPLAY_NAME = "数组连接"
|
|||
|
|
DESCRIPTION = "将多个数组合并为一个"
|
|||
|
|
|
|||
|
|
@input_port("arrays", "Array[Array]", description="要连接的数组们")
|
|||
|
|
@output_port("result", "Array", description="连接后的数组")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
arrays_input = inputs.get("arrays", [])
|
|||
|
|
|
|||
|
|
# 展平:如果输入是 [[1,2], [3,4]],输出 [1,2,3,4]
|
|||
|
|
result = []
|
|||
|
|
for arr in arrays_input:
|
|||
|
|
if isinstance(arr, list):
|
|||
|
|
result.extend(arr)
|
|||
|
|
else:
|
|||
|
|
result.append(arr)
|
|||
|
|
|
|||
|
|
return {"result": result}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class ArrayFilterNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
数组过滤节点 - 数组到数组
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入:[1, 2, 3, 4, 5]
|
|||
|
|
- 操作:filter(x > 2)
|
|||
|
|
- 输出:[3, 4, 5]
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "ArrayFilterNode"
|
|||
|
|
CATEGORY = "Array/Filter"
|
|||
|
|
DISPLAY_NAME = "数组过滤"
|
|||
|
|
DESCRIPTION = "根据条件过滤数组元素"
|
|||
|
|
|
|||
|
|
@input_port("values", "Array[Number]", description="输入数组")
|
|||
|
|
@output_port("filtered", "Array[Number]", description="过滤后的数组")
|
|||
|
|
@param("threshold", "Number", default=0, description="阈值(保留 > threshold 的元素)")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
values = inputs.get("values", [])
|
|||
|
|
threshold = self.get_param("threshold", 0)
|
|||
|
|
|
|||
|
|
if not isinstance(values, list):
|
|||
|
|
values = [values]
|
|||
|
|
|
|||
|
|
filtered = [v for v in values if v > threshold]
|
|||
|
|
|
|||
|
|
return {"filtered": filtered}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class ArrayReduceNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
数组规约节点 - 数组到标量
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入:[1, 2, 3, 4, 5]
|
|||
|
|
- 操作:sum(reduce)
|
|||
|
|
- 输出:15
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "ArrayReduceNode"
|
|||
|
|
CATEGORY = "Array/Reduce"
|
|||
|
|
DISPLAY_NAME = "数组规约"
|
|||
|
|
DESCRIPTION = "将数组规约为单个值"
|
|||
|
|
|
|||
|
|
@input_port("values", "Array[Number]", description="输入数组")
|
|||
|
|
@output_port("result", "Number", description="规约后的值")
|
|||
|
|
@param("operation", "String", default="sum", description="操作类型:sum、product、max、min")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
values = inputs.get("values", [])
|
|||
|
|
operation = self.get_param("operation", "sum")
|
|||
|
|
|
|||
|
|
if not isinstance(values, list):
|
|||
|
|
values = [values]
|
|||
|
|
|
|||
|
|
if operation == "sum":
|
|||
|
|
result = sum(values)
|
|||
|
|
elif operation == "product":
|
|||
|
|
result = 1
|
|||
|
|
for v in values:
|
|||
|
|
result *= v
|
|||
|
|
elif operation == "max":
|
|||
|
|
result = max(values) if values else 0
|
|||
|
|
elif operation == "min":
|
|||
|
|
result = min(values) if values else 0
|
|||
|
|
else:
|
|||
|
|
result = sum(values)
|
|||
|
|
|
|||
|
|
return {"result": result}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class BroadcastNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
广播节点 - 标量到数组
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入:5
|
|||
|
|
- 操作:broadcast(3次)
|
|||
|
|
- 输出:[5, 5, 5]
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "BroadcastNode"
|
|||
|
|
CATEGORY = "Array/Broadcast"
|
|||
|
|
DISPLAY_NAME = "广播"
|
|||
|
|
DESCRIPTION = "将单个值广播为数组"
|
|||
|
|
|
|||
|
|
@input_port("value", "Any", description="要广播的值")
|
|||
|
|
@output_port("broadcast", "Array", description="广播后的数组")
|
|||
|
|
@param("count", "Integer", default=3, description="广播次数")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
value = inputs.get("value")
|
|||
|
|
count = self.get_param("count", 3)
|
|||
|
|
|
|||
|
|
result = [value] * count
|
|||
|
|
|
|||
|
|
return {"broadcast": result}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class ArrayZipNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
数组拉链节点 - 多数组合并
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入1:[1, 2, 3]
|
|||
|
|
- 输入2:['a', 'b', 'c']
|
|||
|
|
- 操作:zip
|
|||
|
|
- 输出:[[1, 'a'], [2, 'b'], [3, 'c']]
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "ArrayZipNode"
|
|||
|
|
CATEGORY = "Array/Combine"
|
|||
|
|
DISPLAY_NAME = "数组拉链"
|
|||
|
|
DESCRIPTION = "将多个数组按位置组合"
|
|||
|
|
|
|||
|
|
@input_port("array1", "Array", description="第一个数组")
|
|||
|
|
@input_port("array2", "Array", description="第二个数组")
|
|||
|
|
@output_port("zipped", "Array[Array]", description="拉链后的数组")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
arr1 = inputs.get("array1", [])
|
|||
|
|
arr2 = inputs.get("array2", [])
|
|||
|
|
|
|||
|
|
if not isinstance(arr1, list):
|
|||
|
|
arr1 = [arr1]
|
|||
|
|
if not isinstance(arr2, list):
|
|||
|
|
arr2 = [arr2]
|
|||
|
|
|
|||
|
|
# 拉链操作
|
|||
|
|
zipped = list(zip(arr1, arr2))
|
|||
|
|
result = [list(pair) for pair in zipped]
|
|||
|
|
|
|||
|
|
return {"zipped": result}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_node
|
|||
|
|
class ConditionalBranchNode(TraceNode):
|
|||
|
|
"""
|
|||
|
|
条件分支节点 - 根据条件选择输出
|
|||
|
|
|
|||
|
|
场景:
|
|||
|
|
- 输入:值
|
|||
|
|
- 条件:值 > 10
|
|||
|
|
- 输出1:值(如果满足条件)
|
|||
|
|
- 输出2:值(如果不满足条件)
|
|||
|
|
"""
|
|||
|
|
NODE_TYPE = "ConditionalBranchNode"
|
|||
|
|
CATEGORY = "Control/Branch"
|
|||
|
|
DISPLAY_NAME = "条件分支"
|
|||
|
|
DESCRIPTION = "根据条件选择不同的输出路径"
|
|||
|
|
|
|||
|
|
@input_port("value", "Any", description="输入值")
|
|||
|
|
@output_port("true_output", "Any", description="条件为真时的输出")
|
|||
|
|
@output_port("false_output", "Any", description="条件为假时的输出")
|
|||
|
|
@param("threshold", "Number", default=0, description="阈值")
|
|||
|
|
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
|
|||
|
|
value = inputs.get("value", 0)
|
|||
|
|
threshold = self.get_param("threshold", 0)
|
|||
|
|
|
|||
|
|
outputs = {}
|
|||
|
|
if value > threshold:
|
|||
|
|
outputs["true_output"] = value
|
|||
|
|
outputs["false_output"] = None
|
|||
|
|
else:
|
|||
|
|
outputs["true_output"] = None
|
|||
|
|
outputs["false_output"] = value
|
|||
|
|
|
|||
|
|
return outputs
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===========================
|
|||
|
|
# 示例:工作流打包为函数节点
|
|||
|
|
# ===========================
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
示例工作流定义(可打包为函数节点):
|
|||
|
|
|
|||
|
|
{
|
|||
|
|
"id": "multiply_and_sum",
|
|||
|
|
"type": "FunctionNode",
|
|||
|
|
"display_name": "乘积求和",
|
|||
|
|
"description": "先将数组中的每个数乘以2,再求和",
|
|||
|
|
"sub_workflow": {
|
|||
|
|
"nodes": [
|
|||
|
|
{
|
|||
|
|
"id": "input",
|
|||
|
|
"type": "InputNodeImpl",
|
|||
|
|
"params": {}
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"id": "multiply",
|
|||
|
|
"type": "ArrayMapNode",
|
|||
|
|
"params": {"multiplier": 2}
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"id": "sum",
|
|||
|
|
"type": "ArrayReduceNode",
|
|||
|
|
"params": {"operation": "sum"}
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"id": "output",
|
|||
|
|
"type": "OutputNodeImpl",
|
|||
|
|
"params": {}
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"edges": [
|
|||
|
|
{
|
|||
|
|
"source": "input",
|
|||
|
|
"sourcePort": "output",
|
|||
|
|
"target": "multiply",
|
|||
|
|
"targetPort": "values",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"source": "multiply",
|
|||
|
|
"sourcePort": "mapped",
|
|||
|
|
"target": "sum",
|
|||
|
|
"targetPort": "values",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"source": "sum",
|
|||
|
|
"sourcePort": "result",
|
|||
|
|
"target": "output",
|
|||
|
|
"targetPort": "input",
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
使用方式:
|
|||
|
|
1. 从上面的定义创建 AdvancedWorkflowGraph
|
|||
|
|
2. 调用 AdvancedWorkflowExecutor.execute()
|
|||
|
|
3. 函数节点会自动展开子工作流并执行
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
__all__ = [
|
|||
|
|
"InputNodeImpl",
|
|||
|
|
"OutputNodeImpl",
|
|||
|
|
"FunctionNodeImpl",
|
|||
|
|
"ArrayMapNode",
|
|||
|
|
"ArrayConcatNode",
|
|||
|
|
"ArrayFilterNode",
|
|||
|
|
"ArrayReduceNode",
|
|||
|
|
"BroadcastNode",
|
|||
|
|
"ArrayZipNode",
|
|||
|
|
"ConditionalBranchNode"
|
|||
|
|
]
|