diff --git a/server/remote_test_server.py b/server/remote_test_server.py new file mode 100644 index 0000000..4b11bd3 --- /dev/null +++ b/server/remote_test_server.py @@ -0,0 +1,93 @@ + +import sys +import os +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field +from typing import Dict, Any +import uvicorn + +# 确保 pytrace 在 Python 路径中 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from pytrace.model.graph import WorkflowGraph +from pytrace.runtime.context import TraceContext +from pytrace.runtime.executor import WorkflowExecutor +from pytrace.runtime.repository import NodeRepository +# 导入项目中已有的 TraceServer +from app.services import trace_server +# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用 +from pytrace.nodes.test_math_suite import MathSuite + + +# --- 1. 数据模型 --- + +class GraphExecutionRequest(BaseModel): + graph_data: Dict[str, Any] + initial_inputs: Dict[str, Any] = Field(default_factory=dict) + + +# --- 2. FastAPI 应用定义 --- + +app = FastAPI( + title="TraceStudio - Remote Test Server", + description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。", + version="2.0.0", +) + +@app.on_event("startup") +async def startup_event(): + # 1. 在后台启动TCP服务器,用于监听客户端的连接和注册 + trace_server.start_background() + + # 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository + NodeRepository.set_remote_registry(trace_server.remote_registry) + + print("--- Remote Test Server ---") + print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.") + print("HTTP Server ready to execute graphs at /execute_graph") + + +@app.post("/execute_graph", tags=["Execution"]) +async def execute_graph(request: GraphExecutionRequest): + """ + 接收图的JSON并执行,可能会调用已通过TCP连接注册的远程节点。 + """ + try: + print("\n--- New Graph Execution Request ---") + # 1. 从请求数据构建 WorkflowGraph + graph = WorkflowGraph.from_dict(request.graph_data) + print(f"Executing graph: '{graph.name}'") + + # 2. 准备执行上下文,注入 trace_server 实例作为RPC服务 + # RemoteProxyNode 将通过这个服务来回调客户端 + context = TraceContext(services={"rpc": trace_server}) + + # 3. 实例化执行器并运行 + executor = WorkflowExecutor(graph, context) + executor.execute_batch(request.initial_inputs) + + # 4. 获取并返回最终输出 + final_outputs = context.get_outputs() + print(f"Graph execution finished. Final outputs: {final_outputs}") + return { + "success": True, + "outputs": final_outputs + } + except Exception as e: + import traceback + print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}") + raise HTTPException( + status_code=500, + detail={ + "error": str(e), + "type": type(e).__name__, + "traceback": traceback.format_exc() + } + ) + +# --- 3. 启动服务器 --- + +if __name__ == "__main__": + # 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。 + # 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000 + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/tests/assets/complex_remote_graph.json b/tests/assets/complex_remote_graph.json new file mode 100644 index 0000000..f1c4420 --- /dev/null +++ b/tests/assets/complex_remote_graph.json @@ -0,0 +1,103 @@ + +{ + "name": "Complex Remote and Subgraph Test", + "nodes": [ + { + "id": "main_input", + "class_id": "InputTraceNode", + "params": { + "main_text": { "mode": "context", "value": "inputs.main_text" }, + "main_num": { "mode": "context", "value": "inputs.main_num" } + } + }, + { + "id": "remote_text_processor", + "class_id": "remote.process_text" + }, + { + "id": "sub_graph_node", + "class_id": "core.Graph", + "data": { + "name": "My SubGraph", + "nodes": [ + { + "id": "sub_input", + "class_id": "InputTraceNode", + "params": { + "sub_in_1": { "mode": "context", "value": "inputs.sub_in_1" }, + "sub_in_2": { "mode": "context", "value": "inputs.sub_in_2" } + } + }, + { + "id": "local_add", + "class_id": "math.add" + }, + { + "id": "remote_ratio_calculator", + "class_id": "remote.calculate_ratio", + "params": { + "denominator": { "mode": "static", "value": 10.0 } + } + }, + { + "id": "sub_output", + "class_id": "OutputTraceNode" + } + ], + "edges": [ + { + "source_id": "sub_input", "source_port": "sub_in_1", + "target_id": "local_add", "target_port": "a" + }, + { + "source_id": "sub_input", "source_port": "sub_in_2", + "target_id": "local_add", "target_port": "b" + }, + { + "source_id": "local_add", "source_port": "c", + "target_id": "remote_ratio_calculator", "target_port": "numerator" + }, + { + "source_id": "remote_ratio_calculator", "source_port": "ratio", + "target_id": "sub_output", "target_port": "sub_ratio" + } + ], + "interface": { + "inputs": { "sub_in_1": "integer", "sub_in_2": "integer" }, + "outputs": { "sub_ratio": "float" } + } + } + }, + { + "id": "main_output", + "class_id": "OutputTraceNode" + } + ], + "edges": [ + { + "source_id": "main_input", "source_port": "main_text", + "target_id": "remote_text_processor", "target_port": "text_in" + }, + { + "source_id": "remote_text_processor", "source_port": "length", + "target_id": "sub_graph_node", "target_port": "sub_in_1" + }, + { + "source_id": "main_input", "source_port": "main_num", + "target_id": "sub_graph_node", "target_port": "sub_in_2" + }, + { + "source_id": "sub_graph_node", "source_port": "sub_ratio", + "target_id": "main_output", "target_port": "final_ratio" + } + ], + "interface": { + "inputs": { + "main_text": "string", + "main_num": "integer" + }, + "outputs": { + "final_ratio": "float" + } + } +} diff --git a/tests/demo_client.py b/tests/demo_client.py deleted file mode 100644 index f9476ca..0000000 --- a/tests/demo_client.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys -import os -import time - -# Ensure pytrace is in python path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -from pytrace.api.node import TraceNode -from pytrace.api.decorators import register_class, register_node, input_port, output_port -from pytrace.runtime.context import TraceContext -from pytrace.remote.worker import TraceWorker, expose_node - -# ============================================================================= -# 1. Define Your Custom Nodes -# ============================================================================= -# Use @expose_node to make the class available to the Remote Server. -# Use @register_class to define the category and display name in the UI. - -@expose_node -@register_class(category="Tutorial", display_name="Client Demo") -class ClientDemoNode(TraceNode): - """ - A collection of demo nodes running on the client side. - Developers can add their own methods here. - """ - - @register_node(display_name="Remote Add", description="Performs addition on the client.") - @input_port(name="a", type="float") - @input_port(name="b", type="float") - @output_port(name="result", type="float") - def add(self, ctx: TraceContext): - # 1. Get Inputs - a = ctx.get_input("a", 0.0) - b = ctx.get_input("b", 0.0) - - # 2. Execute Logic - result = a + b - - # 3. Log execution (This log will appear on the Server console!) - ctx.log(f"[Client] Executing Add: {a} + {b} = {result}") - - # 4. Set Output - ctx.set_output("result", result) - - @register_node(display_name="Remote Echo", description="Echoes a message.") - @input_port(name="message", type="string") - @output_port(name="echo", type="string") - def echo(self, ctx: TraceContext): - msg = ctx.get_input("message", "") - ctx.log(f"[Client] Echoing: {msg}") - ctx.set_output("echo", f"Client says: {msg}") - -# ============================================================================= -# 2. Start the Worker -# ============================================================================= - -if __name__ == "__main__": - SERVER_HOST = 'localhost' - SERVER_PORT = 9999 - - print(f"--- TraceStudio Client Demo ---") - print(f"Connecting to Server at {SERVER_HOST}:{SERVER_PORT}...") - - # Initialize the Worker - worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT) - - try: - # Start the worker (This blocks until stopped) - worker.start() - except KeyboardInterrupt: - print("\n[Client] Worker stopped.") - except Exception as e: - print(f"\n[Client] Error: {e}") \ No newline at end of file diff --git a/tests/demo_server.py b/tests/demo_server.py deleted file mode 100644 index ec027e7..0000000 --- a/tests/demo_server.py +++ /dev/null @@ -1,220 +0,0 @@ -import socket -import json -import threading -import sys -import os -import time - -# Add project root -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -from pytrace.api.decorators import output_port, register_class -from pytrace.remote.manager import RPCManager -from pytrace.remote.protocol import MessageType -from pytrace.model.graph import WorkflowGraph, Node, Edge -from pytrace.model.enums import DimensionMode -from pytrace.runtime.context import TraceContext -from pytrace.runtime.executor import WorkflowExecutor -from pytrace.remote.registry import RemoteRegistry -from pytrace.runtime.repository import NodeRepository -from pytrace.model.data import TraceParam, ParamMode -from pytrace.api.node import GraphTraceNode, TraceNode -from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry - -class ClientSession: - """ - Represents a connected Worker Client on the Server. - Wraps the Socket connection and the RPCManager. - """ - def __init__(self, conn, addr, server: 'TraceServer'): - self.conn = conn - self.addr = addr - self.server = server - self.client_id = None - self.rpc = RPCManager(send_callback=self._send_json) - - # Register internal handlers - self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register) - - def _send_json(self, data): - try: - msg = json.dumps(data) + "\n" - self.conn.sendall(msg.encode('utf-8')) - except Exception as e: - print(f"[Session] Send error: {e}") - self.close() - - def _handle_register(self, payload, msg_id): - self.client_id = payload.get("client_id") - nodes = payload.get("nodes", []) - print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.") - - # Update Server Registry - self.server.remote_registry.register_client(self.client_id, nodes) - self.server.register_session(self.client_id, self) - - def start_loop(self): - """Blocking read loop.""" - f = self.conn.makefile('r', encoding='utf-8') - try: - for line in f: - if not line: break - try: - data = json.loads(line) - self.rpc.process_message(data) - except Exception as e: - print(f"[Session] Msg error: {e}") - finally: - self.close() - - def close(self): - print(f"[Server] Session {self.client_id} disconnected") - self.rpc.close() - try: - self.conn.close() - except: - pass - if self.client_id: - self.server.unregister_session(self.client_id) - -class TraceServer: - def __init__(self, socket_port=9999): - self.socket_port = socket_port - self.sessions = {} # client_id -> ClientSession - self.lock = threading.Lock() - self.remote_registry = RemoteRegistry() - - def start(self): - # 1. Start Socket Server (For Workers) - threading.Thread(target=self._socket_server_loop, daemon=True).start() - - # 2. Start "HTTP" Server (For Users - Simulated) - threading.Thread(target=self._http_server_simulation, daemon=True).start() - - print(f"[Server] Running. Socket Port: {self.socket_port}") - - # Keep main thread alive - while True: - time.sleep(1) - - def _socket_server_loop(self): - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', self.socket_port)) - server_socket.listen(5) - - while True: - conn, addr = server_socket.accept() - session = ClientSession(conn, addr, self) - threading.Thread(target=session.start_loop, daemon=True).start() - - def register_session(self, client_id, session): - with self.lock: - self.sessions[client_id] = session - - def unregister_session(self, client_id): - with self.lock: - self.sessions.pop(client_id, None) - - # --- RPC Service Interface for Workflow --- - def call(self, target_id: str, method: str, params: dict, timeout=3600): - # 1. Routing Logic - if not target_id: - providers = self.remote_registry.find_providers(method) - if not providers: - raise RuntimeError(f"No active client provides node '{method}'") - target_id = providers[0] - print(f"[Server] Auto-routing '{method}' to client '{target_id}'") - - # 2. Get Session - session = self.sessions.get(target_id) - if not session: - raise RuntimeError(f"Client {target_id} not connected") - - # 3. Execute via Session's RPC - return session.rpc.call(target_id, method, params, timeout) - - # --- HTTP API Simulation --- - def _http_server_simulation(self): - """ - Simulates an HTTP endpoint that receives workflow execution requests. - In a real app, this would be Flask/FastAPI. - """ - print("[Server] HTTP API ready. Waiting for requests...") - - # Simulate a user request coming in after 5 seconds - time.sleep(5) - - # Mock Request Payload - # User explicitly wants 'ClientMathNode.add' to run on a specific worker if available - # Or they might leave it empty for auto-routing. - - # Let's find a connected client to simulate the user picking one - target_worker_id = None - while not target_worker_id: - with self.lock: - if self.sessions: - target_worker_id = list(self.sessions.keys())[0] - if not target_worker_id: - print("[Server] HTTP Waiting for workers to connect...") - time.sleep(2) - - print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<") - - # The User Request Data - request_payload = { - "inputs": {"x": 10, "y": 20}, - "node_targets": { - # The user specifies that the node with ID 'node_remote_1' in the graph - # MUST be executed by 'target_worker_id' - "node_remote_1": target_worker_id - } - } - - self.run_workflow(request_payload) - - def run_workflow(self, payload): - print("[Server] Loading test workflow: tests/assets/main.flow") - flow_path = os.path.abspath("tests/assets/main.flow") - with open(flow_path, 'r', encoding='utf-8') as f: - graph_data = json.load(f) - - # Inject a node ID into the graph data for demonstration if needed, - # or assume the graph already has 'node_remote_1'. - # For this demo, let's assume the graph's first node is the one we want to target. - # (In reality, graph_data comes from the frontend with IDs) - - graph = WorkflowGraph.from_dict(graph_data) - - # Prepare Repository - repo = NodeRepository(remote_registry=self.remote_registry) - - # Prepare Context with User Preferences (node_targets) - context = TraceContext( - contexts={}, - services={"rpc": self, "node_repository": repo}, - node_targets=payload.get("node_targets") - ) - - # Execute - print("[Server] Executing workflow...") - executor = WorkflowExecutor(graph, context, node_repository=repo) - - inputs = payload.get("inputs") - executor.execute_batch(inputs) - - # 7. Verify Results - outputs = context.context.get_outputs() - print(f"[Server] Workflow finished. Outputs: {outputs}") - - res = outputs.get("final_result") - val = res.value if hasattr(res, 'value') else res - - if val == 15.0: # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15 - print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}") - else: - print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}") - -if __name__ == "__main__": - server = TraceServer() - server.start() \ No newline at end of file diff --git a/tests/run_remote_test.py b/tests/run_remote_test.py new file mode 100644 index 0000000..ea13595 --- /dev/null +++ b/tests/run_remote_test.py @@ -0,0 +1,93 @@ + +import json +import httpx +import os + +# --- 配置 --- +SERVER_URL = "http://localhost:8000" +GRAPH_EXECUTE_ENDPOINT = f"{SERVER_URL}/execute_graph" +GRAPH_FILE_PATH = os.path.join(os.path.dirname(__file__), "assets", "complex_remote_graph.json") + +def run_test(): + """ + 执行端到端远程调用测试。 + + 使用方法: + 1. 确保 remote_test_server.py 正在运行 (uvicorn server.remote_test_server:app --reload --port 8000) + 2. 确保 test_complex_remote_client.py 正在运行 (python tests/test_complex_remote_client.py) + 3. 运行此脚本 (python tests/run_remote_test.py) + """ + + print("--- Running End-to-End Remote Execution Test ---") + + # 1. 加载图JSON文件 + try: + with open(GRAPH_FILE_PATH, 'r', encoding='utf-8') as f: + graph_data = json.load(f) + print(f"Successfully loaded graph from: {GRAPH_FILE_PATH}") + except FileNotFoundError: + print(f"ERROR: Graph file not found at {GRAPH_FILE_PATH}") + return + + # 2. 准备初始输入 + # 根据 complex_remote_graph.json 的 "interface" 定义 + initial_inputs = { + "main_text": "hello_remote_world", + "main_num": 5 + } + print(f"Initial inputs: {initial_inputs}") + + # 3. 构建请求体 + request_payload = { + "graph_data": graph_data, + "initial_inputs": initial_inputs + } + + # 4. 发送HTTP请求到服务器 + print(f"Sending POST request to {GRAPH_EXECUTE_ENDPOINT}...") + try: + with httpx.Client(timeout=60.0) as client: + response = client.post(GRAPH_EXECUTE_ENDPOINT, json=request_payload) + response.raise_for_status() # 如果状态码不是 2xx,则会引发异常 + + result = response.json() + + print("\n--- Test Result ---") + print("Request successful!") + print("Server response:") + # 使用json.dumps美化输出 + print(json.dumps(result, indent=2, ensure_ascii=False)) + + # 预期结果计算: + # 1. main_text "hello_remote_world" 长度为 18 + # 2. 子图接收 sub_in_1=18, sub_in_2=5 + # 3. 子图内本地相加: 18 + 5 = 23 + # 4. 子图内远程计算比率: 23 / 10.0 = 2.3 + # 5. 主图接收到最终结果 2.3 + expected_ratio = 2.3 + final_ratio = result.get("outputs", {}).get("final_ratio", {}).get("value") + + print("\n--- Verification ---") + print(f"Expected final_ratio: {expected_ratio}") + print(f"Actual final_ratio: {final_ratio}") + assert final_ratio == expected_ratio, "Test assertion failed!" + print("✅ Test Passed Successfully!") + + + except httpx.RequestError as e: + print("\n--- TEST FAILED ---") + print(f"ERROR: Could not connect to the server at {SERVER_URL}.") + print("Please ensure the 'remote_test_server.py' is running.") + print(f"Details: {e}") + except httpx.HTTPStatusError as e: + print("\n--- TEST FAILED ---") + print(f"ERROR: Server returned a non-200 status code: {e.response.status_code}") + print("Server response:") + try: + # 尝试打印JSON错误详情 + print(json.dumps(e.response.json(), indent=2, ensure_ascii=False)) + except json.JSONDecodeError: + print(e.response.text) + +if __name__ == "__main__": + run_test() diff --git a/tests/test_complex_remote_client.py b/tests/test_complex_remote_client.py new file mode 100644 index 0000000..139abfc --- /dev/null +++ b/tests/test_complex_remote_client.py @@ -0,0 +1,104 @@ + +import sys +import os +import time + +# 确保 pytrace 在 Python 路径中 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from pytrace.api.node import TraceNode +from pytrace.api.decorators import register_class, register_node, input_port, output_port +from pytrace.runtime.context import TraceContext +from pytrace.remote.worker import TraceWorker, expose_node + +# --- 1. 定义可被远程调用的节点 --- +# 使用 @expose_node 将此类暴露给服务器 +@expose_node +@register_class(category="ComplexRemote", display_name="Complex Remote Node") +class ComplexRemoteNode(TraceNode): + """ + 这个类包含一组用于远程调用的方法(节点)。 + 每个方法都添加了详细的打印语句,方便您在客户端控制台进行调试。 + """ + + @register_node( + class_id="remote.process_text", + display_name="Process Text (Remote)", + description="在客户端处理文本,计算长度并添加前缀。" + ) + @input_port(name="text_in", type="string") + @output_port(name="text_out", type="string") + @output_port(name="length", type="integer") + def process_text(self, ctx: TraceContext): + text = ctx.get_input("text_in", "") + + print("\n--- [REMOTE] ---") + print(f"Node 'remote.process_text' is executing.") + print(f" - Received input 'text_in': '{text}'") + + # 模拟一些耗时操作 + time.sleep(0.5) + + processed_text = f"processed_by_client:_{text}" + text_length = len(text) + + ctx.set_output("text_out", processed_text) + ctx.set_output("length", text_length) + + print(f" - Sent output 'text_out': '{processed_text}'") + print(f" - Sent output 'length': {text_length}") + print("--- [REMOTE] ---\ +") + + + @register_node( + class_id="remote.calculate_ratio", + display_name="Calculate Ratio (Remote)", + description="在客户端根据两个数字计算比率。" + ) + @input_port(name="numerator", type="float", label="分子") + @input_port(name="denominator", type="float", label="分母") + @output_port(name="ratio", type="float", label="比率") + def calculate_ratio(self, ctx: TraceContext): + numerator = ctx.get_input("numerator", 0.0) + denominator = ctx.get_input("denominator", 1.0) + + print("\n--- [REMOTE] ---") + print(f"Node 'remote.calculate_ratio' is executing.") + print(f" - Received input 'numerator': {numerator}") + print(f" - Received input 'denominator': {denominator}") + + if denominator == 0: + print(" - Error: Denominator is zero.") + # 在实际应用中可能需要更复杂的错误处理 + ratio = 0.0 + else: + ratio = numerator / denominator + + ctx.set_output("ratio", ratio) + print(f" - Sent output 'ratio': {ratio}") + print("--- [REMOTE] ---\ +") + + +# --- 2. 启动客户端 Worker --- +if __name__ == "__main__": + # 服务器的TCP RPC地址和端口 + SERVER_HOST = 'localhost' + SERVER_PORT = 9999 # 这应该与 remote_test_server.py 中的端口一致 + + print(f"--- Complex Remote Test Client ---") + print(f"Attempting to connect to RPC Server at {SERVER_HOST}:{SERVER_PORT}...") + + # 初始化并启动 Worker + # 它会自动查找所有被 @expose_node 装饰的类,并向服务器注册 + worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT) + + try: + # worker.start() 是一个阻塞操作,它会保持运行以监听服务器的调用 + worker.start() + except KeyboardInterrupt: + print("\n[Client] Worker stopped by user.") + except Exception as e: + print(f"\n[Client] An error occurred: {e}") + print("[Client] Is the test server running?")