add tests
This commit is contained in:
parent c2335ae1a6
commit 649e2edca1
93 server/remote_test_server.py Normal file
@@ -0,0 +1,93 @@
import sys
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any
import uvicorn

# Make sure pytrace is on the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.model.graph import WorkflowGraph
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.runtime.repository import NodeRepository
# Import the TraceServer instance that already exists in the project
from app.services import trace_server
# The math node suite is assumed to be defined elsewhere and importable, so it can be used locally in graphs
from pytrace.nodes.test_math_suite import MathSuite


# --- 1. Data models ---

class GraphExecutionRequest(BaseModel):
    graph_data: Dict[str, Any]
    initial_inputs: Dict[str, Any] = Field(default_factory=dict)


# --- 2. FastAPI application ---

app = FastAPI(
    title="TraceStudio - Remote Test Server",
    description="A server for testing remote node registration and execution. It runs both an HTTP and a TCP service.",
    version="2.0.0",
)

@app.on_event("startup")
async def startup_event():
    # 1. Start the TCP server in the background; it listens for client connections and registrations
    trace_server.start_background()

    # 2. Hand the TCP server's remote node registry to the global NodeRepository
    NodeRepository.set_remote_registry(trace_server.remote_registry)

    print("--- Remote Test Server ---")
    print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
    print("HTTP Server ready to execute graphs at /execute_graph")


@app.post("/execute_graph", tags=["Execution"])
async def execute_graph(request: GraphExecutionRequest):
    """
    Accepts a graph as JSON and executes it, possibly invoking remote nodes
    that have registered over the TCP connection.
    """
    try:
        print("\n--- New Graph Execution Request ---")
        # 1. Build the WorkflowGraph from the request data
        graph = WorkflowGraph.from_dict(request.graph_data)
        print(f"Executing graph: '{graph.name}'")

        # 2. Prepare the execution context, injecting the trace_server instance as the RPC service;
        #    RemoteProxyNode uses this service to call back into the client
        context = TraceContext(services={"rpc": trace_server})

        # 3. Instantiate the executor and run
        executor = WorkflowExecutor(graph, context)
        executor.execute_batch(request.initial_inputs)

        # 4. Collect and return the final outputs
        final_outputs = context.get_outputs()
        print(f"Graph execution finished. Final outputs: {final_outputs}")
        return {
            "success": True,
            "outputs": final_outputs
        }
    except Exception as e:
        import traceback
        print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        )

# --- 3. Server startup ---

if __name__ == "__main__":
    # Start the HTTP server with uvicorn; the TCP server is started in the startup event.
    # You can also launch it from the command line: uvicorn server.remote_test_server:app --reload --port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)
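For reference, a sketch of the response body that tests/run_remote_test.py (added below) expects back from /execute_graph when it submits the complex_remote_graph.json asset. The {"value": ...} wrapping mirrors what that test reads, not a guaranteed schema; the exact serialization of each output depends on TraceContext.get_outputs().

# Assumed shape of a successful /execute_graph response for the complex_remote_graph.json test case;
# values here are illustrative and taken from the test's own expected result (2.3).
expected_response = {
    "success": True,
    "outputs": {
        "final_ratio": {"value": 2.3},
    },
}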
103 tests/assets/complex_remote_graph.json Normal file
@@ -0,0 +1,103 @@
{
  "name": "Complex Remote and Subgraph Test",
  "nodes": [
    {
      "id": "main_input",
      "class_id": "InputTraceNode",
      "params": {
        "main_text": { "mode": "context", "value": "inputs.main_text" },
        "main_num": { "mode": "context", "value": "inputs.main_num" }
      }
    },
    {
      "id": "remote_text_processor",
      "class_id": "remote.process_text"
    },
    {
      "id": "sub_graph_node",
      "class_id": "core.Graph",
      "data": {
        "name": "My SubGraph",
        "nodes": [
          {
            "id": "sub_input",
            "class_id": "InputTraceNode",
            "params": {
              "sub_in_1": { "mode": "context", "value": "inputs.sub_in_1" },
              "sub_in_2": { "mode": "context", "value": "inputs.sub_in_2" }
            }
          },
          {
            "id": "local_add",
            "class_id": "math.add"
          },
          {
            "id": "remote_ratio_calculator",
            "class_id": "remote.calculate_ratio",
            "params": {
              "denominator": { "mode": "static", "value": 10.0 }
            }
          },
          {
            "id": "sub_output",
            "class_id": "OutputTraceNode"
          }
        ],
        "edges": [
          {
            "source_id": "sub_input", "source_port": "sub_in_1",
            "target_id": "local_add", "target_port": "a"
          },
          {
            "source_id": "sub_input", "source_port": "sub_in_2",
            "target_id": "local_add", "target_port": "b"
          },
          {
            "source_id": "local_add", "source_port": "c",
            "target_id": "remote_ratio_calculator", "target_port": "numerator"
          },
          {
            "source_id": "remote_ratio_calculator", "source_port": "ratio",
            "target_id": "sub_output", "target_port": "sub_ratio"
          }
        ],
        "interface": {
          "inputs": { "sub_in_1": "integer", "sub_in_2": "integer" },
          "outputs": { "sub_ratio": "float" }
        }
      }
    },
    {
      "id": "main_output",
      "class_id": "OutputTraceNode"
    }
  ],
  "edges": [
    {
      "source_id": "main_input", "source_port": "main_text",
      "target_id": "remote_text_processor", "target_port": "text_in"
    },
    {
      "source_id": "remote_text_processor", "source_port": "length",
      "target_id": "sub_graph_node", "target_port": "sub_in_1"
    },
    {
      "source_id": "main_input", "source_port": "main_num",
      "target_id": "sub_graph_node", "target_port": "sub_in_2"
    },
    {
      "source_id": "sub_graph_node", "source_port": "sub_ratio",
      "target_id": "main_output", "target_port": "final_ratio"
    }
  ],
  "interface": {
    "inputs": {
      "main_text": "string",
      "main_num": "integer"
    },
    "outputs": {
      "final_ratio": "float"
    }
  }
}
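A quick sanity check of the data flow through this graph, using the inputs that tests/run_remote_test.py sends (main_text="hello_remote_world", main_num=5); this is a sketch mirroring the expected-result comments in that test, not part of the committed asset:

# Expected value at each stage, assuming the test inputs above:
length = len("hello_remote_world")   # remote.process_text output 'length' -> 18
sub_sum = length + 5                 # math.add inside the subgraph (a=18, b=5) -> 23
final_ratio = sub_sum / 10.0         # remote.calculate_ratio with static denominator 10.0 -> 2.3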
@@ -1,73 +0,0 @@
import sys
import os
import time

# Ensure pytrace is in python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.api.node import TraceNode
from pytrace.api.decorators import register_class, register_node, input_port, output_port
from pytrace.runtime.context import TraceContext
from pytrace.remote.worker import TraceWorker, expose_node

# =============================================================================
# 1. Define Your Custom Nodes
# =============================================================================
# Use @expose_node to make the class available to the Remote Server.
# Use @register_class to define the category and display name in the UI.

@expose_node
@register_class(category="Tutorial", display_name="Client Demo")
class ClientDemoNode(TraceNode):
    """
    A collection of demo nodes running on the client side.
    Developers can add their own methods here.
    """

    @register_node(display_name="Remote Add", description="Performs addition on the client.")
    @input_port(name="a", type="float")
    @input_port(name="b", type="float")
    @output_port(name="result", type="float")
    def add(self, ctx: TraceContext):
        # 1. Get Inputs
        a = ctx.get_input("a", 0.0)
        b = ctx.get_input("b", 0.0)

        # 2. Execute Logic
        result = a + b

        # 3. Log execution (This log will appear on the Server console!)
        ctx.log(f"[Client] Executing Add: {a} + {b} = {result}")

        # 4. Set Output
        ctx.set_output("result", result)

    @register_node(display_name="Remote Echo", description="Echoes a message.")
    @input_port(name="message", type="string")
    @output_port(name="echo", type="string")
    def echo(self, ctx: TraceContext):
        msg = ctx.get_input("message", "")
        ctx.log(f"[Client] Echoing: {msg}")
        ctx.set_output("echo", f"Client says: {msg}")

# =============================================================================
# 2. Start the Worker
# =============================================================================

if __name__ == "__main__":
    SERVER_HOST = 'localhost'
    SERVER_PORT = 9999

    print("--- TraceStudio Client Demo ---")
    print(f"Connecting to Server at {SERVER_HOST}:{SERVER_PORT}...")

    # Initialize the Worker
    worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT)

    try:
        # Start the worker (This blocks until stopped)
        worker.start()
    except KeyboardInterrupt:
        print("\n[Client] Worker stopped.")
    except Exception as e:
        print(f"\n[Client] Error: {e}")
@@ -1,220 +0,0 @@
import socket
import json
import threading
import sys
import os
import time

# Add project root
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.api.decorators import output_port, register_class
from pytrace.remote.manager import RPCManager
from pytrace.remote.protocol import MessageType
from pytrace.model.graph import WorkflowGraph, Node, Edge
from pytrace.model.enums import DimensionMode
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.remote.registry import RemoteRegistry
from pytrace.runtime.repository import NodeRepository
from pytrace.model.data import TraceParam, ParamMode
from pytrace.api.node import GraphTraceNode, TraceNode
from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry

class ClientSession:
    """
    Represents a connected Worker Client on the Server.
    Wraps the Socket connection and the RPCManager.
    """
    def __init__(self, conn, addr, server: 'TraceServer'):
        self.conn = conn
        self.addr = addr
        self.server = server
        self.client_id = None
        self.rpc = RPCManager(send_callback=self._send_json)

        # Register internal handlers
        self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register)

    def _send_json(self, data):
        try:
            msg = json.dumps(data) + "\n"
            self.conn.sendall(msg.encode('utf-8'))
        except Exception as e:
            print(f"[Session] Send error: {e}")
            self.close()

    def _handle_register(self, payload, msg_id):
        self.client_id = payload.get("client_id")
        nodes = payload.get("nodes", [])
        print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.")

        # Update Server Registry
        self.server.remote_registry.register_client(self.client_id, nodes)
        self.server.register_session(self.client_id, self)

    def start_loop(self):
        """Blocking read loop."""
        f = self.conn.makefile('r', encoding='utf-8')
        try:
            for line in f:
                if not line:
                    break
                try:
                    data = json.loads(line)
                    self.rpc.process_message(data)
                except Exception as e:
                    print(f"[Session] Msg error: {e}")
        finally:
            self.close()

    def close(self):
        print(f"[Server] Session {self.client_id} disconnected")
        self.rpc.close()
        try:
            self.conn.close()
        except Exception:
            pass
        if self.client_id:
            self.server.unregister_session(self.client_id)

class TraceServer:
    def __init__(self, socket_port=9999):
        self.socket_port = socket_port
        self.sessions = {}  # client_id -> ClientSession
        self.lock = threading.Lock()
        self.remote_registry = RemoteRegistry()

    def start(self):
        # 1. Start Socket Server (For Workers)
        threading.Thread(target=self._socket_server_loop, daemon=True).start()

        # 2. Start "HTTP" Server (For Users - Simulated)
        threading.Thread(target=self._http_server_simulation, daemon=True).start()

        print(f"[Server] Running. Socket Port: {self.socket_port}")

        # Keep main thread alive
        while True:
            time.sleep(1)

    def _socket_server_loop(self):
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind(('localhost', self.socket_port))
        server_socket.listen(5)

        while True:
            conn, addr = server_socket.accept()
            session = ClientSession(conn, addr, self)
            threading.Thread(target=session.start_loop, daemon=True).start()

    def register_session(self, client_id, session):
        with self.lock:
            self.sessions[client_id] = session

    def unregister_session(self, client_id):
        with self.lock:
            self.sessions.pop(client_id, None)

    # --- RPC Service Interface for Workflow ---
    def call(self, target_id: str, method: str, params: dict, timeout=3600):
        # 1. Routing Logic
        if not target_id:
            providers = self.remote_registry.find_providers(method)
            if not providers:
                raise RuntimeError(f"No active client provides node '{method}'")
            target_id = providers[0]
            print(f"[Server] Auto-routing '{method}' to client '{target_id}'")

        # 2. Get Session
        session = self.sessions.get(target_id)
        if not session:
            raise RuntimeError(f"Client {target_id} not connected")

        # 3. Execute via Session's RPC
        return session.rpc.call(target_id, method, params, timeout)

    # --- HTTP API Simulation ---
    def _http_server_simulation(self):
        """
        Simulates an HTTP endpoint that receives workflow execution requests.
        In a real app, this would be Flask/FastAPI.
        """
        print("[Server] HTTP API ready. Waiting for requests...")

        # Simulate a user request coming in after 5 seconds
        time.sleep(5)

        # Mock Request Payload
        # User explicitly wants 'ClientMathNode.add' to run on a specific worker if available
        # Or they might leave it empty for auto-routing.

        # Let's find a connected client to simulate the user picking one
        target_worker_id = None
        while not target_worker_id:
            with self.lock:
                if self.sessions:
                    target_worker_id = list(self.sessions.keys())[0]
            if not target_worker_id:
                print("[Server] HTTP Waiting for workers to connect...")
                time.sleep(2)

        print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<")

        # The User Request Data
        request_payload = {
            "inputs": {"x": 10, "y": 20},
            "node_targets": {
                # The user specifies that the node with ID 'node_remote_1' in the graph
                # MUST be executed by 'target_worker_id'
                "node_remote_1": target_worker_id
            }
        }

        self.run_workflow(request_payload)

    def run_workflow(self, payload):
        print("[Server] Loading test workflow: tests/assets/main.flow")
        flow_path = os.path.abspath("tests/assets/main.flow")
        with open(flow_path, 'r', encoding='utf-8') as f:
            graph_data = json.load(f)

        # Inject a node ID into the graph data for demonstration if needed,
        # or assume the graph already has 'node_remote_1'.
        # For this demo, let's assume the graph's first node is the one we want to target.
        # (In reality, graph_data comes from the frontend with IDs)

        graph = WorkflowGraph.from_dict(graph_data)

        # Prepare Repository
        repo = NodeRepository(remote_registry=self.remote_registry)

        # Prepare Context with User Preferences (node_targets)
        context = TraceContext(
            contexts={},
            services={"rpc": self, "node_repository": repo},
            node_targets=payload.get("node_targets")
        )

        # Execute
        print("[Server] Executing workflow...")
        executor = WorkflowExecutor(graph, context, node_repository=repo)

        inputs = payload.get("inputs")
        executor.execute_batch(inputs)

        # Verify Results
        outputs = context.context.get_outputs()
        print(f"[Server] Workflow finished. Outputs: {outputs}")

        res = outputs.get("final_result")
        val = res.value if hasattr(res, 'value') else res

        if val == 15.0:  # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15
            print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}")
        else:
            print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}")

if __name__ == "__main__":
    server = TraceServer()
    server.start()
93 tests/run_remote_test.py Normal file
@@ -0,0 +1,93 @@
import json
import httpx
import os

# --- Configuration ---
SERVER_URL = "http://localhost:8000"
GRAPH_EXECUTE_ENDPOINT = f"{SERVER_URL}/execute_graph"
GRAPH_FILE_PATH = os.path.join(os.path.dirname(__file__), "assets", "complex_remote_graph.json")

def run_test():
    """
    Runs the end-to-end remote execution test.

    Usage:
    1. Make sure remote_test_server.py is running (uvicorn server.remote_test_server:app --reload --port 8000)
    2. Make sure test_complex_remote_client.py is running (python tests/test_complex_remote_client.py)
    3. Run this script (python tests/run_remote_test.py)
    """

    print("--- Running End-to-End Remote Execution Test ---")

    # 1. Load the graph JSON file
    try:
        with open(GRAPH_FILE_PATH, 'r', encoding='utf-8') as f:
            graph_data = json.load(f)
        print(f"Successfully loaded graph from: {GRAPH_FILE_PATH}")
    except FileNotFoundError:
        print(f"ERROR: Graph file not found at {GRAPH_FILE_PATH}")
        return

    # 2. Prepare the initial inputs
    # Matches the "interface" definition in complex_remote_graph.json
    initial_inputs = {
        "main_text": "hello_remote_world",
        "main_num": 5
    }
    print(f"Initial inputs: {initial_inputs}")

    # 3. Build the request body
    request_payload = {
        "graph_data": graph_data,
        "initial_inputs": initial_inputs
    }

    # 4. Send the HTTP request to the server
    print(f"Sending POST request to {GRAPH_EXECUTE_ENDPOINT}...")
    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(GRAPH_EXECUTE_ENDPOINT, json=request_payload)
            response.raise_for_status()  # Raises an exception for non-2xx status codes

            result = response.json()

            print("\n--- Test Result ---")
            print("Request successful!")
            print("Server response:")
            # Pretty-print the response with json.dumps
            print(json.dumps(result, indent=2, ensure_ascii=False))

            # Expected result:
            # 1. main_text "hello_remote_world" has length 18
            # 2. The subgraph receives sub_in_1=18, sub_in_2=5
            # 3. Local addition inside the subgraph: 18 + 5 = 23
            # 4. Remote ratio calculation inside the subgraph: 23 / 10.0 = 2.3
            # 5. The main graph receives the final result 2.3
            expected_ratio = 2.3
            final_ratio = result.get("outputs", {}).get("final_ratio", {}).get("value")

            print("\n--- Verification ---")
            print(f"Expected final_ratio: {expected_ratio}")
            print(f"Actual final_ratio: {final_ratio}")
            assert final_ratio == expected_ratio, "Test assertion failed!"
            print("✅ Test Passed Successfully!")

    except httpx.RequestError as e:
        print("\n--- TEST FAILED ---")
        print(f"ERROR: Could not connect to the server at {SERVER_URL}.")
        print("Please ensure the 'remote_test_server.py' is running.")
        print(f"Details: {e}")
    except httpx.HTTPStatusError as e:
        print("\n--- TEST FAILED ---")
        print(f"ERROR: Server returned a non-200 status code: {e.response.status_code}")
        print("Server response:")
        try:
            # Try to print the JSON error detail
            print(json.dumps(e.response.json(), indent=2, ensure_ascii=False))
        except json.JSONDecodeError:
            print(e.response.text)

if __name__ == "__main__":
    run_test()
104 tests/test_complex_remote_client.py Normal file
@@ -0,0 +1,104 @@
import sys
import os
import time

# Make sure pytrace is on the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.api.node import TraceNode
from pytrace.api.decorators import register_class, register_node, input_port, output_port
from pytrace.runtime.context import TraceContext
from pytrace.remote.worker import TraceWorker, expose_node

# --- 1. Define the nodes that can be called remotely ---
# @expose_node exposes this class to the server
@expose_node
@register_class(category="ComplexRemote", display_name="Complex Remote Node")
class ComplexRemoteNode(TraceNode):
    """
    This class bundles a set of methods (nodes) intended for remote invocation.
    Each method prints detailed messages to make debugging on the client console easier.
    """

    @register_node(
        class_id="remote.process_text",
        display_name="Process Text (Remote)",
        description="Processes text on the client: computes its length and adds a prefix."
    )
    @input_port(name="text_in", type="string")
    @output_port(name="text_out", type="string")
    @output_port(name="length", type="integer")
    def process_text(self, ctx: TraceContext):
        text = ctx.get_input("text_in", "")

        print("\n--- [REMOTE] ---")
        print("Node 'remote.process_text' is executing.")
        print(f" - Received input 'text_in': '{text}'")

        # Simulate some time-consuming work
        time.sleep(0.5)

        processed_text = f"processed_by_client:_{text}"
        text_length = len(text)

        ctx.set_output("text_out", processed_text)
        ctx.set_output("length", text_length)

        print(f" - Sent output 'text_out': '{processed_text}'")
        print(f" - Sent output 'length': {text_length}")
        print("--- [REMOTE] ---\n")


    @register_node(
        class_id="remote.calculate_ratio",
        display_name="Calculate Ratio (Remote)",
        description="Calculates the ratio of two numbers on the client."
    )
    @input_port(name="numerator", type="float", label="Numerator")
    @input_port(name="denominator", type="float", label="Denominator")
    @output_port(name="ratio", type="float", label="Ratio")
    def calculate_ratio(self, ctx: TraceContext):
        numerator = ctx.get_input("numerator", 0.0)
        denominator = ctx.get_input("denominator", 1.0)

        print("\n--- [REMOTE] ---")
        print("Node 'remote.calculate_ratio' is executing.")
        print(f" - Received input 'numerator': {numerator}")
        print(f" - Received input 'denominator': {denominator}")

        if denominator == 0:
            print(" - Error: Denominator is zero.")
            # A real application would likely need more sophisticated error handling
            ratio = 0.0
        else:
            ratio = numerator / denominator

        ctx.set_output("ratio", ratio)
        print(f" - Sent output 'ratio': {ratio}")
        print("--- [REMOTE] ---\n")


# --- 2. Start the client Worker ---
if __name__ == "__main__":
    # TCP RPC address and port of the server
    SERVER_HOST = 'localhost'
    SERVER_PORT = 9999  # Must match the port used in remote_test_server.py

    print("--- Complex Remote Test Client ---")
    print(f"Attempting to connect to RPC Server at {SERVER_HOST}:{SERVER_PORT}...")

    # Initialize and start the Worker.
    # It automatically discovers every class decorated with @expose_node and registers them with the server.
    worker = TraceWorker(host=SERVER_HOST, port=SERVER_PORT)

    try:
        # worker.start() blocks and keeps running to listen for calls from the server
        worker.start()
    except KeyboardInterrupt:
        print("\n[Client] Worker stopped by user.")
    except Exception as e:
        print(f"\n[Client] An error occurred: {e}")
        print("[Client] Is the test server running?")