110 lines
4.1 KiB
Python
110 lines
4.1 KiB
Python
"""
|
|
Tests for the MathSuite node suite and its exposed methods.
|
|
"""
|
|
import pytest
|
|
from pytrace.core.executor import WorkflowExecutor
|
|
from pytrace.core.registry import NodeRegistry, get_type_string, instantiate
|
|
from pytrace.core.node_base import TraceNode
|
|
from pytrace.model.graph import WorkflowGraph, Node, Edge
|
|
from pytrace.core.io import NodeIO
|
|
from pytrace.core.context import TraceContext
|
|
|
|
# Import the MathSuite directly for testing
|
|
from pytrace.nodes.math_suite import MathSuite
|
|
|
|
# Register the MathSuite at import time (normally done at app startup).
# NOTE(review): the autouse `clean_registry` fixture in this module clears the
# registry and re-registers MathSuite before every test, so this call looks
# redundant for these tests — confirm nothing else relies on it before removing.
NodeRegistry.register(MathSuite)
|
|
|
|
# Autouse fixture: every test in this module starts from a registry holding MathSuite.
@pytest.fixture(autouse=True)
def clean_registry():
    """Reset ``NodeRegistry`` so that only ``MathSuite`` is registered.

    Runs automatically before each test in this module. Nothing is done after
    the test body, so the registry is left in whatever state the test produced.

    NOTE(review): ``clear()`` presumably also drops registrations made by other
    modules at import time; they are not restored afterwards — confirm other
    test modules do not depend on them.
    """
    registry = NodeRegistry
    # Start from an empty registry...
    registry.clear()
    # ...then put back the suite under test, since clear() removed it as well.
    registry.register(MathSuite)
    yield
|
|
|
|
def test_math_suite_exposed_nodes_registration():
    """Each exposed MathSuite method is registered with the expected node spec.

    Checks display name, category, tags and the input/output port names of the
    virtual nodes generated for ``MathSuite.add`` and ``MathSuite.subtract``.
    """
    # Look each spec up once and reuse it for every assertion below.
    add_spec = NodeRegistry.get_spec(get_type_string(MathSuite, "add"))
    subtract_spec = NodeRegistry.get_spec(get_type_string(MathSuite, "subtract"))

    # Both node types must be registered before any attribute is inspected.
    assert add_spec is not None
    assert subtract_spec is not None

    assert add_spec.name == "Add Numbers"
    assert subtract_spec.name == "Subtract Numbers"

    # Properties shared by both binary operations in the suite.
    for spec in (add_spec, subtract_spec):
        assert spec.category == "Math"
        assert "suite" in spec.tags
        assert "calculation" in spec.tags
        # Two operands named "a" and "b" (in that order) and a single "result"
        # output; the execution tests below drive the nodes through these ports.
        assert [port.name for port in spec.inputs] == ["a", "b"]
        assert [port.name for port in spec.outputs] == ["result"]
|
|
|
|
def test_math_suite_add_node_execution(mock_io, mock_context):
    """``MathSuite.add`` reads ``a`` and ``b``, outputs their sum, and logs it.

    ``mock_io`` and ``mock_context`` are mock objects (they expose
    ``side_effect`` / ``assert_called_once_with``); presumably provided by a
    conftest.py fixture, since they are not defined in this file.
    """
    add_node_type = get_type_string(MathSuite, "add")

    # Instantiate the virtual node generated for MathSuite.add.
    add_node_instance = NodeRegistry.instantiate(
        add_node_type, uid="add_test_node", name="Add Test", params={}
    )

    # Serve fixed input values. `default` is optional so the stub does not raise
    # TypeError if the node calls get_input(name) without a fallback value.
    inputs = {"a": 10.0, "b": 5.0}
    mock_io.get_input.side_effect = lambda name, default=None: inputs.get(name, default)

    add_node_instance.process(mock_io, mock_context)

    mock_io.set_output.assert_called_once_with("result", 15.0)
    mock_context.log.assert_called_once_with("MathSuite.add: 10.0 + 5.0 = 15.0")
|
|
|
|
def test_math_suite_subtract_node_execution(mock_io, mock_context):
    """``MathSuite.subtract`` outputs ``a - b`` on ``result`` and logs the calculation.

    ``mock_io`` and ``mock_context`` are mock objects (they expose
    ``side_effect`` / ``assert_called_once_with``); presumably provided by a
    conftest.py fixture, since they are not defined in this file.
    """
    subtract_node_type = get_type_string(MathSuite, "subtract")

    # Instantiate the virtual node generated for MathSuite.subtract.
    subtract_node_instance = NodeRegistry.instantiate(
        subtract_node_type, uid="sub_test_node", name="Subtract Test", params={}
    )

    # Serve fixed input values. `default` is optional so the stub does not raise
    # TypeError if the node calls get_input(name) without a fallback value.
    inputs = {"a": 20.0, "b": 7.0}
    mock_io.get_input.side_effect = lambda name, default=None: inputs.get(name, default)

    subtract_node_instance.process(mock_io, mock_context)

    mock_io.set_output.assert_called_once_with("result", 13.0)
    mock_context.log.assert_called_once_with("MathSuite.subtract: 20.0 - 7.0 = 13.0")
|
|
|
|
def test_math_suite_integration_with_executor():
    """Run InputNode -> MathSuite.add -> OutputNode through ``WorkflowExecutor``.

    The two input nodes are seeded with 100.0 and 25.0 via ``initial_inputs``;
    the output node should receive their sum on its ``input`` port.

    NOTE(review): the autouse registry fixture leaves only MathSuite registered;
    confirm the executor can resolve the ``pytrace.nodes.specials`` node types
    without registry entries.
    """
    # Graph: input_a, input_b (InputNode) -> add_node (MathSuite.add)
    #        -> output_res (OutputNode)
    graph = WorkflowGraph(name="Math Suite Integration")

    add_node_type = get_type_string(MathSuite, "add")

    graph.add_node(Node("input_a", "pytrace.nodes.specials.InputNode", "Initial A", {}))
    graph.add_node(Node("input_b", "pytrace.nodes.specials.InputNode", "Initial B", {}))
    graph.add_node(Node("add_node", add_node_type, "Adder", {}))
    graph.add_node(Node("output_res", "pytrace.nodes.specials.OutputNode", "Final Output", {}))

    # Edge arguments presumably are (source uid, source port, target uid, target port).
    graph.add_edge(Edge("input_a", "output", "add_node", "a"))
    graph.add_edge(Edge("input_b", "output", "add_node", "b"))
    graph.add_edge(Edge("add_node", "result", "output_res", "input"))

    # Seed the input nodes' "output" ports, keyed by node uid.
    initial_inputs = {
        "input_a": {"output": 100.0},
        "input_b": {"output": 25.0},
    }

    executor = WorkflowExecutor()
    results = executor.execute(graph, initial_inputs)

    assert results is not None
    # Report a missing node or port explicitly instead of failing later with
    # an AttributeError on None.
    output_node_outputs = results.get("output_res")
    assert output_node_outputs is not None, f"'output_res' missing from results: {results!r}"
    output_port = output_node_outputs.get("input")
    assert output_port is not None, "'output_res' received no value on its 'input' port"
    # Expected: 100 + 25 = 125
    assert output_port.value == 125.0
|
|
|