# pytrace/tests/integration/test_execution.py
# Boshuang Zhao fb8458011a upload
# 2026-01-15 21:58:30 +08:00
#
# 150 lines
# 5.7 KiB
# Python

"""
Integration tests for the WorkflowExecutor.
"""
import pytest
from pytrace.core.executor import WorkflowExecutor
from pytrace.core.registry import NodeRegistry, register_node, instantiate
from pytrace.core.node_base import TraceNode, input_port, output_port, parameter
from pytrace.model.graph import WorkflowGraph, Node, Edge
from pytrace.model.enums import NodeType, CachePolicy
from pytrace.core.io import NodeIO
from pytrace.core.context import TraceContext
# Fixture giving each test a registry that holds only this module's test nodes.
@pytest.fixture(autouse=True)
def clean_registry():
    """Reset ``NodeRegistry`` before each test, then re-register the test nodes.

    ``NodeRegistry.clear()`` isolates tests from registrations made elsewhere.
    However, the ``@register_node`` decorators in this module run only once, at
    import time, before any test executes. Clearing without re-registering
    would drop AddNode, MultiplyNode, InputValueNode and OutputCollectNode for
    every test.

    NOTE(review): assumes the executor resolves node types through
    ``NodeRegistry``. If ``pytrace.nodes.specials`` nodes are also registered
    only at import time, the clear removes them too. Confirm how they are
    looked up.
    """
    NodeRegistry.clear()
    # ``register_node`` is applied as a bare decorator in this module, so
    # calling it directly on a class is equivalent to re-decorating it.
    for node_cls in (AddNode, MultiplyNode, InputValueNode, OutputCollectNode):
        register_node(node_cls)
    yield
# --- Define some test nodes ---
@register_node
class AddNode(TraceNode):
    """Test math node: adds float inputs ``a`` and ``b`` and emits ``sum``."""

    NAME = "Adder"
    CATEGORY = "Math"

    @input_port(name="a", label="A", type="float")
    @input_port(name="b", label="B", type="float")
    @output_port(name="sum", label="Sum", type="float")
    def process(self, io: NodeIO, context: TraceContext):
        """Publish ``a + b`` on the ``sum`` port and log the computation.

        A missing input falls back to 0.0.
        """
        # Read "a" first, then "b" (tuple elements evaluate left to right).
        a, b = io.get_input("a", 0.0), io.get_input("b", 0.0)
        result = a + b
        io.set_output("sum", result)
        context.log(f"AddNode: {a} + {b} = {result}")
@register_node
class MultiplyNode(TraceNode):
    """Test math node: scales input ``val`` by the ``factor`` parameter."""

    NAME = "Multiplier"
    CATEGORY = "Math"

    @input_port(name="val", label="Value", type="float")
    @parameter(name="factor", label="Factor", type="float", default=2.0)
    @output_port(name="product", label="Product", type="float")
    def process(self, io: NodeIO, context: TraceContext):
        """Publish ``val * factor`` on the ``product`` port and log it.

        A missing ``val`` input falls back to 1.0. A missing ``factor`` falls
        back to 2.0, the same value declared as the parameter default.
        """
        val, factor = io.get_input("val", 1.0), io.get_param("factor", 2.0)
        result = val * factor
        io.set_output("product", result)
        context.log(f"MultiplyNode: {val} * {factor} = {result}")
@register_node
class InputValueNode(TraceNode):
    """Graph-input test node exposing a single untyped ``value`` output.

    NOTE(review): none of the tests in this module build graphs with this
    class. They use ``pytrace.nodes.specials.InputNode`` instead.
    """

    NAME = "Input Value"
    NODE_TYPE = NodeType.INPUT

    @output_port(name="value", label="Value", type="Any")
    def process(self, io: NodeIO, context: TraceContext):
        """Intentionally a no-op.

        The executor is meant to provide this node's output, from its
        ``initial_inputs`` (or possibly a parameter), so nothing is computed here.
        """
@register_node
class OutputCollectNode(TraceNode):
    """Graph-output test node that accepts a single untyped ``input``.

    NOTE(review): none of the tests in this module build graphs with this
    class. They use ``pytrace.nodes.specials.OutputNode`` instead.
    """

    NAME = "Output Collector"
    NODE_TYPE = NodeType.OUTPUT

    @input_port(name="input", label="Input", type="Any")
    def process(self, io: NodeIO, context: TraceContext):
        """Intentionally a no-op.

        The node only consumes its input. The original intent is that the
        consumed value shows up in the executor's final results.
        """
# --- Tests ---
def test_simple_execution_pipeline():
    """Run a small linear graph end to end and check final and intermediate values.

    Graph: InputNode ``input_a`` and InputNode ``input_b`` -> AddNode ``add_1``
    -> MultiplyNode ``mult_1`` (factor=3.0) -> OutputNode ``output_res``.
    Expected: add_1.sum == 5 + 3 == 8 and output_res.input == 8 * 3 == 24.
    """
    graph = WorkflowGraph(name="Simple Pipeline")

    # Nodes.
    # NOTE(review): AddNode/MultiplyNode are defined in this test module, yet
    # their type strings point at ``pytrace.core.registry``. Confirm how the
    # registry keys these classes.
    graph.add_node(Node("input_a", "pytrace.nodes.specials.InputNode", "Initial A", {}))
    graph.add_node(Node("input_b", "pytrace.nodes.specials.InputNode", "Initial B", {}))
    graph.add_node(Node("add_1", "pytrace.core.registry.AddNode", "Adder 1", {}))
    graph.add_node(Node("mult_1", "pytrace.core.registry.MultiplyNode", "Multiplier 1", {"factor": 3.0}))
    graph.add_node(Node("output_res", "pytrace.nodes.specials.OutputNode", "Final Output", {}))

    # Edges.
    graph.add_edge(Edge("input_a", "output", "add_1", "a"))
    graph.add_edge(Edge("input_b", "output", "add_1", "b"))
    graph.add_edge(Edge("add_1", "sum", "mult_1", "val"))
    graph.add_edge(Edge("mult_1", "product", "output_res", "input"))

    # Values the executor feeds to each input node's "output" port.
    initial_inputs = {
        "input_a": {"output": 5.0},
        "input_b": {"output": 3.0},
    }

    executor = WorkflowExecutor()
    results = executor.execute(graph, initial_inputs)
    assert results is not None

    # Check each port exists before reading ``.value``. Otherwise a missing
    # node/port fails as an AttributeError on None, not a clear assertion.
    final_port = results.get("output_res", {}).get("input")
    assert final_port is not None, f"output_res produced no 'input' entry: {results!r}"
    assert final_port.value == 24.0

    sum_port = results.get("add_1", {}).get("sum")
    assert sum_port is not None, f"add_1 produced no 'sum' entry: {results!r}"
    assert sum_port.value == 8.0
def test_graph_with_default_param():
    """MultiplyNode must fall back to its default factor when none is configured.

    Graph: InputNode ``input_val`` -> MultiplyNode ``mult_default`` (no
    ``factor`` param) -> OutputNode ``output_final``.
    Expected: output_final.input == 10 * 2.0 == 20.0.
    """
    graph = WorkflowGraph(name="Default Param Test")

    graph.add_node(Node("input_val", "pytrace.nodes.specials.InputNode", "Input", {}))
    # Deliberately no "factor" entry, so the node's default (2.0) must apply.
    # NOTE(review): MultiplyNode is defined in this test module, yet its type
    # string points at ``pytrace.core.registry``. Confirm the registry key.
    graph.add_node(Node("mult_default", "pytrace.core.registry.MultiplyNode", "Multiplier Default", {}))
    graph.add_node(Node("output_final", "pytrace.nodes.specials.OutputNode", "Output", {}))

    graph.add_edge(Edge("input_val", "output", "mult_default", "val"))
    graph.add_edge(Edge("mult_default", "product", "output_final", "input"))

    initial_inputs = {"input_val": {"output": 10.0}}

    executor = WorkflowExecutor()
    results = executor.execute(graph, initial_inputs)
    assert results is not None

    # Check the port exists before reading ``.value``, so a missing entry
    # fails as a clear assertion rather than an AttributeError on None.
    final_port = results.get("output_final", {}).get("input")
    assert final_port is not None, f"output_final produced no 'input' entry: {results!r}"
    assert final_port.value == 20.0
def test_unregistered_node_type_raises_error():
    """Executing a graph that names an unknown node type must raise ValueError."""
    graph = WorkflowGraph(name="Invalid Node Type")
    bogus_node = Node("invalid_node", "non.existent.Node", "Non Existent", {})
    graph.add_node(bogus_node)

    runner = WorkflowExecutor()
    # No initial inputs: the failure should come from resolving the node type.
    with pytest.raises(ValueError, match="Unknown node type"):
        runner.execute(graph)
def test_cycle_in_graph_raises_error():
    """A graph whose edges form a loop (n1 -> n2 -> n1) must be rejected.

    NOTE(review): both nodes are InputNodes wired into an "input" port, while
    the other tests only use InputNode's "output" port. Confirm the executor
    detects the cycle before any port validation could reject these edges.
    """
    graph = WorkflowGraph(name="Cyclic Graph")
    for node_id, label in (("n1", "N1"), ("n2", "N2")):
        graph.add_node(Node(node_id, "pytrace.nodes.specials.InputNode", label, {}))

    # n1 -> n2, then n2 -> n1: the second edge closes the cycle.
    for src, dst in (("n1", "n2"), ("n2", "n1")):
        graph.add_edge(Edge(src, "output", dst, "input"))

    runner = WorkflowExecutor()
    with pytest.raises(ValueError, match="Graph contains a cycle"):
        runner.execute(graph)