146 lines
5.3 KiB
Python
146 lines
5.3 KiB
Python
"""
|
|
Tests for the Node Registry.
|
|
"""
|
|
import pytest
|
|
from pytrace.core.registry import NodeRegistry, NodeEntry, get_type_string
|
|
from pytrace.core.node_base import TraceNode, input_port, output_port, parameter, register_node, expose_node
|
|
from pytrace.model.specs import NodeSpec, PortSpec, ParamSpec
|
|
from pytrace.model.enums import NodeType, CachePolicy
|
|
|
|
# Ensure every test starts from a known registry state.
#
# ``NodeRegistry.clear()`` isolates tests from one another (e.g. a remote spec
# registered by one test must not leak into the next). Presumably it also
# removes the registrations made when this module was imported: ``@register_node``
# on MyStandardNode and the module-level ``NodeRegistry.register(MyNodeSuite)``.
# The tests below look those nodes up in ``NodeRegistry._entries``, so they
# are re-registered here using exactly the same calls. The class names are
# resolved when the fixture runs, so referencing classes defined further
# down the module is fine.
@pytest.fixture(autouse=True)
def clean_registry():
    """Reset the registry and re-register this module's node classes."""
    NodeRegistry.clear()
    # Repeat the import-time registrations that clear() just discarded.
    register_node(MyStandardNode)
    NodeRegistry.register(MyNodeSuite)
    yield
|
|
|
|
# --- Test Standard Node Registration ---


# A minimal standard node: a plain TraceNode subclass registered through the
# ``@register_node`` decorator when this module is imported. The class
# attributes and port decorators below are what
# test_standard_node_registration expects to find in the registry spec.
@register_node
class MyStandardNode(TraceNode):
    NAME = "Standard Test Node"  # expected as entry.spec.name
    CATEGORY = "Test"  # expected as entry.spec.category
    TAGS = ["basic", "standard"]  # expected to appear in entry.spec.tags

    # One input port ("in1") and one output port ("out1"), declared on process().
    @input_port(name="in1", label="Input 1", type="Any")
    @output_port(name="out1", label="Output 1", type="Any")
    def process(self, io, context):
        # No-op: the tests only inspect registration metadata and instantiation.
        pass
|
|
|
|
def test_standard_node_registration():
    """A ``@register_node`` class gets a registry entry built from its metadata."""
    node_type = get_type_string(MyStandardNode)
    entry = NodeRegistry._entries.get(node_type)

    assert entry is not None, f"{node_type!r} was not registered"
    assert entry.spec.name == "Standard Test Node"
    assert entry.spec.category == "Test"
    # Every declared tag should be carried over, not just the first one.
    assert "basic" in entry.spec.tags
    assert "standard" in entry.spec.tags

    assert len(entry.spec.inputs) == 1
    assert entry.spec.inputs[0].name == "in1"
    assert len(entry.spec.outputs) == 1
    assert entry.spec.outputs[0].name == "out1"

    # The factory must build exactly this class (not a proxy or a subclass).
    instance = entry.factory(uid="test", name="test", params={})
    assert type(instance) is MyStandardNode
|
|
|
|
# --- Test Node Suite Registration ---


# A node "suite": a TraceNode subclass that test_node_suite_registration
# expects NOT to be registered as a standard node itself. Instead, each method
# decorated with ``@expose_node`` is expected to become its own virtual node,
# inheriting CATEGORY and TAGS from this class.
class MyNodeSuite(TraceNode):
    CATEGORY = "Suite"
    TAGS = ["suite", "virtual"]

    # Virtual node "Suite Method 1": declares one str input port, no output port.
    @expose_node(display_name="Suite Method 1", icon="✅")
    @input_port(name="s_in1", label="Suite Input 1", type="str")
    def suite_method_1(self, io, context):
        pass

    # Virtual node "Suite Method 2": declares one str output port and one int
    # parameter (default 10), no input port.
    @expose_node(display_name="Suite Method 2")
    @output_port(name="s_out1", label="Suite Output 1", type="str")
    @parameter(name="s_param1", label="Suite Param 1", type="int", default=10)
    def suite_method_2(self, io, context):
        pass


# Manually register the suite class after defining it: unlike MyStandardNode it
# carries no ``@register_node`` decorator. The tests expect this to yield one
# entry per exposed method rather than an entry for the class itself.
NodeRegistry.register(MyNodeSuite)
|
|
|
|
def test_node_suite_registration():
    """Each ``@expose_node`` method of a suite is registered as a virtual node."""
    # MyNodeSuite itself should not be registered as a direct executable node
    assert not NodeRegistry._is_standard_node(MyNodeSuite)
    assert get_type_string(MyNodeSuite) not in NodeRegistry._entries

    # Exposed methods should be registered as virtual nodes
    method1_type = get_type_string(MyNodeSuite, "suite_method_1")
    method2_type = get_type_string(MyNodeSuite, "suite_method_2")
    entry1 = NodeRegistry._entries.get(method1_type)
    entry2 = NodeRegistry._entries.get(method2_type)

    assert entry1 is not None, f"{method1_type!r} was not registered"
    assert entry1.spec.name == "Suite Method 1"
    assert entry1.spec.category == "Suite"
    assert "suite" in entry1.spec.tags
    assert "virtual" in entry1.spec.tags
    assert len(entry1.spec.inputs) == 1
    assert entry1.spec.inputs[0].name == "s_in1"
    assert len(entry1.spec.outputs) == 0  # No output_port on suite_method_1

    assert entry2 is not None, f"{method2_type!r} was not registered"
    assert entry2.spec.name == "Suite Method 2"
    assert entry2.spec.category == "Suite"
    # Suite-level TAGS must apply to every exposed method, not only the first.
    assert "suite" in entry2.spec.tags
    assert "virtual" in entry2.spec.tags
    assert len(entry2.spec.inputs) == 0  # No input_port on suite_method_2
    assert len(entry2.spec.outputs) == 1
    assert entry2.spec.outputs[0].name == "s_out1"
    assert len(entry2.spec.params) == 1
    assert entry2.spec.params[0].name == "s_param1"

    # Both virtual nodes must be instantiable through their factories.
    for entry, uid, name in ((entry1, "s1", "Suite1"), (entry2, "s2", "Suite2")):
        instance = entry.factory(uid=uid, name=name, params={})
        assert isinstance(instance, TraceNode)
        assert hasattr(instance, "process")
|
|
|
|
# --- Test Remote Node Registration ---

def test_remote_node_registration():
    """A remote spec is registered with origin "remote" and a proxy-building factory."""
    type_id = "my.remote.NodeType"
    spec = NodeSpec(
        type=type_id,
        name="Remote Test Node",
        category="Remote",
        tags=["remote"],
        inputs=[PortSpec(name="r_in", label="Remote In", type="str")],
    )
    config = {"url": "http://remote.service/api"}

    NodeRegistry.register_remote(spec, config)

    registered = NodeRegistry._entries.get(type_id)
    assert registered is not None
    assert registered.spec.name == "Remote Test Node"
    assert "remote" in registered.spec.tags
    assert registered.origin == "remote"

    # The factory should hand back a proxy bound to this remote type and config.
    proxy = registered.factory(uid="r1", name="Remote1", params={})
    # Imported lazily because a module-level import would be circular.
    from pytrace.nodes.proxies import RemoteProxyNode

    assert isinstance(proxy, RemoteProxyNode)
    assert proxy.remote_id == type_id
    assert proxy.config == config
|
|
|
|
# --- Test Instantiation ---

def test_instantiate_unified_interface():
    """``NodeRegistry.instantiate`` builds every node kind from its type string."""
    # Standard node
    instance = NodeRegistry.instantiate(
        get_type_string(MyStandardNode), uid="std1", name="StdNode", params={}
    )
    assert isinstance(instance, MyStandardNode)
    assert instance.uid == "std1"

    # Virtual node from suite
    method1_type = get_type_string(MyNodeSuite, "suite_method_1")
    instance = NodeRegistry.instantiate(
        method1_type, uid="v1", name="VirtualNode", params={}
    )
    assert isinstance(instance, TraceNode)
    assert hasattr(instance, "process")

    # Remote node. Register it here instead of relying on
    # test_remote_node_registration having run first. The autouse
    # clean_registry fixture clears the registry before every test, so that
    # test's registration is already gone. Test order is also not guaranteed
    # (e.g. when selecting tests with -k).
    from pytrace.nodes.proxies import RemoteProxyNode

    remote_type = "my.remote.NodeType"
    NodeRegistry.register_remote(
        NodeSpec(
            type=remote_type,
            name="Remote Test Node",
            category="Remote",
            tags=["remote"],
            inputs=[PortSpec(name="r_in", label="Remote In", type="str")],
        ),
        {"url": "http://remote.service/api"},
    )
    instance = NodeRegistry.instantiate(
        remote_type, uid="rem1", name="RemNode", params={}
    )
    assert isinstance(instance, RemoteProxyNode)
    assert instance.uid == "rem1"
    assert instance.remote_id == remote_type

    # Unknown type strings must be rejected.
    with pytest.raises(ValueError):
        NodeRegistry.instantiate("non.existent.Node", uid="bad", name="Bad", params={})
|