import socket
import json
import threading
import sys
import os
import time

# Add project root so the pytrace package resolves when run as a script.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.api.decorators import output_port, register_class
from pytrace.remote.manager import RPCManager
from pytrace.remote.protocol import MessageType
from pytrace.model.graph import WorkflowGraph, Node, Edge
from pytrace.model.enums import DimensionMode
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.remote.registry import RemoteRegistry
from pytrace.runtime.repository import NodeRepository
from pytrace.model.data import TraceParam, ParamMode
from pytrace.api.node import GraphTraceNode, TraceNode
from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry


class ClientSession:
    """
    Represents a connected Worker Client on the Server.
    Wraps the Socket connection and the RPCManager.
    """

    def __init__(self, conn, addr, server: 'TraceServer'):
        self.conn = conn          # accepted socket for this worker
        self.addr = addr          # remote (host, port) tuple
        self.server = server      # owning TraceServer, used for (un)registration
        self.client_id = None     # set once the worker sends REGISTER_NODES
        self.rpc = RPCManager(send_callback=self._send_json)
        # Register internal handlers
        self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register)

    def _send_json(self, data):
        """Serialize `data` as one newline-delimited JSON message and send it.

        Any send failure is treated as a dead connection: the session is
        closed so the server stops routing work to this worker.
        """
        try:
            msg = json.dumps(data) + "\n"
            self.conn.sendall(msg.encode('utf-8'))
        except Exception as e:
            print(f"[Session] Send error: {e}")
            self.close()

    def _handle_register(self, payload, msg_id):
        """Handle the worker's REGISTER_NODES message.

        Records the worker's client_id and advertises its node list to the
        server-wide RemoteRegistry so `call` can auto-route to it.
        """
        self.client_id = payload.get("client_id")
        nodes = payload.get("nodes", [])
        print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.")
        # Update Server Registry
        self.server.remote_registry.register_client(self.client_id, nodes)
        self.server.register_session(self.client_id, self)

    def start_loop(self):
        """Blocking read loop.

        Reads newline-delimited JSON messages until the peer disconnects,
        feeding each one to the RPCManager. A malformed message is logged
        and skipped rather than killing the whole session.
        """
        f = self.conn.makefile('r', encoding='utf-8')
        try:
            for line in f:
                if not line:
                    break
                try:
                    data = json.loads(line)
                    self.rpc.process_message(data)
                except Exception as e:
                    print(f"[Session] Msg error: {e}")
        finally:
            self.close()

    def close(self):
        """Tear down the session: RPC manager, socket, and server registration."""
        print(f"[Server] Session {self.client_id} disconnected")
        self.rpc.close()
        try:
            self.conn.close()
        except OSError:
            # Socket may already be closed by the peer; best-effort cleanup.
            pass
        if self.client_id:
            self.server.unregister_session(self.client_id)


class TraceServer:
    """Central server: accepts worker sockets and dispatches workflow runs."""

    def __init__(self, socket_port=9999):
        self.socket_port = socket_port
        self.sessions = {}  # client_id -> ClientSession
        self.lock = threading.Lock()  # guards self.sessions
        self.remote_registry = RemoteRegistry()

    def start(self):
        """Start the worker socket listener and the simulated HTTP API, then block."""
        # 1. Start Socket Server (For Workers)
        threading.Thread(target=self._socket_server_loop, daemon=True).start()
        # 2. Start "HTTP" Server (For Users - Simulated)
        threading.Thread(target=self._http_server_simulation, daemon=True).start()
        print(f"[Server] Running. Socket Port: {self.socket_port}")
        # Keep main thread alive
        while True:
            time.sleep(1)

    def _socket_server_loop(self):
        """Accept worker connections forever; one ClientSession thread each."""
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow quick restarts without waiting out TIME_WAIT.
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind(('localhost', self.socket_port))
        server_socket.listen(5)
        while True:
            conn, addr = server_socket.accept()
            session = ClientSession(conn, addr, self)
            threading.Thread(target=session.start_loop, daemon=True).start()

    def register_session(self, client_id, session):
        """Record a registered worker session (thread-safe)."""
        with self.lock:
            self.sessions[client_id] = session

    def unregister_session(self, client_id):
        """Drop a worker session if present (thread-safe, idempotent)."""
        with self.lock:
            self.sessions.pop(client_id, None)

    # --- RPC Service Interface for Workflow ---
    def call(self, target_id: str, method: str, params: dict, timeout=3600):
        """Invoke `method` on a worker and return its result.

        If `target_id` is falsy, auto-routes to the first client advertising
        `method` in the RemoteRegistry.

        Raises:
            RuntimeError: no provider for `method`, or target not connected.
        """
        # 1. Routing Logic
        if not target_id:
            providers = self.remote_registry.find_providers(method)
            if not providers:
                raise RuntimeError(f"No active client provides node '{method}'")
            target_id = providers[0]
            print(f"[Server] Auto-routing '{method}' to client '{target_id}'")
        # 2. Get Session
        session = self.sessions.get(target_id)
        if not session:
            raise RuntimeError(f"Client {target_id} not connected")
        # 3. Execute via Session's RPC
        return session.rpc.call(target_id, method, params, timeout)

    # --- HTTP API Simulation ---
    def _http_server_simulation(self):
        """
        Simulates an HTTP endpoint that receives workflow execution requests.
        In a real app, this would be Flask/FastAPI.
        """
        print("[Server] HTTP API ready. Waiting for requests...")
        # Simulate a user request coming in after 5 seconds
        time.sleep(5)

        # Mock Request Payload
        # User explicitly wants 'ClientMathNode.add' to run on a specific worker if available
        # Or they might leave it empty for auto-routing.
        # Let's find a connected client to simulate the user picking one
        target_worker_id = None
        while not target_worker_id:
            with self.lock:
                if self.sessions:
                    # Any connected worker will do for the demo.
                    target_worker_id = next(iter(self.sessions))
            if not target_worker_id:
                print("[Server] HTTP Waiting for workers to connect...")
                time.sleep(2)

        print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<")

        # The User Request Data
        request_payload = {
            "inputs": {"x": 10, "y": 20},
            "node_targets": {
                # The user specifies that the node with ID 'node_remote_1' in the graph
                # MUST be executed by 'target_worker_id'
                "node_remote_1": target_worker_id
            }
        }
        self.run_workflow(request_payload)

    def run_workflow(self, payload):
        """Load the demo workflow, execute it with `payload`, and verify the result.

        Args:
            payload: dict with "inputs" (initial parameter values) and
                "node_targets" (graph-node-id -> worker client_id pinning).
        """
        print("[Server] Loading test workflow: tests/assets/main.flow")
        flow_path = os.path.abspath("tests/assets/main.flow")
        with open(flow_path, 'r', encoding='utf-8') as f:
            graph_data = json.load(f)

        # Inject a node ID into the graph data for demonstration if needed,
        # or assume the graph already has 'node_remote_1'.
        # For this demo, let's assume the graph's first node is the one we want to target.
        # (In reality, graph_data comes from the frontend with IDs)
        graph = WorkflowGraph.from_dict(graph_data)

        # Prepare Repository
        repo = NodeRepository(remote_registry=self.remote_registry)

        # Prepare Context with User Preferences (node_targets)
        context = TraceContext(
            contexts={},
            services={"rpc": self, "node_repository": repo},
            node_targets=payload.get("node_targets")
        )

        # Execute
        print("[Server] Executing workflow...")
        executor = WorkflowExecutor(graph, context, node_repository=repo)
        inputs = payload.get("inputs")
        executor.execute_batch(inputs)

        # 7. Verify Results
        # NOTE(review): `context.context` looks like a double attribute access —
        # verify against TraceContext whether this should be `context.get_outputs()`.
        outputs = context.context.get_outputs()
        print(f"[Server] Workflow finished. Outputs: {outputs}")
        res = outputs.get("final_result")
        # Unwrap a TraceParam-like result if needed; otherwise use the raw value.
        val = res.value if hasattr(res, 'value') else res
        if val == 15.0:
            # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15
            print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}")
        else:
            print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}")


if __name__ == "__main__":
    server = TraceServer()
    server.start()