116 lines
3.8 KiB
Python
116 lines
3.8 KiB
Python
|
|
"""
|
||
|
|
Tests for Spec generation and Pydantic validation.
|
||
|
|
"""
|
||
|
|
import pytest
|
||
|
|
from pydantic import ValidationError
|
||
|
|
from pytrace.model.specs import NodeSpec, PortSpec, ParamSpec
|
||
|
|
from pytrace.model.enums import NodeType, DimensionMode, CachePolicy
|
||
|
|
|
||
|
|
def test_portspec_creation():
    """Only the required fields are given; the remaining ones must take their defaults."""
    spec = PortSpec(name="input_port", label="Input Port", type="str")

    expected_fields = {
        "name": "input_port",
        "label": "Input Port",
        "type": "str",
        # Not passed above: expected to default to a scalar port with no description.
        "dimension": DimensionMode.SCALAR,
        "description": "",
    }
    for field_name, expected in expected_fields.items():
        assert getattr(spec, field_name) == expected, field_name
|
||
|
|
|
||
|
|
def test_portspec_with_dimension_and_description():
    """Explicitly supplied dimension and description values are stored as given."""
    overrides = {
        "dimension": DimensionMode.LIST,
        "description": "A list of integers",
    }
    spec = PortSpec(name="list_port", label="List Port", type="List[int]", **overrides)

    for field_name, expected in overrides.items():
        assert getattr(spec, field_name) == expected, field_name
|
||
|
|
|
||
|
|
def test_paramspec_creation():
    """A ParamSpec keeps every field it was given and leaves `options` as None."""
    given = {"name": "my_param", "label": "My Param", "type": "int", "default": 10}
    param = ParamSpec(**given)

    for field_name, expected in given.items():
        assert getattr(param, field_name) == expected, field_name
    # No choices were supplied, so options must be exactly None (identity check).
    assert param.options is None
|
||
|
|
|
||
|
|
def test_paramspec_with_options():
    """An explicit list of choices is kept on the ParamSpec with equal contents."""
    choices = ["A", "B", "C"]
    # Pass a copy so the comparison below is against an independent list.
    param = ParamSpec(name="choice", label="Choice Param", type="str", options=list(choices))

    assert param.options == choices
|
||
|
|
|
||
|
|
def test_nodespec_creation():
    """Every field passed to NodeSpec is kept; the unset ones take their defaults."""
    supplied = {
        "type": "test_node.MyNode",
        "name": "My Test Node",
        "category": "Test",
        "description": "A test node",
        "version": "1.0.0",
        "cache_policy": CachePolicy.PERMANENT,
        "tags": ["example", "test"],
    }
    node_spec = NodeSpec(**supplied)

    for field_name, expected in supplied.items():
        assert getattr(node_spec, field_name) == expected, field_name

    # node_type was not supplied; the test expects it to fall back to COMMON.
    assert node_spec.node_type == NodeType.COMMON
    # No ports or params were supplied, so each collection is expected to be empty.
    for collection in ("inputs", "outputs", "params"):
        assert getattr(node_spec, collection) == [], collection
|
||
|
|
|
||
|
|
def test_nodespec_with_ports_and_params():
    """Ports and params passed to NodeSpec each end up holding exactly the one item given."""
    in_port = PortSpec(name="in1", label="Input One", type="float")
    out_port = PortSpec(name="out1", label="Output One", type="float")
    factor = ParamSpec(name="factor", label="Factor", type="float", default=1.0)

    node_spec = NodeSpec(
        type="calc.Multiply",
        name="Multiply Node",
        inputs=[in_port],
        outputs=[out_port],
        params=[factor],
    )

    expectations = (
        ("inputs", in_port),
        ("outputs", out_port),
        ("params", factor),
    )
    for field_name, expected_item in expectations:
        stored = getattr(node_spec, field_name)
        assert len(stored) == 1, field_name
        assert stored[0] == expected_item, field_name
|
||
|
|
|
||
|
|
def test_nodespec_default_values():
    """A NodeSpec built from only `type` and `name` gets these default values."""
    node_spec = NodeSpec(type="simple.Node", name="Simple Node")

    expected_defaults = (
        ("node_type", NodeType.COMMON),
        ("category", "Default"),
        ("version", "1.0.0"),
        ("cache_policy", CachePolicy.SESSION),
        ("tags", []),
    )
    for field_name, expected in expected_defaults:
        assert getattr(node_spec, field_name) == expected, field_name
|
||
|
|
|
||
|
|
def test_nodespec_validation_error():
    """NodeSpec rejects construction when either `type` or `name` is missing.

    Each case also checks that the validation error points at the missing
    field, so the test cannot pass because of some unrelated validation failure.
    """
    # Missing required 'type' field
    with pytest.raises(ValidationError) as exc_info:
        NodeSpec(name="Invalid Node")
    assert any(err["loc"] == ("type",) for err in exc_info.value.errors())

    # Missing required 'name' field
    with pytest.raises(ValidationError) as exc_info:
        NodeSpec(type="invalid.Node")
    assert any(err["loc"] == ("name",) for err in exc_info.value.errors())
|
||
|
|
|
||
|
|
def test_portspec_validation_error():
    """PortSpec rejects a missing `name` and a `dimension` outside DimensionMode.

    Each case also checks that the validation error is reported against the
    offending field, not merely that some ValidationError was raised.
    """
    # Missing required 'name' field
    with pytest.raises(ValidationError) as exc_info:
        PortSpec(label="Missing Name", type="str")
    assert any(err["loc"] == ("name",) for err in exc_info.value.errors())

    # Invalid DimensionMode
    with pytest.raises(ValidationError) as exc_info:
        PortSpec(name="port", label="Port", type="str", dimension="invalid_dim")
    assert any(err["loc"] == ("dimension",) for err in exc_info.value.errors())
|
||
|
|
|
||
|
|
def test_paramspec_validation_error():
    """ParamSpec rejects construction when `name` is missing.

    `any(...)` is used because other fields may also be reported as errors;
    the test only requires that `name` is among them.
    """
    # Missing required 'name' field
    with pytest.raises(ValidationError) as exc_info:
        ParamSpec(label="Missing Name Param", type="int")
    assert any(err["loc"] == ("name",) for err in exc_info.value.errors())
|