"""Advanced feature tests: special nodes, dimension transforms, nested function nodes.

Runnable directly as a script or importable (e.g. by the server). Each ``test_*``
coroutine prints a human-readable report rather than asserting; ``main`` runs them
in order and returns whether the run completed without an exception.
"""
import asyncio
import sys
import traceback
from pathlib import Path

# ==== Path fix: make the ``app`` package importable both when this file is run
# directly and when it is imported by the server ====
# Inserts the grandparent directory of this file at the front of sys.path.
cur = Path(__file__).resolve()
root = cur.parent.parent
if str(root) not in sys.path:
    sys.path.insert(0, str(root))

# Project imports must come after the sys.path fix above.
from app.core.node_registry import NodeRegistry  # noqa: E402
from app.core.node_base import (  # noqa: E402
    DimensionMode,
    WorkflowPackager,
    NodeType,
)
from app.core.workflow_executor import WorkflowExecutor, WorkflowGraph  # noqa: E402
# Wildcard import kept from the original. Presumably imported for its side effect of
# registering the example nodes (ArrayMapNode, ...) with NodeRegistry — TODO confirm.
from app.nodes.advanced_example_nodes import *  # noqa: E402,F401,F403


def _final_output(report):
    """Return the result recorded for the node with id ``"output"``, or None if absent.

    Reads ``report["node_results"]["output"]``. The shape of that value is defined
    by WorkflowExecutor and is not known here.
    """
    node_results = report.get("node_results") or {}
    return node_results.get("output")


def _output_matches(result, expected):
    """Leniently check whether *expected* appears in an executor result.

    Matches when *result* equals *expected*, or when any (possibly nested) dict value
    equals it. NOTE(review): the executor's result format is assumed, not verified.
    A False here means "not found", not necessarily "wrong".
    """
    try:
        if result == expected:
            return True
    except (TypeError, ValueError):
        # e.g. array-like values whose ``==`` has no single truth value
        return False
    if isinstance(result, dict):
        return any(_output_matches(value, expected) for value in result.values())
    return False


def _report_output(report, expected):
    """Print the final output next to *expected*, plus a lenient match indicator."""
    actual = _final_output(report)
    print(f" 预期: {expected}")
    print(f" 实际输出: {actual}")
    if _output_matches(actual, expected):
        print(" ✅ 输出与预期一致")
    else:
        print(" ⚠️ 未在输出中找到预期值")


async def test_special_nodes():
    """Report whether the special and array-operation node classes are registered."""
    print("\n" + "=" * 60)
    print("✅ 测试 1: 特殊节点注册")
    print("=" * 60)

    # Special nodes
    special_nodes = ["InputNodeImpl", "OutputNodeImpl", "FunctionNodeImpl"]
    for node_type in special_nodes:
        exists = NodeRegistry.exists(node_type)
        status = "✅ REGISTERED" if exists else "❌ MISSING"
        print(f" {node_type:<20} {status}")

    # Array-operation nodes
    array_nodes = [
        "ArrayMapNode",
        "ArrayConcatNode",
        "ArrayFilterNode",
        "ArrayReduceNode",
        "BroadcastNode",
        "ArrayZipNode",
    ]
    print("\n 数组操作节点:")
    for node_type in array_nodes:
        exists = NodeRegistry.exists(node_type)
        status = "✅" if exists else "❌"
        print(f" {status} {node_type}")

    print(f"\n 总注册节点数: {len(NodeRegistry.list_all())}")


async def test_dimension_inference():
    """Placeholder: dimension-inference tests live in a separate test file."""
    print("\n" + "=" * 60)
    print("✅ 测试 2: 维度转换推断")
    print("=" * 60)
    print(" (维度推断测试已迁移至独立测试文件)")


async def test_simple_workflow():
    """Run input → ArrayMapNode(×2) → ArrayReduceNode(sum) → output and report the result."""
    print("\n" + "=" * 60)
    print("✅ 测试 3: 简单工作流")
    print("=" * 60)

    nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": 2}},
        {"id": "reduce", "type": "normal", "class_name": "ArrayReduceNode", "params": {"operation": "sum"}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    edges = [
        # The input node's output port is named after the global_context field ("values").
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "reduce", "target_port": "values"},
        {"source": "reduce", "source_port": "result", "target": "output", "target_port": "input"},
    ]

    executor = WorkflowExecutor(user_id="test_user")
    success, report = await executor.execute(
        nodes=nodes,
        edges=edges,
        global_context={"values": [1, 2, 3, 4, 5]},
    )

    if success:
        print(" ✅ 工作流执行成功")
        # Describes the intended pipeline; the actual output is checked below.
        print(" 输入: [1, 2, 3, 4, 5]")
        print(" 乘以2: [2, 4, 6, 8, 10]")
        print(" 求和: 30")
        _report_output(report, 30)
    else:
        print(f" ❌ 工作流执行失败: {report.get('error')}")

    print(f" 执行耗时: {report.get('total_duration', 0):.3f}s")


async def test_array_operations():
    """Run small filter / broadcast workflows and compare each output with its expected value."""
    print("\n" + "=" * 60)
    print("✅ 测试 4: 数组操作")
    print("=" * 60)

    test_cases = [
        {
            "name": "数组过滤",
            "nodes": [
                {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "filter", "type": "normal", "class_name": "ArrayFilterNode", "params": {"threshold": 2}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input", "source_port": "values", "target": "filter", "target_port": "values"},
                {"source": "filter", "source_port": "filtered", "target": "output", "target_port": "input"},
            ],
            "global_context": {"values": [1, 2, 3, 4, 5]},
            "expected": [3, 4, 5],
        },
        {
            "name": "广播",
            "nodes": [
                {"id": "input1", "type": "input", "class_name": "InputNodeImpl", "params": {}},
                {"id": "broadcast", "type": "normal", "class_name": "BroadcastNode", "params": {"count": 3}},
                {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
            ],
            "edges": [
                {"source": "input1", "source_port": "value", "target": "broadcast", "target_port": "value"},
                {"source": "broadcast", "source_port": "broadcast", "target": "output", "target_port": "input"},
            ],
            "global_context": {"value": 42},
            "expected": [42, 42, 42],
        },
    ]

    for test_case in test_cases:
        print(f"\n 测试: {test_case['name']}")
        executor = WorkflowExecutor(user_id="test_user")
        success, report = await executor.execute(
            nodes=test_case["nodes"],
            edges=test_case["edges"],
            global_context=test_case["global_context"],
        )
        if success:
            print(" ✅ 执行成功")
            # Previously "expected" was declared but never compared.
            _report_output(report, test_case["expected"])
        else:
            print(f" ❌ 执行失败: {report.get('error')}")


async def test_workflow_graph():
    """Build a 4-node WorkflowGraph and report cycle check, topological order and function validity."""
    print("\n" + "=" * 60)
    print("✅ 测试 5: 工作流图操作")
    print("=" * 60)

    graph = WorkflowGraph()

    # Nodes
    graph.add_node("n1", NodeType.INPUT, sub_workflow_nodes=None, sub_workflow_edges=None)
    graph.nodes["n1"]["class_name"] = "InputNodeImpl"
    graph.add_node("n2", NodeType.NORMAL, sub_workflow_nodes=None, sub_workflow_edges=None)
    graph.nodes["n2"]["class_name"] = "ArrayMapNode"
    graph.add_node("n3", NodeType.NORMAL, sub_workflow_nodes=None, sub_workflow_edges=None)
    graph.nodes["n3"]["class_name"] = "ArrayReduceNode"
    graph.add_node("n4", NodeType.OUTPUT, sub_workflow_nodes=None, sub_workflow_edges=None)
    graph.nodes["n4"]["class_name"] = "OutputNodeImpl"

    # Edges (the input node's output port is the global_context field name)
    graph.add_edge("n1", "values", "n2", "values", DimensionMode.NONE)
    graph.add_edge("n2", "mapped", "n3", "values", DimensionMode.NONE)
    graph.add_edge("n3", "result", "n4", "input", DimensionMode.NONE)

    print(f" 节点数: {len(graph.nodes)}")
    print(f" 边数: {len(graph.edges)}")

    has_cycle = graph.has_cycle()
    print(f" 无循环: {'✅ YES' if not has_cycle else '❌ NO'}")

    # topological_sort signals failure with ValueError
    try:
        topo = graph.topological_sort()
        print(f" 拓扑排序: {' → '.join(topo)}")
        print(" ✅ 排序成功")
    except ValueError as e:
        print(f" ❌ 排序失败: {e}")

    valid, error = graph.validate_function_workflow()
    if valid:
        print(" ✅ 可作为函数节点工作流")
    else:
        print(f" ❌ 不可作为函数节点: {error}")


async def test_nested_function_workflow():
    """Package an input → ×2 → output sub-workflow as a function node and run it inside a main workflow."""
    print("\n" + "=" * 60)
    print("✅ 测试 6: 嵌套函数节点")
    print("=" * 60)

    # Sub-workflow: input → ×2 → output
    sub_workflow_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": 2}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    sub_workflow_edges = [
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    # Validate that the sub-workflow can be used as a function node
    try:
        valid, error = WorkflowPackager.validate_function_workflow(sub_workflow_nodes, sub_workflow_edges)
        if valid:
            print(" ✅ 子工作流验证通过")
        else:
            print(f" ❌ 子工作流验证失败: {error}")
            return
    except Exception as e:
        print(f" ❌ 验证异常: {e}")
        return

    function_node_def = WorkflowPackager.package_as_function(
        node_id="multiply_func",
        nodes=sub_workflow_nodes,
        edges=sub_workflow_edges,
        display_name="乘以2",
        description="将数组中的所有元素乘以2",
    )

    # package_as_function nests the graph under "sub_workflow" -> {nodes, edges};
    # flatten it into the field names used when passing nodes to WorkflowExecutor.
    if "sub_workflow" in function_node_def:
        function_node_def["sub_workflow_nodes"] = function_node_def["sub_workflow"].get("nodes", [])
        function_node_def["sub_workflow_edges"] = function_node_def["sub_workflow"].get("edges", [])

    print(f" ✅ 打包成功: {function_node_def['id']}")

    # Main workflow containing the function node
    main_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        function_node_def,
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    main_edges = [
        {"source": "input", "source_port": "values", "target": "multiply_func", "target_port": "values"},
        {"source": "multiply_func", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    executor = WorkflowExecutor(user_id="test_user")
    success, report = await executor.execute(
        nodes=main_nodes,
        edges=main_edges,
        global_context={"values": [1, 2, 3]},
    )

    if success:
        print(" ✅ 函数节点执行成功")
        print(" 输入: [1, 2, 3]")
        print(" 函数执行完后应得到: [2, 4, 6]")
        _report_output(report, [2, 4, 6])
    else:
        print(f" ❌ 执行失败: {report.get('error')}")


async def main():
    """Run every test in order.

    Returns:
        True if all tests ran without raising; False if one raised. The traceback is
        printed, and the remaining tests are skipped, as before.
    """
    print("\n" + "=" * 60)
    print("🚀 TraceStudio 高级功能完整测试")
    print("=" * 60)

    tests = (
        test_special_nodes,
        test_dimension_inference,
        test_simple_workflow,
        test_array_operations,
        test_workflow_graph,
        test_nested_function_workflow,
    )
    try:
        for test in tests:
            await test()
    except Exception as e:
        # Top-level boundary: report and signal failure instead of silently succeeding.
        print(f"\n❌ 测试过程中出现错误: {e}")
        traceback.print_exc()
        return False

    print("\n" + "=" * 60)
    print("✅ 所有测试完成!")
    print("=" * 60)
    return True


if __name__ == "__main__":
    # Exit non-zero when a test raised, so shell/CI callers can detect the failure.
    sys.exit(0 if asyncio.run(main()) else 1)