add tests

This commit is contained in:
ouczbs 2026-01-22 23:22:01 +08:00
parent c2335ae1a6
commit 649e2edca1
6 changed files with 393 additions and 293 deletions

View File

@ -0,0 +1,93 @@
import sys
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any
import uvicorn
# 确保 pytrace 在 Python 路径中
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.model.graph import WorkflowGraph
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.runtime.repository import NodeRepository
# 导入项目中已有的 TraceServer
from app.services import trace_server
# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用
from pytrace.nodes.test_math_suite import MathSuite
# --- 1. 数据模型 ---
class GraphExecutionRequest(BaseModel):
graph_data: Dict[str, Any]
initial_inputs: Dict[str, Any] = Field(default_factory=dict)
# --- 2. FastAPI 应用定义 ---
app = FastAPI(
title="TraceStudio - Remote Test Server",
description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。",
version="2.0.0",
)
@app.on_event("startup")
async def startup_event():
# 1. 在后台启动TCP服务器用于监听客户端的连接和注册
trace_server.start_background()
# 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository
NodeRepository.set_remote_registry(trace_server.remote_registry)
print("--- Remote Test Server ---")
print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
print("HTTP Server ready to execute graphs at /execute_graph")
@app.post("/execute_graph", tags=["Execution"])
async def execute_graph(request: GraphExecutionRequest):
"""
接收图的JSON并执行可能会调用已通过TCP连接注册的远程节点
"""
try:
print("\n--- New Graph Execution Request ---")
# 1. 从请求数据构建 WorkflowGraph
graph = WorkflowGraph.from_dict(request.graph_data)
print(f"Executing graph: '{graph.name}'")
# 2. 准备执行上下文,注入 trace_server 实例作为RPC服务
# RemoteProxyNode 将通过这个服务来回调客户端
context = TraceContext(services={"rpc": trace_server})
# 3. 实例化执行器并运行
executor = WorkflowExecutor(graph, context)
executor.execute_batch(request.initial_inputs)
# 4. 获取并返回最终输出
final_outputs = context.get_outputs()
print(f"Graph execution finished. Final outputs: {final_outputs}")
return {
"success": True,
"outputs": final_outputs
}
except Exception as e:
import traceback
print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
raise HTTPException(
status_code=500,
detail={
"error": str(e),
"type": type(e).__name__,
"traceback": traceback.format_exc()
}
)
# --- 3. 启动服务器 ---
if __name__ == "__main__":
# 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。
# 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -0,0 +1,103 @@
{
"name": "Complex Remote and Subgraph Test",
"nodes": [
{
"id": "main_input",
"class_id": "InputTraceNode",
"params": {
"main_text": { "mode": "context", "value": "inputs.main_text" },
"main_num": { "mode": "context", "value": "inputs.main_num" }
}
},
{
"id": "remote_text_processor",
"class_id": "remote.process_text"
},
{
"id": "sub_graph_node",
"class_id": "core.Graph",
"data": {
"name": "My SubGraph",
"nodes": [
{
"id": "sub_input",
"class_id": "InputTraceNode",
"params": {
"sub_in_1": { "mode": "context", "value": "inputs.sub_in_1" },
"sub_in_2": { "mode": "context", "value": "inputs.sub_in_2" }
}
},
{
"id": "local_add",
"class_id": "math.add"
},
{
"id": "remote_ratio_calculator",
"class_id": "remote.calculate_ratio",
"params": {
"denominator": { "mode": "static", "value": 10.0 }
}
},
{
"id": "sub_output",
"class_id": "OutputTraceNode"
}
],
"edges": [
{
"source_id": "sub_input", "source_port": "sub_in_1",
"target_id": "local_add", "target_port": "a"
},
{
"source_id": "sub_input", "source_port": "sub_in_2",
"target_id": "local_add", "target_port": "b"
},
{
"source_id": "local_add", "source_port": "c",
"target_id": "remote_ratio_calculator", "target_port": "numerator"
},
{
"source_id": "remote_ratio_calculator", "source_port": "ratio",
"target_id": "sub_output", "target_port": "sub_ratio"
}
],
"interface": {
"inputs": { "sub_in_1": "integer", "sub_in_2": "integer" },
"outputs": { "sub_ratio": "float" }
}
}
},
{
"id": "main_output",
"class_id": "OutputTraceNode"
}
],
"edges": [
{
"source_id": "main_input", "source_port": "main_text",
"target_id": "remote_text_processor", "target_port": "text_in"
},
{
"source_id": "remote_text_processor", "source_port": "length",
"target_id": "sub_graph_node", "target_port": "sub_in_1"
},
{
"source_id": "main_input", "source_port": "main_num",
"target_id": "sub_graph_node", "target_port": "sub_in_2"
},
{
"source_id": "sub_graph_node", "source_port": "sub_ratio",
"target_id": "main_output", "target_port": "final_ratio"
}
],
"interface": {
"inputs": {
"main_text": "string",
"main_num": "integer"
},
"outputs": {
"final_ratio": "float"
}
}
}

View File

@ -1,73 +0,0 @@
import sys
import os
import time
# Ensure pytrace is in python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.api.node import TraceNode
from pytrace.api.decorators import register_class, register_node, input_port, output_port
from pytrace.runtime.context import TraceContext
from pytrace.remote.worker import TraceWorker, expose_node
# =============================================================================
# 1. Define Your Custom Nodes
# =============================================================================
# Use @expose_node to make the class available to the Remote Server.
# Use @register_class to define the category and display name in the UI.
@expose_node
@register_class(category="Tutorial", display_name="Client Demo")
class ClientDemoNode(TraceNode):
"""
A collection of demo nodes running on the client side.
Developers can add their own methods here.
"""
@register_node(display_name="Remote Add", description="Performs addition on the client.")
@input_port(name="a", type="float")
@input_port(name="b", type="float")
@output_port(name="result", type="float")
def add(self, ctx: TraceContext):
# 1. Get Inputs
a = ctx.get_input("a", 0.0)
b = ctx.get_input("b", 0.0)
# 2. Execute Logic
result = a + b
# 3. Log execution (This log will appear on the Server console!)
ctx.log(f"[Client] Executing Add: {a} + {b} = {result}")
# 4. Set Output
ctx.set_output("result", result)
@register_node(display_name="Remote Echo", description="Echoes a message.")
@input_port(name="message", type="string")
@output_port(name="echo", type="string")
def echo(self, ctx: TraceContext):
msg = ctx.get_input("message", "")
ctx.log(f"[Client] Echoing: {msg}")
ctx.set_output("echo", f"Client says: {msg}")
# =============================================================================
# 2. Start the Worker
# =============================================================================
if __name__ == "__main__":
SERVER_HOST = 'localhost'
SERVER_PORT = 9999
print(f"--- TraceStudio Client Demo ---")
print(f"Connecting to Server at {SERVER_HOST}:{SERVER_PORT}...")
# Initialize the Worker
worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT)
try:
# Start the worker (This blocks until stopped)
worker.start()
except KeyboardInterrupt:
print("\n[Client] Worker stopped.")
except Exception as e:
print(f"\n[Client] Error: {e}")

View File

@ -1,220 +0,0 @@
import socket
import json
import threading
import sys
import os
import time
# Add project root
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.api.decorators import output_port, register_class
from pytrace.remote.manager import RPCManager
from pytrace.remote.protocol import MessageType
from pytrace.model.graph import WorkflowGraph, Node, Edge
from pytrace.model.enums import DimensionMode
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.remote.registry import RemoteRegistry
from pytrace.runtime.repository import NodeRepository
from pytrace.model.data import TraceParam, ParamMode
from pytrace.api.node import GraphTraceNode, TraceNode
from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry
class ClientSession:
"""
Represents a connected Worker Client on the Server.
Wraps the Socket connection and the RPCManager.
"""
def __init__(self, conn, addr, server: 'TraceServer'):
self.conn = conn
self.addr = addr
self.server = server
self.client_id = None
self.rpc = RPCManager(send_callback=self._send_json)
# Register internal handlers
self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register)
def _send_json(self, data):
try:
msg = json.dumps(data) + "\n"
self.conn.sendall(msg.encode('utf-8'))
except Exception as e:
print(f"[Session] Send error: {e}")
self.close()
def _handle_register(self, payload, msg_id):
self.client_id = payload.get("client_id")
nodes = payload.get("nodes", [])
print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.")
# Update Server Registry
self.server.remote_registry.register_client(self.client_id, nodes)
self.server.register_session(self.client_id, self)
def start_loop(self):
"""Blocking read loop."""
f = self.conn.makefile('r', encoding='utf-8')
try:
for line in f:
if not line: break
try:
data = json.loads(line)
self.rpc.process_message(data)
except Exception as e:
print(f"[Session] Msg error: {e}")
finally:
self.close()
def close(self):
print(f"[Server] Session {self.client_id} disconnected")
self.rpc.close()
try:
self.conn.close()
except:
pass
if self.client_id:
self.server.unregister_session(self.client_id)
class TraceServer:
def __init__(self, socket_port=9999):
self.socket_port = socket_port
self.sessions = {} # client_id -> ClientSession
self.lock = threading.Lock()
self.remote_registry = RemoteRegistry()
def start(self):
# 1. Start Socket Server (For Workers)
threading.Thread(target=self._socket_server_loop, daemon=True).start()
# 2. Start "HTTP" Server (For Users - Simulated)
threading.Thread(target=self._http_server_simulation, daemon=True).start()
print(f"[Server] Running. Socket Port: {self.socket_port}")
# Keep main thread alive
while True:
time.sleep(1)
def _socket_server_loop(self):
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(('localhost', self.socket_port))
server_socket.listen(5)
while True:
conn, addr = server_socket.accept()
session = ClientSession(conn, addr, self)
threading.Thread(target=session.start_loop, daemon=True).start()
def register_session(self, client_id, session):
with self.lock:
self.sessions[client_id] = session
def unregister_session(self, client_id):
with self.lock:
self.sessions.pop(client_id, None)
# --- RPC Service Interface for Workflow ---
def call(self, target_id: str, method: str, params: dict, timeout=3600):
# 1. Routing Logic
if not target_id:
providers = self.remote_registry.find_providers(method)
if not providers:
raise RuntimeError(f"No active client provides node '{method}'")
target_id = providers[0]
print(f"[Server] Auto-routing '{method}' to client '{target_id}'")
# 2. Get Session
session = self.sessions.get(target_id)
if not session:
raise RuntimeError(f"Client {target_id} not connected")
# 3. Execute via Session's RPC
return session.rpc.call(target_id, method, params, timeout)
# --- HTTP API Simulation ---
def _http_server_simulation(self):
"""
Simulates an HTTP endpoint that receives workflow execution requests.
In a real app, this would be Flask/FastAPI.
"""
print("[Server] HTTP API ready. Waiting for requests...")
# Simulate a user request coming in after 5 seconds
time.sleep(5)
# Mock Request Payload
# User explicitly wants 'ClientMathNode.add' to run on a specific worker if available
# Or they might leave it empty for auto-routing.
# Let's find a connected client to simulate the user picking one
target_worker_id = None
while not target_worker_id:
with self.lock:
if self.sessions:
target_worker_id = list(self.sessions.keys())[0]
if not target_worker_id:
print("[Server] HTTP Waiting for workers to connect...")
time.sleep(2)
print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<")
# The User Request Data
request_payload = {
"inputs": {"x": 10, "y": 20},
"node_targets": {
# The user specifies that the node with ID 'node_remote_1' in the graph
# MUST be executed by 'target_worker_id'
"node_remote_1": target_worker_id
}
}
self.run_workflow(request_payload)
def run_workflow(self, payload):
print("[Server] Loading test workflow: tests/assets/main.flow")
flow_path = os.path.abspath("tests/assets/main.flow")
with open(flow_path, 'r', encoding='utf-8') as f:
graph_data = json.load(f)
# Inject a node ID into the graph data for demonstration if needed,
# or assume the graph already has 'node_remote_1'.
# For this demo, let's assume the graph's first node is the one we want to target.
# (In reality, graph_data comes from the frontend with IDs)
graph = WorkflowGraph.from_dict(graph_data)
# Prepare Repository
repo = NodeRepository(remote_registry=self.remote_registry)
# Prepare Context with User Preferences (node_targets)
context = TraceContext(
contexts={},
services={"rpc": self, "node_repository": repo},
node_targets=payload.get("node_targets")
)
# Execute
print("[Server] Executing workflow...")
executor = WorkflowExecutor(graph, context, node_repository=repo)
inputs = payload.get("inputs")
executor.execute_batch(inputs)
# 7. Verify Results
outputs = context.context.get_outputs()
print(f"[Server] Workflow finished. Outputs: {outputs}")
res = outputs.get("final_result")
val = res.value if hasattr(res, 'value') else res
if val == 15.0: # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15
print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}")
else:
print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}")
if __name__ == "__main__":
server = TraceServer()
server.start()

93
tests/run_remote_test.py Normal file
View File

@ -0,0 +1,93 @@
import json
import httpx
import os
# --- 配置 ---
SERVER_URL = "http://localhost:8000"
GRAPH_EXECUTE_ENDPOINT = f"{SERVER_URL}/execute_graph"
GRAPH_FILE_PATH = os.path.join(os.path.dirname(__file__), "assets", "complex_remote_graph.json")
def run_test():
"""
执行端到端远程调用测试
使用方法:
1. 确保 remote_test_server.py 正在运行 (uvicorn server.remote_test_server:app --reload --port 8000)
2. 确保 test_complex_remote_client.py 正在运行 (python tests/test_complex_remote_client.py)
3. 运行此脚本 (python tests/run_remote_test.py)
"""
print("--- Running End-to-End Remote Execution Test ---")
# 1. 加载图JSON文件
try:
with open(GRAPH_FILE_PATH, 'r', encoding='utf-8') as f:
graph_data = json.load(f)
print(f"Successfully loaded graph from: {GRAPH_FILE_PATH}")
except FileNotFoundError:
print(f"ERROR: Graph file not found at {GRAPH_FILE_PATH}")
return
# 2. 准备初始输入
# 根据 complex_remote_graph.json 的 "interface" 定义
initial_inputs = {
"main_text": "hello_remote_world",
"main_num": 5
}
print(f"Initial inputs: {initial_inputs}")
# 3. 构建请求体
request_payload = {
"graph_data": graph_data,
"initial_inputs": initial_inputs
}
# 4. 发送HTTP请求到服务器
print(f"Sending POST request to {GRAPH_EXECUTE_ENDPOINT}...")
try:
with httpx.Client(timeout=60.0) as client:
response = client.post(GRAPH_EXECUTE_ENDPOINT, json=request_payload)
response.raise_for_status() # 如果状态码不是 2xx则会引发异常
result = response.json()
print("\n--- Test Result ---")
print("Request successful!")
print("Server response:")
# 使用json.dumps美化输出
print(json.dumps(result, indent=2, ensure_ascii=False))
# 预期结果计算:
# 1. main_text "hello_remote_world" 长度为 18
# 2. 子图接收 sub_in_1=18, sub_in_2=5
# 3. 子图内本地相加: 18 + 5 = 23
# 4. 子图内远程计算比率: 23 / 10.0 = 2.3
# 5. 主图接收到最终结果 2.3
expected_ratio = 2.3
final_ratio = result.get("outputs", {}).get("final_ratio", {}).get("value")
print("\n--- Verification ---")
print(f"Expected final_ratio: {expected_ratio}")
print(f"Actual final_ratio: {final_ratio}")
assert final_ratio == expected_ratio, "Test assertion failed!"
print("✅ Test Passed Successfully!")
except httpx.RequestError as e:
print("\n--- TEST FAILED ---")
print(f"ERROR: Could not connect to the server at {SERVER_URL}.")
print("Please ensure the 'remote_test_server.py' is running.")
print(f"Details: {e}")
except httpx.HTTPStatusError as e:
print("\n--- TEST FAILED ---")
print(f"ERROR: Server returned a non-200 status code: {e.response.status_code}")
print("Server response:")
try:
# 尝试打印JSON错误详情
print(json.dumps(e.response.json(), indent=2, ensure_ascii=False))
except json.JSONDecodeError:
print(e.response.text)
if __name__ == "__main__":
run_test()

View File

@ -0,0 +1,104 @@
import sys
import os
import time
# 确保 pytrace 在 Python 路径中
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.api.node import TraceNode
from pytrace.api.decorators import register_class, register_node, input_port, output_port
from pytrace.runtime.context import TraceContext
from pytrace.remote.worker import TraceWorker, expose_node
# --- 1. 定义可被远程调用的节点 ---
# 使用 @expose_node 将此类暴露给服务器
@expose_node
@register_class(category="ComplexRemote", display_name="Complex Remote Node")
class ComplexRemoteNode(TraceNode):
"""
这个类包含一组用于远程调用的方法节点
每个方法都添加了详细的打印语句方便您在客户端控制台进行调试
"""
@register_node(
class_id="remote.process_text",
display_name="Process Text (Remote)",
description="在客户端处理文本,计算长度并添加前缀。"
)
@input_port(name="text_in", type="string")
@output_port(name="text_out", type="string")
@output_port(name="length", type="integer")
def process_text(self, ctx: TraceContext):
text = ctx.get_input("text_in", "")
print("\n--- [REMOTE] ---")
print(f"Node 'remote.process_text' is executing.")
print(f" - Received input 'text_in': '{text}'")
# 模拟一些耗时操作
time.sleep(0.5)
processed_text = f"processed_by_client:_{text}"
text_length = len(text)
ctx.set_output("text_out", processed_text)
ctx.set_output("length", text_length)
print(f" - Sent output 'text_out': '{processed_text}'")
print(f" - Sent output 'length': {text_length}")
print("--- [REMOTE] ---\
")
@register_node(
class_id="remote.calculate_ratio",
display_name="Calculate Ratio (Remote)",
description="在客户端根据两个数字计算比率。"
)
@input_port(name="numerator", type="float", label="分子")
@input_port(name="denominator", type="float", label="分母")
@output_port(name="ratio", type="float", label="比率")
def calculate_ratio(self, ctx: TraceContext):
numerator = ctx.get_input("numerator", 0.0)
denominator = ctx.get_input("denominator", 1.0)
print("\n--- [REMOTE] ---")
print(f"Node 'remote.calculate_ratio' is executing.")
print(f" - Received input 'numerator': {numerator}")
print(f" - Received input 'denominator': {denominator}")
if denominator == 0:
print(" - Error: Denominator is zero.")
# 在实际应用中可能需要更复杂的错误处理
ratio = 0.0
else:
ratio = numerator / denominator
ctx.set_output("ratio", ratio)
print(f" - Sent output 'ratio': {ratio}")
print("--- [REMOTE] ---\
")
# --- 2. 启动客户端 Worker ---
if __name__ == "__main__":
# 服务器的TCP RPC地址和端口
SERVER_HOST = 'localhost'
SERVER_PORT = 9999 # 这应该与 remote_test_server.py 中的端口一致
print(f"--- Complex Remote Test Client ---")
print(f"Attempting to connect to RPC Server at {SERVER_HOST}:{SERVER_PORT}...")
# 初始化并启动 Worker
# 它会自动查找所有被 @expose_node 装饰的类,并向服务器注册
worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT)
try:
# worker.start() 是一个阻塞操作,它会保持运行以监听服务器的调用
worker.start()
except KeyboardInterrupt:
print("\n[Client] Worker stopped by user.")
except Exception as e:
print(f"\n[Client] An error occurred: {e}")
print("[Client] Is the test server running?")