150 lines
5.7 KiB
Python
150 lines
5.7 KiB
Python
"""
|
|
Integration tests for the WorkflowExecutor.
|
|
"""
|
|
import pytest
|
|
from pytrace.core.executor import WorkflowExecutor
|
|
from pytrace.core.registry import NodeRegistry, register_node, instantiate
|
|
from pytrace.core.node_base import TraceNode, input_port, output_port, parameter
|
|
from pytrace.model.graph import WorkflowGraph, Node, Edge
|
|
from pytrace.model.enums import NodeType, CachePolicy
|
|
from pytrace.core.io import NodeIO
|
|
from pytrace.core.context import TraceContext
|
|
|
|
# Fixture to clear the registry before each test
|
|
@pytest.fixture(autouse=True)
def clean_registry():
    """Reset the node registry before each test, keeping this module's test nodes.

    ``NodeRegistry.clear()`` runs before every test, but the ``@register_node``
    decorators on the test node classes in this module run only once, at import
    time. Clearing alone would therefore (presumably) drop those registrations
    before any test looks them up, so they are re-applied here.

    NOTE(review): if ``clear()`` also removes built-in nodes such as
    ``pytrace.nodes.specials.InputNode``, those need restoring as well --
    confirm against NodeRegistry.
    """
    NodeRegistry.clear()
    # ``@register_node`` on a class is just ``register_node(cls)``, so calling it
    # directly re-creates the import-time registrations. The names are resolved
    # when the fixture runs, after the classes below have been defined.
    for node_cls in (AddNode, MultiplyNode, InputValueNode, OutputCollectNode):
        register_node(node_cls)
    yield
|
|
|
|
# --- Define some test nodes ---
|
|
|
|
@register_node
class AddNode(TraceNode):
    """Test node that adds inputs ``a`` and ``b`` and writes the total to ``sum``."""

    NAME = "Adder"
    CATEGORY = "Math"

    @input_port(name="a", label="A", type="float")
    @input_port(name="b", label="B", type="float")
    @output_port(name="sum", label="Sum", type="float")
    def process(self, io: NodeIO, context: TraceContext):
        """Read both operands, publish their sum, and log the computation."""
        # 0.0 is the fallback handed to get_input for each operand (presumably
        # used when the port carries no value).
        a, b = io.get_input("a", 0.0), io.get_input("b", 0.0)
        result = a + b
        io.set_output("sum", result)
        context.log(f"AddNode: {a} + {b} = {result}")
|
|
|
|
@register_node
class MultiplyNode(TraceNode):
    """Test node that scales input ``val`` by parameter ``factor`` into ``product``."""

    NAME = "Multiplier"
    CATEGORY = "Math"

    @input_port(name="val", label="Value", type="float")
    @parameter(name="factor", label="Factor", type="float", default=2.0)
    @output_port(name="product", label="Product", type="float")
    def process(self, io: NodeIO, context: TraceContext):
        """Multiply the input by the factor, publish the product, and log it."""
        # Fallbacks passed to the accessors: 1.0 for the input, 2.0 for the factor.
        # NOTE: the 2.0 repeats ``default=2.0`` from the @parameter declaration
        # above; keep the two values in sync.
        val = io.get_input("val", 1.0)
        factor = io.get_param("factor", 2.0)
        result = val * factor
        io.set_output("product", result)
        context.log(f"MultiplyNode: {val} * {factor} = {result}")
|
|
|
|
@register_node
class InputValueNode(TraceNode):
    """Source node exposing a single ``value`` output port.

    The node computes nothing itself; its output is expected to be supplied
    from outside (per the original intent, from the executor's
    ``initial_inputs``, or possibly from a direct parameter).

    NOTE(review): the tests below build their graphs from
    ``pytrace.nodes.specials.InputNode`` rather than from this class.
    """

    NAME = "Input Value"
    NODE_TYPE = NodeType.INPUT

    @output_port(name="value", label="Value", type="Any")
    def process(self, io: NodeIO, context: TraceContext):
        """Intentionally a no-op; the executor is expected to provide ``value``."""
|
|
|
|
@register_node
class OutputCollectNode(TraceNode):
    """Sink node exposing a single ``input`` port.

    It does no work of its own; per the original intent, the value arriving on
    ``input`` becomes part of the final execution report.

    NOTE(review): the tests below build their graphs from
    ``pytrace.nodes.specials.OutputNode`` rather than from this class.
    """

    NAME = "Output Collector"
    NODE_TYPE = NodeType.OUTPUT

    @input_port(name="input", label="Input", type="Any")
    def process(self, io: NodeIO, context: TraceContext):
        """Intentionally a no-op; the node only consumes its input."""
|
|
|
|
# --- Tests ---
|
|
|
|
def test_simple_execution_pipeline():
    """Run an add-then-multiply pipeline and check the final and intermediate values.

    Graph::

        input_a --+
                  +--> add_1 (a + b) --> mult_1 (* 3.0) --> output_res
        input_b --+
    """
    graph = WorkflowGraph(name="Simple Pipeline")

    # Nodes: two input nodes, the two math nodes under test, one output node.
    # NOTE(review): AddNode/MultiplyNode are defined in this test module but are
    # referenced as "pytrace.core.registry.*" -- confirm this matches the type
    # id that register_node assigns.
    graph.add_node(Node("input_a", "pytrace.nodes.specials.InputNode", "Initial A", {}))
    graph.add_node(Node("input_b", "pytrace.nodes.specials.InputNode", "Initial B", {}))
    graph.add_node(Node("add_1", "pytrace.core.registry.AddNode", "Adder 1", {}))
    graph.add_node(Node("mult_1", "pytrace.core.registry.MultiplyNode", "Multiplier 1", {"factor": 3.0}))
    graph.add_node(Node("output_res", "pytrace.nodes.specials.OutputNode", "Final Output", {}))

    # Edges, written as (source node, source port, target node, target port).
    graph.add_edge(Edge("input_a", "output", "add_1", "a"))
    graph.add_edge(Edge("input_b", "output", "add_1", "b"))
    graph.add_edge(Edge("add_1", "sum", "mult_1", "val"))
    graph.add_edge(Edge("mult_1", "product", "output_res", "input"))

    # Values fed to the input nodes: node id -> {port name: value}.
    initial_inputs = {
        "input_a": {"output": 5.0},
        "input_b": {"output": 3.0},
    }

    executor = WorkflowExecutor()
    results = executor.execute(graph, initial_inputs)

    assert results is not None

    # Final output: (5 + 3) * 3 = 24. Each lookup is asserted separately so a
    # missing node or port fails with a readable message rather than an
    # AttributeError on ``None.value``.
    output_ports = results.get("output_res")
    assert output_ports is not None, "executor produced no entry for 'output_res'"
    final_value = output_ports.get("input")
    assert final_value is not None, "'output_res' has no value on port 'input'"
    assert final_value.value == 24.0

    # Intermediate result from the AddNode: 5 + 3 = 8.
    add_ports = results.get("add_1")
    assert add_ports is not None, "executor produced no entry for 'add_1'"
    sum_value = add_ports.get("sum")
    assert sum_value is not None, "'add_1' has no value on port 'sum'"
    assert sum_value.value == 8.0
|
|
|
|
def test_graph_with_default_param():
    """A MultiplyNode given no ``factor`` parameter should use its default of 2.0.

    Graph: input_val --> mult_default (* default factor) --> output_final
    """
    graph = WorkflowGraph(name="Default Param Test")

    graph.add_node(Node("input_val", "pytrace.nodes.specials.InputNode", "Input", {}))
    # Empty params dict: "factor" is deliberately omitted so the default applies.
    graph.add_node(Node("mult_default", "pytrace.core.registry.MultiplyNode", "Multiplier Default", {}))
    graph.add_node(Node("output_final", "pytrace.nodes.specials.OutputNode", "Output", {}))

    # Edges, written as (source node, source port, target node, target port).
    graph.add_edge(Edge("input_val", "output", "mult_default", "val"))
    graph.add_edge(Edge("mult_default", "product", "output_final", "input"))

    initial_inputs = {"input_val": {"output": 10.0}}

    executor = WorkflowExecutor()
    results = executor.execute(graph, initial_inputs)

    assert results is not None

    # 10 * 2.0 (default) = 20.0. Each lookup is asserted separately so a missing
    # node or port fails with a readable message rather than an AttributeError
    # on ``None.value``.
    output_ports = results.get("output_final")
    assert output_ports is not None, "executor produced no entry for 'output_final'"
    final_value = output_ports.get("input")
    assert final_value is not None, "'output_final' has no value on port 'input'"
    assert final_value.value == 20.0
|
|
|
|
def test_unregistered_node_type_raises_error():
    """Executing a graph that names an unknown node type must raise ValueError."""
    graph = WorkflowGraph(name="Invalid Node Type")
    bogus_node = Node("invalid_node", "non.existent.Node", "Non Existent", {})
    graph.add_node(bogus_node)

    runner = WorkflowExecutor()

    # The error message must identify the problem as an unknown node type.
    with pytest.raises(ValueError, match="Unknown node type"):
        runner.execute(graph)
|
|
|
|
def test_cycle_in_graph_raises_error():
    """A graph whose edges form a loop (n1 -> n2 -> n1) must raise ValueError."""
    graph = WorkflowGraph(name="Cyclic Graph")
    for node_id, label in (("n1", "N1"), ("n2", "N2")):
        graph.add_node(Node(node_id, "pytrace.nodes.specials.InputNode", label, {}))

    # n1.output -> n2.input, then n2.output -> n1.input closes the cycle.
    # NOTE(review): these edges target an "input" port on InputNode, which the
    # other tests only ever wire through its "output" port. If the executor
    # validates ports before detecting cycles, it may raise a different error
    # here -- confirm.
    graph.add_edge(Edge("n1", "output", "n2", "input"))
    graph.add_edge(Edge("n2", "output", "n1", "input"))  # Cycle

    runner = WorkflowExecutor()
    with pytest.raises(ValueError, match="Graph contains a cycle"):
        runner.execute(graph)
|