"""Graph execution and query API endpoints."""

import asyncio
import contextlib
import json
import logging
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from server.app.core.node_base import DimensionMode, NodeType
from ..core.user_manager import get_user_path
from ..core.security import is_safe_path, validate_filename, sanitize_path

logger = logging.getLogger(__name__)

router = APIRouter()

# Simple per-user execution lock that stops one user from triggering several
# concurrent executions and overloading the backend. Keyed by user_id (currently
# always "guest"); could later be scoped per workspace or per graph.
_EXECUTION_LOCKS: Dict[str, asyncio.Lock] = {}


def _get_user_lock(user_id: str = "guest") -> asyncio.Lock:
    """Return the execution lock for ``user_id``, creating it on first use."""
    lock = _EXECUTION_LOCKS.get(user_id)
    if lock is None:
        lock = asyncio.Lock()
        _EXECUTION_LOCKS[user_id] = lock
    return lock


def _strip_handle_prefix(handle: Any, prefix: str) -> Any:
    """Remove a single leading ``prefix`` (e.g. ``"output-"``) from a port handle.

    Non-string handles are returned unchanged. Only the leading occurrence is
    removed, so port names that themselves contain ``prefix`` survive intact.
    """
    if isinstance(handle, str) and handle.startswith(prefix):
        return handle[len(prefix):]
    return handle


def _json_default(obj: Any) -> Any:
    """``json.dumps`` fallback that serializes Enum members by their value."""
    if isinstance(obj, Enum):
        return obj.value
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _function_paths() -> Tuple[Path, Path]:
    """Return ``(project_root, functions_dir)`` where function definitions are stored."""
    project_root = Path(__file__).resolve().parents[3]
    return project_root, project_root / "cloud" / "custom_nodes" / "functions"


class NodeSchema(BaseModel):
    """A graph node as sent by the frontend."""
    id: str
    node_type: NodeType
    class_name: Optional[str] = None
    params: Optional[Dict[str, Any]] = None


class EdgeSchema(BaseModel):
    """A connection between two node ports."""
    # The frontend may omit the id (the backend synthesizes one), but a stable
    # frontend-provided id is recommended.
    id: Optional[str] = None
    source: str
    target: str
    dimension_mode: DimensionMode = DimensionMode.NONE
    source_port: Optional[str] = None
    target_port: Optional[str] = None


class GraphExecuteRequest(BaseModel):
    """Request body for executing a whole graph."""
    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    settings: Optional[Dict[str, Any]] = None


class NodePreviewRequest(BaseModel):
    """Request body for previewing a single node."""
    node: NodeSchema
    limit: int = 10


@router.get("/plugins")
async def get_plugins():
    """List every available node plugin.

    Plugins are read dynamically from ``NodeRegistry`` and reshaped into the
    structure the frontend expects.
    """
    from ..core.node_registry import NodeRegistry

    plugins = {}
    for class_name, node_class in NodeRegistry.list_all().items():
        metadata = node_class.get_metadata()
        plugins[class_name] = {
            "display_name": metadata.get("display_name", class_name),
            "category": metadata.get("category", "Uncategorized"),
            "description": metadata.get("description", ""),
            "icon": metadata.get("icon", "📦"),
            "node_type": metadata.get("node_type", "REGULAR"),
            # The frontend's double-click detection needs the implementation class name.
            "class_name": class_name,
            "node_logic": "standard",
            "supports_preview": True,
            "inputs": metadata.get("inputs", []),
            "outputs": metadata.get("outputs", []),
            "param_schema": metadata.get("params", []),
            "context_vars": metadata.get("context_vars", {}),
            "cache_policy": metadata.get("cache_policy", "NONE"),
        }

    return {
        "plugins": plugins,
        "total": len(plugins),
        "categories": NodeRegistry.get_categories(),
    }


@router.post("/node/preview")
async def preview_node(request: NodePreviewRequest):
    """Run a single node in isolation and return a preview of its first output.

    The node class is looked up in ``NodeRegistry`` by ``class_name``, given the
    request's ``params``, and executed with no inputs and a preview context.
    Only the first output is previewed, truncated to ``request.limit`` rows/items.

    Raises:
        HTTPException 400: ``class_name`` is missing, or the registry returned no class.
        HTTPException 500: the node failed to instantiate or execute.
    """
    from ..core.node_registry import NodeRegistry
    import polars as pl

    class_name = request.node.class_name
    if not class_name:
        raise HTTPException(status_code=400, detail="Missing class_name for node preview")

    # A negative limit would make head()/slicing drop rows from the end instead.
    limit = max(request.limit, 0)

    try:
        node_class = NodeRegistry.get(class_name)
        if node_class is None:
            raise HTTPException(status_code=400, detail=f"Unknown class_name: {class_name}")

        node_instance = node_class()

        # Apply runtime parameters. Not every node implementation accepts every
        # name via set_param, so fall back to setting the attribute directly.
        for param_name, param_value in (request.node.params or {}).items():
            try:
                node_instance.set_param(param_name, param_value)
            except Exception:
                setattr(node_instance, param_name, param_value)

        # Preview mode runs without any upstream inputs.
        context = {
            "user_id": "guest",
            "preview_mode": True,
        }
        result = await node_instance.wrap_process({}, context)
        outputs = result.get("outputs", {})

        preview_data: List[Any] = []
        columns: List[str] = []
        # Only the first output is converted into preview rows.
        for output_value in outputs.values():
            if isinstance(output_value, pl.DataFrame):
                columns = list(output_value.columns)
                preview_data = output_value.head(limit).to_dicts()
            elif isinstance(output_value, list):
                columns = ["value"]
                preview_data = [{"value": item} for item in output_value[:limit]]
            else:
                # Scalar output.
                columns = ["value"]
                preview_data = [{"value": output_value}]
            break

        return {
            "class_name": class_name,
            "status": "success",
            "columns": columns,
            "preview": preview_data,
            "context": result.get("context", {}),
        }
    except HTTPException:
        # Deliberate client errors must keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"节点预览失败: {str(e)}",
        )


@router.post("/graph/execute")
async def execute_graph(request: GraphExecuteRequest):
    """Execute a complete computation graph and return the executor's results.

    Only one execution per user may run at a time (see ``_get_user_lock``).

    Raises:
        HTTPException 409: another execution for this user is in progress.
        HTTPException 400: a node is missing ``class_name``.
        HTTPException 500: execution failed; ``detail`` carries the error
            message, the exception type name and the formatted traceback.
    """
    from ..core.workflow_executor import WorkflowExecutor

    try:
        user_id = "guest"  # TODO: take the user id from the request / auth
        executor = WorkflowExecutor(user_id=user_id)

        lock = _get_user_lock(user_id)
        if lock.locked():
            raise HTTPException(status_code=409, detail="已有执行在进行中,请稍后重试")

        # Convert nodes; class_name is strictly required. Runtime parameters
        # are passed to the executor under a uniform `params` key.
        nodes = []
        for node in request.nodes:
            if not node.class_name:
                raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name")
            nodes.append({
                "id": node.id,
                "node_type": node.node_type,
                "class_name": node.class_name,
                "params": node.params or {},
            })

        # Convert edges; missing ports default to "output"/"input" and a
        # deterministic id is synthesized when the frontend sends none.
        edges = []
        for edge in request.edges:
            s_handle = edge.source_port or "output"
            t_handle = edge.target_port or "input"
            edges.append({
                "id": edge.id or f"{edge.source}:{s_handle}->{edge.target}:{t_handle}",
                "source": edge.source,
                "source_port": s_handle,
                "target": edge.target,
                "target_port": t_handle,
                "dimension_mode": edge.dimension_mode,
            })

        global_context = request.settings or {}

        # There is no await between the locked() check above and this
        # acquisition, so no other coroutine can slip in between them.
        async with lock:
            success, results = await executor.execute(nodes, edges, global_context)
            if not success:
                raise Exception(str(results))
            return results
    except HTTPException:
        # 400/409 are deliberate client-facing errors; don't rewrap them as 500.
        raise
    except Exception as e:
        import traceback
        # NOTE(review): returning the traceback exposes server internals to
        # clients; consider logging it server-side only.
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            },
        )


@router.websocket("/ws/graph/execute")
async def ws_execute_graph(websocket: WebSocket):
    """Execute a computation graph over a WebSocket, streaming events.

    The client must first send one JSON message::

        {"nodes": [...], "edges": [...], "settings": {...}}

    Every event yielded by ``WorkflowExecutor.execute_stream`` is then forwarded
    to the client (documented as start / node_started / node_completed /
    progress / error / completed), after which the socket is closed. Failures
    are reported as an ``{"event": "error", ...}`` message followed by a close
    with code 1008 (execution already running), 1003 (node without
    ``class_name``) or 1011 (unexpected error).
    """
    await websocket.accept()
    try:
        init_msg = await websocket.receive_json()
        from ..core.workflow_executor import WorkflowExecutor

        user_id = "guest"  # TODO: take the user id from the handshake / auth
        lock = _get_user_lock(user_id)
        if lock.locked():
            # Reject concurrent execution with an error event.
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": "已有执行在进行中,请稍后重试"},
            })
            await websocket.close(code=1008)
            return

        # Normalize nodes the same way the HTTP route does.
        nodes_in = []
        for node in init_msg.get("nodes", []):
            # Older frontend payloads may keep the class name under data.meta.
            class_name = node.get("class_name") or (
                ((node.get("data") or {}).get("meta") or {}).get("class_name")
            )
            if not class_name:
                await websocket.send_json({
                    "event": "error",
                    "run_id": None,
                    "data": {"message": f"Node {node.get('id')} missing class_name"},
                })
                await websocket.close(code=1003)
                return
            node["class_name"] = class_name
            nodes_in.append({
                "id": node.get("id"),
                "node_type": node.get("node_type"),
                "class_name": class_name,
                "params": node.get("params") or {},
            })

        # Accept either source_port/target_port or sourceHandle/targetHandle,
        # dropping the "output-"/"input-" prefix carried by frontend handle ids.
        edges_in = []
        for edge in init_msg.get("edges", []):
            s_handle = _strip_handle_prefix(
                edge.get("source_port") or edge.get("sourceHandle") or "output", "output-"
            )
            t_handle = _strip_handle_prefix(
                edge.get("target_port") or edge.get("targetHandle") or "input", "input-"
            )
            edges_in.append({
                "id": edge.get("id") or f"{edge.get('source')}:{s_handle}->{edge.get('target')}:{t_handle}",
                "source": edge.get("source"),
                "source_port": s_handle,
                "target": edge.get("target"),
                "target_port": t_handle,
                "dimension_mode": edge.get("dimension_mode"),
            })

        settings = init_msg.get("settings") or {}
        executor = WorkflowExecutor(user_id=user_id)

        # Hold the lock for the whole stream so the same user cannot start a
        # second execution concurrently.
        async with lock:
            async for event in executor.execute_stream(nodes_in, edges_in, settings):
                await websocket.send_json(event)
        await websocket.close()
    except WebSocketDisconnect:
        # The client disconnected; nobody is left to notify.
        return
    except Exception as e:
        # Best effort: the socket may already be closed, in which case sending
        # or closing raises again and there is nothing more to do.
        with contextlib.suppress(Exception):
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": str(e)},
            })
        with contextlib.suppress(Exception):
            await websocket.close(code=1011)


@router.get("/users/list")
async def list_users():
    """Return the user list provided by ``user_manager.list_users`` and its size."""
    from ..core.user_manager import list_users as load_user_list

    users = load_user_list()
    return {
        "users": users,
        "count": len(users),
    }


@router.post("/users/add")
async def add_user(payload: Dict[str, Any]):
    """Create a new user record and its workspace.

    Expects ``username`` (alphanumerics plus ``_`` and ``-``) and an optional
    ``display_name`` that defaults to the username.

    Raises:
        HTTPException 400: missing/invalid username, or the user already exists.
        HTTPException 500: the user database could not be saved.
    """
    from datetime import datetime
    from ..core.user_manager import load_users_db, save_users_db, create_user_workspace

    username = payload.get("username")
    display_name = payload.get("display_name", username)

    if not username:
        raise HTTPException(status_code=400, detail="Missing username")

    # The isinstance guard turns a non-string JSON value into a 400 instead of
    # an AttributeError (500) on .replace().
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    db = load_users_db()
    # setdefault tolerates a database that has no "users" list yet.
    users = db.setdefault("users", [])

    if any(u["username"] == username for u in users):
        raise HTTPException(status_code=400, detail="User already exists")

    users.append({
        "username": username,
        "display_name": display_name,
        "created_at": datetime.now().isoformat(),
        "active": True,
    })

    if not save_users_db(db):
        raise HTTPException(status_code=500, detail="Failed to save user database")

    create_user_workspace(username)

    return {
        "success": True,
        "message": f"用户 {username} 创建成功",
    }


@router.post("/nodes/save")
async def save_node(payload: Dict[str, Any]):
    """Acknowledge a node-configuration save request (persistence not implemented)."""
    node_id = payload.get("nodeId")

    if not node_id:
        raise HTTPException(status_code=400, detail="Missing nodeId")

    # TODO: persist payload["nodeData"] (database or file); it is currently ignored.

    return {
        "success": True,
        "message": f"节点 {node_id} 配置已保存",
        "nodeId": node_id,
    }


@router.post("/workflows/save")
async def save_workflow(payload: Dict[str, Any]):
    """Save a workflow as JSON under the user's ``workflows`` directory.

    A ``.json`` extension is appended when missing.

    Raises:
        HTTPException 400: missing filename/data, or the filename is rejected
            by ``validate_filename``.
        HTTPException 403: the resolved path is outside the user's root.
        HTTPException 500: writing the file failed.
    """
    filename = payload.get("filename")
    # NOTE(review): `username` is client-controlled and only checked indirectly
    # via get_user_path/is_safe_path relative to its own root; confirm that
    # get_user_path rejects traversal such as "../other".
    username = payload.get("username", "guest")
    workflow_data = payload.get("data")

    if not filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    if not workflow_data:
        raise HTTPException(status_code=400, detail="Missing workflow data")

    is_valid, error_msg = validate_filename(filename)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg)

    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    workflow_dir = get_user_path(username, "workflows")
    workflow_dir.mkdir(parents=True, exist_ok=True)
    workflow_file = workflow_dir / filename

    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        with open(workflow_file, "w", encoding="utf-8") as f:
            json.dump(workflow_data, f, indent=2, ensure_ascii=False)

        return {
            "success": True,
            "message": f"工作流已保存: {filename}",
            "path": str(workflow_file.relative_to(user_root.parent)),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}")


@router.get("/workflows/load")
async def load_workflow(filename: str, username: str = "guest"):
    """Load a saved workflow JSON file from the user's ``workflows`` directory.

    Raises:
        HTTPException 403: the resolved path is outside the user's root.
        HTTPException 404: the file does not exist.
        HTTPException 500: reading or parsing the file failed.
    """
    filename = sanitize_path(filename)

    # NOTE(review): see save_workflow regarding validation of `username`.
    workflow_file = get_user_path(username, f"workflows/{filename}")

    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    if not workflow_file.exists():
        raise HTTPException(status_code=404, detail="Workflow not found")

    try:
        with open(workflow_file, "r", encoding="utf-8") as f:
            workflow_data = json.load(f)

        return {
            "success": True,
            "filename": filename,
            "data": workflow_data,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"加载失败: {str(e)}")


class SaveFunctionRequest(BaseModel):
    """Request body for saving a sub-graph as a reusable function node."""
    function_name: str
    display_name: str
    description: Optional[str] = ""
    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    inputs: List[Dict[str, Any]]
    outputs: List[Dict[str, Any]]


@router.post("/functions/save")
async def save_function(request: SaveFunctionRequest):
    """Save the selected nodes and edges as a reusable function definition.

    The definition is written to ``cloud/custom_nodes/functions/<name>.json``
    and function nodes are reloaded afterwards.

    Raises:
        HTTPException 400: the function name is rejected by ``validate_filename``.
        HTTPException 409: a function with that name already exists.
        HTTPException 500: serialization, writing or reloading failed.
    """
    from ..core.node_loader import load_function_nodes

    # validate_filename returns (is_valid, error_msg), as used in save_workflow.
    is_valid, _error_msg = validate_filename(request.function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail="Invalid function name")
    function_name = request.function_name

    project_root, functions_dir = _function_paths()
    functions_dir.mkdir(parents=True, exist_ok=True)
    function_file = functions_dir / f"{function_name}.json"

    if function_file.exists():
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")

    function_data = {
        "function_name": function_name,
        "display_name": request.display_name,
        "description": request.description,
        "inputs": request.inputs,
        "outputs": request.outputs,
        "nodes": [node.dict() for node in request.nodes],
        "edges": [edge.dict() for edge in request.edges],
    }

    try:
        # Serialize before touching the disk so a serialization error cannot
        # leave a partial file behind that would block retries with 409.
        # _json_default covers Enum fields (node_type, dimension_mode) from .dict().
        text = json.dumps(function_data, indent=2, ensure_ascii=False, default=_json_default)
        # Mode "x" refuses to overwrite a file created since the check above.
        with open(function_file, "x", encoding="utf-8") as f:
            f.write(text)

        load_result = load_function_nodes()

        return {
            "success": True,
            "message": f"函数已保存: {request.display_name}",
            "function_name": function_name,
            "path": str(function_file.relative_to(project_root)),
            "reload_result": load_result,
        }
    except FileExistsError:
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}")


@router.get("/functions/list")
async def list_functions():
    """Summarize every saved function definition (name, description, ports).

    Files that cannot be read or parsed are logged and skipped.
    """
    _project_root, functions_dir = _function_paths()

    if not functions_dir.exists():
        return {"functions": []}

    functions = []
    for json_file in functions_dir.glob("*.json"):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            functions.append({
                "function_name": data.get("function_name", json_file.stem),
                "display_name": data.get("display_name", json_file.stem),
                "description": data.get("description", ""),
                "inputs": data.get("inputs", []),
                "outputs": data.get("outputs", []),
            })
        except Exception as e:
            logger.warning("❌ 读取函数失败: %s - %s", json_file.name, e)

    return {"functions": functions}


@router.get("/functions/{function_name}")
async def get_function(function_name: str):
    """Return a saved function definition, including its internal workflow.

    Raises:
        HTTPException 400: the function name is rejected by ``validate_filename``.
        HTTPException 404: no such function file exists.
        HTTPException 500: reading or parsing the file failed.
    """
    # validate_filename returns (is_valid, error_msg), as used in save_workflow.
    is_valid, _error_msg = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail="Invalid function name")

    _project_root, functions_dir = _function_paths()
    function_file = functions_dir / f"{function_name}.json"

    if not function_file.exists():
        raise HTTPException(status_code=404, detail="Function not found")

    try:
        with open(function_file, 'r', encoding='utf-8') as f:
            function_data = json.load(f)

        return {
            "success": True,
            "data": function_data,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}")