pytrace/server/remote_test_server.py
2026-01-22 23:22:01 +08:00

94 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any
import uvicorn
# 确保 pytrace 在 Python 路径中
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from pytrace.model.graph import WorkflowGraph
from pytrace.runtime.context import TraceContext
from pytrace.runtime.executor import WorkflowExecutor
from pytrace.runtime.repository import NodeRepository
# 导入项目中已有的 TraceServer
from app.services import trace_server
# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用
from pytrace.nodes.test_math_suite import MathSuite
# --- 1. 数据模型 ---
class GraphExecutionRequest(BaseModel):
graph_data: Dict[str, Any]
initial_inputs: Dict[str, Any] = Field(default_factory=dict)
# --- 2. FastAPI 应用定义 ---
app = FastAPI(
title="TraceStudio - Remote Test Server",
description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。",
version="2.0.0",
)
@app.on_event("startup")
async def startup_event():
# 1. 在后台启动TCP服务器用于监听客户端的连接和注册
trace_server.start_background()
# 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository
NodeRepository.set_remote_registry(trace_server.remote_registry)
print("--- Remote Test Server ---")
print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
print("HTTP Server ready to execute graphs at /execute_graph")
@app.post("/execute_graph", tags=["Execution"])
async def execute_graph(request: GraphExecutionRequest):
"""
接收图的JSON并执行可能会调用已通过TCP连接注册的远程节点。
"""
try:
print("\n--- New Graph Execution Request ---")
# 1. 从请求数据构建 WorkflowGraph
graph = WorkflowGraph.from_dict(request.graph_data)
print(f"Executing graph: '{graph.name}'")
# 2. 准备执行上下文,注入 trace_server 实例作为RPC服务
# RemoteProxyNode 将通过这个服务来回调客户端
context = TraceContext(services={"rpc": trace_server})
# 3. 实例化执行器并运行
executor = WorkflowExecutor(graph, context)
executor.execute_batch(request.initial_inputs)
# 4. 获取并返回最终输出
final_outputs = context.get_outputs()
print(f"Graph execution finished. Final outputs: {final_outputs}")
return {
"success": True,
"outputs": final_outputs
}
except Exception as e:
import traceback
print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
raise HTTPException(
status_code=500,
detail={
"error": str(e),
"type": type(e).__name__,
"traceback": traceback.format_exc()
}
)
# --- 3. 启动服务器 ---
if __name__ == "__main__":
# 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。
# 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000
uvicorn.run(app, host="0.0.0.0", port=8000)