TraceStudio-dist/server/app/nodes/advanced_example_nodes.py
2026-01-13 16:41:31 +08:00

381 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
特殊节点示例和高级使用场景
"""
from ..core.node_base import (
TraceNode, CachePolicy, input_port, output_port, param, context_var,
InputNode, OutputNode, FunctionNode
)
from ..core.node_registry import NodeRegistry, register_node
from typing import Any, Dict, Optional, List
# ===========================
# 注册特殊节点
# ===========================
@register_node
class InputNodeImpl(InputNode):
"""
输入节点实现
特殊逻辑:
- 没有输入端口inputs 为空)
- 从 context全局上下文中提取所有字段作为输出端口
- 每个全局上下文字段都映射为一个输出端口
"""
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
# 如果没有连线到此节点inputs 会是空的
# 我们从 context全局上下文中提取数据
# context 中的数据会由工作流执行器自动传递
# 返回整个全局上下文作为输出(严格返回 dict
# 如果 context 是 ExecutionContext 实例,读取其 global_context。
if context is None:
return {}
# Duck-typing: prefer `global_context` attribute when present
gc = None
try:
if hasattr(context, "global_context") and isinstance(getattr(context, "global_context"), dict):
gc = dict(getattr(context, "global_context"))
except Exception:
gc = None
if gc is None and isinstance(context, dict):
gc = dict(context)
return gc or {}
@register_node
class OutputNodeImpl(OutputNode):
"""输出节点实现"""
@input_port("input", "Any", description="要输出的数据")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
return inputs
@register_node
class FunctionNodeImpl(FunctionNode):
"""函数节点实现"""
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
# 由 AdvancedWorkflowExecutor 处理
return inputs
# ===========================
# 示例:支持数组操作的节点
# ===========================
@register_node
class ArrayMapNode(TraceNode):
"""
数组映射节点 - 支持升维操作
场景:
- 输入:整数数组 [1, 2, 3]
- 操作map(x => x * 2)
- 输出:[2, 4, 6]
通过升维机制实现:
AddNode作为内部逻辑升维(EXPAND)处理数组遍历
"""
NODE_TYPE = "ArrayMapNode"
CATEGORY = "Array/Transform"
DISPLAY_NAME = "数组映射"
DESCRIPTION = "对数组中的每个元素应用操作"
@input_port("values", "Array[Number]", description="输入数组")
@output_port("mapped", "Array[Number]", description="映射后的数组")
@param("multiplier", "Number", default=1, description="乘数")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
values = inputs.get("values", [])
multiplier = self.get_param("multiplier", 1)
if not isinstance(values, list):
values = [values]
mapped = [v * multiplier for v in values]
return {"mapped": mapped}
@register_node
class ArrayConcatNode(TraceNode):
"""
数组连接节点 - 支持多线汇聚
场景:
- 输入1[1, 2] (来自节点A)
- 输入2[3, 4] (来自节点B)
- 操作concat
- 输出:[1, 2, 3, 4]
通过BROADCAST机制实现
多条线汇聚到数组输入,自动打包为嵌套数组再展平
"""
NODE_TYPE = "ArrayConcatNode"
CATEGORY = "Array/Combine"
DISPLAY_NAME = "数组连接"
DESCRIPTION = "将多个数组合并为一个"
@input_port("arrays", "Array[Array]", description="要连接的数组们")
@output_port("result", "Array", description="连接后的数组")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
arrays_input = inputs.get("arrays", [])
# 展平:如果输入是 [[1,2], [3,4]],输出 [1,2,3,4]
result = []
for arr in arrays_input:
if isinstance(arr, list):
result.extend(arr)
else:
result.append(arr)
return {"result": result}
@register_node
class ArrayFilterNode(TraceNode):
"""
数组过滤节点 - 数组到数组
场景:
- 输入:[1, 2, 3, 4, 5]
- 操作filter(x > 2)
- 输出:[3, 4, 5]
"""
NODE_TYPE = "ArrayFilterNode"
CATEGORY = "Array/Filter"
DISPLAY_NAME = "数组过滤"
DESCRIPTION = "根据条件过滤数组元素"
@input_port("values", "Array[Number]", description="输入数组")
@output_port("filtered", "Array[Number]", description="过滤后的数组")
@param("threshold", "Number", default=0, description="阈值(保留 > threshold 的元素)")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
values = inputs.get("values", [])
threshold = self.get_param("threshold", 0)
if not isinstance(values, list):
values = [values]
filtered = [v for v in values if v > threshold]
return {"filtered": filtered}
@register_node
class ArrayReduceNode(TraceNode):
"""
数组规约节点 - 数组到标量
场景:
- 输入:[1, 2, 3, 4, 5]
- 操作sumreduce
- 输出15
"""
NODE_TYPE = "ArrayReduceNode"
CATEGORY = "Array/Reduce"
DISPLAY_NAME = "数组规约"
DESCRIPTION = "将数组规约为单个值"
@input_port("values", "Array[Number]", description="输入数组")
@output_port("result", "Number", description="规约后的值")
@param("operation", "String", default="sum", description="操作类型sum、product、max、min")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
values = inputs.get("values", [])
operation = self.get_param("operation", "sum")
if not isinstance(values, list):
values = [values]
if operation == "sum":
result = sum(values)
elif operation == "product":
result = 1
for v in values:
result *= v
elif operation == "max":
result = max(values) if values else 0
elif operation == "min":
result = min(values) if values else 0
else:
result = sum(values)
return {"result": result}
@register_node
class BroadcastNode(TraceNode):
"""
广播节点 - 标量到数组
场景:
- 输入5
- 操作broadcast(3次)
- 输出:[5, 5, 5]
"""
NODE_TYPE = "BroadcastNode"
CATEGORY = "Array/Broadcast"
DISPLAY_NAME = "广播"
DESCRIPTION = "将单个值广播为数组"
@input_port("value", "Any", description="要广播的值")
@output_port("broadcast", "Array", description="广播后的数组")
@param("count", "Integer", default=3, description="广播次数")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
value = inputs.get("value")
count = self.get_param("count", 3)
result = [value] * count
return {"broadcast": result}
@register_node
class ArrayZipNode(TraceNode):
"""
数组拉链节点 - 多数组合并
场景:
- 输入1[1, 2, 3]
- 输入2['a', 'b', 'c']
- 操作zip
- 输出:[[1, 'a'], [2, 'b'], [3, 'c']]
"""
NODE_TYPE = "ArrayZipNode"
CATEGORY = "Array/Combine"
DISPLAY_NAME = "数组拉链"
DESCRIPTION = "将多个数组按位置组合"
@input_port("array1", "Array", description="第一个数组")
@input_port("array2", "Array", description="第二个数组")
@output_port("zipped", "Array[Array]", description="拉链后的数组")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
arr1 = inputs.get("array1", [])
arr2 = inputs.get("array2", [])
if not isinstance(arr1, list):
arr1 = [arr1]
if not isinstance(arr2, list):
arr2 = [arr2]
# 拉链操作
zipped = list(zip(arr1, arr2))
result = [list(pair) for pair in zipped]
return {"zipped": result}
@register_node
class ConditionalBranchNode(TraceNode):
"""
条件分支节点 - 根据条件选择输出
场景:
- 输入:值
- 条件:值 > 10
- 输出1如果满足条件
- 输出2如果不满足条件
"""
NODE_TYPE = "ConditionalBranchNode"
CATEGORY = "Control/Branch"
DISPLAY_NAME = "条件分支"
DESCRIPTION = "根据条件选择不同的输出路径"
@input_port("value", "Any", description="输入值")
@output_port("true_output", "Any", description="条件为真时的输出")
@output_port("false_output", "Any", description="条件为假时的输出")
@param("threshold", "Number", default=0, description="阈值")
async def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
value = inputs.get("value", 0)
threshold = self.get_param("threshold", 0)
outputs = {}
if value > threshold:
outputs["true_output"] = value
outputs["false_output"] = None
else:
outputs["true_output"] = None
outputs["false_output"] = value
return outputs
# ===========================
# 示例:工作流打包为函数节点
# ===========================
"""
示例工作流定义(可打包为函数节点):
{
"id": "multiply_and_sum",
"type": "FunctionNode",
"display_name": "乘积求和",
"description": "先将数组中的每个数乘以2再求和",
"sub_workflow": {
"nodes": [
{
"id": "input",
"type": "InputNodeImpl",
"params": {}
},
{
"id": "multiply",
"type": "ArrayMapNode",
"params": {"multiplier": 2}
},
{
"id": "sum",
"type": "ArrayReduceNode",
"params": {"operation": "sum"}
},
{
"id": "output",
"type": "OutputNodeImpl",
"params": {}
}
],
"edges": [
{
"source": "input",
"sourcePort": "output",
"target": "multiply",
"targetPort": "values",
},
{
"source": "multiply",
"sourcePort": "mapped",
"target": "sum",
"targetPort": "values",
},
{
"source": "sum",
"sourcePort": "result",
"target": "output",
"targetPort": "input",
}
]
}
}
使用方式:
1. 从上面的定义创建 AdvancedWorkflowGraph
2. 调用 AdvancedWorkflowExecutor.execute()
3. 函数节点会自动展开子工作流并执行
"""
__all__ = [
"InputNodeImpl",
"OutputNodeImpl",
"FunctionNodeImpl",
"ArrayMapNode",
"ArrayConcatNode",
"ArrayFilterNode",
"ArrayReduceNode",
"BroadcastNode",
"ArrayZipNode",
"ConditionalBranchNode"
]