pytrace/tests/unit/test_spec.py

116 lines
3.8 KiB
Python
Raw Normal View History

2026-01-15 21:58:30 +08:00
"""
Tests for Spec generation and Pydantic validation.
"""
import pytest
from pydantic import ValidationError
from pytrace.model.specs import NodeSpec, PortSpec, ParamSpec
from pytrace.model.enums import NodeType, DimensionMode, CachePolicy
def test_portspec_creation():
    """Explicit PortSpec fields are stored; omitted ones take their defaults."""
    spec = PortSpec(name="input_port", label="Input Port", type="str")

    expected = {
        "name": "input_port",
        "label": "Input Port",
        "type": "str",
        # Not passed to the constructor above, so these come from the model.
        "dimension": DimensionMode.SCALAR,
        "description": "",
    }
    for field_name, want in expected.items():
        assert getattr(spec, field_name) == want, field_name
def test_portspec_with_dimension_and_description():
    """PortSpec stores an explicitly supplied dimension and description."""
    kwargs = dict(
        name="list_port",
        label="List Port",
        type="List[int]",
        dimension=DimensionMode.LIST,
        description="A list of integers",
    )
    spec = PortSpec(**kwargs)

    assert spec.dimension == DimensionMode.LIST
    assert spec.description == "A list of integers"
def test_paramspec_creation():
    """ParamSpec stores its core fields; options stay unset when not given."""
    p = ParamSpec(name="my_param", label="My Param", type="int", default=10)

    checks = (
        ("name", "my_param"),
        ("label", "My Param"),
        ("type", "int"),
        ("default", 10),
    )
    for attr, want in checks:
        assert getattr(p, attr) == want, attr

    # `options` was not supplied, so it must be None (identity, not equality).
    assert p.options is None
def test_paramspec_with_options():
    """ParamSpec keeps the list of allowed choices it was constructed with."""
    choices = ["A", "B", "C"]
    spec = ParamSpec(
        name="choice",
        label="Choice Param",
        type="str",
        options=choices,
    )

    # Compare against a fresh literal rather than `choices` itself.
    assert spec.options == ["A", "B", "C"]
def test_nodespec_creation():
    """NodeSpec keeps explicit metadata and defaults node_type and collections."""
    spec = NodeSpec(
        type="test_node.MyNode",
        name="My Test Node",
        category="Test",
        description="A test node",
        version="1.0.0",
        cache_policy=CachePolicy.PERMANENT,
        tags=["example", "test"],
    )

    # Fields supplied explicitly above.
    assert spec.type == "test_node.MyNode"
    assert spec.name == "My Test Node"
    assert spec.category == "Test"
    assert spec.description == "A test node"
    assert spec.version == "1.0.0"
    assert spec.cache_policy == CachePolicy.PERMANENT
    assert spec.tags == ["example", "test"]

    # Fields not supplied: node_type defaults to COMMON, collections to [].
    assert spec.node_type == NodeType.COMMON
    for collection in ("inputs", "outputs", "params"):
        assert getattr(spec, collection) == [], collection
def test_nodespec_with_ports_and_params():
    """Ports and params handed to NodeSpec are stored, one of each here."""
    in_port = PortSpec(name="in1", label="Input One", type="float")
    out_port = PortSpec(name="out1", label="Output One", type="float")
    factor = ParamSpec(name="factor", label="Factor", type="float", default=1.0)

    spec = NodeSpec(
        type="calc.Multiply",
        name="Multiply Node",
        inputs=[in_port],
        outputs=[out_port],
        params=[factor],
    )

    pairs = (("inputs", in_port), ("outputs", out_port), ("params", factor))
    for attr, item in pairs:
        stored = getattr(spec, attr)
        assert len(stored) == 1, attr
        assert stored[0] == item, attr
def test_nodespec_default_values():
    """With only type and name given, NodeSpec's optional metadata is defaulted."""
    spec = NodeSpec(type="simple.Node", name="Simple Node")

    expected_defaults = {
        "node_type": NodeType.COMMON,
        "category": "Default",
        "version": "1.0.0",
        "cache_policy": CachePolicy.SESSION,
        "tags": [],
    }
    for attr, default in expected_defaults.items():
        assert getattr(spec, attr) == default, attr
def test_nodespec_validation_error():
    """NodeSpec rejects construction when a required field is omitted.

    Each case also checks that the reported error points at the omitted
    field. Without that check, any unrelated ValidationError would make
    the case pass.
    """
    # Missing required 'type' field
    with pytest.raises(ValidationError) as exc_info:
        NodeSpec(name="Invalid Node")
    assert ("type",) in [err["loc"] for err in exc_info.value.errors()]

    # Missing required 'name' field
    with pytest.raises(ValidationError) as exc_info:
        NodeSpec(type="invalid.Node")
    assert ("name",) in [err["loc"] for err in exc_info.value.errors()]
def test_portspec_validation_error():
    """PortSpec rejects a missing name and an unknown dimension value.

    Each case also checks that the error is reported against the
    offending field. Without that check, an unrelated ValidationError
    would make the case pass.
    """
    # Missing required 'name' field
    with pytest.raises(ValidationError) as exc_info:
        PortSpec(label="Missing Name", type="str")
    assert ("name",) in [err["loc"] for err in exc_info.value.errors()]

    # Invalid DimensionMode
    with pytest.raises(ValidationError) as exc_info:
        PortSpec(name="port", label="Port", type="str", dimension="invalid_dim")
    assert ("dimension",) in [err["loc"] for err in exc_info.value.errors()]
def test_paramspec_validation_error():
with pytest.raises(ValidationError):
# Missing required 'name' field
ParamSpec(label="Missing Name Param", type="int")