TraceStudio-dev/server/tests/test_advanced_features.py

305 lines
11 KiB
Python
Raw Normal View History

"""
高级功能测试 - 特殊节点、维度转换、函数节点嵌套
"""
import asyncio
from pathlib import Path
# 导入节点和执行器
# ==== 路径修正:兼容直接运行和服务器导入 ====
import sys
from pathlib import Path
# Make the directory two levels above this file (the one expected to contain the
# ``app`` package imported below) importable, whether this file is run directly
# as a script or collected/imported by a test runner.
cur = Path(__file__).resolve()
root = cur.parent.parent
_already_importable = str(root) in sys.path
if not _already_importable:
    # Prepend so the local ``app`` package wins over any installed copy.
    sys.path.insert(0, str(root))
from app.core.node_registry import NodeRegistry
from app.core.node_base import (
DimensionMode, WorkflowPackager, NodeType
)
from app.core.workflow_executor import WorkflowExecutor, WorkflowGraph
from app.nodes.advanced_example_nodes import *
async def test_special_nodes():
    """Test 1: report whether the special and array-operation nodes are registered.

    Looks up each expected node class name with ``NodeRegistry.exists`` and
    prints one ✅/❌ line per node, then the total count reported by
    ``NodeRegistry.list_all()``. This is a printed report only; it does not
    assert, so a missing node does not stop the remaining tests.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 1: 特殊节点注册")
    print("=" * 60)
    # Special node implementations the workflows in this file rely on.
    special_nodes = ["InputNodeImpl", "OutputNodeImpl", "FunctionNodeImpl"]
    for node_type in special_nodes:
        exists = NodeRegistry.exists(node_type)
        status = "✅ REGISTERED" if exists else "❌ MISSING"
        print(f" {node_type:<20} {status}")
    # Array-operation nodes; presumably registered as a side effect of the
    # ``app.nodes.advanced_example_nodes`` star import — TODO confirm.
    array_nodes = ["ArrayMapNode", "ArrayConcatNode", "ArrayFilterNode",
                   "ArrayReduceNode", "BroadcastNode", "ArrayZipNode"]
    print("\n 数组操作节点:")
    for node_type in array_nodes:
        exists = NodeRegistry.exists(node_type)
        # Both outcomes need a distinct marker, otherwise the report cannot
        # show which nodes are missing.
        status = "✅" if exists else "❌"
        print(f" {status} {node_type}")
    print(f"\n 总注册节点数: {len(NodeRegistry.list_all())}")
async def test_dimension_inference():
    """Test 2: placeholder for the dimension-conversion inference checks.

    The real checks have moved to a separate test file; this only prints the
    section banner followed by a note saying so.
    """
    separator = "=" * 60
    report_lines = (
        "\n" + separator,
        "✅ 测试 2: 维度转换推断",
        separator,
        " (维度推断测试已迁移至独立测试文件)",
    )
    for text in report_lines:
        print(text)
async def test_simple_workflow():
    """Test 3: run a linear workflow with no dimension conversion.

    Pipeline: input -> ArrayMapNode (multiplier=2) -> ArrayReduceNode (sum)
    -> output, executed with ``global_context={"values": [1, 2, 3, 4, 5]}``.
    The input node's ``values`` port carries the same-named global-context
    field. On success it prints the expected intermediate values and, if
    present, the ``output`` node's result; the elapsed time is always printed.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("✅ 测试 3: 简单工作流")
    print(banner)

    def _node(node_id, node_kind, class_name, **params):
        # One workflow node record in the shape WorkflowExecutor.execute expects.
        return {"id": node_id, "type": node_kind, "class_name": class_name, "params": params}

    def _edge(src, src_port, dst, dst_port):
        # One connection from ``src.src_port`` to ``dst.dst_port``.
        return {"source": src, "source_port": src_port, "target": dst, "target_port": dst_port}

    values = [1, 2, 3, 4, 5]
    doubled = [v * 2 for v in values]

    workflow_nodes = [
        _node("input", "input", "InputNodeImpl"),
        _node("map", "normal", "ArrayMapNode", multiplier=2),
        _node("reduce", "normal", "ArrayReduceNode", operation="sum"),
        _node("output", "output", "OutputNodeImpl"),
    ]
    workflow_edges = [
        _edge("input", "values", "map", "values"),
        _edge("map", "mapped", "reduce", "values"),
        _edge("reduce", "result", "output", "input"),
    ]

    runner = WorkflowExecutor(user_id="test_user")
    # Hand the executor its own copy so the list printed below stays untouched.
    ok, report = await runner.execute(
        nodes=workflow_nodes,
        edges=workflow_edges,
        global_context={"values": list(values)},
    )

    if not ok:
        print(f" ❌ 工作流执行失败: {report.get('error')}")
    else:
        print(" ✅ 工作流执行成功")
        print(f" 输入: {values}")
        print(f" 乘以2: {doubled}")
        print(f" 求和: {sum(doubled)}")
        results_by_node = report.get("node_results", {})
        if "output" in results_by_node:
            print(f" ✅ 最终输出: {results_by_node['output']}")
    print(f" 执行耗时: {report.get('total_duration', 0):.3f}s")
async def test_array_operations():
    """Test 4: run small single-operation workflows (filter, broadcast).

    Each case wires input -> operation node -> output, executes it with its own
    ``global_context`` and prints whether execution succeeded. On success the
    case's ``expected`` value is printed next to the result of the ``output``
    node, so a wrong result is visible in the report instead of being hidden
    behind a bare "success" line.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 4: 数组操作")
    print("=" * 60)
    test_cases = [
        {
            "name": "数组过滤",
            "nodes": [
                {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "filter", "type": "normal", "class_name": "ArrayFilterNode", "params": {"threshold": 2}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input", "source_port": "values", "target": "filter", "target_port": "values"},
                {"source": "filter", "source_port": "filtered", "target": "output", "target_port": "input"},
            ],
            "global_context": {"values": [1, 2, 3, 4, 5]},
            "expected": [3, 4, 5],
        },
        {
            "name": "广播",
            "nodes": [
                {"id": "input1", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "broadcast", "type": "normal", "class_name": "BroadcastNode", "params": {"count": 3}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input1", "source_port": "value", "target": "broadcast", "target_port": "value"},
                {"source": "broadcast", "source_port": "broadcast", "target": "output", "target_port": "input"},
            ],
            "global_context": {"value": 42},
            "expected": [42, 42, 42],
        },
    ]
    for test_case in test_cases:
        print(f"\n 测试: {test_case['name']}")
        executor = WorkflowExecutor(user_id="test_user")
        success, report = await executor.execute(
            nodes=test_case["nodes"],
            edges=test_case["edges"],
            global_context=test_case["global_context"],
        )
        if not success:
            print(f" ❌ 执行失败: {report.get('error')}")
            continue
        print(" ✅ 执行成功")
        # Show the expectation next to what the output node produced. The exact
        # shape of an output node's result is not visible here (it may wrap the
        # value — TODO confirm), so the two are printed rather than compared.
        actual = report.get("node_results", {}).get("output")
        print(f" 期望: {test_case['expected']}")
        print(f" 实际: {actual}")
async def test_workflow_graph():
    """Test 5: build a WorkflowGraph by hand and exercise its graph operations.

    Builds the chain n1 (InputNodeImpl) -> n2 (ArrayMapNode) ->
    n3 (ArrayReduceNode) -> n4 (OutputNodeImpl), then prints the node/edge
    counts, the cycle check, the topological order, and whether the graph is
    accepted as a function-node workflow.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 5: 工作流图操作")
    print("=" * 60)
    graph = WorkflowGraph()
    # Add nodes in order; the implementation class name is written onto each
    # node record right after it is created.
    node_specs = (
        ("n1", NodeType.INPUT, "InputNodeImpl"),
        ("n2", NodeType.NORMAL, "ArrayMapNode"),
        ("n3", NodeType.NORMAL, "ArrayReduceNode"),
        ("n4", NodeType.OUTPUT, "OutputNodeImpl"),
    )
    for node_id, node_type, class_name in node_specs:
        graph.add_node(node_id, node_type, sub_workflow_nodes=None, sub_workflow_edges=None)
        graph.nodes[node_id]["class_name"] = class_name
    # Add edges (the input node's output port is named after the global_context field).
    graph.add_edge("n1", "values", "n2", "values", DimensionMode.NONE)
    graph.add_edge("n2", "mapped", "n3", "values", DimensionMode.NONE)
    graph.add_edge("n3", "result", "n4", "input", DimensionMode.NONE)
    print(f" 节点数: {len(graph.nodes)}")
    print(f" 边数: {len(graph.edges)}")
    has_cycle = graph.has_cycle()
    print(f" 无循环: {'✅ YES' if not has_cycle else '❌ NO'}")
    # topological_sort signals an invalid graph with ValueError.
    try:
        topo = graph.topological_sort()
    except ValueError as e:
        print(f" ❌ 排序失败: {e}")
    else:
        # An explicit separator keeps the node ids readable (an empty one ran
        # them together as "n1n2n3n4").
        print(f" 拓扑排序: {' → '.join(topo)}")
        print(" ✅ 排序成功")
    valid, error = graph.validate_function_workflow()
    if valid:
        print(" ✅ 可作为函数节点工作流")
    else:
        print(f" ❌ 不可作为函数节点: {error}")
async def test_nested_function_workflow():
    """Test 6: package a sub-workflow as a function node and run it nested.

    The sub-workflow (input -> ArrayMapNode x2 -> output) is checked with
    ``WorkflowPackager.validate_function_workflow``, packaged with
    ``WorkflowPackager.package_as_function`` as node ``multiply_func``, and
    then executed inside a main workflow on ``[1, 2, 3]``. Returns early,
    after printing the reason, if validation fails or raises. On success the
    expected result and the ``output`` node's actual result are printed.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 6: 嵌套函数节点")
    print("=" * 60)
    # Sub-workflow: input -> x2 -> output
    sub_workflow_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": 2}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    sub_workflow_edges = [
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]
    # Check the sub-workflow may become a function node. A validator error is
    # reported and ends this test rather than aborting the whole suite.
    try:
        valid, error = WorkflowPackager.validate_function_workflow(sub_workflow_nodes, sub_workflow_edges)
    except Exception as e:
        print(f" ❌ 验证异常: {e}")
        return
    if not valid:
        print(f" ❌ 子工作流验证失败: {error}")
        return
    print(" ✅ 子工作流验证通过")
    function_node_def = WorkflowPackager.package_as_function(
        node_id="multiply_func",
        nodes=sub_workflow_nodes,
        edges=sub_workflow_edges,
        display_name="乘以2",
        description="将数组中的所有元素乘以2"
    )
    # package_as_function nests the sub-graph as "sub_workflow" -> {nodes, edges};
    # copy it into the flat sub_workflow_nodes / sub_workflow_edges fields this
    # test hands to WorkflowExecutor.
    if "sub_workflow" in function_node_def:
        function_node_def["sub_workflow_nodes"] = function_node_def["sub_workflow"].get("nodes", [])
        function_node_def["sub_workflow_edges"] = function_node_def["sub_workflow"].get("edges", [])
    print(f" ✅ 打包成功: {function_node_def['id']}")
    # Main workflow: input -> multiply_func (packaged function node) -> output.
    main_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        function_node_def,
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    main_edges = [
        {"source": "input", "source_port": "values", "target": "multiply_func", "target_port": "values"},
        {"source": "multiply_func", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]
    executor = WorkflowExecutor(user_id="test_user")
    success, report = await executor.execute(
        nodes=main_nodes,
        edges=main_edges,
        global_context={"values": [1, 2, 3]}
    )
    if not success:
        print(f" ❌ 执行失败: {report.get('error')}")
        return
    print(" ✅ 函数节点执行成功")
    print(" 输入: [1, 2, 3]")
    print(" 函数执行完后应得到: [2, 4, 6]")
    # Show what the output node actually produced so it can be checked against
    # the expectation above (read the same way test_simple_workflow does).
    node_results = report.get("node_results", {})
    if "output" in node_results:
        print(f" ✅ 最终输出: {node_results['output']}")
async def main():
    """Run every test in order and report the overall outcome.

    The first exception raised by a test stops the remaining tests. Its message
    and traceback are printed, and then ``SystemExit(1)`` is raised so that a
    script run (``asyncio.run(main())``) ends with a non-zero exit status.
    Previously, the run would have looked successful.
    """
    print("\n" + "=" * 60)
    print("🚀 TraceStudio 高级功能完整测试")
    print("=" * 60)
    try:
        await test_special_nodes()
        await test_dimension_inference()
        await test_simple_workflow()
        await test_array_operations()
        await test_workflow_graph()
        await test_nested_function_workflow()
    except Exception as e:
        print(f"\n❌ 测试过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        # Report the failure through the exit status; asyncio.run lets
        # SystemExit propagate out of the event loop.
        raise SystemExit(1) from e
    print("\n" + "=" * 60)
    print("✅ 所有测试完成!")
    print("=" * 60)
if __name__ == "__main__":
    # Script entry point: build the suite coroutine and drive it on a fresh event loop.
    suite = main()
    asyncio.run(suite)