"""
pytrace/tests/demo_server.py

Demo server for remote workflow execution: accepts worker connections over a
socket, records which nodes each worker provides, and runs a workflow that
routes remote node calls to those workers.
"""
import socket
import json
import threading
import sys
import os
import time
# Add project root
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.api.decorators import output_port, register_class
from pytrace.remote.manager import RPCManager
from pytrace.remote.protocol import MessageType
from pytrace.model.graph import WorkflowGraph, Node, Edge
from pytrace.model.enums import DimensionMode
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.remote.registry import RemoteRegistry
from pytrace.runtime.repository import NodeRepository
from pytrace.model.data import TraceParam, ParamMode
from pytrace.api.node import GraphTraceNode, TraceNode
from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry


class ClientSession:
"""
Represents a connected Worker Client on the Server.
Wraps the Socket connection and the RPCManager.
"""
def __init__(self, conn, addr, server: 'TraceServer'):
self.conn = conn
self.addr = addr
self.server = server
self.client_id = None
self.rpc = RPCManager(send_callback=self._send_json)
# Register internal handlers
self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register)
def _send_json(self, data):
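        """Serialize the message as a single JSON line and write it to the socket."""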
try:
msg = json.dumps(data) + "\n"
self.conn.sendall(msg.encode('utf-8'))
except Exception as e:
print(f"[Session] Send error: {e}")
self.close()
def _handle_register(self, payload, msg_id):
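        """Handle REGISTER_NODES: record the worker's client_id and the nodes it provides."""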
self.client_id = payload.get("client_id")
nodes = payload.get("nodes", [])
print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.")
# Update Server Registry
self.server.remote_registry.register_client(self.client_id, nodes)
self.server.register_session(self.client_id, self)
def start_loop(self):
"""Blocking read loop."""
f = self.conn.makefile('r', encoding='utf-8')
try:
for line in f:
if not line: break
try:
data = json.loads(line)
self.rpc.process_message(data)
except Exception as e:
print(f"[Session] Msg error: {e}")
finally:
self.close()
def close(self):
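        """Shut down the RPC manager and socket, then remove this session from the server."""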
print(f"[Server] Session {self.client_id} disconnected")
self.rpc.close()
try:
self.conn.close()
        except Exception:
            pass
if self.client_id:
self.server.unregister_session(self.client_id)


class TraceServer:
def __init__(self, socket_port=9999):
self.socket_port = socket_port
self.sessions = {} # client_id -> ClientSession
self.lock = threading.Lock()
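        # Tracks which nodes each connected worker provides; used for auto-routing in call().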
self.remote_registry = RemoteRegistry()
def start(self):
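        """Start the worker-facing socket server and the simulated HTTP API, then block forever."""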
# 1. Start Socket Server (For Workers)
threading.Thread(target=self._socket_server_loop, daemon=True).start()
# 2. Start "HTTP" Server (For Users - Simulated)
threading.Thread(target=self._http_server_simulation, daemon=True).start()
print(f"[Server] Running. Socket Port: {self.socket_port}")
# Keep main thread alive
while True:
time.sleep(1)
def _socket_server_loop(self):
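        """Accept worker connections and run each session's read loop in a daemon thread."""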
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(('localhost', self.socket_port))
server_socket.listen(5)
while True:
conn, addr = server_socket.accept()
session = ClientSession(conn, addr, self)
threading.Thread(target=session.start_loop, daemon=True).start()
def register_session(self, client_id, session):
with self.lock:
self.sessions[client_id] = session
def unregister_session(self, client_id):
with self.lock:
self.sessions.pop(client_id, None)
# --- RPC Service Interface for Workflow ---
def call(self, target_id: str, method: str, params: dict, timeout=3600):
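        """Invoke `method` with `params` on a remote worker.

        If `target_id` is empty, auto-route to the first registered provider of
        `method`; otherwise dispatch to that client's session, passing `timeout`
        through to its RPCManager.
        """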
# 1. Routing Logic
if not target_id:
providers = self.remote_registry.find_providers(method)
if not providers:
raise RuntimeError(f"No active client provides node '{method}'")
target_id = providers[0]
print(f"[Server] Auto-routing '{method}' to client '{target_id}'")
# 2. Get Session
session = self.sessions.get(target_id)
if not session:
raise RuntimeError(f"Client {target_id} not connected")
# 3. Execute via Session's RPC
return session.rpc.call(target_id, method, params, timeout)
# --- HTTP API Simulation ---
def _http_server_simulation(self):
"""
Simulates an HTTP endpoint that receives workflow execution requests.
In a real app, this would be Flask/FastAPI.
"""
print("[Server] HTTP API ready. Waiting for requests...")
# Simulate a user request coming in after 5 seconds
time.sleep(5)
# Mock Request Payload
        # The user wants 'ClientMathNode.add' to run on a specific worker if one is
        # available; leaving the target empty would fall back to auto-routing.
        # Wait for a worker to connect and use its client_id as the chosen target.
target_worker_id = None
while not target_worker_id:
with self.lock:
if self.sessions:
target_worker_id = list(self.sessions.keys())[0]
if not target_worker_id:
print("[Server] HTTP Waiting for workers to connect...")
time.sleep(2)
print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<")
# The User Request Data
request_payload = {
"inputs": {"x": 10, "y": 20},
"node_targets": {
# The user specifies that the node with ID 'node_remote_1' in the graph
# MUST be executed by 'target_worker_id'
"node_remote_1": target_worker_id
}
}
self.run_workflow(request_payload)
def run_workflow(self, payload):
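        """Load tests/assets/main.flow, execute it with the request's inputs
        and node targets, and verify the final result."""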
print("[Server] Loading test workflow: tests/assets/main.flow")
        flow_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "main.flow")
with open(flow_path, 'r', encoding='utf-8') as f:
graph_data = json.load(f)
        # Assume the graph already contains a node with ID 'node_remote_1' that the
        # request's node_targets refer to. (In a real deployment, graph_data comes
        # from the frontend with node IDs already assigned.)
graph = WorkflowGraph.from_dict(graph_data)
# Prepare Repository
repo = NodeRepository(remote_registry=self.remote_registry)
# Prepare Context with User Preferences (node_targets)
context = TraceContext(
contexts={},
services={"rpc": self, "node_repository": repo},
node_targets=payload.get("node_targets")
)
# Execute
print("[Server] Executing workflow...")
executor = WorkflowExecutor(graph, context, node_repository=repo)
inputs = payload.get("inputs")
executor.execute_batch(inputs)
        # Verify results
outputs = context.context.get_outputs()
print(f"[Server] Workflow finished. Outputs: {outputs}")
res = outputs.get("final_result")
val = res.value if hasattr(res, 'value') else res
if val == 15.0: # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15
print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}")
else:
print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}")


if __name__ == "__main__":
server = TraceServer()
server.start()