TraceStudio/server/app/core/function_nodes.py
2026-01-12 21:51:45 +08:00

312 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
函数节点系统 - 支持工作流封装和嵌套
提供 InputNode、OutputNode、FunctionNode 三类特殊节点
关键修复 (v1.1):
- InputSpec/OutputSpec/ParamSpec 改为字典格式(非列表)
- 类属性使用大写 NODE_TYPE 而非小写 node_type
- 方法改为 process() 而非 execute()
- 返回格式改为 {"outputs": {}, "context": {}}
"""
from typing import Any, Dict, List, Optional
from server.app.core.node_base import (
TraceNode, NodeType
)
class InputNode(TraceNode):
"""
输入节点 - 函数的入口
作用:接收外部输入,传递给函数内部的工作流
在函数中的角色:
- 作为子工作流的入口点
- 从 context 读取 __function_input_{param_name}
- 将外部输入转换为内部端口输出
"""
NODE_TYPE = NodeType.INPUT
CATEGORY = "Function/Input"
DISPLAY_NAME = "函数输入"
DESCRIPTION = "函数的输入入口节点"
# 输入节点没有输入端口(自己就是输入)
InputSpec = {}
# 输出端口
OutputSpec = {
"value": ("Any", {"description": "输入值", "required": True})
}
# 参数定义(用于配置参数名和类型)
ParamSpec = {
"param_name": ("String", {"description": "参数名称", "default": "input"}),
"param_type": ("String", {"description": "参数类型", "default": "Any"})
}
ContextSpec = {}
def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
输入节点执行:直接从 context 中获取外部传入的值
Args:
inputs: 输入(对于 InputNode 应该为空)
context: 上下文,包含 __function_input_{param_name}
Returns:
{"outputs": {"value": ...}, "context": {}}
"""
param_name = self.get_param("param_name", "input")
external_value = (context or {}).get(f"__function_input_{param_name}", None)
return {"value": external_value}
class OutputNode(TraceNode):
"""
输出节点 - 函数的出口
作用:收集函数内部工作流的结果,返回给外部
在函数中的角色:
- 作为子工作流的出口点
- 从内部节点接收输入
- 写入到 context 的 __function_output_{output_name}
- 供外部 FunctionNode 读取
"""
NODE_TYPE = NodeType.OUTPUT
CATEGORY = "Function/Output"
DISPLAY_NAME = "函数输出"
DESCRIPTION = "函数的输出出口节点"
# 输入端口
InputSpec = {
"value": ("Any", {"description": "输出值", "required": True})
}
# 输出节点没有后续输出
OutputSpec = {}
# 参数定义
ParamSpec = {
"output_name": ("String", {"description": "输出名称", "default": "result"}),
"output_type": ("String", {"description": "输出类型", "default": "Any"})
}
ContextSpec = {}
def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
输出节点执行:将输入值存储到 context供外部读取
Args:
inputs: 输入数据,应包含 "value"
context: 上下文,用于存储输出
Returns:
{"outputs": {}, "context": {...}}
"""
output_name = self.get_param("output_name", "result")
value = inputs.get("value")
# 存储到 context 供函数节点读取
if context is None:
context = {}
context[f"__function_output_{output_name}"] = value
return {}
class FunctionNode(TraceNode):
"""
函数节点 - 对工作流的封装
作用:将整个工作流作为一个节点在其他工作流中使用
"""
NODE_TYPE = NodeType.FUNCTION
CATEGORY = "Function/Workflow"
# 这些会在 __init__ 中动态生成
InputSpec = {}
OutputSpec = {}
ParamSpec = {}
ContextSpec = {}
def __init__(self, node_id: str, params: Optional[Dict[str, Any]] = None,
function_name: Optional[str] = None,
workflow_data: Optional[Dict[str, Any]] = None):
"""
初始化 FunctionNode
Args:
node_id: 节点唯一标识
params: 节点参数
function_name: 函数名称
workflow_data: 函数对应的工作流数据
"""
super().__init__(node_id, params)
self.function_name = function_name or "unnamed_function"
self.workflow_data = workflow_data or {}
# 动态生成规范
self._build_specs()
# 设置显示名称
self.DISPLAY_NAME = f"函数:{self.function_name}"
self.DESCRIPTION = f"执行函数工作流:{self.function_name}"
def _build_specs(self):
"""
从工作流数据动态生成 InputSpec、OutputSpec、ParamSpec
"""
# 从 workflow_data 的 inputs/outputs 构建规范
input_spec = {}
output_spec = {}
# 处理工作流中定义的输入
for inp in self.workflow_data.get("inputs", []):
input_name = inp.get("name", "input")
input_type = inp.get("type", "Any")
input_spec[input_name] = (
input_type,
{"description": inp.get("description", f"输入:{input_name}"), "required": True}
)
# 处理工作流中定义的输出
for out in self.workflow_data.get("outputs", []):
output_name = out.get("name", "output")
output_type = out.get("type", "Any")
output_spec[output_name] = (
output_type,
{"description": out.get("description", f"输出:{output_name}")}
)
# 如果没有定义,设置默认值
if not input_spec:
input_spec = {
"input": ("Any", {"description": "默认输入", "required": True})
}
if not output_spec:
output_spec = {
"output": ("Any", {"description": "默认输出"})
}
self.InputSpec = input_spec
self.OutputSpec = output_spec
# ParamSpec: 基本参数
self.ParamSpec = {
"function_name": (
"String",
{"description": "函数名称", "default": self.function_name}
)
}
def process(self, inputs: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
执行函数:在子工作流中执行,收集结果
Args:
inputs: 输入数据
context: 执行上下文
Returns:
{"outputs": {...}, "context": {...}}
"""
# 准备执行上下文
exec_context = context.copy() if context else {}
# 注入输入数据到 context供子工作流中的 InputNode 读取)
for input_name, input_value in inputs.items():
exec_context[f"__function_input_{input_name}"] = input_value
# 执行子工作流(这里简化处理,实际应该由工作流引擎执行)
# 在生产环境中,应该调用工作流执行引擎
# workflow_executor = WorkflowExecutor()
# results = workflow_executor.execute(
# nodes=self.workflow_data.get("nodes", []),
# edges=self.workflow_data.get("edges", []),
# context=exec_context
# )
# 临时处理:收集子工作流的输出
outputs = {}
for output_key in self.OutputSpec.keys():
# 尝试从 context 读取对应的输出(由子工作流的 OutputNode 写入)
output_value = exec_context.get(f"__function_output_{output_key}")
if output_value is not None:
outputs[output_key] = output_value
else:
# 如果子工作流没有产生输出,使用 None
outputs[output_key] = None
return outputs
def get_workflow_data(self) -> Dict[str, Any]:
"""获取函数对应的工作流数据"""
return self.workflow_data
def set_workflow_data(self, workflow_data: Dict[str, Any]):
"""设置函数对应的工作流数据"""
self.workflow_data = workflow_data
# 重新生成规范
self._build_specs()
# 工厂函数:从工作流创建函数节点类
def create_function_node_class(function_name: str, workflow_data: Dict[str, Any]) -> type:
"""
动态创建函数节点类
Args:
function_name: 函数名称
workflow_data: 工作流数据
Returns:
FunctionNode 的子类
"""
class_name = f"Function_{function_name}"
# 预先根据 workflow_data 生成类级规范,供前端插件元数据使用
cls_input_spec: Dict[str, tuple] = {}
for inp in workflow_data.get("inputs", []) or []:
name = inp.get("name", "input")
dtype = inp.get("type", "Any")
cls_input_spec[name] = (dtype, {"description": inp.get("description", name), "required": True})
cls_output_spec: Dict[str, tuple] = {}
for out in workflow_data.get("outputs", []) or []:
name = out.get("name", "output")
dtype = out.get("type", "Any")
cls_output_spec[name] = (dtype, {"description": out.get("description", name)})
if not cls_input_spec:
cls_input_spec = {"input": ("Any", {"description": "默认输入", "required": True})}
if not cls_output_spec:
cls_output_spec = {"output": ("Any", {"description": "默认输出"})}
class DynamicFunctionNode(FunctionNode):
"""动态生成的函数节点"""
# 类级元数据,供插件列表/前端元数据使用
CATEGORY = "Function/Custom"
DISPLAY_NAME = workflow_data.get("display_name", function_name)
DESCRIPTION = workflow_data.get("description", "")
NODE_TYPE = NodeType.FUNCTION
# 四大规范(类级),用于 NodeRegistry.get_metadata
InputSpec = cls_input_spec
OutputSpec = cls_output_spec
ParamSpec = {
"function_name": ("String", {"description": "函数名称", "default": function_name})
}
ContextSpec = {}
def __init__(self, node_id: str, params: Optional[Dict[str, Any]] = None):
"""实例初始化时仍构建实例级规范,保障执行期一致"""
super().__init__(node_id, params, function_name, workflow_data)
# 设置类名便于调试
DynamicFunctionNode.__name__ = class_name
DynamicFunctionNode.__qualname__ = class_name
return DynamicFunctionNode