# File metadata (from the source viewer): 188 lines, 8.1 KiB, Python
import asyncio
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
# Optional pytest integration: when pytest is importable, async tests are
# tagged with ``pytest.mark.asyncio``; otherwise ``async_marker`` degrades to a
# pass-through decorator so this module can still be run as a plain script.
try:
    import pytest
except ImportError:
    # Only a missing pytest selects the fallback. Any other failure while
    # importing it is a real problem and is allowed to propagate instead of
    # being silently swallowed.
    pytest = None

    def async_marker(f):
        """Return ``f`` unchanged (no-op stand-in for ``pytest.mark.asyncio``)."""
        return f
else:
    async_marker = pytest.mark.asyncio
|
|
|
|
# Put the parent of this file's directory at the front of sys.path (unless it
# is already listed) so the imports below work when the file is run directly.
cur = Path(__file__).resolve()
root = cur.parent.parent
if sys.path.count(str(root)) == 0:
    sys.path.insert(0, str(root))
|
|
|
|
from app.core.node_registry import NodeRegistry
|
|
from app.core.node_base import DimensionMode, WorkflowPackager, NodeType
|
|
from app.core.workflow_executor import WorkflowExecutor, WorkflowGraph
|
|
from app.nodes.advanced_example_nodes import *
|
|
|
|
|
|
@async_marker
async def test_special_nodes_registered():
    """Check that the special and array node classes exist in NodeRegistry.

    Each class name is looked up with ``NodeRegistry.exists``; a missing name
    fails the test with a message that names the class.
    """
    # Special (input / output / function) node implementations.
    for class_name in ("InputNodeImpl", "OutputNodeImpl", "FunctionNodeImpl"):
        registered = NodeRegistry.exists(class_name)
        assert registered, f"节点 {class_name} 未注册"

    # Array-processing node implementations.
    for class_name in (
        "ArrayMapNode",
        "ArrayConcatNode",
        "ArrayFilterNode",
        "ArrayReduceNode",
        "BroadcastNode",
        "ArrayZipNode",
    ):
        registered = NodeRegistry.exists(class_name)
        assert registered, f"数组节点 {class_name} 未注册"
|
|
|
|
|
|
@async_marker
async def test_simple_workflow_execution():
    """Execute an input -> map -> reduce -> output chain and check the result.

    The workflow is fed ``values=[1, 2, 3, 4, 5]`` through the global context,
    uses ``ArrayMapNode`` with ``multiplier=2`` and ``ArrayReduceNode`` with
    ``operation="sum"``, and the output node's ``input`` port must equal 30.
    """

    def node(node_id, node_type, class_name, **params):
        # Node description in the dict shape passed to WorkflowExecutor.execute.
        return {
            "id": node_id,
            "type": node_type,
            "class_name": class_name,
            "params": params,
        }

    def edge(source, source_port, target, target_port):
        # Connection from one node's output port to another node's input port.
        return {
            "source": source,
            "source_port": source_port,
            "target": target,
            "target_port": target_port,
        }

    pipeline_nodes = [
        node("input", "input", "InputNodeImpl"),
        node("map", "normal", "ArrayMapNode", multiplier=2),
        node("reduce", "normal", "ArrayReduceNode", operation="sum"),
        node("output", "output", "OutputNodeImpl"),
    ]
    pipeline_edges = [
        edge("input", "values", "map", "values"),
        edge("map", "mapped", "reduce", "values"),
        edge("reduce", "result", "output", "input"),
    ]

    runner = WorkflowExecutor(user_id="test_user")
    success, report = await runner.execute(
        nodes=pipeline_nodes,
        edges=pipeline_edges,
        global_context={"values": [1, 2, 3, 4, 5]},
    )
    assert success is True

    results_by_node = report.get("node_results", {})
    assert "output" in results_by_node

    final_ports = results_by_node["output"]["outputs"]
    assert final_ports.get("input") == 30
|
|
|
|
|
|
@async_marker
async def test_array_operations():
    """Exercise ArrayFilterNode and BroadcastNode through small workflows.

    Two independent workflows are executed one after the other:

    * filter: ``values=[1, 2, 3, 4, 5]`` with ``threshold=2`` must yield
      ``[3, 4, 5]`` on the output node's ``input`` port;
    * broadcast: ``value=42`` with ``count=3`` must yield ``[42, 42, 42]``.
    """

    async def run_and_read_output(nodes, edges, context):
        # Run on a fresh executor, require success, then return the output
        # node's ports (an empty dict when any level of the report is absent).
        success, report = await WorkflowExecutor(user_id="test_user").execute(
            nodes=nodes, edges=edges, global_context=context
        )
        assert success is True
        return report.get("node_results", {}).get("output", {}).get("outputs", {})

    # --- filter case ---
    filter_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "filter", "type": "normal", "class_name": "ArrayFilterNode", "params": {"threshold": 2}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    filter_edges = [
        {"source": "input", "source_port": "values", "target": "filter", "target_port": "values"},
        {"source": "filter", "source_port": "filtered", "target": "output", "target_port": "input"},
    ]
    filtered_ports = await run_and_read_output(
        filter_nodes, filter_edges, {"values": [1, 2, 3, 4, 5]}
    )
    assert filtered_ports.get("input") == [3, 4, 5]

    # --- broadcast case ---
    broadcast_nodes = [
        {"id": "input1", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "broadcast", "type": "normal", "class_name": "BroadcastNode", "params": {"count": 3}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    broadcast_edges = [
        {"source": "input1", "source_port": "value", "target": "broadcast", "target_port": "value"},
        {"source": "broadcast", "source_port": "broadcast", "target": "output", "target_port": "input"},
    ]
    broadcast_ports = await run_and_read_output(
        broadcast_nodes, broadcast_edges, {"value": 42}
    )
    assert broadcast_ports.get("input") == [42, 42, 42]
|
|
|
|
|
|
@async_marker
async def test_workflow_graph_and_function_packaging():
    """Check WorkflowGraph structure queries and WorkflowPackager packaging.

    Part 1 builds a four-node linear graph (n1 -> n2 -> n3 -> n4) and checks
    node/edge counts, cycle detection, topological order and function-workflow
    validation.

    Part 2 validates an input -> map -> output sub-workflow and packages it as
    a function node, checking that the packaged definition keeps the
    requested id.

    Validation failures include the validator's error value in the assertion
    message so the reason is visible in the test report.
    """
    graph = WorkflowGraph()

    # (node id, node type, implementing class name)
    node_specs = (
        ("n1", NodeType.INPUT, "InputNodeImpl"),
        ("n2", NodeType.NORMAL, "ArrayMapNode"),
        ("n3", NodeType.NORMAL, "ArrayReduceNode"),
        ("n4", NodeType.OUTPUT, "OutputNodeImpl"),
    )
    for node_id, node_type, class_name in node_specs:
        graph.add_node(node_id, node_type, sub_workflow_nodes=None, sub_workflow_edges=None)
        # The class name is attached to the stored node entry after add_node.
        graph.nodes[node_id]["class_name"] = class_name

    # (source, source port, target, target port)
    edge_specs = (
        ("n1", "values", "n2", "values"),
        ("n2", "mapped", "n3", "values"),
        ("n3", "result", "n4", "input"),
    )
    for source, source_port, target, target_port in edge_specs:
        graph.add_edge(source, source_port, target, target_port, DimensionMode.NONE)

    assert len(graph.nodes) == 4
    assert len(graph.edges) == 3
    assert graph.has_cycle() is False
    topo = graph.topological_sort()
    assert topo == ["n1", "n2", "n3", "n4"]

    valid, error = graph.validate_function_workflow()
    assert valid is True, f"graph rejected as a function workflow: {error}"

    # package function workflow
    sub_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": 2}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    sub_edges = [
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]
    valid, err = WorkflowPackager.validate_function_workflow(sub_nodes, sub_edges)
    assert valid is True, f"sub-workflow failed validation: {err}"

    fn_def = WorkflowPackager.package_as_function(
        node_id="multiply_func",
        nodes=sub_nodes,
        edges=sub_edges,
        display_name="乘以2",
        description="test",
    )
    assert fn_def.get("id") == "multiply_func"
|
|
|
|
|
|
@async_marker
async def test_nested_function_execution():
    """Package a sub-workflow as a function node and run it inside a workflow.

    The sub-workflow (input -> ArrayMapNode with ``multiplier=2`` -> output)
    is validated, packaged under the id ``multiply_func`` and placed between
    an input and an output node of a main workflow. With
    ``values=[1, 2, 3]`` the main output node's ``input`` port must be
    ``[2, 4, 6]``.

    The execution report is attached to the assertion messages, so it is
    shown only when the test fails rather than printed on every run.
    """
    sub_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        {"id": "map", "type": "normal", "class_name": "ArrayMapNode", "params": {"multiplier": 2}},
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    sub_edges = [
        {"source": "input", "source_port": "values", "target": "map", "target_port": "values"},
        {"source": "map", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    valid, error = WorkflowPackager.validate_function_workflow(sub_nodes, sub_edges)
    assert valid is True, f"sub-workflow failed validation: {error}"

    fn_def = WorkflowPackager.package_as_function(
        node_id="multiply_func",
        nodes=sub_nodes,
        edges=sub_edges,
        display_name="乘以2",
        description="将数组中的所有元素乘以2",
    )
    # When the packaged definition nests its graph under "sub_workflow", also
    # expose it through flat "sub_workflow_nodes" / "sub_workflow_edges" keys.
    # NOTE(review): presumably the executor reads the flat keys - confirm.
    if "sub_workflow" in fn_def:
        fn_def["sub_workflow_nodes"] = fn_def["sub_workflow"].get("nodes", [])
        fn_def["sub_workflow_edges"] = fn_def["sub_workflow"].get("edges", [])

    main_nodes = [
        {"id": "input", "type": "input", "class_name": "InputNodeImpl", "params": {}},
        fn_def,
        {"id": "output", "type": "output", "class_name": "OutputNodeImpl", "params": {}},
    ]
    main_edges = [
        {"source": "input", "source_port": "values", "target": "multiply_func", "target_port": "values"},
        {"source": "multiply_func", "source_port": "mapped", "target": "output", "target_port": "input"},
    ]

    success, report = await WorkflowExecutor(user_id="test_user").execute(
        nodes=main_nodes,
        edges=main_edges,
        global_context={"values": [1, 2, 3]},
    )
    assert success is True, f"nested workflow execution failed: {report}"

    out = report.get("node_results", {}).get("output", {}).get("outputs", {})
    assert out.get("input") == [2, 4, 6], f"unexpected nested report: {report}"
|
|
|
|
|
|
if __name__ == "__main__":
    async def _run_all():
        """Await every test coroutine in this module, in definition order."""
        for test_coroutine in (
            test_special_nodes_registered,
            test_simple_workflow_execution,
            test_array_operations,
            test_workflow_graph_and_function_packaging,
            test_nested_function_execution,
        ):
            await test_coroutine()

    # Direct-run entry point: drive all tests on a single event loop.
    asyncio.run(_run_all())