TraceStudio-dev/server/app/nodes/advanced_example_nodes.py

381 lines
12 KiB
Python
Raw Normal View History

"""
特殊节点示例和高级使用场景
"""
from ..core.node_base import (
TraceNode, CachePolicy, input_port, output_port, param, context_var,
InputNode, OutputNode, FunctionNode
)
from ..core.node_registry import NodeRegistry, register_node
from typing import Any, Dict, Optional, List
# ===========================
# 注册特殊节点
# ===========================
@register_node
class InputNodeImpl(InputNode):
"""
输入节点实现
特殊逻辑
- 没有输入端口inputs 为空
- context全局上下文中提取所有字段作为输出端口
- 每个全局上下文字段都映射为一个输出端口
"""
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
# 如果没有连线到此节点inputs 会是空的
# 我们从 context全局上下文中提取数据
# context 中的数据会由工作流执行器自动传递
# 返回整个全局上下文作为输出(严格返回 dict
# 如果 context 是 ExecutionContext 实例,读取其 global_context。
if context is None:
return {}
# Duck-typing: prefer `global_context` attribute when present
gc = None
try:
if hasattr(context, "global_context") and isinstance(getattr(context, "global_context"), dict):
gc = dict(getattr(context, "global_context"))
except Exception:
gc = None
if gc is None and isinstance(context, dict):
gc = dict(context)
return gc or {}
@register_node
class OutputNodeImpl(OutputNode):
"""输出节点实现"""
@input_port("input", "Any", description="要输出的数据")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
return inputs
@register_node
class FunctionNodeImpl(FunctionNode):
"""函数节点实现"""
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
# 由 AdvancedWorkflowExecutor 处理
return inputs
# ===========================
# 示例:支持数组操作的节点
# ===========================
@register_node
class ArrayMapNode(TraceNode):
"""
数组映射节点 - 支持升维操作
场景
- 输入整数数组 [1, 2, 3]
- 操作map(x => x * 2)
- 输出[2, 4, 6]
通过升维机制实现
AddNode作为内部逻辑升维(EXPAND)处理数组遍历
"""
NODE_TYPE = "ArrayMapNode"
CATEGORY = "Array/Transform"
DISPLAY_NAME = "数组映射"
DESCRIPTION = "对数组中的每个元素应用操作"
@input_port("values", "Array[Number]", description="输入数组")
@output_port("mapped", "Array[Number]", description="映射后的数组")
@param("multiplier", "Number", default=1, description="乘数")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
values = inputs.get("values", [])
multiplier = self.get_param("multiplier", 1)
if not isinstance(values, list):
values = [values]
mapped = [v * multiplier for v in values]
return {"mapped": mapped}
@register_node
class ArrayConcatNode(TraceNode):
"""
数组连接节点 - 支持多线汇聚
场景
- 输入1[1, 2] (来自节点A)
- 输入2[3, 4] (来自节点B)
- 操作concat
- 输出[1, 2, 3, 4]
通过BROADCAST机制实现
多条线汇聚到数组输入自动打包为嵌套数组再展平
"""
NODE_TYPE = "ArrayConcatNode"
CATEGORY = "Array/Combine"
DISPLAY_NAME = "数组连接"
DESCRIPTION = "将多个数组合并为一个"
@input_port("arrays", "Array[Array]", description="要连接的数组们")
@output_port("result", "Array", description="连接后的数组")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
arrays_input = inputs.get("arrays", [])
# 展平:如果输入是 [[1,2], [3,4]],输出 [1,2,3,4]
result = []
for arr in arrays_input:
if isinstance(arr, list):
result.extend(arr)
else:
result.append(arr)
return {"result": result}
@register_node
class ArrayFilterNode(TraceNode):
"""
数组过滤节点 - 数组到数组
场景
- 输入[1, 2, 3, 4, 5]
- 操作filter(x > 2)
- 输出[3, 4, 5]
"""
NODE_TYPE = "ArrayFilterNode"
CATEGORY = "Array/Filter"
DISPLAY_NAME = "数组过滤"
DESCRIPTION = "根据条件过滤数组元素"
@input_port("values", "Array[Number]", description="输入数组")
@output_port("filtered", "Array[Number]", description="过滤后的数组")
@param("threshold", "Number", default=0, description="阈值(保留 > threshold 的元素)")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
values = inputs.get("values", [])
threshold = self.get_param("threshold", 0)
if not isinstance(values, list):
values = [values]
filtered = [v for v in values if v > threshold]
return {"filtered": filtered}
@register_node
class ArrayReduceNode(TraceNode):
"""
数组规约节点 - 数组到标量
场景
- 输入[1, 2, 3, 4, 5]
- 操作sumreduce
- 输出15
"""
NODE_TYPE = "ArrayReduceNode"
CATEGORY = "Array/Reduce"
DISPLAY_NAME = "数组规约"
DESCRIPTION = "将数组规约为单个值"
@input_port("values", "Array[Number]", description="输入数组")
@output_port("result", "Number", description="规约后的值")
@param("operation", "String", default="sum", description="操作类型sum、product、max、min")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
values = inputs.get("values", [])
operation = self.get_param("operation", "sum")
if not isinstance(values, list):
values = [values]
if operation == "sum":
result = sum(values)
elif operation == "product":
result = 1
for v in values:
result *= v
elif operation == "max":
result = max(values) if values else 0
elif operation == "min":
result = min(values) if values else 0
else:
result = sum(values)
return {"result": result}
@register_node
class BroadcastNode(TraceNode):
"""
广播节点 - 标量到数组
场景
- 输入5
- 操作broadcast(3)
- 输出[5, 5, 5]
"""
NODE_TYPE = "BroadcastNode"
CATEGORY = "Array/Broadcast"
DISPLAY_NAME = "广播"
DESCRIPTION = "将单个值广播为数组"
@input_port("value", "Any", description="要广播的值")
@output_port("broadcast", "Array", description="广播后的数组")
@param("count", "Integer", default=3, description="广播次数")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
value = inputs.get("value")
count = self.get_param("count", 3)
result = [value] * count
return {"broadcast": result}
@register_node
class ArrayZipNode(TraceNode):
"""
数组拉链节点 - 多数组合并
场景
- 输入1[1, 2, 3]
- 输入2['a', 'b', 'c']
- 操作zip
- 输出[[1, 'a'], [2, 'b'], [3, 'c']]
"""
NODE_TYPE = "ArrayZipNode"
CATEGORY = "Array/Combine"
DISPLAY_NAME = "数组拉链"
DESCRIPTION = "将多个数组按位置组合"
@input_port("array1", "Array", description="第一个数组")
@input_port("array2", "Array", description="第二个数组")
@output_port("zipped", "Array[Array]", description="拉链后的数组")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
arr1 = inputs.get("array1", [])
arr2 = inputs.get("array2", [])
if not isinstance(arr1, list):
arr1 = [arr1]
if not isinstance(arr2, list):
arr2 = [arr2]
# 拉链操作
zipped = list(zip(arr1, arr2))
result = [list(pair) for pair in zipped]
return {"zipped": result}
@register_node
class ConditionalBranchNode(TraceNode):
"""
条件分支节点 - 根据条件选择输出
场景
- 输入
- 条件 > 10
- 输出1如果满足条件
- 输出2如果不满足条件
"""
NODE_TYPE = "ConditionalBranchNode"
CATEGORY = "Control/Branch"
DISPLAY_NAME = "条件分支"
DESCRIPTION = "根据条件选择不同的输出路径"
@input_port("value", "Any", description="输入值")
@output_port("true_output", "Any", description="条件为真时的输出")
@output_port("false_output", "Any", description="条件为假时的输出")
@param("threshold", "Number", default=0, description="阈值")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
value = inputs.get("value", 0)
threshold = self.get_param("threshold", 0)
outputs = {}
if value > threshold:
outputs["true_output"] = value
outputs["false_output"] = None
else:
outputs["true_output"] = None
outputs["false_output"] = value
return outputs
# ===========================
# 示例:工作流打包为函数节点
# ===========================
"""
示例工作流定义可打包为函数节点
{
"id": "multiply_and_sum",
"type": "FunctionNode",
"display_name": "乘积求和",
"description": "先将数组中的每个数乘以2再求和",
"sub_workflow": {
"nodes": [
{
"id": "input",
"type": "InputNodeImpl",
"params": {}
},
{
"id": "multiply",
"type": "ArrayMapNode",
"params": {"multiplier": 2}
},
{
"id": "sum",
"type": "ArrayReduceNode",
"params": {"operation": "sum"}
},
{
"id": "output",
"type": "OutputNodeImpl",
"params": {}
}
],
"edges": [
{
"source": "input",
"sourcePort": "output",
"target": "multiply",
"targetPort": "values",
},
{
"source": "multiply",
"sourcePort": "mapped",
"target": "sum",
"targetPort": "values",
},
{
"source": "sum",
"sourcePort": "result",
"target": "output",
"targetPort": "input",
}
]
}
}
使用方式
1. 从上面的定义创建 AdvancedWorkflowGraph
2. 调用 AdvancedWorkflowExecutor.execute()
3. 函数节点会自动展开子工作流并执行
"""
__all__ = [
"InputNodeImpl",
"OutputNodeImpl",
"FunctionNodeImpl",
"ArrayMapNode",
"ArrayConcatNode",
"ArrayFilterNode",
"ArrayReduceNode",
"BroadcastNode",
"ArrayZipNode",
"ConditionalBranchNode"
]