647 lines
21 KiB
Python
647 lines
21 KiB
Python
"""
|
||
图执行和查询相关 API
|
||
"""
|
||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||
from pydantic import BaseModel
|
||
from typing import List, Dict, Any, Optional
|
||
import json
|
||
from pathlib import Path
|
||
import asyncio
|
||
|
||
from server.app.core.node_base import DimensionMode, NodeType
|
||
|
||
from ..core.user_manager import get_user_path
|
||
from ..core.security import is_safe_path, validate_filename, sanitize_path
|
||
from ..core.node_loader import reload_custom_nodes
|
||
|
||
# Shared router: every endpoint in this module registers on it through the
# @router.get / @router.post / @router.websocket decorators.
router = APIRouter()
|
||
|
||
# Per-user execution locks that stop one user from triggering overlapping graph
# executions and overloading the backend. Keyed by user_id (currently always
# "guest"); could later be keyed per workspace or per graph instead.
_EXECUTION_LOCKS: Dict[str, asyncio.Lock] = {}


def _get_user_lock(user_id: str = "guest") -> asyncio.Lock:
    """Return the execution lock for ``user_id``, creating it on first request.

    Every call with the same ``user_id`` yields the very same ``asyncio.Lock``
    object, so all callers for that user contend on a single lock.
    """
    try:
        return _EXECUTION_LOCKS[user_id]
    except KeyError:
        fresh_lock = asyncio.Lock()
        _EXECUTION_LOCKS[user_id] = fresh_lock
        return fresh_lock
|
||
|
||
|
||
class NodeSchema(BaseModel):
    """A graph node as sent by the client."""

    # Node identifier; forwarded to the executor as `id` and used in error messages.
    id: str
    # Node kind, validated against the NodeType enum.
    node_type: NodeType
    # Registry name of the implementing class. Optional in the schema, but the
    # execute and preview endpoints reject nodes that do not provide it.
    class_name: Optional[str] = None
    # Runtime parameters for the node (forwarded to the executor as `params`).
    params: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
class EdgeSchema(BaseModel):
    """A connection between two node ports as sent by the client."""

    # Clients may omit the id (the backend then synthesizes one from the
    # endpoints and ports), but a stable client-provided id is recommended.
    id: Optional[str] = None
    # Endpoints of the edge (presumably node ids -- confirm against the frontend).
    source: str

    target: str
    # Dimension mode for this edge (DimensionMode enum); defaults to NONE.
    dimension_mode: DimensionMode = DimensionMode.NONE
    # Port names; the execute endpoint falls back to "output" / "input" when unset.
    source_port: Optional[str] = None

    target_port: Optional[str] = None
|
||
|
||
|
||
class GraphExecuteRequest(BaseModel):
    """Request body for POST /graph/execute: the whole graph to run."""

    nodes: List[NodeSchema]

    edges: List[EdgeSchema]
    # Optional global settings; the execute endpoint hands them to the executor
    # as the global context (an empty dict when omitted).
    settings: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
class NodePreviewRequest(BaseModel):
    """Request body for POST /node/preview."""

    # The node to run in preview mode.
    node: NodeSchema
    # Maximum number of rows / list items included in the preview.
    limit: int = 10
|
||
|
||
|
||
@router.get("/plugins")
async def get_plugins():
    """List every node plugin registered in ``NodeRegistry``.

    Each registered class is described from its ``get_metadata()`` output,
    reshaped into the dictionary the frontend consumes; missing metadata keys
    fall back to fixed defaults.

    Returns:
        dict with ``plugins`` (class name -> description), ``total`` (number of
        plugins) and ``categories`` (``NodeRegistry.get_categories()``).
    """
    from ..core.node_registry import NodeRegistry

    def describe(impl_name, node_cls):
        meta = node_cls.get_metadata()
        return {
            "display_name": meta.get("display_name", impl_name),
            "category": meta.get("category", "Uncategorized"),
            "description": meta.get("description", ""),
            "icon": meta.get("icon", "📦"),
            "node_type": meta.get("node_type", "REGULAR"),
            # The frontend's double-click detection keys on the implementation class name.
            "class_name": impl_name,
            "node_logic": "standard",
            "supports_preview": True,
            # Ports and parameter schema are passed through unchanged from the metadata.
            "inputs": meta.get("inputs", []),
            "outputs": meta.get("outputs", []),
            "param_schema": meta.get("params", []),
            "context_vars": meta.get("context_vars", {}),
            "cache_policy": meta.get("cache_policy", "NONE"),
        }

    plugins = {
        impl_name: describe(impl_name, node_cls)
        for impl_name, node_cls in NodeRegistry.list_all().items()
    }
    return {
        "plugins": plugins,
        "total": len(plugins),
        "categories": NodeRegistry.get_categories(),
    }
|
||
|
||
|
||
@router.post("/node/preview")
async def preview_node(request: NodePreviewRequest):
    """Run a single node in preview mode (with no inputs) and return a preview.

    Only the first entry of the node's ``outputs`` is previewed:
    a polars DataFrame yields its columns and first ``request.limit`` rows,
    a list yields its first ``request.limit`` items under a ``value`` column,
    and any other value yields a single ``value`` row.

    Raises:
        HTTPException(400): the node has no ``class_name``.
        HTTPException(500): the node could not be built or failed to run.
    """
    from ..core.node_registry import NodeRegistry
    import polars as pl

    # Validate before the broad try/except below so the 400 is not turned into a 500.
    class_name = request.node.class_name
    if not class_name:
        raise HTTPException(status_code=400, detail="Missing class_name for node preview")

    try:
        node_class = NodeRegistry.get(class_name)
        node_instance = node_class()

        # NodeSchema declares no `data` field, so `request.node.data` raised
        # AttributeError on every call; read it defensively and fall back to `params`.
        pdata = getattr(request.node, "data", None) or request.node.params or {}
        if isinstance(pdata, dict):
            for param_name, param_value in pdata.items():
                try:
                    node_instance.set_param(param_name, param_value)
                except Exception:
                    # Some node implementations lack a compatible set_param;
                    # fall back to plain attribute assignment.
                    setattr(node_instance, param_name, param_value)

        # Preview mode runs the node without upstream inputs.
        context = {
            "user_id": "guest",
            "preview_mode": True
        }
        result = await node_instance.wrap_process({}, context)

        outputs = result.get("outputs", {})
        columns = []
        preview_data = []
        # Only the first output is previewed.
        for output_value in outputs.values():
            if isinstance(output_value, pl.DataFrame):
                columns = list(output_value.columns)
                preview_data = output_value.head(request.limit).to_dicts()
            elif isinstance(output_value, list):
                columns = ["value"]
                preview_data = [{"value": item} for item in output_value[:request.limit]]
            else:
                columns = ["value"]
                preview_data = [{"value": output_value}]
            break

        return {
            "class_name": class_name,
            "status": "success",
            "columns": columns,
            "preview": preview_data,
            "context": result.get("context", {})
        }
    except HTTPException:
        # Keep deliberate HTTP errors raised by the node machinery intact.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"节点预览失败: {str(e)}"
        ) from e
|
||
|
||
|
||
@router.post("/graph/execute")
async def execute_graph(request: GraphExecuteRequest):
    """Execute a complete computation graph and return the executor's results.

    Executions for the same user are serialized by a per-user lock; a request
    arriving while another execution holds the lock is rejected immediately.

    Raises:
        HTTPException(409): an execution for this user is already running.
        HTTPException(400): a node is missing ``class_name``.
        HTTPException(500): the executor reported failure or raised; ``detail``
            carries the error message, exception type name and traceback.
    """
    from ..core.workflow_executor import WorkflowExecutor
    import traceback

    try:
        user_id = "guest"  # TODO: take the user id from the request / authentication
        lock = _get_user_lock(user_id)
        if lock.locked():
            raise HTTPException(status_code=409, detail="已有执行在进行中,请稍后重试")

        # Convert nodes; class_name is mandatory.
        nodes = []
        for node in request.nodes:
            if not node.class_name:
                raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name")
            nodes.append({
                "id": node.id,
                "node_type": node.node_type,
                "class_name": node.class_name,
                # Runtime parameters always reach the executor under `params`.
                "params": node.params or {},
            })

        # Convert edges, defaulting unnamed ports to "output" / "input".
        edges = []
        for edge in request.edges:
            s_handle = edge.source_port or "output"
            t_handle = edge.target_port or "input"
            # Synthesize a deterministic id when the client did not send one.
            eid = edge.id or f"{edge.source}:{s_handle}->{edge.target}:{t_handle}"
            edges.append({
                "id": eid,
                "source": edge.source,
                "source_port": s_handle,
                "target": edge.target,
                "target_port": t_handle,
                "dimension_mode": edge.dimension_mode
            })

        global_context = request.settings or {}

        executor = WorkflowExecutor(user_id=user_id)
        async with lock:
            success, results = await executor.execute(nodes, edges, global_context)
            if not success:
                raise Exception(str(results))
            return results

    except HTTPException:
        # The 409/400 above are deliberate client errors; previously they were
        # caught below and re-reported as 500s.
        raise
    except Exception as e:
        # NOTE(review): the traceback is returned to the client; consider
        # restricting this to debug deployments.
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        ) from e
|
||
|
||
|
||
def _ws_legacy_class_name(node):
    """Return ``node["data"]["meta"]["class_name"]`` from the legacy layout, or None.

    Tolerates ``data`` / ``meta`` being absent, null, or not JSON objects.
    """
    data = node.get("data")
    meta = data.get("meta") if isinstance(data, dict) else None
    return meta.get("class_name") if isinstance(meta, dict) else None


def _strip_handle_prefix(handle, prefix):
    """Remove ``prefix`` from the start of a string handle; other values pass through."""
    if isinstance(handle, str) and handle.startswith(prefix):
        return handle[len(prefix):]
    return handle


@router.websocket("/ws/graph/execute")
async def ws_execute_graph(websocket: WebSocket):
    """Execute a computation graph over a WebSocket, streaming executor events.

    The client must first send one JSON message::

        {"nodes": [...], "edges": [...], "settings": {...}}

    Every event yielded by ``WorkflowExecutor.execute_stream`` is forwarded to
    the client as JSON. Per the endpoint's protocol these are start /
    node_started / node_completed / progress / error / completed. Setup errors
    are sent as ``error`` events, followed by a close with code 1008
    (concurrent run) or 1003 (node without class_name). Unexpected failures
    are sent as an ``error`` event, followed by a close with code 1011.
    """
    import contextlib

    await websocket.accept()
    try:
        init_msg = await websocket.receive_json()
        from ..core.workflow_executor import WorkflowExecutor

        user_id = "guest"  # TODO: take the user id from the handshake / authentication
        lock = _get_user_lock(user_id)
        if lock.locked():
            # Refuse to run concurrently with another execution for this user.
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": "已有执行在进行中,请稍后重试"}
            })
            await websocket.close(code=1008)
            return

        # Same normalization rules as the HTTP route, plus legacy-layout support.
        nodes_in = []
        for node in init_msg.get("nodes", []):
            # Older frontends put the class name under data.meta.
            class_name = node.get("class_name") or _ws_legacy_class_name(node)
            if not class_name:
                await websocket.send_json({
                    "event": "error",
                    "run_id": None,
                    "data": {"message": f"Node {node.get('id')} missing class_name"}
                })
                await websocket.close(code=1003)
                return
            nodes_in.append({
                "id": node.get("id"),
                "node_type": node.get("node_type"),
                "class_name": class_name,
                "params": node.get("params") or {},
            })

        edges_in = []
        for edge in init_msg.get("edges", []):
            # Accept source_port/target_port or sourceHandle/targetHandle, and
            # strip a leading "output-"/"input-" prefix (only the leading one:
            # str.replace used to remove every occurrence).
            s_handle = _strip_handle_prefix(
                edge.get("source_port") or edge.get("sourceHandle") or "output", "output-")
            t_handle = _strip_handle_prefix(
                edge.get("target_port") or edge.get("targetHandle") or "input", "input-")
            eid = edge.get("id") or f"{edge.get('source')}:{s_handle}->{edge.get('target')}:{t_handle}"
            raw_mode = edge.get("dimension_mode")
            edges_in.append({
                "id": eid,
                "source": edge.get("source"),
                "source_port": s_handle,
                "target": edge.get("target"),
                "target_port": t_handle,
                # Like EdgeSchema on the HTTP route, a missing mode means NONE;
                # DimensionMode(None) would otherwise abort the whole run.
                "dimension_mode": DimensionMode.NONE if raw_mode is None else DimensionMode(raw_mode),
            })

        settings = init_msg.get("settings") or {}
        executor = WorkflowExecutor(user_id=user_id)

        # Hold the lock for the whole stream so this user cannot start a second run.
        async with lock:
            async for event in executor.execute_stream(nodes_in, edges_in, settings):
                await websocket.send_json(event)
        await websocket.close()

    except WebSocketDisconnect:
        # The client disconnected; nobody is left to notify.
        return
    except Exception as e:
        # Best effort: report and close. The socket may already be closed, in
        # which case these calls raise again; suppress that so the handler
        # itself never raises from its error path.
        with contextlib.suppress(Exception):
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": str(e)}
            })
        with contextlib.suppress(Exception):
            await websocket.close(code=1011)
|
||
|
||
|
||
@router.get("/users/list")
async def list_users():
    """Return the users reported by ``user_manager.list_users`` and their count."""
    # Aliased so the helper does not shadow this endpoint's own name.
    from ..core.user_manager import list_users as fetch_users

    all_users = fetch_users()
    return {"users": all_users, "count": len(all_users)}
|
||
|
||
|
||
@router.post("/users/add")
async def add_user(payload: Dict[str, Any]):
    """Register a new user in the user database and create their workspace.

    Payload keys: ``username`` (letters, digits, ``_`` and ``-`` only) and an
    optional ``display_name`` (defaults to the username).

    Raises:
        HTTPException(400): username missing, not a string, badly formatted,
            or already taken.
        HTTPException(500): the user database could not be saved.
    """
    from datetime import datetime
    from ..core.user_manager import load_users_db, save_users_db, create_user_workspace

    username = payload.get("username")
    display_name = payload.get("display_name", username)

    if not username:
        raise HTTPException(status_code=400, detail="Missing username")

    # Only letters, digits, '_' and '-'. A non-string value (e.g. a JSON number)
    # used to crash here with a 500; it is now rejected as a client error.
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    db = load_users_db()
    # Create the list if the database has none yet; the old code tolerated a missing
    # key when checking duplicates but then crashed with KeyError on append.
    users = db.setdefault("users", [])

    if any(u.get("username") == username for u in users):
        raise HTTPException(status_code=400, detail="User already exists")

    users.append({
        "username": username,
        "display_name": display_name,
        "created_at": datetime.now().isoformat(),
        "active": True
    })

    if not save_users_db(db):
        raise HTTPException(status_code=500, detail="Failed to save user database")

    create_user_workspace(username)

    return {
        "success": True,
        "message": f"用户 {username} 创建成功"
    }
|
||
|
||
|
||
@router.post("/reload-custom-nodes")
async def reload_custom_nodes_endpoint():
    """Reload every custom node through the node loader and report the outcome.

    Returns:
        ``success``, a message built from the loader's ``loaded`` / ``total``
        counters, and the loader's full result under ``details``.

    Raises:
        HTTPException(500): the reload failed; ``detail`` holds the error text
            and the formatted traceback.
    """
    import traceback

    try:
        outcome = reload_custom_nodes()
        loaded_count = outcome.get('loaded', 0)
        total_count = outcome.get('total', 0)
        return {
            "success": True,
            "message": f"已加载 {loaded_count}/{total_count} 个节点",
            "details": outcome,
        }
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail={"error": str(err), "traceback": traceback.format_exc()},
        )
|
||
|
||
|
||
@router.post("/nodes/save")
async def save_node(payload: Dict[str, Any]):
    """Acknowledge a request to save a node's configuration.

    NOTE: nothing is persisted yet. The endpoint only checks that ``nodeId``
    is present and echoes it back; ``nodeData`` is currently ignored.

    Raises:
        HTTPException(400): ``nodeId`` is missing or empty.
    """
    node_id = payload.get("nodeId")
    if not node_id:
        raise HTTPException(status_code=400, detail="Missing nodeId")

    # TODO: persist payload.get("nodeData") (to a database or a file).
    return {
        "success": True,
        "message": f"节点 {node_id} 配置已保存",
        "nodeId": node_id,
    }
|
||
|
||
|
||
@router.post("/workflows/save")
async def save_workflow(payload: Dict[str, Any]):
    """Save a workflow as JSON under the user's ``workflows`` directory.

    Payload keys: ``filename`` (``.json`` is appended when missing), ``data``
    (the workflow itself) and ``username`` (defaults to ``"guest"``).

    Returns:
        ``success``, a message, and the saved file's path relative to the parent
        of the user's root directory.

    Raises:
        HTTPException(400): filename or data missing, or the filename fails validation.
        HTTPException(403): the target path escapes the user's directory.
        HTTPException(500): creating the directory or writing the file failed.
    """
    filename = payload.get("filename")
    # NOTE(review): username comes straight from the payload; confirm that
    # get_user_path rejects traversal in it (is_safe_path below only checks the
    # file against a root derived from the same username).
    username = payload.get("username", "guest")
    workflow_data = payload.get("data")

    if not filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    if not workflow_data:
        raise HTTPException(status_code=400, detail="Missing workflow data")

    is_valid, error_msg = validate_filename(filename)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg)

    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    workflow_dir = get_user_path(username, "workflows")
    workflow_file = workflow_dir / filename

    # Run the safety check BEFORE touching the filesystem; the directory used to
    # be created first, even for paths that were then rejected.
    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        # Serialize first so a serialization error cannot leave a truncated file.
        serialized = json.dumps(workflow_data, indent=2, ensure_ascii=False)
        workflow_dir.mkdir(parents=True, exist_ok=True)
        with open(workflow_file, "w", encoding="utf-8") as f:
            f.write(serialized)

        return {
            "success": True,
            "message": f"工作流已保存: {filename}",
            "path": str(workflow_file.relative_to(user_root.parent))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/workflows/load")
async def load_workflow(filename: str, username: str = "guest"):
    """Load a workflow JSON file from the user's ``workflows`` directory.

    ``filename`` is passed through ``sanitize_path`` and used as-is; no
    ``.json`` suffix is added here.

    Returns:
        ``success``, the sanitized ``filename`` and the parsed ``data``.

    Raises:
        HTTPException(403): the path escapes the user's directory.
        HTTPException(404): the file does not exist.
        HTTPException(500): the file could not be read or parsed.
    """
    safe_name = sanitize_path(filename)
    target = get_user_path(username, f"workflows/{safe_name}")
    root = get_user_path(username)

    if not is_safe_path(root, target):
        raise HTTPException(status_code=403, detail="Access denied")
    if not target.exists():
        raise HTTPException(status_code=404, detail="Workflow not found")

    try:
        with open(target, "r", encoding="utf-8") as fh:
            loaded = json.load(fh)
        return {"success": True, "filename": safe_name, "data": loaded}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"加载失败: {str(exc)}")
|
||
|
||
|
||
class SaveFunctionRequest(BaseModel):
    """Request body for POST /functions/save: a sub-graph to store as a reusable function."""

    # Identifier that names the stored file (`<function_name>.json`).
    function_name: str
    # Display name; stored in the file and echoed in the success message.
    display_name: str

    description: Optional[str] = ""
    # The sub-graph forming the function body.
    nodes: List[NodeSchema]

    edges: List[EdgeSchema]
    # Declared function inputs/outputs; stored verbatim (their schema is not validated here).
    inputs: List[Dict[str, Any]]

    outputs: List[Dict[str, Any]]
|
||
|
||
|
||
@router.post("/functions/save")
async def save_function(request: SaveFunctionRequest):
    """Save selected nodes and edges as a reusable function node.

    The sub-graph plus its declared inputs/outputs is written to
    ``<project_root>/cloud/custom_nodes/functions/<function_name>.json``.
    ``load_function_nodes()`` is then called so the new function gets registered.

    Raises:
        HTTPException(400): the function name fails ``validate_filename``.
        HTTPException(409): a function with this name already exists.
        HTTPException(500): serializing, writing or reloading failed.
    """
    from ..core.node_loader import load_function_nodes

    function_name = request.function_name
    # validate_filename returns an (is_valid, error_msg) pair (as unpacked in
    # save_workflow). Assigning the pair itself, as before, made every name
    # "valid" and named the file after the tuple's repr.
    is_valid, _error_msg = validate_filename(function_name)
    if not function_name or not is_valid:
        raise HTTPException(status_code=400, detail="Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    functions_dir.mkdir(parents=True, exist_ok=True)

    function_file = functions_dir / f"{function_name}.json"

    # Cheap early exit; the exclusive-create open below closes the race window.
    if function_file.exists():
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")

    function_data = {
        "function_name": function_name,
        "display_name": request.display_name,
        "description": request.description,
        "inputs": request.inputs,
        "outputs": request.outputs,
        "nodes": [node.dict() for node in request.nodes],
        "edges": [edge.dict() for edge in request.edges]
    }

    try:
        # Serialize before creating the file so a serialization error cannot
        # leave a truncated file behind (which would then block retries with 409).
        serialized = json.dumps(function_data, indent=2, ensure_ascii=False)
        # "x" = exclusive create: a concurrent save of the same name fails here
        # instead of silently overwriting the other request's file.
        with open(function_file, "x", encoding="utf-8") as f:
            f.write(serialized)

        load_result = load_function_nodes()

        return {
            "success": True,
            "message": f"函数已保存: {request.display_name}",
            "function_name": function_name,
            "path": str(function_file.relative_to(project_root)),
            "reload_result": load_result
        }
    except FileExistsError:
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/functions/list")
async def list_functions():
    """List the function definitions stored as JSON files in the functions directory.

    Each ``*.json`` file contributes a summary (name, display name, description,
    inputs, outputs); missing keys fall back to the file stem or empty values.
    Files that cannot be read or parsed are reported with ``print`` and skipped.
    """
    functions_dir = Path(__file__).resolve().parents[3] / "cloud" / "custom_nodes" / "functions"
    if not functions_dir.exists():
        return {"functions": []}

    def summarize(path):
        with open(path, 'r', encoding='utf-8') as fh:
            spec = json.load(fh)
        stem = path.stem
        return {
            "function_name": spec.get("function_name", stem),
            "display_name": spec.get("display_name", stem),
            "description": spec.get("description", ""),
            "inputs": spec.get("inputs", []),
            "outputs": spec.get("outputs", []),
        }

    summaries = []
    for path in functions_dir.glob("*.json"):
        try:
            summaries.append(summarize(path))
        except Exception as exc:
            print(f"❌ 读取函数失败: {path.name} - {str(exc)}")

    return {"functions": summaries}
|
||
|
||
|
||
@router.get("/functions/{function_name}")
async def get_function(function_name: str):
    """Return a stored function definition, including its internal workflow.

    Raises:
        HTTPException(400): the name fails ``validate_filename``.
        HTTPException(404): no function file with that name exists.
        HTTPException(500): the file could not be read or parsed.
    """
    # validate_filename returns an (is_valid, error_msg) pair (as unpacked in
    # save_workflow). Rebinding function_name to that pair, as before, made
    # every lookup target a file named after the tuple, so every request 404'd.
    is_valid, _error_msg = validate_filename(function_name)
    if not function_name or not is_valid:
        raise HTTPException(status_code=400, detail="Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    function_file = project_root / "cloud" / "custom_nodes" / "functions" / f"{function_name}.json"

    # is_file(): a directory with a matching name is "not found", not a read error.
    if not function_file.is_file():
        raise HTTPException(status_code=404, detail="Function not found")

    try:
        with open(function_file, 'r', encoding='utf-8') as f:
            function_data = json.load(f)

        return {
            "success": True,
            "data": function_data
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}") from e
|
||
|