import socket
import threading
import logging
from typing import Dict, Optional

from pytrace.remote.engine import ServerEngine
from pytrace.remote.registry import RemoteRegistry
from pytrace.remote.manager import SyncRPCManager

logger = logging.getLogger(__name__)


class TraceServer(ServerEngine):
    """
    Extended ServerEngine that integrates with RemoteRegistry and handles
    TCP socket listening for workers.

    Each accepted worker connection is passed to ``self.handle_connection``
    (inherited from ServerEngine). Registered clients are tracked both in
    ``self.sessions`` (inherited) and in ``self.remote_registry``, so that
    :meth:`call` can route a request to a client providing a given node.
    """

    # Pause after a failed accept() so a persistent error (e.g. running out
    # of file descriptors) does not turn the loop into a busy spin.
    _ACCEPT_ERROR_BACKOFF = 0.1

    def __init__(self, host="0.0.0.0", port=9999, accept_timeout=1.0):
        """
        Args:
            host: Interface address to bind the worker socket to.
            port: TCP port to listen on.
            accept_timeout: Seconds each ``accept()`` call waits before the
                loop re-checks the stop flag. This bounds how long
                :meth:`stop_background` takes to take effect. Must be a
                positive number; ``None`` would block forever and make the
                loop unstoppable.
        """
        super().__init__()
        self.host = host
        self.port = port
        self.accept_timeout = accept_timeout
        self.remote_registry = RemoteRegistry()
        self._stop_event = threading.Event()
        self._socket_thread: Optional[threading.Thread] = None

    def start_background(self):
        """Starts the socket server in a background daemon thread.

        Calling this while the socket thread is still running is a no-op
        (a warning is logged). Otherwise the stop flag is cleared, so the
        server can be restarted after :meth:`stop_background`.
        """
        if self._socket_thread is not None and self._socket_thread.is_alive():
            logger.warning("TraceServer socket thread is already running")
            return
        self._stop_event.clear()
        self._socket_thread = threading.Thread(
            target=self._socket_loop, daemon=True, name="TraceServerSocket"
        )
        self._socket_thread.start()

    def stop_background(self, timeout: Optional[float] = None):
        """Ask the socket loop to exit and wait for its thread to finish.

        The loop notices the request within roughly ``accept_timeout``
        seconds. Connections already handed to ``handle_connection`` are not
        affected.

        Args:
            timeout: Maximum seconds to wait for the thread to exit; ``None``
                waits indefinitely.
        """
        self._stop_event.set()
        thread = self._socket_thread
        # Joining the current thread would raise RuntimeError.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)

    def _socket_loop(self):
        """Bind and listen, then hand each accepted connection to the engine.

        Runs until the stop flag is set. A bind/listen failure is logged and
        ends the thread.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                server_socket.bind((self.host, self.port))
                server_socket.listen(5)
            except OSError:
                logger.exception(
                    "TraceServer failed to listen on %s:%s", self.host, self.port
                )
                return

            # A finite timeout lets the loop observe _stop_event; a plain
            # blocking accept() would never return control to check it.
            server_socket.settimeout(self.accept_timeout)
            logger.info(
                "TraceServer worker socket listening on %s:%s", self.host, self.port
            )

            while not self._stop_event.is_set():
                try:
                    conn, addr = server_socket.accept()
                except socket.timeout:
                    # No connection this interval; loop to re-check the flag.
                    continue
                except OSError:
                    logger.exception("Socket accept error")
                    self._stop_event.wait(self._ACCEPT_ERROR_BACKOFF)
                    continue

                try:
                    self.handle_connection(conn, addr)
                except Exception:
                    # Keep serving other workers if one connection fails.
                    logger.exception("Error handling connection from %s", addr)

    def _on_client_register(self, payload: Dict, manager: SyncRPCManager):
        """Register a connected client, preferring the worker-supplied ID.

        Overrides the engine hook so that the nodes a client advertises are
        also recorded in ``self.remote_registry`` for routing.

        Args:
            payload: Registration message. Reads ``"client_id"`` (optional)
                and ``"nodes"`` (optional; a missing or null value is treated
                as no nodes).
            manager: RPC manager for this client; stored in ``self.sessions``
                under the resolved client ID.
        """
        # Use the worker-provided ID, or derive one from the manager object.
        client_id = payload.get("client_id") or f"client_{id(manager)}"
        # `or []` also covers an explicit "nodes": null, which would otherwise
        # break len() below after the client was already half-registered.
        nodes = payload.get("nodes") or []

        self.sessions[client_id] = manager
        self.remote_registry.register_client(client_id, nodes)
        logger.info("Client registered: %s with %d nodes", client_id, len(nodes))

    def call(self, target_id: str, method: str, params: dict, timeout=60.0):
        """
        RPC Service interface compatible with RemoteProxyNode.

        Args:
            target_id: ID of the client to call. If empty or None, the first
                client that the registry reports as providing ``method`` is
                chosen.
            method: Name of the node/method to invoke on the client.
            params: Parameters forwarded to the client.
            timeout: Accepted for interface compatibility; currently not
                forwarded to ``call_client``.

        Returns:
            Whatever ``self.call_client`` returns.

        Raises:
            RuntimeError: If no ``target_id`` is given and no registered
                client provides ``method``.
        """
        # 1. Routing logic
        if not target_id:
            providers = self.remote_registry.find_providers(method)
            if not providers:
                raise RuntimeError(f"No active client provides node '{method}'")
            # Simple load balancing: pick the first provider.
            target_id = providers[0]

        # 2. Call via engine.
        # NOTE(review): `timeout` is not passed on; confirm whether
        # ServerEngine.call_client accepts a timeout and forward it if so.
        return self.call_client(target_id, method, params)


trace_server = TraceServer()