add tests
This commit is contained in:
parent
c2335ae1a6
commit
649e2edca1
93
server/remote_test_server.py
Normal file
93
server/remote_test_server.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Dict, Any
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# 确保 pytrace 在 Python 路径中
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
from pytrace.model.graph import WorkflowGraph
|
||||||
|
from pytrace.runtime.context import TraceContext
|
||||||
|
from pytrace.runtime.executor import WorkflowExecutor
|
||||||
|
from pytrace.runtime.repository import NodeRepository
|
||||||
|
# 导入项目中已有的 TraceServer
|
||||||
|
from app.services import trace_server
|
||||||
|
# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用
|
||||||
|
from pytrace.nodes.test_math_suite import MathSuite
|
||||||
|
|
||||||
|
|
||||||
|
# --- 1. 数据模型 ---
|
||||||
|
|
||||||
|
class GraphExecutionRequest(BaseModel):
|
||||||
|
graph_data: Dict[str, Any]
|
||||||
|
initial_inputs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. FastAPI 应用定义 ---
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="TraceStudio - Remote Test Server",
|
||||||
|
description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。",
|
||||||
|
version="2.0.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
# 1. 在后台启动TCP服务器,用于监听客户端的连接和注册
|
||||||
|
trace_server.start_background()
|
||||||
|
|
||||||
|
# 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository
|
||||||
|
NodeRepository.set_remote_registry(trace_server.remote_registry)
|
||||||
|
|
||||||
|
print("--- Remote Test Server ---")
|
||||||
|
print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
|
||||||
|
print("HTTP Server ready to execute graphs at /execute_graph")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/execute_graph", tags=["Execution"])
|
||||||
|
async def execute_graph(request: GraphExecutionRequest):
|
||||||
|
"""
|
||||||
|
接收图的JSON并执行,可能会调用已通过TCP连接注册的远程节点。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
print("\n--- New Graph Execution Request ---")
|
||||||
|
# 1. 从请求数据构建 WorkflowGraph
|
||||||
|
graph = WorkflowGraph.from_dict(request.graph_data)
|
||||||
|
print(f"Executing graph: '{graph.name}'")
|
||||||
|
|
||||||
|
# 2. 准备执行上下文,注入 trace_server 实例作为RPC服务
|
||||||
|
# RemoteProxyNode 将通过这个服务来回调客户端
|
||||||
|
context = TraceContext(services={"rpc": trace_server})
|
||||||
|
|
||||||
|
# 3. 实例化执行器并运行
|
||||||
|
executor = WorkflowExecutor(graph, context)
|
||||||
|
executor.execute_batch(request.initial_inputs)
|
||||||
|
|
||||||
|
# 4. 获取并返回最终输出
|
||||||
|
final_outputs = context.get_outputs()
|
||||||
|
print(f"Graph execution finished. Final outputs: {final_outputs}")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"outputs": final_outputs
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": str(e),
|
||||||
|
"type": type(e).__name__,
|
||||||
|
"traceback": traceback.format_exc()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- 3. 启动服务器 ---
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。
|
||||||
|
# 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
103
tests/assets/complex_remote_graph.json
Normal file
103
tests/assets/complex_remote_graph.json
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
|
||||||
|
{
|
||||||
|
"name": "Complex Remote and Subgraph Test",
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "main_input",
|
||||||
|
"class_id": "InputTraceNode",
|
||||||
|
"params": {
|
||||||
|
"main_text": { "mode": "context", "value": "inputs.main_text" },
|
||||||
|
"main_num": { "mode": "context", "value": "inputs.main_num" }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "remote_text_processor",
|
||||||
|
"class_id": "remote.process_text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "sub_graph_node",
|
||||||
|
"class_id": "core.Graph",
|
||||||
|
"data": {
|
||||||
|
"name": "My SubGraph",
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "sub_input",
|
||||||
|
"class_id": "InputTraceNode",
|
||||||
|
"params": {
|
||||||
|
"sub_in_1": { "mode": "context", "value": "inputs.sub_in_1" },
|
||||||
|
"sub_in_2": { "mode": "context", "value": "inputs.sub_in_2" }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "local_add",
|
||||||
|
"class_id": "math.add"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "remote_ratio_calculator",
|
||||||
|
"class_id": "remote.calculate_ratio",
|
||||||
|
"params": {
|
||||||
|
"denominator": { "mode": "static", "value": 10.0 }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "sub_output",
|
||||||
|
"class_id": "OutputTraceNode"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"source_id": "sub_input", "source_port": "sub_in_1",
|
||||||
|
"target_id": "local_add", "target_port": "a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_id": "sub_input", "source_port": "sub_in_2",
|
||||||
|
"target_id": "local_add", "target_port": "b"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_id": "local_add", "source_port": "c",
|
||||||
|
"target_id": "remote_ratio_calculator", "target_port": "numerator"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_id": "remote_ratio_calculator", "source_port": "ratio",
|
||||||
|
"target_id": "sub_output", "target_port": "sub_ratio"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"interface": {
|
||||||
|
"inputs": { "sub_in_1": "integer", "sub_in_2": "integer" },
|
||||||
|
"outputs": { "sub_ratio": "float" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "main_output",
|
||||||
|
"class_id": "OutputTraceNode"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"source_id": "main_input", "source_port": "main_text",
|
||||||
|
"target_id": "remote_text_processor", "target_port": "text_in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_id": "remote_text_processor", "source_port": "length",
|
||||||
|
"target_id": "sub_graph_node", "target_port": "sub_in_1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_id": "main_input", "source_port": "main_num",
|
||||||
|
"target_id": "sub_graph_node", "target_port": "sub_in_2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_id": "sub_graph_node", "source_port": "sub_ratio",
|
||||||
|
"target_id": "main_output", "target_port": "final_ratio"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"interface": {
|
||||||
|
"inputs": {
|
||||||
|
"main_text": "string",
|
||||||
|
"main_num": "integer"
|
||||||
|
},
|
||||||
|
"outputs": {
|
||||||
|
"final_ratio": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,73 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
# Ensure pytrace is in python path
|
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|
||||||
|
|
||||||
from pytrace.api.node import TraceNode
|
|
||||||
from pytrace.api.decorators import register_class, register_node, input_port, output_port
|
|
||||||
from pytrace.runtime.context import TraceContext
|
|
||||||
from pytrace.remote.worker import TraceWorker, expose_node
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 1. Define Your Custom Nodes
|
|
||||||
# =============================================================================
|
|
||||||
# Use @expose_node to make the class available to the Remote Server.
|
|
||||||
# Use @register_class to define the category and display name in the UI.
|
|
||||||
|
|
||||||
@expose_node
|
|
||||||
@register_class(category="Tutorial", display_name="Client Demo")
|
|
||||||
class ClientDemoNode(TraceNode):
|
|
||||||
"""
|
|
||||||
A collection of demo nodes running on the client side.
|
|
||||||
Developers can add their own methods here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@register_node(display_name="Remote Add", description="Performs addition on the client.")
|
|
||||||
@input_port(name="a", type="float")
|
|
||||||
@input_port(name="b", type="float")
|
|
||||||
@output_port(name="result", type="float")
|
|
||||||
def add(self, ctx: TraceContext):
|
|
||||||
# 1. Get Inputs
|
|
||||||
a = ctx.get_input("a", 0.0)
|
|
||||||
b = ctx.get_input("b", 0.0)
|
|
||||||
|
|
||||||
# 2. Execute Logic
|
|
||||||
result = a + b
|
|
||||||
|
|
||||||
# 3. Log execution (This log will appear on the Server console!)
|
|
||||||
ctx.log(f"[Client] Executing Add: {a} + {b} = {result}")
|
|
||||||
|
|
||||||
# 4. Set Output
|
|
||||||
ctx.set_output("result", result)
|
|
||||||
|
|
||||||
@register_node(display_name="Remote Echo", description="Echoes a message.")
|
|
||||||
@input_port(name="message", type="string")
|
|
||||||
@output_port(name="echo", type="string")
|
|
||||||
def echo(self, ctx: TraceContext):
|
|
||||||
msg = ctx.get_input("message", "")
|
|
||||||
ctx.log(f"[Client] Echoing: {msg}")
|
|
||||||
ctx.set_output("echo", f"Client says: {msg}")
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 2. Start the Worker
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
SERVER_HOST = 'localhost'
|
|
||||||
SERVER_PORT = 9999
|
|
||||||
|
|
||||||
print(f"--- TraceStudio Client Demo ---")
|
|
||||||
print(f"Connecting to Server at {SERVER_HOST}:{SERVER_PORT}...")
|
|
||||||
|
|
||||||
# Initialize the Worker
|
|
||||||
worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Start the worker (This blocks until stopped)
|
|
||||||
worker.start()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n[Client] Worker stopped.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n[Client] Error: {e}")
|
|
||||||
@ -1,220 +0,0 @@
|
|||||||
import socket
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
# Add project root
|
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|
||||||
|
|
||||||
from pytrace.api.decorators import output_port, register_class
|
|
||||||
from pytrace.remote.manager import RPCManager
|
|
||||||
from pytrace.remote.protocol import MessageType
|
|
||||||
from pytrace.model.graph import WorkflowGraph, Node, Edge
|
|
||||||
from pytrace.model.enums import DimensionMode
|
|
||||||
from pytrace.runtime.context import TraceContext
|
|
||||||
from pytrace.runtime.executor import WorkflowExecutor
|
|
||||||
from pytrace.remote.registry import RemoteRegistry
|
|
||||||
from pytrace.runtime.repository import NodeRepository
|
|
||||||
from pytrace.model.data import TraceParam, ParamMode
|
|
||||||
from pytrace.api.node import GraphTraceNode, TraceNode
|
|
||||||
from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry
|
|
||||||
|
|
||||||
class ClientSession:
|
|
||||||
"""
|
|
||||||
Represents a connected Worker Client on the Server.
|
|
||||||
Wraps the Socket connection and the RPCManager.
|
|
||||||
"""
|
|
||||||
def __init__(self, conn, addr, server: 'TraceServer'):
|
|
||||||
self.conn = conn
|
|
||||||
self.addr = addr
|
|
||||||
self.server = server
|
|
||||||
self.client_id = None
|
|
||||||
self.rpc = RPCManager(send_callback=self._send_json)
|
|
||||||
|
|
||||||
# Register internal handlers
|
|
||||||
self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register)
|
|
||||||
|
|
||||||
def _send_json(self, data):
|
|
||||||
try:
|
|
||||||
msg = json.dumps(data) + "\n"
|
|
||||||
self.conn.sendall(msg.encode('utf-8'))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Session] Send error: {e}")
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def _handle_register(self, payload, msg_id):
|
|
||||||
self.client_id = payload.get("client_id")
|
|
||||||
nodes = payload.get("nodes", [])
|
|
||||||
print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.")
|
|
||||||
|
|
||||||
# Update Server Registry
|
|
||||||
self.server.remote_registry.register_client(self.client_id, nodes)
|
|
||||||
self.server.register_session(self.client_id, self)
|
|
||||||
|
|
||||||
def start_loop(self):
|
|
||||||
"""Blocking read loop."""
|
|
||||||
f = self.conn.makefile('r', encoding='utf-8')
|
|
||||||
try:
|
|
||||||
for line in f:
|
|
||||||
if not line: break
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
self.rpc.process_message(data)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Session] Msg error: {e}")
|
|
||||||
finally:
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
print(f"[Server] Session {self.client_id} disconnected")
|
|
||||||
self.rpc.close()
|
|
||||||
try:
|
|
||||||
self.conn.close()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if self.client_id:
|
|
||||||
self.server.unregister_session(self.client_id)
|
|
||||||
|
|
||||||
class TraceServer:
|
|
||||||
def __init__(self, socket_port=9999):
|
|
||||||
self.socket_port = socket_port
|
|
||||||
self.sessions = {} # client_id -> ClientSession
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.remote_registry = RemoteRegistry()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
# 1. Start Socket Server (For Workers)
|
|
||||||
threading.Thread(target=self._socket_server_loop, daemon=True).start()
|
|
||||||
|
|
||||||
# 2. Start "HTTP" Server (For Users - Simulated)
|
|
||||||
threading.Thread(target=self._http_server_simulation, daemon=True).start()
|
|
||||||
|
|
||||||
print(f"[Server] Running. Socket Port: {self.socket_port}")
|
|
||||||
|
|
||||||
# Keep main thread alive
|
|
||||||
while True:
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
def _socket_server_loop(self):
|
|
||||||
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
server_socket.bind(('localhost', self.socket_port))
|
|
||||||
server_socket.listen(5)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
conn, addr = server_socket.accept()
|
|
||||||
session = ClientSession(conn, addr, self)
|
|
||||||
threading.Thread(target=session.start_loop, daemon=True).start()
|
|
||||||
|
|
||||||
def register_session(self, client_id, session):
|
|
||||||
with self.lock:
|
|
||||||
self.sessions[client_id] = session
|
|
||||||
|
|
||||||
def unregister_session(self, client_id):
|
|
||||||
with self.lock:
|
|
||||||
self.sessions.pop(client_id, None)
|
|
||||||
|
|
||||||
# --- RPC Service Interface for Workflow ---
|
|
||||||
def call(self, target_id: str, method: str, params: dict, timeout=3600):
|
|
||||||
# 1. Routing Logic
|
|
||||||
if not target_id:
|
|
||||||
providers = self.remote_registry.find_providers(method)
|
|
||||||
if not providers:
|
|
||||||
raise RuntimeError(f"No active client provides node '{method}'")
|
|
||||||
target_id = providers[0]
|
|
||||||
print(f"[Server] Auto-routing '{method}' to client '{target_id}'")
|
|
||||||
|
|
||||||
# 2. Get Session
|
|
||||||
session = self.sessions.get(target_id)
|
|
||||||
if not session:
|
|
||||||
raise RuntimeError(f"Client {target_id} not connected")
|
|
||||||
|
|
||||||
# 3. Execute via Session's RPC
|
|
||||||
return session.rpc.call(target_id, method, params, timeout)
|
|
||||||
|
|
||||||
# --- HTTP API Simulation ---
|
|
||||||
def _http_server_simulation(self):
|
|
||||||
"""
|
|
||||||
Simulates an HTTP endpoint that receives workflow execution requests.
|
|
||||||
In a real app, this would be Flask/FastAPI.
|
|
||||||
"""
|
|
||||||
print("[Server] HTTP API ready. Waiting for requests...")
|
|
||||||
|
|
||||||
# Simulate a user request coming in after 5 seconds
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
# Mock Request Payload
|
|
||||||
# User explicitly wants 'ClientMathNode.add' to run on a specific worker if available
|
|
||||||
# Or they might leave it empty for auto-routing.
|
|
||||||
|
|
||||||
# Let's find a connected client to simulate the user picking one
|
|
||||||
target_worker_id = None
|
|
||||||
while not target_worker_id:
|
|
||||||
with self.lock:
|
|
||||||
if self.sessions:
|
|
||||||
target_worker_id = list(self.sessions.keys())[0]
|
|
||||||
if not target_worker_id:
|
|
||||||
print("[Server] HTTP Waiting for workers to connect...")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<")
|
|
||||||
|
|
||||||
# The User Request Data
|
|
||||||
request_payload = {
|
|
||||||
"inputs": {"x": 10, "y": 20},
|
|
||||||
"node_targets": {
|
|
||||||
# The user specifies that the node with ID 'node_remote_1' in the graph
|
|
||||||
# MUST be executed by 'target_worker_id'
|
|
||||||
"node_remote_1": target_worker_id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.run_workflow(request_payload)
|
|
||||||
|
|
||||||
def run_workflow(self, payload):
|
|
||||||
print("[Server] Loading test workflow: tests/assets/main.flow")
|
|
||||||
flow_path = os.path.abspath("tests/assets/main.flow")
|
|
||||||
with open(flow_path, 'r', encoding='utf-8') as f:
|
|
||||||
graph_data = json.load(f)
|
|
||||||
|
|
||||||
# Inject a node ID into the graph data for demonstration if needed,
|
|
||||||
# or assume the graph already has 'node_remote_1'.
|
|
||||||
# For this demo, let's assume the graph's first node is the one we want to target.
|
|
||||||
# (In reality, graph_data comes from the frontend with IDs)
|
|
||||||
|
|
||||||
graph = WorkflowGraph.from_dict(graph_data)
|
|
||||||
|
|
||||||
# Prepare Repository
|
|
||||||
repo = NodeRepository(remote_registry=self.remote_registry)
|
|
||||||
|
|
||||||
# Prepare Context with User Preferences (node_targets)
|
|
||||||
context = TraceContext(
|
|
||||||
contexts={},
|
|
||||||
services={"rpc": self, "node_repository": repo},
|
|
||||||
node_targets=payload.get("node_targets")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
print("[Server] Executing workflow...")
|
|
||||||
executor = WorkflowExecutor(graph, context, node_repository=repo)
|
|
||||||
|
|
||||||
inputs = payload.get("inputs")
|
|
||||||
executor.execute_batch(inputs)
|
|
||||||
|
|
||||||
# 7. Verify Results
|
|
||||||
outputs = context.context.get_outputs()
|
|
||||||
print(f"[Server] Workflow finished. Outputs: {outputs}")
|
|
||||||
|
|
||||||
res = outputs.get("final_result")
|
|
||||||
val = res.value if hasattr(res, 'value') else res
|
|
||||||
|
|
||||||
if val == 15.0: # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15
|
|
||||||
print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}")
|
|
||||||
else:
|
|
||||||
print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
server = TraceServer()
|
|
||||||
server.start()
|
|
||||||
93
tests/run_remote_test.py
Normal file
93
tests/run_remote_test.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
|
||||||
|
import json
|
||||||
|
import httpx
|
||||||
|
import os
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
SERVER_URL = "http://localhost:8000"
|
||||||
|
GRAPH_EXECUTE_ENDPOINT = f"{SERVER_URL}/execute_graph"
|
||||||
|
GRAPH_FILE_PATH = os.path.join(os.path.dirname(__file__), "assets", "complex_remote_graph.json")
|
||||||
|
|
||||||
|
def run_test():
|
||||||
|
"""
|
||||||
|
执行端到端远程调用测试。
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
1. 确保 remote_test_server.py 正在运行 (uvicorn server.remote_test_server:app --reload --port 8000)
|
||||||
|
2. 确保 test_complex_remote_client.py 正在运行 (python tests/test_complex_remote_client.py)
|
||||||
|
3. 运行此脚本 (python tests/run_remote_test.py)
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("--- Running End-to-End Remote Execution Test ---")
|
||||||
|
|
||||||
|
# 1. 加载图JSON文件
|
||||||
|
try:
|
||||||
|
with open(GRAPH_FILE_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
graph_data = json.load(f)
|
||||||
|
print(f"Successfully loaded graph from: {GRAPH_FILE_PATH}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"ERROR: Graph file not found at {GRAPH_FILE_PATH}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 准备初始输入
|
||||||
|
# 根据 complex_remote_graph.json 的 "interface" 定义
|
||||||
|
initial_inputs = {
|
||||||
|
"main_text": "hello_remote_world",
|
||||||
|
"main_num": 5
|
||||||
|
}
|
||||||
|
print(f"Initial inputs: {initial_inputs}")
|
||||||
|
|
||||||
|
# 3. 构建请求体
|
||||||
|
request_payload = {
|
||||||
|
"graph_data": graph_data,
|
||||||
|
"initial_inputs": initial_inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 发送HTTP请求到服务器
|
||||||
|
print(f"Sending POST request to {GRAPH_EXECUTE_ENDPOINT}...")
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
response = client.post(GRAPH_EXECUTE_ENDPOINT, json=request_payload)
|
||||||
|
response.raise_for_status() # 如果状态码不是 2xx,则会引发异常
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
print("\n--- Test Result ---")
|
||||||
|
print("Request successful!")
|
||||||
|
print("Server response:")
|
||||||
|
# 使用json.dumps美化输出
|
||||||
|
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
# 预期结果计算:
|
||||||
|
# 1. main_text "hello_remote_world" 长度为 18
|
||||||
|
# 2. 子图接收 sub_in_1=18, sub_in_2=5
|
||||||
|
# 3. 子图内本地相加: 18 + 5 = 23
|
||||||
|
# 4. 子图内远程计算比率: 23 / 10.0 = 2.3
|
||||||
|
# 5. 主图接收到最终结果 2.3
|
||||||
|
expected_ratio = 2.3
|
||||||
|
final_ratio = result.get("outputs", {}).get("final_ratio", {}).get("value")
|
||||||
|
|
||||||
|
print("\n--- Verification ---")
|
||||||
|
print(f"Expected final_ratio: {expected_ratio}")
|
||||||
|
print(f"Actual final_ratio: {final_ratio}")
|
||||||
|
assert final_ratio == expected_ratio, "Test assertion failed!"
|
||||||
|
print("✅ Test Passed Successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print("\n--- TEST FAILED ---")
|
||||||
|
print(f"ERROR: Could not connect to the server at {SERVER_URL}.")
|
||||||
|
print("Please ensure the 'remote_test_server.py' is running.")
|
||||||
|
print(f"Details: {e}")
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
print("\n--- TEST FAILED ---")
|
||||||
|
print(f"ERROR: Server returned a non-200 status code: {e.response.status_code}")
|
||||||
|
print("Server response:")
|
||||||
|
try:
|
||||||
|
# 尝试打印JSON错误详情
|
||||||
|
print(json.dumps(e.response.json(), indent=2, ensure_ascii=False))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(e.response.text)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_test()
|
||||||
104
tests/test_complex_remote_client.py
Normal file
104
tests/test_complex_remote_client.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 确保 pytrace 在 Python 路径中
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
from pytrace.api.node import TraceNode
|
||||||
|
from pytrace.api.decorators import register_class, register_node, input_port, output_port
|
||||||
|
from pytrace.runtime.context import TraceContext
|
||||||
|
from pytrace.remote.worker import TraceWorker, expose_node
|
||||||
|
|
||||||
|
# --- 1. 定义可被远程调用的节点 ---
|
||||||
|
# 使用 @expose_node 将此类暴露给服务器
|
||||||
|
@expose_node
|
||||||
|
@register_class(category="ComplexRemote", display_name="Complex Remote Node")
|
||||||
|
class ComplexRemoteNode(TraceNode):
|
||||||
|
"""
|
||||||
|
这个类包含一组用于远程调用的方法(节点)。
|
||||||
|
每个方法都添加了详细的打印语句,方便您在客户端控制台进行调试。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@register_node(
|
||||||
|
class_id="remote.process_text",
|
||||||
|
display_name="Process Text (Remote)",
|
||||||
|
description="在客户端处理文本,计算长度并添加前缀。"
|
||||||
|
)
|
||||||
|
@input_port(name="text_in", type="string")
|
||||||
|
@output_port(name="text_out", type="string")
|
||||||
|
@output_port(name="length", type="integer")
|
||||||
|
def process_text(self, ctx: TraceContext):
|
||||||
|
text = ctx.get_input("text_in", "")
|
||||||
|
|
||||||
|
print("\n--- [REMOTE] ---")
|
||||||
|
print(f"Node 'remote.process_text' is executing.")
|
||||||
|
print(f" - Received input 'text_in': '{text}'")
|
||||||
|
|
||||||
|
# 模拟一些耗时操作
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
processed_text = f"processed_by_client:_{text}"
|
||||||
|
text_length = len(text)
|
||||||
|
|
||||||
|
ctx.set_output("text_out", processed_text)
|
||||||
|
ctx.set_output("length", text_length)
|
||||||
|
|
||||||
|
print(f" - Sent output 'text_out': '{processed_text}'")
|
||||||
|
print(f" - Sent output 'length': {text_length}")
|
||||||
|
print("--- [REMOTE] ---\
|
||||||
|
")
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(
|
||||||
|
class_id="remote.calculate_ratio",
|
||||||
|
display_name="Calculate Ratio (Remote)",
|
||||||
|
description="在客户端根据两个数字计算比率。"
|
||||||
|
)
|
||||||
|
@input_port(name="numerator", type="float", label="分子")
|
||||||
|
@input_port(name="denominator", type="float", label="分母")
|
||||||
|
@output_port(name="ratio", type="float", label="比率")
|
||||||
|
def calculate_ratio(self, ctx: TraceContext):
|
||||||
|
numerator = ctx.get_input("numerator", 0.0)
|
||||||
|
denominator = ctx.get_input("denominator", 1.0)
|
||||||
|
|
||||||
|
print("\n--- [REMOTE] ---")
|
||||||
|
print(f"Node 'remote.calculate_ratio' is executing.")
|
||||||
|
print(f" - Received input 'numerator': {numerator}")
|
||||||
|
print(f" - Received input 'denominator': {denominator}")
|
||||||
|
|
||||||
|
if denominator == 0:
|
||||||
|
print(" - Error: Denominator is zero.")
|
||||||
|
# 在实际应用中可能需要更复杂的错误处理
|
||||||
|
ratio = 0.0
|
||||||
|
else:
|
||||||
|
ratio = numerator / denominator
|
||||||
|
|
||||||
|
ctx.set_output("ratio", ratio)
|
||||||
|
print(f" - Sent output 'ratio': {ratio}")
|
||||||
|
print("--- [REMOTE] ---\
|
||||||
|
")
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. 启动客户端 Worker ---
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 服务器的TCP RPC地址和端口
|
||||||
|
SERVER_HOST = 'localhost'
|
||||||
|
SERVER_PORT = 9999 # 这应该与 remote_test_server.py 中的端口一致
|
||||||
|
|
||||||
|
print(f"--- Complex Remote Test Client ---")
|
||||||
|
print(f"Attempting to connect to RPC Server at {SERVER_HOST}:{SERVER_PORT}...")
|
||||||
|
|
||||||
|
# 初始化并启动 Worker
|
||||||
|
# 它会自动查找所有被 @expose_node 装饰的类,并向服务器注册
|
||||||
|
worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# worker.start() 是一个阻塞操作,它会保持运行以监听服务器的调用
|
||||||
|
worker.start()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n[Client] Worker stopped by user.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n[Client] An error occurred: {e}")
|
||||||
|
print("[Client] Is the test server running?")
|
||||||
Loading…
Reference in New Issue
Block a user