TraceStudio-dev/server/app/api/endpoints_graph.py

622 lines
20 KiB
Python
Raw Permalink Normal View History

"""
图执行和查询相关 API
"""
2026-01-12 03:32:51 +08:00
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import json
from pathlib import Path
2026-01-12 03:32:51 +08:00
import asyncio
2026-01-10 19:08:49 +08:00
from server.app.core.node_base import DimensionMode, NodeType
from ..core.user_manager import get_user_path
from ..core.security import is_safe_path, validate_filename, sanitize_path
# Router holding the graph execution, node preview, user, workflow and
# function endpoints declared in this module.
router = APIRouter()
2026-01-12 03:32:51 +08:00
# Simple per-user execution locks that keep one user from triggering overlapping
# runs and overloading the backend. Scoped by user_id for now ("guest" by
# default); could later be narrowed to per-workspace or per-graph.
_EXECUTION_LOCKS: Dict[str, asyncio.Lock] = {}


def _get_user_lock(user_id: str = "guest") -> asyncio.Lock:
    """Return the execution lock for *user_id*, creating it on first request.

    Repeated calls with the same ``user_id`` return the very same
    ``asyncio.Lock`` instance.
    """
    if user_id not in _EXECUTION_LOCKS:
        _EXECUTION_LOCKS[user_id] = asyncio.Lock()
    return _EXECUTION_LOCKS[user_id]
class NodeSchema(BaseModel):
    """Node payload as sent by the frontend."""
    # Node identifier (presumably what EdgeSchema.source / target refer to).
    id: str
    # Kind of node (project enum NodeType).
    node_type: NodeType
    # Implementation class name looked up in the node registry; the preview and
    # execute endpoints in this module reject nodes that lack it.
    class_name: Optional[str] = None
    # Runtime parameters for the node.
    params: Optional[Dict[str, Any]] = None
class EdgeSchema(BaseModel):
    """Edge (connection) payload as sent by the frontend."""
    # The frontend may omit the id (the backend then synthesises one), but a
    # stable id supplied by the frontend is preferred.
    id: Optional[str] = None
    # Source node identifier (presumably NodeSchema.id).
    source: str
    # Target node identifier (presumably NodeSchema.id).
    target: str
    # Dimension handling mode for this edge (project enum DimensionMode;
    # semantics are defined there).
    dimension_mode: DimensionMode = DimensionMode.NONE
    # Port on the source node; may be omitted.
    source_port: Optional[str] = None
    # Port on the target node; may be omitted.
    target_port: Optional[str] = None
class GraphExecuteRequest(BaseModel):
    """Request body for executing a whole graph."""
    # Nodes of the graph.
    nodes: List[NodeSchema]
    # Connections between the nodes.
    edges: List[EdgeSchema]
    # Optional execution settings (free-form dict).
    settings: Optional[Dict[str, Any]] = None
class NodePreviewRequest(BaseModel):
    """Request body for previewing a single node's output."""
    # The node to preview.
    node: NodeSchema
    # Maximum number of rows/items to include in the preview.
    limit: int = 10
@router.get("/plugins")
async def get_plugins():
    """
    List every available node plugin.

    Plugins are read dynamically from ``NodeRegistry`` and reshaped into the
    structure the frontend expects, keyed by implementation class name.
    """
    from ..core.node_registry import NodeRegistry

    def _describe(cls_name, node_cls):
        # Build the frontend description of one registered node class.
        meta = node_cls.get_metadata()
        return {
            "display_name": meta.get("display_name", cls_name),
            "category": meta.get("category", "Uncategorized"),
            "description": meta.get("description", ""),
            "icon": meta.get("icon", "📦"),
            "node_type": meta.get("node_type", "REGULAR"),
            # The frontend's double-click detection relies on the implementation class name.
            "class_name": cls_name,
            "node_logic": "standard",
            "supports_preview": True,
            # Ports and params are passed through from the metadata unchanged.
            "inputs": meta.get("inputs", []),
            "outputs": meta.get("outputs", []),
            "param_schema": meta.get("params", []),
            "context_vars": meta.get("context_vars", {}),
            "cache_policy": meta.get("cache_policy", "NONE"),
        }

    registered = NodeRegistry.list_all()
    plugins = {name: _describe(name, cls) for name, cls in registered.items()}
    return {
        "plugins": plugins,
        "total": len(plugins),
        "categories": NodeRegistry.get_categories(),
    }
@router.post("/node/preview")
async def preview_node(request: NodePreviewRequest):
    """
    Preview the output of a single node.

    The node class is looked up in ``NodeRegistry`` by ``class_name``,
    instantiated, given the request's parameters, and run once via
    ``wrap_process`` with empty inputs and a preview-mode context. Only the
    first output is turned into a table of at most ``request.limit`` rows.

    Raises:
        HTTPException(400): the node has no ``class_name``.
        HTTPException(500): any other failure while building or running the node.
    """
    from ..core.node_registry import NodeRegistry
    import polars as pl

    # Validate outside the generic handler below so the client receives a 400
    # rather than a 500 wrapping it.
    class_name = request.node.class_name
    if not class_name:
        raise HTTPException(status_code=400, detail="Missing class_name for node preview")

    try:
        node_class = NodeRegistry.get(class_name)
        node_instance = node_class()

        # Parameters may arrive under `data` or `params`. `data` is not a
        # declared NodeSchema field, so it must be read defensively; accessing
        # it as an attribute raised AttributeError and broke every preview.
        pdata = getattr(request.node, "data", None) or request.node.params or {}
        if isinstance(pdata, dict):
            for param_name, param_value in pdata.items():
                # Prefer the node's set_param API; fall back to a plain
                # attribute when set_param fails for this implementation.
                try:
                    node_instance.set_param(param_name, param_value)
                except Exception:
                    setattr(node_instance, param_name, param_value)

        # The node runs in isolation: no upstream inputs in preview mode.
        inputs = {}
        context = {
            "user_id": "guest",
            "preview_mode": True
        }
        result = await node_instance.wrap_process(inputs, context)
        outputs = result.get("outputs", {})

        # Only the first output is previewed.
        preview_data = []
        columns = []
        for output_value in outputs.values():
            if isinstance(output_value, pl.DataFrame):
                columns = list(output_value.columns)
                preview_data = output_value.head(request.limit).to_dicts()
            elif isinstance(output_value, list):
                columns = ["value"]
                preview_data = [{"value": item} for item in output_value[:request.limit]]
            else:
                # Scalar output.
                columns = ["value"]
                preview_data = [{"value": output_value}]
            break

        return {
            "class_name": class_name,
            "status": "success",
            "columns": columns,
            "preview": preview_data,
            "context": result.get("context", {})
        }
    except HTTPException:
        # Keep deliberate HTTP errors (and their status codes) intact.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"节点预览失败: {str(e)}"
        ) from e
@router.post("/graph/execute")
async def execute_graph(request: GraphExecuteRequest):
    """
    Execute a complete computation graph.

    Nodes and edges are normalised into plain dicts and passed to
    ``WorkflowExecutor.execute`` along with ``request.settings`` as the global
    context. Only one execution per user may run at a time.

    Returns:
        The executor's results when it reports success.

    Raises:
        HTTPException(409): an execution for this user is already running.
        HTTPException(400): a node is missing ``class_name``.
        HTTPException(500): the executor reported failure or raised; the detail
            carries the error message, exception type and traceback.
    """
    from ..core.workflow_executor import WorkflowExecutor

    user_id = "guest"  # TODO: take the user id from the request / authentication
    lock = _get_user_lock(user_id)
    # These client errors are raised outside the generic handler below so they
    # keep their status codes instead of being turned into a 500.
    if lock.locked():
        raise HTTPException(status_code=409, detail="已有执行在进行中,请稍后重试")

    # Normalise nodes. class_name is mandatory; runtime parameters are read from
    # `params` and always passed to the executor under that key.
    nodes = []
    for node in request.nodes:
        if not node.class_name:
            raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name")
        nodes.append({
            "id": node.id,
            "node_type": node.node_type,
            "class_name": node.class_name,
            "params": node.params or {},
        })

    # Normalise edges: default missing ports and synthesise a missing id from
    # the endpoints and ports.
    edges = []
    for edge in request.edges:
        s_handle = edge.source_port or "output"
        t_handle = edge.target_port or "input"
        edges.append({
            "id": edge.id or f"{edge.source}:{s_handle}->{edge.target}:{t_handle}",
            "source": edge.source,
            "source_port": s_handle,
            "target": edge.target,
            "target_port": t_handle,
            "dimension_mode": edge.dimension_mode,
        })

    global_context = request.settings or {}

    try:
        executor = WorkflowExecutor(user_id=user_id)
        # Hold the lock for the whole run. Nothing is awaited between the
        # locked() check above and this acquire, so the check cannot race.
        async with lock:
            success, results = await executor.execute(nodes, edges, global_context)
        if not success:
            raise Exception(str(results))
        return results
    except Exception as e:
        import traceback
        # NOTE(review): returning the traceback exposes server internals to the
        # client; consider limiting it to a debug mode.
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        ) from e
2026-01-12 03:32:51 +08:00
@router.websocket("/ws/graph/execute")
async def ws_execute_graph(websocket: WebSocket):
    """Execute a computation graph over a WebSocket, streaming progress events.

    The client must first send one JSON message::

        {
            "nodes": [...],
            "edges": [...],
            "settings": { ... }
        }

    The server then forwards every event yielded by
    ``WorkflowExecutor.execute_stream`` (e.g. start / node_started /
    node_completed / progress / error / completed) and closes the socket.

    Rejections and failures are reported as an ``error`` event, then the
    socket is closed:
        - 1008: an execution for this user is already running.
        - 1003: a node has no class name.
        - 1011: any other failure.
    """
    await websocket.accept()
    try:
        init_msg = await websocket.receive_json()
        from ..core.workflow_executor import WorkflowExecutor

        user_id = "guest"  # TODO: take the user id from the handshake / authentication
        lock = _get_user_lock(user_id)
        if lock.locked():
            # Reject concurrent executions with an error event.
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": "已有执行在进行中,请稍后重试"}
            })
            await websocket.close(code=1008)
            return

        # Normalise nodes following the same conventions as the HTTP route.
        nodes_in = []
        for node in init_msg.get("nodes", []):
            # Older frontend payloads may keep the class name under data.meta;
            # tolerate `data` / `meta` being null as well as absent.
            class_name = node.get("class_name") or (
                ((node.get("data") or {}).get("meta") or {}).get("class_name")
            )
            if not class_name:
                await websocket.send_json({
                    "event": "error",
                    "run_id": None,
                    "data": {"message": f"Node {node.get('id')} missing class_name"}
                })
                await websocket.close(code=1003)
                return
            nodes_in.append({
                "id": node.get("id"),
                "node_type": node.get("node_type"),
                "class_name": class_name,
                "params": node.get("params") or {},
            })

        edges_in = []
        for edge in init_msg.get("edges", []):
            # Accept both source_port/target_port and React-Flow style
            # sourceHandle/targetHandle, stripping the "output-" / "input-"
            # handle prefix. Only the leading prefix is removed; str.replace
            # would also mangle later occurrences inside the port name.
            s_handle = edge.get("source_port") or edge.get("sourceHandle") or "output"
            if isinstance(s_handle, str) and s_handle.startswith("output-"):
                s_handle = s_handle[len("output-"):]
            t_handle = edge.get("target_port") or edge.get("targetHandle") or "input"
            if isinstance(t_handle, str) and t_handle.startswith("input-"):
                t_handle = t_handle[len("input-"):]
            raw_mode = edge.get("dimension_mode")
            edges_in.append({
                "id": edge.get("id") or f"{edge.get('source')}:{s_handle}->{edge.get('target')}:{t_handle}",
                "source": edge.get("source"),
                "source_port": s_handle,
                "target": edge.get("target"),
                "target_port": t_handle,
                # An omitted mode falls back to EdgeSchema's default instead of
                # failing in DimensionMode(None).
                "dimension_mode": DimensionMode.NONE if raw_mode is None else DimensionMode(raw_mode),
            })

        settings = init_msg.get("settings") or {}
        executor = WorkflowExecutor(user_id=user_id)
        # Hold the lock for the whole stream so one user cannot run two
        # executions at once. Nothing is awaited between the locked() check
        # and this acquire on the success path.
        async with lock:
            async for event in executor.execute_stream(nodes_in, edges_in, settings):
                await websocket.send_json(event)
        await websocket.close()
    except WebSocketDisconnect:
        # The client disconnected.
        return
    except Exception as e:
        # Best effort: the socket may already be closed (client gone, or the
        # failure came from send/close itself). A failure while reporting the
        # error must not raise out of the handler.
        try:
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": str(e)}
            })
        except Exception:
            pass
        try:
            await websocket.close(code=1011)
        except Exception:
            pass
@router.get("/users/list")
async def list_users():
    """
    Return the users reported by ``user_manager.list_users`` and their count.
    """
    # Aliased so the import does not shadow this endpoint's own name.
    from ..core.user_manager import list_users as fetch_users

    all_users = fetch_users()
    return {"users": all_users, "count": len(all_users)}
@router.post("/users/add")
async def add_user(payload: Dict[str, Any]):
    """
    Register a new user and create their workspace.

    Expected payload keys:
        username: required string of letters/digits, '_' and '-'.
        display_name: optional; falls back to the username when missing or empty.

    Raises:
        HTTPException(400): missing or malformed username, or the user exists.
        HTTPException(500): the user database could not be saved.
    """
    from ..core.user_manager import load_users_db, save_users_db, create_user_workspace
    from datetime import datetime

    username = payload.get("username")
    if not username:
        raise HTTPException(status_code=400, detail="Missing username")
    # Reject non-strings explicitly (they would otherwise crash below with a
    # 500) and restrict the alphabet, since the name is handed to
    # create_user_workspace (presumably used as a directory name).
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")
    display_name = payload.get("display_name") or username

    db = load_users_db()
    # Create the "users" list if the database does not have one yet, so the
    # duplicate check and the append operate on the same list.
    users = db.setdefault("users", [])
    if any(u.get("username") == username for u in users):
        raise HTTPException(status_code=400, detail="User already exists")

    users.append({
        "username": username,
        "display_name": display_name,
        "created_at": datetime.now().isoformat(),
        "active": True
    })

    if not save_users_db(db):
        raise HTTPException(status_code=500, detail="Failed to save user database")

    create_user_workspace(username)

    return {
        "success": True,
        "message": f"用户 {username} 创建成功"
    }
@router.post("/nodes/save")
async def save_node(payload: Dict[str, Any]):
    """
    Save a node's configuration on the server.

    Currently a stub: it requires ``nodeId`` and acknowledges the save without
    persisting ``nodeData`` anywhere.
    """
    node_id = payload.get("nodeId")
    node_data = payload.get("nodeData")  # not persisted yet (see TODO)
    if node_id:
        # TODO: implement persistence (database or file).
        return {
            "success": True,
            "message": f"节点 {node_id} 配置已保存",
            "nodeId": node_id,
        }
    raise HTTPException(status_code=400, detail="Missing nodeId")
@router.post("/workflows/save")
async def save_workflow(payload: Dict[str, Any]):
    """
    Save a workflow as JSON in the user's ``workflows`` directory.

    Expected payload keys:
        filename: required; checked with ``validate_filename``; ".json" is
            appended when missing.
        username: optional, defaults to "guest".
        data: required; the workflow object to serialise.

    Returns:
        ``success``, a message, and the saved file's path relative to the
        parent of the user's root directory.

    Raises:
        HTTPException(400): missing filename/data, or invalid filename/username.
        HTTPException(403): the resolved path escapes the user's directory.
        HTTPException(500): creating the directory or writing the file failed.
    """
    filename = payload.get("filename")
    username = payload.get("username", "guest")
    workflow_data = payload.get("data")

    if not filename:
        raise HTTPException(status_code=400, detail="Missing filename")
    if not workflow_data:
        raise HTTPException(status_code=400, detail="Missing workflow data")
    # is_safe_path below only checks the file against a root that is itself
    # derived from `username`, so a username like "../x" would move that root.
    # Restrict usernames to letters/digits, '_' and '-'.
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username")

    is_valid, error_msg = validate_filename(filename)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg)
    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    workflow_dir = get_user_path(username, "workflows")
    workflow_file = workflow_dir / filename
    user_root = get_user_path(username)
    # Check before mkdir so a rejected path is never created here.
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        workflow_dir.mkdir(parents=True, exist_ok=True)
        with open(workflow_file, "w", encoding="utf-8") as f:
            json.dump(workflow_data, f, indent=2, ensure_ascii=False)
        return {
            "success": True,
            "message": f"工作流已保存: {filename}",
            "path": str(workflow_file.relative_to(user_root.parent))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
@router.get("/workflows/load")
async def load_workflow(filename: str, username: str = "guest"):
    """
    Load a saved workflow from the user's ``workflows`` directory.

    Returns:
        ``success``, the sanitised ``filename`` and the parsed JSON ``data``.

    Raises:
        HTTPException(400): invalid username.
        HTTPException(403): the resolved path escapes the user's directory.
        HTTPException(404): no such workflow file.
        HTTPException(500): the file could not be read or parsed.
    """
    # is_safe_path below only checks the file against a root that is itself
    # derived from `username`, so a username like "../x" would move that root.
    # Restrict usernames to letters/digits, '_' and '-'.
    if not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username")

    filename = sanitize_path(filename)
    workflow_file = get_user_path(username, f"workflows/{filename}")
    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")
    # is_file() rather than exists(): a directory would pass exists() and then
    # fail in open() with a 500.
    if not workflow_file.is_file():
        raise HTTPException(status_code=404, detail="Workflow not found")

    try:
        with open(workflow_file, "r", encoding="utf-8") as f:
            workflow_data = json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"加载失败: {str(e)}") from e

    return {
        "success": True,
        "filename": filename,
        "data": workflow_data
    }
class SaveFunctionRequest(BaseModel):
    """Request body for saving a subgraph as a reusable function."""
    # Function identifier; also used as the stored file name (<function_name>.json).
    function_name: str
    # Human-readable name.
    display_name: str
    # Optional description.
    description: Optional[str] = ""
    # Nodes of the encapsulated subgraph.
    nodes: List[NodeSchema]
    # Edges of the encapsulated subgraph.
    edges: List[EdgeSchema]
    # Function input definitions; stored as given (not validated here).
    inputs: List[Dict[str, Any]]
    # Function output definitions; stored as given (not validated here).
    outputs: List[Dict[str, Any]]
@router.post("/functions/save")
async def save_function(request: SaveFunctionRequest):
    """
    Save a subgraph as a reusable function node.

    The selected nodes and edges are written to
    ``cloud/custom_nodes/functions/<function_name>.json`` under the project
    root, and function nodes are then reloaded via ``load_function_nodes``.

    Raises:
        HTTPException(400): the function name fails ``validate_filename``.
        HTTPException(409): a function with that name already exists.
        HTTPException(500): serialising/writing the file or reloading failed.
    """
    from ..core.node_loader import load_function_nodes

    # validate_filename returns an (is_valid, error_message) pair. The pair
    # itself is always truthy, so it must be unpacked. Using it directly skipped
    # validation and produced file names like "(True, None).json".
    function_name = request.function_name
    is_valid, _ = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail="Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    functions_dir.mkdir(parents=True, exist_ok=True)
    function_file = functions_dir / f"{function_name}.json"

    function_data = {
        "function_name": function_name,
        "display_name": request.display_name,
        "description": request.description,
        "inputs": request.inputs,
        "outputs": request.outputs,
        # NOTE(review): .dict() keeps enum members (node_type, dimension_mode);
        # serialisation only succeeds if those enums are str/int based — confirm.
        "nodes": [node.dict() for node in request.nodes],
        "edges": [edge.dict() for edge in request.edges]
    }

    try:
        # Serialise first so an encoding error cannot leave a partial file
        # behind (which would make every later save of this name a 409).
        content = json.dumps(function_data, indent=2, ensure_ascii=False)
        # Mode "x" creates the file atomically and fails if it already exists,
        # so two concurrent saves cannot both pass an exists() check.
        with open(function_file, "x", encoding="utf-8") as f:
            f.write(content)
    except FileExistsError:
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在") from None
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e

    try:
        # Reload function nodes so the new function becomes available.
        load_result = load_function_nodes()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e

    return {
        "success": True,
        "message": f"函数已保存: {request.display_name}",
        "function_name": function_name,
        "path": str(function_file.relative_to(project_root)),
        "reload_result": load_result
    }
@router.get("/functions/list")
async def list_functions():
    """
    List the saved functions under ``cloud/custom_nodes/functions``.

    Each ``*.json`` file contributes its function name, display name,
    description, inputs and outputs; the names fall back to the file stem.
    Files that cannot be read or parsed are reported on stdout and skipped.
    Results are ordered by file path so the response is stable across calls.
    """
    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"

    if not functions_dir.exists():
        return {"functions": []}

    functions = []
    # Path.glob guarantees no ordering; sort for a deterministic response.
    for json_file in sorted(functions_dir.glob("*.json")):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            functions.append({
                "function_name": data.get("function_name", json_file.stem),
                "display_name": data.get("display_name", json_file.stem),
                "description": data.get("description", ""),
                "inputs": data.get("inputs", []),
                "outputs": data.get("outputs", [])
            })
        except Exception as e:
            print(f"❌ 读取函数失败: {json_file.name} - {str(e)}")

    return {"functions": functions}
@router.get("/functions/{function_name}")
async def get_function(function_name: str):
    """
    Return a saved function's full definition, including its internal workflow.

    Raises:
        HTTPException(400): the function name fails ``validate_filename``.
        HTTPException(404): no such function file.
        HTTPException(500): the file could not be read or parsed.
    """
    # validate_filename returns an (is_valid, error_message) pair. The pair
    # itself is always truthy, so it must be unpacked. Using it directly made
    # the lookup path "(True, None).json" and every request a 404.
    is_valid, _ = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail="Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    function_file = project_root / "cloud" / "custom_nodes" / "functions" / f"{function_name}.json"

    # is_file() rather than exists(): a directory would otherwise reach open().
    if not function_file.is_file():
        raise HTTPException(status_code=404, detail="Function not found")

    try:
        with open(function_file, 'r', encoding='utf-8') as f:
            function_data = json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}") from e

    return {
        "success": True,
        "data": function_data
    }