pytrace/tests/unit/test_spec.py
Boshuang Zhao fb8458011a upload
2026-01-15 21:58:30 +08:00

116 lines
3.8 KiB
Python

"""
Tests for Spec generation and Pydantic validation.
"""
import pytest
from pydantic import ValidationError
from pytrace.model.specs import NodeSpec, PortSpec, ParamSpec
from pytrace.model.enums import NodeType, DimensionMode, CachePolicy
def test_portspec_creation():
    """A PortSpec built from only name/label/type stores those and defaults the rest."""
    supplied = {"name": "input_port", "label": "Input Port", "type": "str"}
    port = PortSpec(**supplied)

    # Every explicitly supplied field should round-trip unchanged.
    for field_name, expected in supplied.items():
        assert getattr(port, field_name) == expected, field_name

    # Fields that were not passed are expected to take these default values.
    assert port.dimension == DimensionMode.SCALAR
    assert port.description == ""
def test_portspec_with_dimension_and_description():
    """Explicit dimension and description values are stored on the PortSpec."""
    port = PortSpec(
        name="list_port",
        label="List Port",
        type="List[int]",
        dimension=DimensionMode.LIST,
        description="A list of integers",
    )
    assert (port.dimension, port.description) == (
        DimensionMode.LIST,
        "A list of integers",
    )
def test_paramspec_creation():
    """A ParamSpec keeps its explicit default and has no options unless given."""
    param = ParamSpec(name="my_param", label="My Param", type="int", default=10)

    assert (param.name, param.label, param.type) == ("my_param", "My Param", "int")
    assert param.default == 10
    # `options` was not passed, so it is expected to be unset.
    assert param.options is None
def test_paramspec_with_options():
    """The `options` list handed to ParamSpec is exposed with the same contents."""
    param = ParamSpec(name="choice", label="Choice Param", type="str", options=["A", "B", "C"])
    expected_options = ["A", "B", "C"]
    assert param.options == expected_options
def test_nodespec_creation():
    """Explicitly supplied NodeSpec fields are stored; omitted ones use defaults."""
    provided = dict(
        type="test_node.MyNode",
        name="My Test Node",
        category="Test",
        description="A test node",
        version="1.0.0",
        cache_policy=CachePolicy.PERMANENT,
        tags=["example", "test"],
    )
    node_spec = NodeSpec(**provided)

    # Each constructor argument should be readable back as an attribute.
    for attr, value in provided.items():
        assert getattr(node_spec, attr) == value, attr

    # Not supplied above: expected defaults for node type and the collections.
    assert node_spec.node_type == NodeType.COMMON
    assert (node_spec.inputs, node_spec.outputs, node_spec.params) == ([], [], [])
def test_nodespec_with_ports_and_params():
    """Ports and params passed to NodeSpec are kept as one-element lists."""
    in_port = PortSpec(name="in1", label="Input One", type="float")
    out_port = PortSpec(name="out1", label="Output One", type="float")
    factor = ParamSpec(name="factor", label="Factor", type="float", default=1.0)

    node_spec = NodeSpec(
        type="calc.Multiply",
        name="Multiply Node",
        inputs=[in_port],
        outputs=[out_port],
        params=[factor],
    )

    # Pair each stored collection with the single spec it was built from.
    for stored, source in (
        (node_spec.inputs, in_port),
        (node_spec.outputs, out_port),
        (node_spec.params, factor),
    ):
        assert len(stored) == 1
        assert stored[0] == source
def test_nodespec_default_values():
    """A NodeSpec given only `type` and `name` falls back to these default values."""
    node_spec = NodeSpec(type="simple.Node", name="Simple Node")

    expected_defaults = {
        "node_type": NodeType.COMMON,
        "category": "Default",
        "version": "1.0.0",
        "cache_policy": CachePolicy.SESSION,
        "tags": [],
    }
    for field_name, expected in expected_defaults.items():
        assert getattr(node_spec, field_name) == expected, field_name
def test_nodespec_validation_error():
    """NodeSpec rejects construction when a required field is missing.

    Each case also checks that the reported error location names the omitted
    field. This keeps the test from passing because of some unrelated
    validation failure.
    """
    with pytest.raises(ValidationError) as exc_info:
        # Missing required 'type' field
        NodeSpec(name="Invalid Node")
    # `loc` is a tuple whose first element is the offending field name.
    assert any(err["loc"][:1] == ("type",) for err in exc_info.value.errors())

    with pytest.raises(ValidationError) as exc_info:
        # Missing required 'name' field
        NodeSpec(type="invalid.Node")
    assert any(err["loc"][:1] == ("name",) for err in exc_info.value.errors())
def test_portspec_validation_error():
    """PortSpec rejects a missing required field and an unknown dimension value.

    Each case also checks that the reported error location names the field
    under test, so an unrelated validation failure cannot satisfy it.
    """
    with pytest.raises(ValidationError) as exc_info:
        # Missing required 'name' field
        PortSpec(label="Missing Name", type="str")
    # `loc` is a tuple whose first element is the offending field name.
    assert any(err["loc"][:1] == ("name",) for err in exc_info.value.errors())

    with pytest.raises(ValidationError) as exc_info:
        # Invalid DimensionMode
        PortSpec(name="port", label="Port", type="str", dimension="invalid_dim")
    # Prefix match: depending on the field's declared type, pydantic may append
    # extra path elements after the field name.
    assert any(err["loc"][:1] == ("dimension",) for err in exc_info.value.errors())
def test_paramspec_validation_error():
    """ParamSpec rejects construction when the required `name` field is missing.

    The check on the error location ensures the failure is attributed to
    `name`, not to some other validation problem.
    """
    with pytest.raises(ValidationError) as exc_info:
        # Missing required 'name' field
        ParamSpec(label="Missing Name Param", type="int")
    # `loc` is a tuple whose first element is the offending field name.
    assert any(err["loc"][:1] == ("name",) for err in exc_info.value.errors())