94 lines
3.3 KiB
Python
94 lines
3.3 KiB
Python
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
from fastapi import FastAPI, HTTPException
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from typing import Dict, Any
|
|||
|
|
import uvicorn
|
|||
|
|
|
|||
|
|
# 确保 pytrace 在 Python 路径中
|
|||
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|||
|
|
|
|||
|
|
from pytrace.model.graph import WorkflowGraph
|
|||
|
|
from pytrace.runtime.context import TraceContext
|
|||
|
|
from pytrace.runtime.executor import WorkflowExecutor
|
|||
|
|
from pytrace.runtime.repository import NodeRepository
|
|||
|
|
# 导入项目中已有的 TraceServer
|
|||
|
|
from app.services import trace_server
|
|||
|
|
# 假定数学节点套件已在某处定义并可以导入,以便在图中本地使用
|
|||
|
|
from pytrace.nodes.test_math_suite import MathSuite
|
|||
|
|
|
|||
|
|
|
|||
|
|
# --- 1. 数据模型 ---
|
|||
|
|
|
|||
|
|
class GraphExecutionRequest(BaseModel):
|
|||
|
|
graph_data: Dict[str, Any]
|
|||
|
|
initial_inputs: Dict[str, Any] = Field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# --- 2. FastAPI 应用定义 ---
|
|||
|
|
|
|||
|
|
app = FastAPI(
|
|||
|
|
title="TraceStudio - Remote Test Server",
|
|||
|
|
description="一个用于测试远程节点注册和执行的服务器。它同时开启HTTP和TCP服务。",
|
|||
|
|
version="2.0.0",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@app.on_event("startup")
|
|||
|
|
async def startup_event():
|
|||
|
|
# 1. 在后台启动TCP服务器,用于监听客户端的连接和注册
|
|||
|
|
trace_server.start_background()
|
|||
|
|
|
|||
|
|
# 2. 将TCP服务器的远程节点注册表配置给全局的NodeRepository
|
|||
|
|
NodeRepository.set_remote_registry(trace_server.remote_registry)
|
|||
|
|
|
|||
|
|
print("--- Remote Test Server ---")
|
|||
|
|
print(f"TCP RPC Server started on port {trace_server.port}. Waiting for clients to connect and register.")
|
|||
|
|
print("HTTP Server ready to execute graphs at /execute_graph")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/execute_graph", tags=["Execution"])
|
|||
|
|
async def execute_graph(request: GraphExecutionRequest):
|
|||
|
|
"""
|
|||
|
|
接收图的JSON并执行,可能会调用已通过TCP连接注册的远程节点。
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
print("\n--- New Graph Execution Request ---")
|
|||
|
|
# 1. 从请求数据构建 WorkflowGraph
|
|||
|
|
graph = WorkflowGraph.from_dict(request.graph_data)
|
|||
|
|
print(f"Executing graph: '{graph.name}'")
|
|||
|
|
|
|||
|
|
# 2. 准备执行上下文,注入 trace_server 实例作为RPC服务
|
|||
|
|
# RemoteProxyNode 将通过这个服务来回调客户端
|
|||
|
|
context = TraceContext(services={"rpc": trace_server})
|
|||
|
|
|
|||
|
|
# 3. 实例化执行器并运行
|
|||
|
|
executor = WorkflowExecutor(graph, context)
|
|||
|
|
executor.execute_batch(request.initial_inputs)
|
|||
|
|
|
|||
|
|
# 4. 获取并返回最终输出
|
|||
|
|
final_outputs = context.get_outputs()
|
|||
|
|
print(f"Graph execution finished. Final outputs: {final_outputs}")
|
|||
|
|
return {
|
|||
|
|
"success": True,
|
|||
|
|
"outputs": final_outputs
|
|||
|
|
}
|
|||
|
|
except Exception as e:
|
|||
|
|
import traceback
|
|||
|
|
print(f"ERROR: Graph execution failed: {e}\n{traceback.format_exc()}")
|
|||
|
|
raise HTTPException(
|
|||
|
|
status_code=500,
|
|||
|
|
detail={
|
|||
|
|
"error": str(e),
|
|||
|
|
"type": type(e).__name__,
|
|||
|
|
"traceback": traceback.format_exc()
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# --- 3. 启动服务器 ---
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
# 使用 uvicorn 启动HTTP服务器。TCP服务器会在startup事件中被启动。
|
|||
|
|
# 您可以通过命令行启动: uvicorn server.remote_test_server:app --reload --port 8000
|
|||
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|