305 lines
11 KiB
Python
305 lines
11 KiB
Python
"""
高级功能测试 - 特殊节点、维度转换、函数节点嵌套
Advanced feature tests: special nodes, dimension conversion, nested function nodes.
"""
|
||
import asyncio
|
||
from pathlib import Path
|
||
|
||
# 导入节点和执行器
|
||
|
||
# ==== 路径修正:兼容直接运行和服务器导入 ====
|
||
import sys
|
||
from pathlib import Path
|
||
# Put the project root (two directory levels above this file) at the front of
# sys.path so the `app.*` imports resolve both when this script is run
# directly and when it is imported by the server.
cur = Path(__file__).resolve()
root = cur.parent.parent
_root_entry = str(root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
|
||
|
||
from app.core.node_registry import NodeRegistry
|
||
from app.core.node_base import (
|
||
DimensionMode, WorkflowPackager, NodeType
|
||
)
|
||
from app.core.workflow_executor import WorkflowExecutor, WorkflowGraph
|
||
from app.nodes.advanced_example_nodes import *
|
||
|
||
|
||
async def test_special_nodes():
    """Report whether the special nodes and array-operation nodes are registered.

    Prints one status line per expected node type (queried through
    ``NodeRegistry.exists``), then the total count from
    ``NodeRegistry.list_all()``. Nothing is asserted; results go to stdout.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("✅ 测试 1: 特殊节点注册")
    print(banner)

    # Special node implementations: input, output and function node.
    for name in ("InputNodeImpl", "OutputNodeImpl", "FunctionNodeImpl"):
        label = "✅ REGISTERED" if NodeRegistry.exists(name) else "❌ MISSING"
        print(f" {name:<20} {label}")

    # Array-operation nodes.
    print("\n 数组操作节点:")
    for name in (
        "ArrayMapNode",
        "ArrayConcatNode",
        "ArrayFilterNode",
        "ArrayReduceNode",
        "BroadcastNode",
        "ArrayZipNode",
    ):
        mark = "✅" if NodeRegistry.exists(name) else "❌"
        print(f" {mark} {name}")

    total = len(NodeRegistry.list_all())
    print(f"\n 总注册节点数: {total}")
|
||
|
||
|
||
async def test_dimension_inference():
    """Placeholder section for the dimension-conversion inference test.

    The actual checks live in a separate test file; this coroutine only
    prints the section header and a note saying so.
    """
    separator = "=" * 60
    header_lines = ["\n" + separator, "✅ 测试 2: 维度转换推断", separator]
    for line in header_lines:
        print(line)
    print(" (维度推断测试已迁移至独立测试文件)")
|
||
|
||
|
||
async def test_simple_workflow():
    """Run a linear workflow: input -> ArrayMapNode(xN) -> ArrayReduceNode(sum) -> output.

    The input list and multiplier are defined once. The expected intermediate
    and final values are derived from them, so the printed expectations cannot
    drift from what the workflow is actually configured with. The executor's
    actual ``node_results["output"]`` entry is printed next to them.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 3: 简单工作流")
    print("=" * 60)

    input_values = [1, 2, 3, 4, 5]
    multiplier = 2
    # Expected values assume ArrayMapNode multiplies each element by
    # `multiplier` and ArrayReduceNode("sum") adds them up, as the workflow
    # description states -- TODO confirm against the node implementations.
    expected_mapped = [value * multiplier for value in input_values]
    expected_sum = sum(expected_mapped)

    # Workflow: input -> ArrayMapNode(xN) -> ArrayReduceNode(sum) -> output
    nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": multiplier}},
        {"id": "reduce", "type": "normal", "class_name": "ArrayReduceNode", "params": {"operation": "sum"}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]

    edges = [
        # The input node exposes the global_context "values" field as its output port.
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "reduce", "target_port": "values"},
        {"source": "reduce", "source_port": "result", "target": "output", "target_port": "input"},
    ]

    executor = WorkflowExecutor(user_id="test_user")
    success, report = await executor.execute(
        nodes=nodes,
        edges=edges,
        # Pass a copy so the locally kept input stays intact for reporting.
        global_context={"values": list(input_values)},
    )

    if success:
        print(" ✅ 工作流执行成功")
        print(f" 输入: {input_values}")
        print(f" 乘以{multiplier} (期望): {expected_mapped}")
        print(f" 求和 (期望): {expected_sum}")

        # Show what the output node actually produced.
        node_results = report.get("node_results", {})
        if "output" in node_results:
            print(f" ✅ 最终输出: {node_results['output']}")
        else:
            print(" ❌ 报告中缺少 output 节点结果")
    else:
        print(f" ❌ 工作流执行失败: {report.get('error')}")

    # `or 0` also covers a present-but-None duration, which `:.3f` cannot format.
    duration = report.get("total_duration") or 0
    print(f" 执行耗时: {duration:.3f}s")
|
||
|
||
|
||
async def test_array_operations():
    """Run small array-operation workflows and show each output next to its expected value.

    Each case is a self-contained workflow (nodes, edges, global context).
    The case's ``expected`` value used to be defined but never used. It is now
    printed alongside the executor's ``node_results["output"]`` entry so the
    two can be compared.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 4: 数组操作")
    print("=" * 60)

    test_cases = [
        {
            "name": "数组过滤",
            "nodes": [
                {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "filter", "type": "normal", "class_name": "ArrayFilterNode", "params": {"threshold": 2}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input", "source_port": "values", "target": "filter", "target_port": "values"},
                {"source": "filter", "source_port": "filtered", "target": "output", "target_port": "input"},
            ],
            "global_context": {"values": [1, 2, 3, 4, 5]},
            "expected": [3, 4, 5],
        },
        {
            "name": "广播",
            "nodes": [
                {"id": "input1", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "broadcast", "type": "normal", "class_name": "BroadcastNode", "params": {"count": 3}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input1", "source_port": "value", "target": "broadcast", "target_port": "value"},
                {"source": "broadcast", "source_port": "broadcast", "target": "output", "target_port": "input"},
            ],
            "global_context": {"value": 42},
            "expected": [42, 42, 42],
        },
    ]

    for case in test_cases:
        print(f"\n 测试: {case['name']}")
        executor = WorkflowExecutor(user_id="test_user")
        success, report = await executor.execute(
            nodes=case["nodes"],
            edges=case["edges"],
            global_context=case["global_context"],
        )

        if not success:
            print(f" ❌ 执行失败: {report.get('error')}")
            continue

        print(" ✅ 执行成功")
        # NOTE(review): the exact shape of a node_results entry (raw value vs.
        # per-port dict) is not visible here, so the values are shown side by
        # side rather than asserted equal.
        node_results = report.get("node_results", {})
        actual = node_results["output"] if "output" in node_results else "<缺失>"
        print(f" 期望输出: {case['expected']}")
        print(f" 实际输出: {actual}")
|
||
|
||
|
||
async def test_workflow_graph():
    """Build a four-node WorkflowGraph by hand and exercise its graph helpers.

    The chain is n1 (InputNodeImpl) -> n2 (ArrayMapNode) -> n3 (ArrayReduceNode)
    -> n4 (OutputNodeImpl). The test prints:

    - the node and edge counts;
    - whether ``has_cycle`` found a cycle;
    - the topological order, or the ``ValueError`` raised by ``topological_sort``;
    - whether ``validate_function_workflow`` accepts the graph.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("✅ 测试 5: 工作流图操作")
    print(rule)

    graph = WorkflowGraph()

    # (node id, node type, implementation class), added in this order.
    node_specs = [
        ("n1", NodeType.INPUT, "InputNodeImpl"),
        ("n2", NodeType.NORMAL, "ArrayMapNode"),
        ("n3", NodeType.NORMAL, "ArrayReduceNode"),
        ("n4", NodeType.OUTPUT, "OutputNodeImpl"),
    ]
    for node_id, node_type, class_name in node_specs:
        graph.add_node(node_id, node_type, sub_workflow_nodes=None, sub_workflow_edges=None)
        # class_name is written straight onto the stored node entry after add_node.
        graph.nodes[node_id]["class_name"] = class_name

    # Edges; the input node's output port is named after the global_context field.
    edge_specs = [
        ("n1", "values", "n2", "values"),
        ("n2", "mapped", "n3", "values"),
        ("n3", "result", "n4", "input"),
    ]
    for src, src_port, dst, dst_port in edge_specs:
        graph.add_edge(src, src_port, dst, dst_port, DimensionMode.NONE)

    print(f" 节点数: {len(graph.nodes)}")
    print(f" 边数: {len(graph.edges)}")

    cycle_found = graph.has_cycle()
    print(f" 无循环: {'❌ NO' if cycle_found else '✅ YES'}")

    try:
        order = graph.topological_sort()
    except ValueError as exc:
        print(f" ❌ 排序失败: {exc}")
    else:
        print(f" 拓扑排序: {' → '.join(order)}")
        print(" ✅ 排序成功")

    ok, problem = graph.validate_function_workflow()
    print(" ✅ 可作为函数节点工作流" if ok else f" ❌ 不可作为函数节点: {problem}")
|
||
|
||
|
||
async def test_nested_function_workflow():
    """Package a sub-workflow as a function node and run it inside a main workflow.

    Steps:

    1. Validate the sub-workflow (input -> ArrayMapNode(xN) -> output) with
       ``WorkflowPackager.validate_function_workflow``.
    2. Package it with ``WorkflowPackager.package_as_function``.
    3. Embed it as node ``multiply_func`` in input -> multiply_func -> output
       and execute that main workflow.

    The expected result is derived from the input and multiplier. The actual
    ``node_results["output"]`` entry is printed next to it. The coroutine
    returns early, after printing the reason, if validation fails or raises.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 6: 嵌套函数节点")
    print("=" * 60)

    input_values = [1, 2, 3]
    multiplier = 2
    # Assumes ArrayMapNode multiplies each element by `multiplier`, as the
    # sub-workflow description states -- TODO confirm against the node.
    expected = [value * multiplier for value in input_values]

    # Sub-workflow: input -> xN -> output
    sub_workflow_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": multiplier}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    sub_workflow_edges = [
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    # The sub-workflow must be valid as a function node before packaging.
    try:
        valid, error = WorkflowPackager.validate_function_workflow(sub_workflow_nodes, sub_workflow_edges)
    except Exception as e:
        print(f" ❌ 验证异常: {e}")
        return
    if not valid:
        print(f" ❌ 子工作流验证失败: {error}")
        return
    print(" ✅ 子工作流验证通过")

    function_node_def = WorkflowPackager.package_as_function(
        node_id="multiply_func",
        nodes=sub_workflow_nodes,
        edges=sub_workflow_edges,
        display_name=f"乘以{multiplier}",
        description=f"将数组中的所有元素乘以{multiplier}",
    )

    # package_as_function nests the sub-graph under "sub_workflow" -> {nodes, edges};
    # expand it into the sub_workflow_nodes / sub_workflow_edges keys that
    # WorkflowExecutor expects.
    if "sub_workflow" in function_node_def:
        function_node_def["sub_workflow_nodes"] = function_node_def["sub_workflow"].get("nodes", [])
        function_node_def["sub_workflow_edges"] = function_node_def["sub_workflow"].get("edges", [])

    print(f" ✅ 打包成功: {function_node_def['id']}")

    # Main workflow: input -> multiply_func -> output
    main_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        function_node_def,
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    main_edges = [
        {"source": "input", "source_port": "values", "target": "multiply_func", "target_port": "values"},
        {"source": "multiply_func", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    executor = WorkflowExecutor(user_id="test_user")
    success, report = await executor.execute(
        nodes=main_nodes,
        edges=main_edges,
        # Pass a copy so the locally kept input stays intact for reporting.
        global_context={"values": list(input_values)},
    )

    if success:
        print(" ✅ 函数节点执行成功")
        print(f" 输入: {input_values}")
        print(f" 期望输出: {expected}")
        # Show what the output node actually produced.
        node_results = report.get("node_results", {})
        if "output" in node_results:
            print(f" ✅ 最终输出: {node_results['output']}")
        else:
            print(" ❌ 报告中缺少 output 节点结果")
    else:
        print(f" ❌ 执行失败: {report.get('error')}")
|
||
|
||
|
||
async def main():
    """Run every test section and report which ones raised.

    Each section runs independently. An exception in one section is printed
    with its traceback and the remaining sections still run; previously, a
    single try-block aborted everything after the first failure.

    Returns:
        The number of test sections that raised an exception (0 when all
        sections completed).
    """
    import traceback

    print("\n" + "=" * 60)
    print("🚀 TraceStudio 高级功能完整测试")
    print("=" * 60)

    sections = (
        test_special_nodes,
        test_dimension_inference,
        test_simple_workflow,
        test_array_operations,
        test_workflow_graph,
        test_nested_function_workflow,
    )

    failed = []
    for section in sections:
        try:
            await section()
        except Exception as e:
            # Top-level test boundary: report the failure and continue with
            # the next section instead of skipping all of them.
            print(f"\n❌ 测试过程中出现错误 ({section.__name__}): {e}")
            traceback.print_exc()
            failed.append(section.__name__)

    print("\n" + "=" * 60)
    if failed:
        print(f"❌ {len(failed)} 个测试出错: {', '.join(failed)}")
    else:
        print("✅ 所有测试完成!")
    print("=" * 60)

    return len(failed)
|
||
|
||
|
||
if __name__ == "__main__":
    # main() may return the number of failed test sections; exit with status 1
    # in that case so shells/CI notice failures. A None/0 return exits 0.
    failures = asyncio.run(main())
    sys.exit(1 if failures else 0)
|