pytrace/server/app/services.py

72 lines
2.7 KiB
Python
Raw Permalink Normal View History

2026-01-19 00:49:55 +08:00
import socket
import threading
import logging
from typing import Dict
from pytrace.remote.engine import ServerEngine
from pytrace.remote.registry import RemoteRegistry
from pytrace.remote.manager import SyncRPCManager
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class TraceServer(ServerEngine):
    """
    Extended ServerEngine that integrates with RemoteRegistry and handles
    TCP socket listening for workers.

    Workers connect to ``(host, port)``; each accepted connection is passed
    to ``handle_connection`` (inherited from ServerEngine). Registrations are
    recorded in ``self.sessions`` and ``self.remote_registry`` so that
    ``call`` can route an RPC to a worker by node name.
    """

    # Seconds accept() may block before the loop re-checks the stop flag.
    _ACCEPT_POLL_INTERVAL = 0.5
    # Seconds to pause after a failed accept() so that a persistent error
    # (e.g. running out of file descriptors) cannot busy-spin the loop.
    _ACCEPT_ERROR_BACKOFF = 0.1

    def __init__(self, host="0.0.0.0", port=9999):
        """
        Args:
            host: Interface to bind the worker socket to. The default
                ``"0.0.0.0"`` listens on every IPv4 interface.
            port: TCP port workers connect to.
        """
        super().__init__()
        self.host = host
        self.port = port
        self.remote_registry = RemoteRegistry()
        self._stop_event = threading.Event()
        # Thread running _socket_loop, set by start_background().
        self._socket_thread = None

    def start_background(self):
        """Start the worker socket server in a daemon background thread.

        The socket is bound on the new thread; the "listening" message is
        logged only after binding succeeds, and a bind failure is logged as
        an error instead.

        Returns:
            The started ``threading.Thread``.
        """
        # Allow a restart after stop_background() has returned.
        self._stop_event.clear()
        t = threading.Thread(target=self._socket_loop, daemon=True, name="TraceServerSocket")
        self._socket_thread = t
        t.start()
        return t

    def stop_background(self, timeout=None):
        """Ask the worker socket loop to exit and wait for its thread.

        The loop notices the request within ``_ACCEPT_POLL_INTERVAL`` seconds
        unless it is currently inside ``handle_connection``.

        Args:
            timeout: Maximum seconds to wait for the thread to finish;
                ``None`` waits indefinitely.
        """
        self._stop_event.set()
        t = self._socket_thread
        # Joining the current thread raises RuntimeError (e.g. when called
        # from inside handle_connection on the socket thread).
        if t is not None and t is not threading.current_thread():
            t.join(timeout)

    def _socket_loop(self):
        """Bind the worker socket and dispatch accepted connections until stopped."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
            # Allow rebinding the port immediately after a restart.
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                server_socket.bind((self.host, self.port))
                server_socket.listen(5)
            except OSError:
                logger.exception("TraceServer failed to listen on %s:%s", self.host, self.port)
                return
            # A timeout makes accept() return periodically so _stop_event is
            # honoured. Sockets returned by accept() from a timeout-mode
            # listener are blocking (unless socket.setdefaulttimeout() is set).
            server_socket.settimeout(self._ACCEPT_POLL_INTERVAL)
            logger.info("TraceServer worker socket listening on %s:%s", self.host, self.port)
            while not self._stop_event.is_set():
                try:
                    conn, addr = server_socket.accept()
                except socket.timeout:
                    continue  # periodic wake-up to re-check _stop_event
                except OSError as e:
                    logger.error("Socket accept error: %s", e)
                    self._stop_event.wait(self._ACCEPT_ERROR_BACKOFF)
                    continue
                # NOTE(review): handle_connection runs on this accept thread;
                # if it blocks for the lifetime of the connection, other
                # workers cannot connect meanwhile — confirm ServerEngine
                # hands the connection off to its own thread.
                try:
                    self.handle_connection(conn, addr)
                except Exception:
                    logger.exception("Error handling worker connection from %s", addr)
                    # Assumes the handler did not keep ownership of conn when
                    # it raised; close it so the descriptor is not leaked.
                    conn.close()

    def _on_client_register(self, payload: Dict, manager: SyncRPCManager):
        """Record a worker registration in the session table and the registry.

        Overrides the ServerEngine hook to use RemoteRegistry and to prefer
        the worker-supplied ``client_id``.

        Args:
            payload: Registration message; the optional keys ``"client_id"``
                and ``"nodes"`` are read.
            manager: RPC manager for the worker's connection, stored in
                ``self.sessions`` under the resolved client id.
        """
        # Use the worker-provided ID, or derive one from the manager object.
        client_id = payload.get("client_id") or f"client_{id(manager)}"
        nodes = payload.get("nodes")
        if nodes is None:
            # Treat an explicit null the same as a missing key.
            nodes = []
        self.sessions[client_id] = manager
        self.remote_registry.register_client(client_id, nodes)
        logger.info("Client registered: %s with %d nodes", client_id, len(nodes))

    def call(self, target_id: str, method: str, params: dict, timeout=60.0):
        """
        RPC Service interface compatible with RemoteProxyNode.

        Args:
            target_id: Client to call. If falsy, the first client returned by
                ``remote_registry.find_providers(method)`` is used.
            method: Node/method name to invoke on the client.
            params: Parameters forwarded to the client.
            timeout: Accepted for interface compatibility.
                NOTE(review): it is not forwarded to ``call_client`` — confirm
                whether ``call_client`` supports a timeout and wire it through.

        Returns:
            Whatever ``call_client`` returns.

        Raises:
            RuntimeError: If no target is given and no client provides
                ``method``.
        """
        # 1. Routing logic
        if not target_id:
            providers = self.remote_registry.find_providers(method)
            if not providers:
                raise RuntimeError(f"No active client provides node '{method}'")
            # No real load balancing yet: always the first provider.
            target_id = providers[0]
        # 2. Call via engine
        return self.call_client(target_id, method, params)
# Shared module-level server instance. Constructing it at import time runs
# ServerEngine.__init__ and creates a RemoteRegistry. TraceServer.__init__
# itself opens no socket: the worker socket is bound by _socket_loop, which
# trace_server.start_background() runs on a background thread.
trace_server = TraceServer()