""" Integration tests for the WorkflowExecutor. """ import pytest from pytrace.core.executor import WorkflowExecutor from pytrace.core.registry import NodeRegistry, register_node, instantiate from pytrace.core.node_base import TraceNode, input_port, output_port, parameter from pytrace.model.graph import WorkflowGraph, Node, Edge from pytrace.model.enums import NodeType, CachePolicy from pytrace.core.io import NodeIO from pytrace.core.context import TraceContext # Fixture to clear the registry before each test @pytest.fixture(autouse=True) def clean_registry(): NodeRegistry.clear() yield # --- Define some test nodes --- @register_node class AddNode(TraceNode): NAME = "Adder" CATEGORY = "Math" @input_port(name="a", label="A", type="float") @input_port(name="b", label="B", type="float") @output_port(name="sum", label="Sum", type="float") def process(self, io: NodeIO, context: TraceContext): a = io.get_input("a", 0.0) b = io.get_input("b", 0.0) result = a + b io.set_output("sum", result) context.log(f"AddNode: {a} + {b} = {result}") @register_node class MultiplyNode(TraceNode): NAME = "Multiplier" CATEGORY = "Math" @input_port(name="val", label="Value", type="float") @parameter(name="factor", label="Factor", type="float", default=2.0) @output_port(name="product", label="Product", type="float") def process(self, io: NodeIO, context: TraceContext): val = io.get_input("val", 1.0) factor = io.get_param("factor", 2.0) result = val * factor io.set_output("product", result) context.log(f"MultiplyNode: {val} * {factor} = {result}") @register_node class InputValueNode(TraceNode): NAME = "Input Value" NODE_TYPE = NodeType.INPUT @output_port(name="value", label="Value", type="Any") def process(self, io: NodeIO, context: TraceContext): # InputValueNode gets its value from initial_inputs of the executor # or from a direct parameter. For this test, it will come from initial_inputs. pass # The executor is responsible for providing its output based on initial_inputs @register_node class OutputCollectNode(TraceNode): NAME = "Output Collector" NODE_TYPE = NodeType.OUTPUT @input_port(name="input", label="Input", type="Any") def process(self, io: NodeIO, context: TraceContext): # OutputCollectNode typically just consumes an input, # its value is then part of the final execution report. pass # --- Tests --- def test_simple_execution_pipeline(): # Graph: InputValue -> Add -> Multiply -> OutputCollect graph = WorkflowGraph(name="Simple Pipeline") # Add nodes graph.add_node(Node("input_a", "pytrace.nodes.specials.InputNode", "Initial A", {})) graph.add_node(Node("input_b", "pytrace.nodes.specials.InputNode", "Initial B", {})) graph.add_node(Node("add_1", "pytrace.core.registry.AddNode", "Adder 1", {})) graph.add_node(Node("mult_1", "pytrace.core.registry.MultiplyNode", "Multiplier 1", {"factor": 3.0})) graph.add_node(Node("output_res", "pytrace.nodes.specials.OutputNode", "Final Output", {})) # Add edges graph.add_edge(Edge("input_a", "output", "add_1", "a")) graph.add_edge(Edge("input_b", "output", "add_1", "b")) graph.add_edge(Edge("add_1", "sum", "mult_1", "val")) graph.add_edge(Edge("mult_1", "product", "output_res", "input")) # Initial inputs for input nodes initial_inputs = { "input_a": {"output": 5.0}, "input_b": {"output": 3.0} } executor = WorkflowExecutor() results = executor.execute(graph, initial_inputs) assert results is not None # Check if final output from output_res node is correct # (5 + 3) * 3 = 24 output_collect_node_outputs = results.get("output_res", {}) assert output_collect_node_outputs.get("input").value == 24.0 # Verify intermediate results (e.g., from AddNode) add_node_outputs = results.get("add_1", {}) assert add_node_outputs.get("sum").value == 8.0 def test_graph_with_default_param(): # Graph: InputValue -> Multiply (default factor) -> OutputCollect graph = WorkflowGraph(name="Default Param Test") graph.add_node(Node("input_val", "pytrace.nodes.specials.InputNode", "Input", {})) graph.add_node(Node("mult_default", "pytrace.core.registry.MultiplyNode", "Multiplier Default", {})) # No factor param graph.add_node(Node("output_final", "pytrace.nodes.specials.OutputNode", "Output", {})) graph.add_edge(Edge("input_val", "output", "mult_default", "val")) graph.add_edge(Edge("mult_default", "product", "output_final", "input")) initial_inputs = {"input_val": {"output": 10.0}} executor = WorkflowExecutor() results = executor.execute(graph, initial_inputs) assert results is not None # 10 * 2.0 (default) = 20.0 output_collect_node_outputs = results.get("output_final", {}) assert output_collect_node_outputs.get("input").value == 20.0 def test_unregistered_node_type_raises_error(): graph = WorkflowGraph(name="Invalid Node Type") graph.add_node(Node("invalid_node", "non.existent.Node", "Non Existent", {})) executor = WorkflowExecutor() with pytest.raises(ValueError, match="Unknown node type"): executor.execute(graph) def test_cycle_in_graph_raises_error(): graph = WorkflowGraph(name="Cyclic Graph") graph.add_node(Node("n1", "pytrace.nodes.specials.InputNode", "N1", {})) graph.add_node(Node("n2", "pytrace.nodes.specials.InputNode", "N2", {})) graph.add_edge(Edge("n1", "output", "n2", "input")) graph.add_edge(Edge("n2", "output", "n1", "input")) # Cycle executor = WorkflowExecutor() with pytest.raises(ValueError, match="Graph contains a cycle"): executor.execute(graph)