72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
import socket
|
|
import threading
|
|
import logging
|
|
from typing import Dict
|
|
from pytrace.remote.engine import ServerEngine
|
|
from pytrace.remote.registry import RemoteRegistry
|
|
from pytrace.remote.manager import SyncRPCManager
|
|
|
|
# Logger named after this module so its records can be filtered/configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class TraceServer(ServerEngine):
    """
    Extended ServerEngine that integrates with RemoteRegistry and handles
    TCP socket listening for workers.

    Workers connect over TCP to ``(host, port)``. Each accepted connection is
    passed to ``handle_connection`` (inherited from ServerEngine).
    ``_on_client_register`` (presumably invoked by ServerEngine when a worker
    registers) records the worker's nodes in ``remote_registry``, and ``call``
    uses that registry to route RPCs to a providing worker.
    """

    # Seconds accept() may block before the loop re-checks the stop flag.
    _ACCEPT_POLL_INTERVAL = 1.0
    # Pause after a failed accept() so a persistent error does not busy-loop.
    _ACCEPT_ERROR_BACKOFF = 0.1

    def __init__(self, host="0.0.0.0", port=9999):
        """
        Args:
            host: Interface address to bind the worker socket to.
            port: TCP port to listen on for worker connections.
        """
        super().__init__()
        self.host = host
        self.port = port
        self.remote_registry = RemoteRegistry()
        # Set to ask the background accept loop to exit.
        self._stop_event = threading.Event()
        # Background accept thread; None until start_background() is called.
        self._thread = None

    def start_background(self):
        """Start the worker socket server in a background daemon thread.

        Returns immediately. Binding happens inside the thread, so a bind
        failure (e.g. the port is already in use) is logged by that thread
        rather than raised to the caller. If the server thread is already
        running, this logs a warning and does nothing.
        """
        if self._thread is not None and self._thread.is_alive():
            logger.warning("TraceServer socket thread is already running")
            return
        # Allow a restart after stop_background().
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._socket_loop, daemon=True, name="TraceServerSocket"
        )
        self._thread.start()

    def stop_background(self, timeout=None):
        """Ask the background accept loop to exit and wait for its thread.

        The loop notices the request within about ``_ACCEPT_POLL_INTERVAL``
        seconds (or once an in-progress ``handle_connection`` call returns),
        then closes the listening socket. Connections already passed to
        ``handle_connection`` are not closed by this method.

        Args:
            timeout: Maximum seconds to wait for the thread; None waits
                indefinitely.
        """
        self._stop_event.set()
        thread = self._thread
        # Joining the current thread would raise RuntimeError.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)

    def _socket_loop(self):
        """Bind the worker socket and accept connections until stopped.

        Runs on the background thread started by ``start_background``. The
        listening socket is always closed when this returns.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
            # Allow rebinding the port right after a restart.
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                server_socket.bind((self.host, self.port))
                server_socket.listen(5)
            except OSError:
                logger.exception(
                    f"TraceServer failed to listen on {self.host}:{self.port}"
                )
                return

            # A finite timeout makes accept() return periodically so the stop
            # flag is honoured even when no worker connects. socket.accept()
            # puts accepted sockets back in blocking mode when no global
            # default timeout is configured, so connections are unaffected.
            server_socket.settimeout(self._ACCEPT_POLL_INTERVAL)
            logger.info(
                f"TraceServer worker socket listening on {self.host}:{self.port}"
            )

            while not self._stop_event.is_set():
                try:
                    conn, addr = server_socket.accept()
                except socket.timeout:
                    continue
                except OSError as e:
                    logger.error(f"Socket accept error: {e}")
                    # wait() returns early if a stop is requested meanwhile.
                    self._stop_event.wait(self._ACCEPT_ERROR_BACKOFF)
                    continue

                # NOTE(review): handle_connection runs on this accept thread.
                # If it blocks for the lifetime of the connection, further
                # workers cannot connect until it returns. Confirm that
                # ServerEngine hands the connection off to its own thread.
                try:
                    self.handle_connection(conn, addr)
                except Exception:
                    logger.exception(f"Error handling worker connection from {addr}")

    def _on_client_register(self, payload: Dict, manager: SyncRPCManager):
        """Record a newly registered worker session and the nodes it provides.

        Overrides the ServerEngine hook to also populate ``remote_registry``
        and to prefer the worker-supplied ``client_id``. Re-registering with
        an existing id replaces the previous entry in ``self.sessions``.

        Args:
            payload: Registration message; reads the optional ``client_id``
                and ``nodes`` keys.
            manager: RPC manager for the worker's connection; stored in
                ``self.sessions`` under the resolved client id.
        """
        # Fall back to an id derived from the manager object when the worker
        # supplies none (or an empty one).
        client_id = payload.get("client_id") or f"client_{id(manager)}"
        # Treat an explicit null "nodes" the same as a missing key, so len()
        # below cannot fail after the session has been recorded.
        nodes = payload.get("nodes") or []

        self.sessions[client_id] = manager
        self.remote_registry.register_client(client_id, nodes)

        logger.info(f"Client registered: {client_id} with {len(nodes)} nodes")

    def call(self, target_id: str, method: str, params: dict, timeout=60.0):
        """
        RPC Service interface compatible with RemoteProxyNode.

        Args:
            target_id: Client id to send the call to. If empty or None, the
                first client that ``remote_registry`` reports as providing
                node ``method`` is used.
            method: Name of the node to invoke; also the routing key.
            params: Parameters forwarded unchanged to ``call_client``.
            timeout: Accepted for interface compatibility (see note below).

        Returns:
            Whatever ``call_client`` returns.

        Raises:
            RuntimeError: If no target is given and no client provides ``method``.
        """
        if not target_id:
            providers = self.remote_registry.find_providers(method)
            if not providers:
                raise RuntimeError(f"No active client provides node '{method}'")
            # No balancing yet: always route to the first provider returned.
            target_id = providers[0]

        # NOTE(review): ``timeout`` is not forwarded, so the caller's timeout
        # is ignored here. Confirm whether ServerEngine.call_client accepts one.
        return self.call_client(target_id, method, params)
|
|
|
|
# Shared module-level TraceServer instance, created at import time with the
# default host/port. TraceServer.__init__ itself does not open the worker
# socket; call trace_server.start_background() to begin accepting workers.
trace_server = TraceServer()