pytrace/server/remote_test_server.py

94 lines
3.3 KiB
Python
Raw Normal View History

2026-01-22 23:22:01 +08:00
import sys
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any
import uvicorn
# 确保 pytrace 在 Python 路径中
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.model.graph import WorkflowGraph
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.runtime.repository import NodeRepository
# 导入项目中已有的 TraceServer
from app.services import trace_server
# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用
from pytrace.nodes.test_math_suite import MathSuite
# --- 1. 数据模型 ---
class GraphExecutionRequest(BaseModel):
graph_data: Dict[str, Any]
initial_inputs: Dict[str, Any] = Field(default_factory=dict)
# --- 2. FastAPI 应用定义 ---
app = FastAPI(
title="TraceStudio - Remote Test Server",
description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。",
version="2.0.0",
)
@app.on_event("startup")
async def startup_event():
# 1. 在后台启动TCP服务器用于监听客户端的连接和注册
trace_server.start_background()
# 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository
NodeRepository.set_remote_registry(trace_server.remote_registry)
print("--- Remote Test Server ---")
print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
print("HTTP Server ready to execute graphs at /execute_graph")
@app.post("/execute_graph", tags=["Execution"])
async def execute_graph(request: GraphExecutionRequest):
"""
接收图的JSON并执行可能会调用已通过TCP连接注册的远程节点
"""
try:
print("\n--- New Graph Execution Request ---")
# 1. 从请求数据构建 WorkflowGraph
graph = WorkflowGraph.from_dict(request.graph_data)
print(f"Executing graph: '{graph.name}'")
# 2. 准备执行上下文,注入 trace_server 实例作为RPC服务
# RemoteProxyNode 将通过这个服务来回调客户端
context = TraceContext(services={"rpc": trace_server})
# 3. 实例化执行器并运行
executor = WorkflowExecutor(graph, context)
executor.execute_batch(request.initial_inputs)
# 4. 获取并返回最终输出
final_outputs = context.get_outputs()
print(f"Graph execution finished. Final outputs: {final_outputs}")
return {
"success": True,
"outputs": final_outputs
}
except Exception as e:
import traceback
print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
raise HTTPException(
status_code=500,
detail={
"error": str(e),
"type": type(e).__name__,
"traceback": traceback.format_exc()
}
)
# --- 3. 启动服务器 ---
if __name__ == "__main__":
# 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。
# 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000
uvicorn.run(app, host="0.0.0.0", port=8000)