# (file-viewer header: 102 lines, 3.4 KiB, Python)
import json
|
|
import os
|
|
import sys
|
|
from typing import Dict, Any
|
|
|
|
# Add project root to path *before* importing any pytrace module, so the
# pytrace imports resolve when this file is run directly from its directory.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.model.data import TraceData
|
|
|
|
from pytrace.api.node import TraceNode, GraphTraceNode, InputTraceNode, OutputTraceNode
|
|
from pytrace.api.decorators import output_port, register_class, register_node
|
|
from pytrace.runtime.context import TraceContext
|
|
from pytrace.runtime.executor import WorkflowExecutor
|
|
from pytrace.runtime.registry import NodeRegistry
|
|
from pytrace.model.graph import WorkflowGraph
|
|
from pytrace.model.meta import ClassMeta
|
|
from pytrace.internal.meta_map import MetaMap
|
|
|
|
# Ensure MathSuite is registered
|
|
from pytrace.nodes.test_math_suite import MathSuite
|
|
|
|
# --- Define FlowLoaderNode ---
|
|
@register_class(category="Flow", tags=["Test"], display_name="FlowLoaderNode test", is_dynamic=True)
class FlowLoaderNode(TraceNode):
    """Load a sub-workflow from a ``.flow`` JSON file and execute it.

    The file is named by the ``flow_path`` parameter of the trace context.
    Relative paths are resolved against the parent of this file's directory
    (the project root this module adds to ``sys.path``). The parsed graph is
    wrapped in a ``GraphTraceNode`` and cached on the instance; the cache is
    rebuilt whenever ``flow_path`` resolves to a different file.
    """

    def __init__(self):
        super().__init__()
        # GraphTraceNode built from the most recently loaded flow file.
        self._cached_graph_node = None
        # Resolved path the cached node was loaded from; lets __call__ notice
        # when ``flow_path`` changes between calls.
        self._cached_flow_path = None

    def __call__(self, ctx: TraceContext):
        """Execute the sub-flow named by the ``flow_path`` parameter of ``ctx``.

        Args:
            ctx: Trace context. Supplies ``flow_path`` and is passed on to the
                cached ``GraphTraceNode`` when it is invoked.

        Raises:
            ValueError: If ``flow_path`` is missing or empty.
            FileNotFoundError: If the resolved flow file does not exist.
        """
        # 1. Get flow file path
        flow_path = ctx.get_param("flow_path")
        if not flow_path:
            raise ValueError("Flow path is required")

        # Resolve relative paths against the project root, not the CWD.
        if not os.path.isabs(flow_path):
            base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
            flow_path = os.path.join(base_dir, flow_path)

        # 2. Load a GraphTraceNode unless one for this exact path is cached.
        # ``is None`` (not truthiness) so a falsy node object is never taken as
        # "not loaded"; the path comparison ensures a changed ``flow_path`` is
        # honoured instead of silently re-running the previously loaded graph.
        if self._cached_graph_node is None or self._cached_flow_path != flow_path:
            print(f"[FlowLoader] Loading sub-flow from: {flow_path}")
            if not os.path.exists(flow_path):
                raise FileNotFoundError(f"Flow file not found: {flow_path}")

            with open(flow_path, 'r', encoding='utf-8') as f:
                graph_data = json.load(f)

            # NOTE(review): the node is constructed with the context of the
            # call that loaded it and then reused with later contexts; confirm
            # GraphTraceNode does not retain per-call state from ``contexts``.
            self._cached_graph_node = GraphTraceNode(graph_data=graph_data, contexts=ctx)
            self._cached_flow_path = flow_path

        # 3. Delegate execution
        self._cached_graph_node(ctx)
|
|
|
|
@register_class(category="Flow", tags=["Test"], display_name="TestYNode test")
class TestYNode(TraceNode):
    """Test node that emits a fixed list of floats on its ``y`` output port."""

    # This module defines pytest-style ``test_*`` functions; without this flag
    # pytest would also try to collect this ``Test*``-named node class.
    __test__ = False

    @output_port(name="y", type="list<float>")
    def __call__(self, ctx: TraceContext):
        """Write ``[11.0, 22.0, 33.0, 44.0]`` to the ``y`` output of ``ctx``."""
        # Every element is a float to match the declared ``list<float>`` port
        # type (the last element used to be the int ``44``).
        ctx.set_output("y", [11.0, 22.0, 33.0, 44.0])
|
|
|
|
def test_nested_workflow():
    """Run ``assets/main.flow`` end to end and check its ``final_result``.

    Loads the flow file stored in ``assets/`` next to this module, executes it
    as a batch with a fixed set of inputs, and asserts that the
    ``final_result`` output equals 15.0.

    Raises:
        KeyError: If the run produces no ``final_result`` output.
        AssertionError: If ``final_result`` is not (close to) 15.0.
    """
    import math

    print("--- Starting Nested Workflow Test ---")

    # 1. Load main graph
    main_flow_path = os.path.join(os.path.dirname(__file__), "assets", "main.flow")
    print(f"Loading main flow: {main_flow_path}")
    with open(main_flow_path, 'r', encoding='utf-8') as f:
        main_graph_data = json.load(f)
    graph = WorkflowGraph.from_dict(main_graph_data)

    # 2. Prepare context
    context = TraceContext(contexts={})

    # 3. Execute
    executor = WorkflowExecutor(graph, context)
    # NOTE(review): the second mapping presumably marks "x" as the input to
    # iterate over in the batch — confirm against WorkflowExecutor.execute_batch.
    executor.execute_batch(
        {
            "x": [1.0, 2.0, 3.0],
            "y": 100,
            "z": [33, 44],
        },
        {"x": True},
    )

    # 4. Verify results
    trace_final_outputs = context.get_outputs()
    print("Final Context Outputs:", trace_final_outputs)

    result_data = trace_final_outputs["final_result"]
    # Outputs may be wrapped in an object exposing ``.value``; unwrap if so.
    value = result_data.value if hasattr(result_data, 'value') else result_data

    print(f"Result Check: {value} == 15.0")
    # The check used to be print-only, so the test "passed" for any value.
    assert math.isclose(value, 15.0), f"Expected final_result == 15.0, got {value!r}"
    print("--- Test Passed Successfully ---")
|
|
|
|
if __name__ == "__main__":
    # Allow running the integration test directly as a script, outside pytest.
    test_nested_workflow()