# pytrace/tests/test_workflow.py
# 2026-01-19 00:49:55 +08:00 · 102 lines · 3.4 KiB · Python

import json
import os
import sys
from typing import Dict, Any
from pytrace.model.data import TraceData
# Add project root to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.api.node import TraceNode, GraphTraceNode, InputTraceNode, OutputTraceNode
from pytrace.api.decorators import output_port, register_class, register_node
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.runtime.registry import NodeRegistry
from pytrace.model.graph import WorkflowGraph
from pytrace.model.meta import ClassMeta
from pytrace.internal.meta_map import MetaMap
# Ensure MathSuite is registered
from pytrace.nodes.test_math_suite import MathSuite
# --- Define FlowLoaderNode ---
@register_class(category="Flow", tags=["Test"], display_name="FlowLoaderNode test", is_dynamic=True)
class FlowLoaderNode(TraceNode):
    """Node that loads a sub-workflow from a JSON flow file and executes it.

    The file is named by the ``flow_path`` parameter read from the execution
    context. Relative paths are resolved against the parent directory of the
    directory containing this test file. The resulting ``GraphTraceNode`` is
    cached together with the path it was loaded from. Repeated calls with the
    same path reuse it. A different path triggers a reload, so a stale graph
    is never run.
    """

    def __init__(self):
        super().__init__()
        # GraphTraceNode built from the most recently loaded flow file (or None).
        self._cached_graph_node = None
        # Resolved path that ``_cached_graph_node`` was loaded from (or None).
        self._cached_flow_path = None

    def __call__(self, ctx: TraceContext):
        """Load (or reuse) the sub-flow named by ``flow_path`` and run it with *ctx*.

        Args:
            ctx: Execution context; must provide the ``flow_path`` parameter.

        Raises:
            ValueError: If ``flow_path`` is missing or empty.
            FileNotFoundError: If the resolved flow file does not exist.
        """
        # 1. Get flow file path
        flow_path = ctx.get_param("flow_path")
        if not flow_path:
            raise ValueError("Flow path is required")

        # Resolve relative paths against the directory above this test file.
        if not os.path.isabs(flow_path):
            base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
            flow_path = os.path.join(base_dir, flow_path)

        # 2. Load the GraphTraceNode unless one for this exact path is cached.
        #    Compare against None explicitly: node objects may define __len__/__bool__.
        if self._cached_graph_node is None or self._cached_flow_path != flow_path:
            print(f"[FlowLoader] Loading sub-flow from: {flow_path}")
            if not os.path.exists(flow_path):
                raise FileNotFoundError(f"Flow file not found: {flow_path}")
            with open(flow_path, 'r', encoding='utf-8') as f:
                graph_data = json.load(f)
            # Instantiate GraphTraceNode from core library
            self._cached_graph_node = GraphTraceNode(graph_data=graph_data, contexts=ctx)
            # Record the path only after a successful load so a failed reload
            # leaves the previous cache entry consistent.
            self._cached_flow_path = flow_path

        # 3. Delegate execution
        self._cached_graph_node(ctx)
@register_class(category="Flow", tags=["Test"], display_name="TestYNode test")
class TestYNode(TraceNode):
    """Source node that writes a fixed list of floats to its ``y`` output port."""

    # The ``Test`` name prefix would otherwise make pytest try to collect this
    # node as a test class (and emit a collection warning, since it has __init__).
    __test__ = False

    def __init__(self):
        super().__init__()

    @output_port(name="y", type="list<float>")
    def __call__(self, ctx: TraceContext):
        """Set output ``y`` to the constant sample ``[11.0, 22.0, 33.0, 44.0]``."""
        # 44.0 rather than 44 so every element matches the declared list<float> type.
        ctx.set_output("y", [11.0, 22.0, 33.0, 44.0])
def test_nested_workflow():
    """Execute ``assets/main.flow`` and check that ``final_result`` equals 15.0.

    The graph is run through ``WorkflowExecutor.execute_batch`` with inputs
    ``x``, ``y`` and ``z``. The second argument ``{"x": True}`` presumably
    marks ``x`` as the batched input; confirm this against ``execute_batch``.
    """
    import math

    print("--- Starting Nested Workflow Test ---")

    # 1. Load the main graph definition stored next to this test.
    main_flow_path = os.path.join(os.path.dirname(__file__), "assets", "main.flow")
    print(f"Loading main flow: {main_flow_path}")
    with open(main_flow_path, 'r', encoding='utf-8') as f:
        main_graph_data = json.load(f)
    graph = WorkflowGraph.from_dict(main_graph_data)

    # 2. Prepare an empty execution context.
    context = TraceContext(contexts={})

    # 3. Execute
    executor = WorkflowExecutor(graph, context)
    executor.execute_batch({
        "x": [1.0, 2.0, 3.0],
        "y": 100,
        "z": [33, 44],
    }, {"x": True})

    # 4. Verify results
    trace_final_outputs = context.get_outputs()
    print("Final Context Outputs:", trace_final_outputs)
    result_data = trace_final_outputs["final_result"]
    # Outputs may be wrapped in an object exposing ``.value``; unwrap if so.
    value = result_data.value if hasattr(result_data, 'value') else result_data
    print(f"Result Check: {value} == 15.0")
    # Previously this comparison was only printed, so the test passed even on a
    # wrong result. Use isclose to tolerate float rounding in the computed sum.
    assert isinstance(value, (int, float)) and math.isclose(value, 15.0), (
        f"Expected final_result == 15.0, got {value!r}"
    )
    print("--- Test Passed Successfully ---")
if __name__ == "__main__":
    # Allow running this module directly (outside pytest) as a quick smoke check.
    test_nested_workflow()