# Demo server for pytrace remote execution: accepts worker clients over a
# socket and routes workflow node calls to them (see TraceServer below).
import json
import os
import socket
import sys
import threading
import time

# Add project root so the pytrace package resolves when run as a script.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pytrace.api.decorators import output_port, register_class
from pytrace.api.node import GraphTraceNode, TraceNode
from pytrace.model.data import TraceParam, ParamMode
from pytrace.model.enums import DimensionMode
from pytrace.model.graph import WorkflowGraph, Node, Edge
from pytrace.remote.manager import RPCManager
from pytrace.remote.protocol import MessageType
from pytrace.remote.registry import RemoteRegistry
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.runtime.registry import NodeRegistry as LocalNodeRegistry
from pytrace.runtime.repository import NodeRepository

class ClientSession:
    """
    Represents a connected Worker Client on the Server.

    Wraps the socket connection and the RPCManager that speaks the
    newline-delimited JSON protocol over it.
    """
    def __init__(self, conn, addr, server: 'TraceServer'):
        self.conn = conn
        self.addr = addr
        self.server = server
        self.client_id = None  # set once the client sends REGISTER_NODES
        # Guards against running teardown twice: close() is reachable both
        # from start_loop's `finally` and from a failed send in _send_json.
        self._closed = False
        self.rpc = RPCManager(send_callback=self._send_json)

        # Register internal handlers
        self.rpc.register_handler(MessageType.REGISTER_NODES, self._handle_register)

    def _send_json(self, data):
        """Serialize ``data`` as one newline-terminated JSON line and send it.

        Any send failure tears the session down — the connection is
        assumed dead at that point.
        """
        try:
            msg = json.dumps(data) + "\n"
            self.conn.sendall(msg.encode('utf-8'))
        except Exception as e:
            print(f"[Session] Send error: {e}")
            self.close()

    def _handle_register(self, payload, msg_id):
        """Handle REGISTER_NODES: record the client's id and advertised nodes."""
        self.client_id = payload.get("client_id")
        nodes = payload.get("nodes", [])
        print(f"[Server] Client {self.client_id} registered {len(nodes)} nodes.")

        # Update Server Registry
        self.server.remote_registry.register_client(self.client_id, nodes)
        self.server.register_session(self.client_id, self)

    def start_loop(self):
        """Blocking read loop: one JSON message per line until EOF, then close."""
        f = self.conn.makefile('r', encoding='utf-8')
        try:
            for line in f:
                if not line:
                    break
                try:
                    data = json.loads(line)
                    self.rpc.process_message(data)
                except Exception as e:
                    # A single malformed message must not kill the session.
                    print(f"[Session] Msg error: {e}")
        finally:
            self.close()

    def close(self):
        """Tear the session down (idempotent): RPC, socket, server registration.

        Safe to call multiple times; only the first call does any work.
        """
        if self._closed:
            return
        self._closed = True
        print(f"[Server] Session {self.client_id} disconnected")
        self.rpc.close()
        try:
            self.conn.close()
        except OSError:
            # Socket may already be closed by the peer; nothing left to do.
            pass
        if self.client_id:
            self.server.unregister_session(self.client_id)
class TraceServer:
    """
    Central server: accepts Worker Client socket connections, keeps a
    registry of which client provides which nodes, and dispatches workflow
    node executions to them via RPC.
    """
    def __init__(self, socket_port=9999):
        self.socket_port = socket_port
        self.sessions = {}  # client_id -> ClientSession
        self.lock = threading.Lock()  # guards self.sessions
        self.remote_registry = RemoteRegistry()

    def start(self):
        """Start the worker socket server and the simulated HTTP API, then block."""
        # 1. Start Socket Server (For Workers)
        threading.Thread(target=self._socket_server_loop, daemon=True).start()

        # 2. Start "HTTP" Server (For Users - Simulated)
        threading.Thread(target=self._http_server_simulation, daemon=True).start()

        print(f"[Server] Running. Socket Port: {self.socket_port}")

        # Keep main thread alive; all work happens in the daemon threads.
        while True:
            time.sleep(1)

    def _socket_server_loop(self):
        """Accept worker connections forever; one ClientSession thread each."""
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind(('localhost', self.socket_port))
        server_socket.listen(5)

        while True:
            conn, addr = server_socket.accept()
            session = ClientSession(conn, addr, self)
            threading.Thread(target=session.start_loop, daemon=True).start()

    def register_session(self, client_id, session):
        """Record a registered worker session (thread-safe)."""
        with self.lock:
            self.sessions[client_id] = session

    def unregister_session(self, client_id):
        """Forget a disconnected worker session (thread-safe; no-op if unknown)."""
        with self.lock:
            self.sessions.pop(client_id, None)

    # --- RPC Service Interface for Workflow ---
    def call(self, target_id: str, method: str, params: dict, timeout=3600):
        """
        Execute ``method`` on a worker and return its result.

        If ``target_id`` is falsy, auto-routes to the first client that
        provides the node. Raises RuntimeError when no provider exists or
        when the chosen client is no longer connected.
        """
        # 1. Routing Logic: pick a provider when the caller did not pin one.
        if not target_id:
            providers = self.remote_registry.find_providers(method)
            if not providers:
                raise RuntimeError(f"No active client provides node '{method}'")
            target_id = providers[0]
            print(f"[Server] Auto-routing '{method}' to client '{target_id}'")

        # 2. Get Session — under the lock, consistent with
        #    register_session/unregister_session running on other threads.
        with self.lock:
            session = self.sessions.get(target_id)
        if not session:
            raise RuntimeError(f"Client {target_id} not connected")

        # 3. Execute via Session's RPC (blocks until reply or timeout).
        return session.rpc.call(target_id, method, params, timeout)

    # --- HTTP API Simulation ---
    def _http_server_simulation(self):
        """
        Simulates an HTTP endpoint that receives workflow execution requests.
        In a real app, this would be Flask/FastAPI.
        """
        print("[Server] HTTP API ready. Waiting for requests...")

        # Simulate a user request coming in after 5 seconds
        time.sleep(5)

        # Mock request: the user may pin a node to a specific worker, or
        # leave the target empty for auto-routing. Here we wait until at
        # least one worker is connected and pick the first one.
        target_worker_id = None
        while not target_worker_id:
            with self.lock:
                if self.sessions:
                    target_worker_id = next(iter(self.sessions))
            if not target_worker_id:
                print("[Server] HTTP Waiting for workers to connect...")
                time.sleep(2)

        print(f"\n[Server] >>> HTTP Request Received: Execute Workflow (Target: {target_worker_id}) <<<")

        # The User Request Data: the node with graph ID 'node_remote_1'
        # MUST be executed by the chosen worker.
        request_payload = {
            "inputs": {"x": 10, "y": 20},
            "node_targets": {
                "node_remote_1": target_worker_id
            }
        }

        self.run_workflow(request_payload)

    def run_workflow(self, payload):
        """
        Load the demo workflow from disk and execute it.

        ``payload`` keys:
          - "inputs": dict of workflow input values
          - "node_targets": optional mapping of graph node id -> worker client id
        """
        print("[Server] Loading test workflow: tests/assets/main.flow")
        flow_path = os.path.abspath("tests/assets/main.flow")
        with open(flow_path, 'r', encoding='utf-8') as f:
            graph_data = json.load(f)

        # (In a real deployment graph_data comes from the frontend with node
        # IDs already assigned, e.g. 'node_remote_1'.)
        graph = WorkflowGraph.from_dict(graph_data)

        # Repository resolves node implementations, locally or via remote workers.
        repo = NodeRepository(remote_registry=self.remote_registry)

        # Context carries the RPC service and user routing preferences (node_targets).
        context = TraceContext(
            contexts={},
            services={"rpc": self, "node_repository": repo},
            node_targets=payload.get("node_targets")
        )

        # Execute
        print("[Server] Executing workflow...")
        executor = WorkflowExecutor(graph, context, node_repository=repo)

        inputs = payload.get("inputs")
        executor.execute_batch(inputs)

        # Verify results. NOTE(review): outputs are read from the nested
        # `context.context` — presumably the active run context; confirm.
        outputs = context.context.get_outputs()
        print(f"[Server] Workflow finished. Outputs: {outputs}")

        res = outputs.get("final_result")
        val = res.value if hasattr(res, 'value') else res

        if val == 15.0:  # main.flow logic: x(10) -> sub_flow(a=10, b=5) -> 10+5=15
            print(f"[Server] ✅ Remote Execution SUCCESS! Result: {val}")
        else:
            print(f"[Server] ❌ Remote Execution FAILED. Expected 15.0, got {val}")
if __name__ == "__main__":
    # Entry point: run the demo server until interrupted.
    TraceServer().start()