94 lines
3.3 KiB
Python
94 lines
3.3 KiB
Python
|
||
import sys
|
||
import os
|
||
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
from typing import Dict, Any
|
||
import uvicorn
|
||
|
||
# 确保 pytrace 在 Python 路径中
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||
|
||
from pytrace.model.graph import WorkflowGraph
|
||
from pytrace.runtime.context import TraceContext
|
||
from pytrace.runtime.executor import WorkflowExecutor
|
||
from pytrace.runtime.repository import NodeRepository
|
||
# 导入项目中已有的 TraceServer
|
||
from app.services import trace_server
|
||
# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用
|
||
from pytrace.nodes.test_math_suite import MathSuite
|
||
|
||
|
||
# --- 1. 数据模型 ---
|
||
|
||
class GraphExecutionRequest(BaseModel):
|
||
graph_data: Dict[str, Any]
|
||
initial_inputs: Dict[str, Any] = Field(default_factory=dict)
|
||
|
||
|
||
# --- 2. FastAPI 应用定义 ---
|
||
|
||
app = FastAPI(
|
||
title="TraceStudio - Remote Test Server",
|
||
description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。",
|
||
version="2.0.0",
|
||
)
|
||
|
||
@app.on_event("startup")
|
||
async def startup_event():
|
||
# 1. 在后台启动TCP服务器,用于监听客户端的连接和注册
|
||
trace_server.start_background()
|
||
|
||
# 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository
|
||
NodeRepository.set_remote_registry(trace_server.remote_registry)
|
||
|
||
print("--- Remote Test Server ---")
|
||
print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
|
||
print("HTTP Server ready to execute graphs at /execute_graph")
|
||
|
||
|
||
@app.post("/execute_graph", tags=["Execution"])
|
||
async def execute_graph(request: GraphExecutionRequest):
|
||
"""
|
||
接收图的JSON并执行,可能会调用已通过TCP连接注册的远程节点。
|
||
"""
|
||
try:
|
||
print("\n--- New Graph Execution Request ---")
|
||
# 1. 从请求数据构建 WorkflowGraph
|
||
graph = WorkflowGraph.from_dict(request.graph_data)
|
||
print(f"Executing graph: '{graph.name}'")
|
||
|
||
# 2. 准备执行上下文,注入 trace_server 实例作为RPC服务
|
||
# RemoteProxyNode 将通过这个服务来回调客户端
|
||
context = TraceContext(services={"rpc": trace_server})
|
||
|
||
# 3. 实例化执行器并运行
|
||
executor = WorkflowExecutor(graph, context)
|
||
executor.execute_batch(request.initial_inputs)
|
||
|
||
# 4. 获取并返回最终输出
|
||
final_outputs = context.get_outputs()
|
||
print(f"Graph execution finished. Final outputs: {final_outputs}")
|
||
return {
|
||
"success": True,
|
||
"outputs": final_outputs
|
||
}
|
||
except Exception as e:
|
||
import traceback
|
||
print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail={
|
||
"error": str(e),
|
||
"type": type(e).__name__,
|
||
"traceback": traceback.format_exc()
|
||
}
|
||
)
|
||
|
||
# --- 3. 启动服务器 ---
|
||
|
||
if __name__ == "__main__":
|
||
# 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。
|
||
# 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000
|
||
uvicorn.run(app, host="0.0.0.0", port=8000)
|