TraceStudio-dev/server/tests/test_advanced_features.py
2026-01-09 21:37:02 +08:00

305 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
高级功能测试 - 特殊节点、维度转换、函数节点嵌套
"""
import asyncio
from pathlib import Path
# 导入节点和执行器
# ==== 路径修正:兼容直接运行和服务器导入 ====
import sys
from pathlib import Path
# Resolve this file and its grandparent directory (tests/ -> server/), then put
# that directory at the front of sys.path so the ``app.*`` imports that follow
# work both when the file is run directly and when it is imported.
cur = Path(__file__).resolve()
root = cur.parent.parent
_root_entry = str(root)
if _root_entry not in sys.path:
    # insert(0, ...) makes this directory the first place imports are searched.
    sys.path.insert(0, _root_entry)
from app.core.node_registry import NodeRegistry
from app.core.node_base import (
DimensionMode, WorkflowPackager, NodeType
)
from app.core.workflow_executor import WorkflowExecutor, WorkflowGraph
from app.nodes.advanced_example_nodes import *
async def test_special_nodes():
    """Report whether the special and array-operation nodes are registered.

    Prints one status line per expected node class name (looked up through
    ``NodeRegistry.exists``), then the total count from
    ``NodeRegistry.list_all()``. Nothing is asserted; missing nodes are only
    reported in the output.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 1: 特殊节点注册")
    print("=" * 60)
    # Special node implementations (input / output / function nodes).
    special_nodes = ["InputNodeImpl", "OutputNodeImpl", "FunctionNodeImpl"]
    for node_type in special_nodes:
        exists = NodeRegistry.exists(node_type)
        status = "✅ REGISTERED" if exists else "❌ MISSING"
        print(f" {node_type:<20} {status}")
    # Array-operation nodes.
    array_nodes = ["ArrayMapNode", "ArrayConcatNode", "ArrayFilterNode",
                   "ArrayReduceNode", "BroadcastNode", "ArrayZipNode"]
    print("\n 数组操作节点:")
    for node_type in array_nodes:
        exists = NodeRegistry.exists(node_type)
        # A visible marker is required; with empty strings on both branches
        # a missing node was indistinguishable from a registered one.
        status = "✅" if exists else "❌"
        print(f" {status} {node_type}")
    print(f"\n 总注册节点数: {len(NodeRegistry.list_all())}")
async def test_dimension_inference():
    """Print the banner for scenario 2 (dimension-conversion inference).

    The inference checks themselves were moved to a separate test file, so
    this scenario only prints a note saying so.
    """
    rule = "=" * 60
    lines = (
        "\n" + rule,
        "✅ 测试 2: 维度转换推断",
        rule,
        " (维度推断测试已迁移至独立测试文件)",
    )
    for line in lines:
        print(line)
async def test_simple_workflow():
    """Run input -> ArrayMapNode(x2) -> ArrayReduceNode(sum) -> output.

    Executes the workflow with ``WorkflowExecutor`` and prints the expected
    intermediate values (derived from the input, not hard-coded), the output
    node's entry from ``report["node_results"]``, and the total duration.
    The shape of that output entry is not visible from this file, so it is
    printed rather than compared.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 3: 简单工作流")
    print("=" * 60)
    input_values = [1, 2, 3, 4, 5]
    multiplier = 2
    # Workflow: input → ArrayMapNode(×2) → ArrayReduceNode(sum) → output
    nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": multiplier}},
        {"id": "reduce", "type": "normal", "class_name": "ArrayReduceNode", "params": {"operation": "sum"}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    edges = [
        # The InputNode exposes the "values" field of the global context directly.
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "reduce", "target_port": "values"},
        {"source": "reduce", "source_port": "result", "target": "output", "target_port": "input"},
    ]
    executor = WorkflowExecutor(user_id="test_user")
    success, report = await executor.execute(
        nodes=nodes,
        edges=edges,
        # Pass a copy so the values printed below cannot be affected if the
        # executor mutates its context.
        global_context={"values": list(input_values)},
    )
    if success:
        mapped = [value * multiplier for value in input_values]
        print(" ✅ 工作流执行成功")
        print(f" 输入: {input_values}")
        print(f" 乘以{multiplier}: {mapped}")
        print(f" 求和: {sum(mapped)}")
        # Verify the result: make a missing output entry visible instead of
        # printing nothing.
        node_results = report.get("node_results") or {}
        if "output" in node_results:
            print(f" ✅ 最终输出: {node_results['output']}")
        else:
            print(" ❌ 报告中缺少 output 节点的结果")
    else:
        print(f" ❌ 工作流执行失败: {report.get('error')}")
    print(f" 执行耗时: {report.get('total_duration', 0):.3f}s")
def _output_matches(actual, expected):
    """Return True if an output-node result matches the expected value.

    The exact shape of an entry in the executor's ``node_results`` is not
    visible from this file, so both a bare value equal to ``expected`` and a
    dict (e.g. port name -> value) containing ``expected`` among its values
    are accepted. NOTE(review): tighten once the result shape is confirmed.
    """
    if actual == expected:
        return True
    return isinstance(actual, dict) and expected in actual.values()


async def test_array_operations():
    """Run small filter / broadcast workflows and compare against expectations.

    Each case wires input -> operation node -> output, executes it with its
    own ``global_context``, and compares the output node's entry in
    ``report["node_results"]`` with the case's ``expected`` value. Results
    are printed; nothing is asserted.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 4: 数组操作")
    print("=" * 60)
    # Test cases
    test_cases = [
        {
            "name": "数组过滤",
            "nodes": [
                {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "filter", "type": "normal", "class_name": "ArrayFilterNode", "params": {"threshold": 2}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input", "source_port": "values", "target": "filter", "target_port": "values"},
                {"source": "filter", "source_port": "filtered", "target": "output", "target_port": "input"},
            ],
            "global_context": {"values": [1, 2, 3, 4, 5]},
            "expected": [3, 4, 5],
        },
        {
            "name": "广播",
            "nodes": [
                {"id": "input1", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "broadcast", "type": "normal", "class_name": "BroadcastNode", "params": {"count": 3}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input1", "source_port": "value", "target": "broadcast", "target_port": "value"},
                {"source": "broadcast", "source_port": "broadcast", "target": "output", "target_port": "input"},
            ],
            "global_context": {"value": 42},
            "expected": [42, 42, 42],
        },
    ]
    for test_case in test_cases:
        print(f"\n 测试: {test_case['name']}")
        executor = WorkflowExecutor(user_id="test_user")
        success, report = await executor.execute(
            nodes=test_case["nodes"],
            edges=test_case["edges"],
            global_context=test_case["global_context"],
        )
        if not success:
            print(f" ❌ 执行失败: {report.get('error')}")
            continue
        print(" ✅ 执行成功")
        # Compare the output node's result with the case's declared expectation.
        expected = test_case["expected"]
        actual = (report.get("node_results") or {}).get("output")
        if _output_matches(actual, expected):
            print(f" ✅ 结果符合预期: {expected}")
        else:
            print(f" ❌ 结果不符合预期: 期望 {expected}, 实际 {actual}")
async def test_workflow_graph():
    """Build a 4-node linear ``WorkflowGraph`` and exercise its graph operations.

    Prints node/edge counts, the result of ``has_cycle()``, the order from
    ``topological_sort()``, and whether ``validate_function_workflow()``
    accepts the graph as a function-node workflow.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 5: 工作流图操作")
    print("=" * 60)
    graph = WorkflowGraph()
    # Nodes: (node id, node type, implementation class name), added in order.
    node_specs = [
        ("n1", NodeType.INPUT, "InputNodeImpl"),
        ("n2", NodeType.NORMAL, "ArrayMapNode"),
        ("n3", NodeType.NORMAL, "ArrayReduceNode"),
        ("n4", NodeType.OUTPUT, "OutputNodeImpl"),
    ]
    for node_id, node_type, class_name in node_specs:
        graph.add_node(node_id, node_type, sub_workflow_nodes=None, sub_workflow_edges=None)
        graph.nodes[node_id]["class_name"] = class_name
    # Edges: (source, source port, target, target port). An InputNode's output
    # port is named directly after the global_context field it exposes.
    edge_specs = [
        ("n1", "values", "n2", "values"),
        ("n2", "mapped", "n3", "values"),
        ("n3", "result", "n4", "input"),
    ]
    for source, source_port, target, target_port in edge_specs:
        graph.add_edge(source, source_port, target, target_port, DimensionMode.NONE)
    print(f" 节点数: {len(graph.nodes)}")
    print(f" 边数: {len(graph.edges)}")
    # Cycle check
    has_cycle = graph.has_cycle()
    print(f" 无循环: {'✅ YES' if not has_cycle else '❌ NO'}")
    # Topological sort (topological_sort() is expected to raise ValueError on failure).
    try:
        topo = graph.topological_sort()
    except ValueError as e:
        print(f" ❌ 排序失败: {e}")
    else:
        # Separate the ids so the order is readable instead of "n1n2n3n4".
        print(f" 拓扑排序: {' → '.join(topo)}")
        print(" ✅ 排序成功")
    # Can this graph be packaged as a function-node workflow?
    valid, error = graph.validate_function_workflow()
    if valid:
        print(" ✅ 可作为函数节点工作流")
    else:
        print(f" ❌ 不可作为函数节点: {error}")
async def test_nested_function_workflow():
    """Package a sub-workflow as a function node and run it inside a main workflow.

    Steps:
      1. Validate the sub-workflow (input -> ArrayMapNode(x2) -> output) with
         ``WorkflowPackager.validate_function_workflow``; stop early if it is
         rejected or validation raises.
      2. Package it with ``WorkflowPackager.package_as_function`` and flatten
         its ``sub_workflow`` {nodes, edges} into the ``sub_workflow_nodes`` /
         ``sub_workflow_edges`` keys used by ``WorkflowExecutor``.
      3. Execute input -> function node -> output and print the outcome.
    """
    print("\n" + "=" * 60)
    print("✅ 测试 6: 嵌套函数节点")
    print("=" * 60)
    # Sub-workflow: input → ×2 → output
    child_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": 2}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    child_edges = [
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    # Step 1: the sub-workflow must be acceptable as a function node.
    try:
        is_valid, reason = WorkflowPackager.validate_function_workflow(child_nodes, child_edges)
        if not is_valid:
            print(f" ❌ 子工作流验证失败: {reason}")
            return
        print(" ✅ 子工作流验证通过")
    except Exception as exc:
        print(f" ❌ 验证异常: {exc}")
        return

    # Step 2: package it and expose the nested graph under the executor's key names.
    func_id = "multiply_func"
    packaged = WorkflowPackager.package_as_function(
        node_id=func_id,
        nodes=child_nodes,
        edges=child_edges,
        display_name="乘以2",
        description="将数组中的所有元素乘以2",
    )
    if "sub_workflow" in packaged:
        nested = packaged["sub_workflow"]
        packaged["sub_workflow_nodes"] = nested.get("nodes", [])
        packaged["sub_workflow_edges"] = nested.get("edges", [])
    print(f" ✅ 打包成功: {packaged['id']}")

    # Step 3: run input → function node → output.
    outer_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        packaged,
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    outer_edges = [
        {"source": "input", "source_port": "values", "target": func_id, "target_port": "values"},
        {"source": func_id, "source_port": "mapped", "target": "output", "target_port": "input"},
    ]
    source_values = [1, 2, 3]
    runner = WorkflowExecutor(user_id="test_user")
    ok, report = await runner.execute(
        nodes=outer_nodes,
        edges=outer_edges,
        global_context={"values": list(source_values)},
    )
    if not ok:
        print(f" ❌ 执行失败: {report.get('error')}")
        return
    print(" ✅ 函数节点执行成功")
    print(f" 输入: {source_values}")
    print(f" 函数执行完后应得到: {[value * 2 for value in source_values]}")
async def main():
    """Run every scenario in this file, isolating failures from one another.

    Each scenario runs in its own ``try`` so an exception in one no longer
    skips the remaining ones; the traceback is printed and the scenario name
    is recorded. Only scenarios that raise count as failures; scenarios that
    merely print a ❌ line and return normally do not.

    Returns:
        0 if no scenario raised, 1 otherwise (usable as a process exit code).
    """
    import traceback

    print("\n" + "=" * 60)
    print("🚀 TraceStudio 高级功能完整测试")
    print("=" * 60)
    scenarios = (
        test_special_nodes,
        test_dimension_inference,
        test_simple_workflow,
        test_array_operations,
        test_workflow_graph,
        test_nested_function_workflow,
    )
    failed = []
    for scenario in scenarios:
        try:
            await scenario()
        except Exception as e:
            print(f"\n❌ 测试过程中出现错误: {e}")
            traceback.print_exc()
            failed.append(scenario.__name__)
    print("\n" + "=" * 60)
    if failed:
        print(f"❌ {len(failed)} 个测试出现错误: {', '.join(failed)}")
    else:
        print("✅ 所有测试完成!")
    print("=" * 60)
    return 1 if failed else 0


if __name__ == "__main__":
    # Propagate main()'s status so a crashing scenario yields a non-zero exit code.
    sys.exit(asyncio.run(main()))