""" 图执行和查询相关 API """ from fastapi import APIRouter, HTTPException from pydantic import BaseModel from typing import List, Dict, Any, Optional import json from pathlib import Path from server.app.core.node_base import DimensionMode from ..core.user_manager import get_user_path from ..core.security import is_safe_path, validate_filename, sanitize_path router = APIRouter() class NodeSchema(BaseModel): """节点数据结构""" id: str type: str class_name: Optional[str] = None function: Optional[str] = None # 支持前端发送 `data` 或历史字段 `params`,两者任选其一 data: Optional[Dict[str, Any]] = None params: Optional[Dict[str, Any]] = None class EdgeSchema(BaseModel): """连线数据结构""" # 允许前端不携带 id(后端会合成),但推荐前端提供稳定 id id: Optional[str] = None source: str target: str dimension_mode: DimensionMode = DimensionMode.NONE source_port: Optional[str] = None target_port: Optional[str] = None class GraphExecuteRequest(BaseModel): """图执行请求""" nodes: List[NodeSchema] edges: List[EdgeSchema] settings: Optional[Dict[str, Any]] = None class NodePreviewRequest(BaseModel): """节点预览请求""" node: NodeSchema limit: int = 10 @router.get("/plugins") async def get_plugins(): """ 获取所有可用的节点插件 从 NodeRegistry 动态获取 """ from ..core.node_registry import NodeRegistry # 获取所有注册的节点 all_nodes = NodeRegistry.list_all() # 转换为前端需要的格式 plugins = {} for class_name, node_class in all_nodes.items(): metadata = node_class.get_metadata() # 转换输入端口格式 inputs = metadata.get("inputs", []) # 转换输出端口格式 outputs = metadata.get("outputs", []) # 转换参数格式 param_schema = metadata.get("params", []) plugins[class_name] = { "display_name": metadata.get("display_name", class_name), "category": metadata.get("category", "Uncategorized"), "description": metadata.get("description", ""), "icon": metadata.get("icon", "📦"), "node_type": metadata.get("node_type", "REGULAR"), # 前端双击检测所需,使用实现类名 `class_name` "class_name": class_name, "node_logic": "standard", "supports_preview": True, "inputs": inputs, "outputs": outputs, "param_schema": param_schema, "context_vars": metadata.get("context_vars", {}), "cache_policy": metadata.get("cache_policy", "NONE") } return { "plugins": plugins, "total": len(plugins), "categories": NodeRegistry.get_categories() } @router.post("/node/preview") async def preview_node(request: NodePreviewRequest): """ 预览单个节点的输出 """ from ..core.node_registry import NodeRegistry import polars as pl try: # 获取实现类名(严格要求存在) class_name = request.node.class_name if not class_name: raise HTTPException(status_code=400, detail="Missing class_name for node preview") node_class = NodeRegistry.get(class_name) # 创建节点实例 node_instance = node_class() # 设置参数(兼容前端可能发送的 `data` 或 `params`) pdata = request.node.data or request.node.params or {} for param_name, param_value in (pdata.items() if isinstance(pdata, dict) else []): # 使用兼容的 set_param 接口(有的节点实现可能不同) try: node_instance.set_param(param_name, param_value) except Exception: setattr(node_instance, param_name, param_value) # 准备输入(预览模式下输入可能为空) inputs = {} # 准备上下文 context = { "user_id": "guest", "preview_mode": True } # 执行节点 result = await node_instance.wrap_process(inputs, context) # 提取输出 outputs = result.get("outputs", {}) # 如果输出是 DataFrame,转换为预览格式 preview_data = [] columns = [] for output_name, output_value in outputs.items(): if isinstance(output_value, pl.DataFrame): # polars DataFrame 预览 columns = list(output_value.columns) preview_data = output_value.head(request.limit).to_dicts() break elif isinstance(output_value, list): # 数组预览 preview_data = output_value[:request.limit] columns = ["value"] preview_data = [{"value": item} for item in preview_data] break else: # 单值 preview_data = [{"value": output_value}] columns = ["value"] break return { "class_name" : class_name, "status": "success", "columns": columns, "preview": preview_data, "context": result.get("context", {}) } except Exception as e: raise HTTPException( status_code=500, detail=f"节点预览失败: {str(e)}" ) @router.post("/graph/execute") async def execute_graph(request: GraphExecuteRequest): """ 执行完整的计算图 """ from ..core.workflow_executor import WorkflowExecutor import time try: start_time = time.time() # 创建工作流执行器 executor = WorkflowExecutor(user_id="guest") # TODO: 从请求获取用户ID # 转换节点格式,严格要求 class_name nodes = [] for node in request.nodes: # 支持前端将参数放在 `data` 或 `params` 中 pdata = node.data or node.params or {} if not getattr(node, 'class_name', None) and not (isinstance(pdata, dict) and pdata.get('meta', {}).get('class_name')): raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name") nodes.append({ "id": node.id, "type": node.type, "class_name": node.class_name or (pdata.get('meta', {}) or {}).get('class_name'), # 将运行时参数统一为 `params` 字段传入执行器 "params": pdata, "function": node.function, }) # 转换连线格式 edges = [] for edge in request.edges: # 支持前端使用 source_port/target_port 或 sourceHandle/targetHandle s_handle = edge.source_port or "output" t_handle = edge.target_port or "input" eid = edge.id or f"{edge.source}:{s_handle}->{edge.target}:{t_handle}" dimension_mode = edge.dimension_mode edges.append({ "id": eid, "source": edge.source, "source_port": s_handle, "target": edge.target, "target_port": t_handle, "dimension_mode": dimension_mode }) # 准备全局上下文 global_context = request.settings or {} # 执行工作流 success, results = await executor.execute(nodes, edges, global_context) if not success: print(results) raise Exception(str(results)) execution_time = time.time() - start_time # 兼容处理:executor 可能返回旧格式(mapping node_id->info)或新格式报告(包含 node_infos/node_results) report = results node_infos = {} node_results = {} if isinstance(report, dict) and ('node_infos' in report or 'node_results' in report): node_infos = report.get('node_infos', {}) or {} node_results = report.get('node_results', {}) or {} elif isinstance(report, dict): # 退回兼容:report 可能本身就是 node_id->NodeExecutionInfo 映射 node_infos = report node_results = {} def _get_status(info): if info is None: return 'unknown' if isinstance(info, dict): return info.get('status') if hasattr(info, 'status'): return info.status.value return str(info) def _get_error(info): if info is None: return None if isinstance(info, dict): return info.get('error') if hasattr(info, 'error'): return info.error return None # 合并输出:优先从 node_infos 获取状态与错误,从 node_results 获取 outputs all_node_ids = set(list(node_infos.keys()) + list(node_results.keys())) import polars as pl # 把 node_results 中的 polars.DataFrame 转为可序列化的预览字典,防止 FastAPI 在响应序列化时报错 results_out: Dict[str, Any] = {} PREVIEW_LIMIT = 20 for nid in all_node_ids: info = node_infos.get(nid) outputs = node_results.get(nid) def _serialize_value(v): # polars DataFrame -> preview dict if isinstance(v, pl.DataFrame): return { "__type": "DataFrame", "columns": list(v.columns), "preview": v.head(PREVIEW_LIMIT).to_dicts(), "rows": v.height } # list/tuple -> recursively serialize if isinstance(v, (list, tuple)): return [_serialize_value(x) for x in v] # dict -> recursively serialize values if isinstance(v, dict): return {k: _serialize_value(val) for k, val in v.items()} return v ser_outputs = None if isinstance(outputs, dict): ser_outputs = {k: _serialize_value(v) for k, v in outputs.items()} else: ser_outputs = _serialize_value(outputs) results_out[nid] = { 'status': _get_status(info), 'outputs': ser_outputs, 'error': _get_error(info) } return { "status": "success", "message": f"成功执行 {len(request.nodes)} 个节点", "execution_time": round(execution_time, 3), "results": results_out, "stats": { "total_nodes": len(request.nodes), "total_edges": len(request.edges), "execution_order": list(results_out.keys()) } } except Exception as e: import traceback raise HTTPException( status_code=500, detail={ "error": str(e), "type": type(e).__name__, "traceback": traceback.format_exc() } ) @router.get("/users/list") async def list_users(): """ 获取所有活跃用户列表(从数据库) """ from ..core.user_manager import list_users users = list_users() return { "users": users, "count": len(users) } @router.post("/users/add") async def add_user(payload: Dict[str, Any]): """ 添加新用户 """ from ..core.user_manager import load_users_db, save_users_db, create_user_workspace username = payload.get("username") display_name = payload.get("display_name", username) if not username: raise HTTPException(status_code=400, detail="Missing username") # 验证用户名格式 if not username.replace('_', '').replace('-', '').isalnum(): raise HTTPException(status_code=400, detail="Invalid username format") # 加载数据库 db = load_users_db() # 检查用户是否已存在 if any(u["username"] == username for u in db.get("users", [])): raise HTTPException(status_code=400, detail="User already exists") # 添加用户 from datetime import datetime db["users"].append({ "username": username, "display_name": display_name, "created_at": datetime.now().isoformat(), "active": True }) # 保存数据库 if not save_users_db(db): raise HTTPException(status_code=500, detail="Failed to save user database") # 创建用户工作空间 create_user_workspace(username) return { "success": True, "message": f"用户 {username} 创建成功" } @router.post("/nodes/save") async def save_node(payload: Dict[str, Any]): """ 保存节点配置到服务器 """ node_id = payload.get("nodeId") node_data = payload.get("nodeData") if not node_id: raise HTTPException(status_code=400, detail="Missing nodeId") # TODO: 实现节点保存逻辑(可以存到数据库或文件) return { "success": True, "message": f"节点 {node_id} 配置已保存", "nodeId": node_id } @router.post("/workflows/save") async def save_workflow(payload: Dict[str, Any]): """ 保存工作流到用户目录 """ filename = payload.get("filename") username = payload.get("username", "guest") workflow_data = payload.get("data") if not filename: raise HTTPException(status_code=400, detail="Missing filename") if not workflow_data: raise HTTPException(status_code=400, detail="Missing workflow data") # 验证文件名 is_valid, error_msg = validate_filename(filename) if not is_valid: raise HTTPException(status_code=400, detail=error_msg) # 确保.json扩展名 if not filename.endswith(".json"): filename = f"{filename}.json" # 构建保存路径 workflow_dir = get_user_path(username, "workflows") workflow_dir.mkdir(parents=True, exist_ok=True) workflow_file = workflow_dir / filename # 安全检查 user_root = get_user_path(username) if not is_safe_path(user_root, workflow_file): raise HTTPException(status_code=403, detail="Access denied") # 保存工作流 try: with open(workflow_file, "w", encoding="utf-8") as f: json.dump(workflow_data, f, indent=2, ensure_ascii=False) return { "success": True, "message": f"工作流已保存: {filename}", "path": str(workflow_file.relative_to(user_root.parent)) } except Exception as e: raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") @router.get("/workflows/load") async def load_workflow(filename: str, username: str = "guest"): """ 加载工作流 """ # 清理文件名 filename = sanitize_path(filename) # 构建文件路径 workflow_file = get_user_path(username, f"workflows/{filename}") # 安全检查 user_root = get_user_path(username) if not is_safe_path(user_root, workflow_file): raise HTTPException(status_code=403, detail="Access denied") if not workflow_file.exists(): raise HTTPException(status_code=404, detail="Workflow not found") # 加载工作流 try: with open(workflow_file, "r", encoding="utf-8") as f: workflow_data = json.load(f) return { "success": True, "filename": filename, "data": workflow_data } except Exception as e: raise HTTPException(status_code=500, detail=f"加载失败: {str(e)}") class SaveFunctionRequest(BaseModel): """保存函数请求""" function_name: str display_name: str description: Optional[str] = "" nodes: List[NodeSchema] edges: List[EdgeSchema] inputs: List[Dict[str, Any]] outputs: List[Dict[str, Any]] @router.post("/functions/save") async def save_function(request: SaveFunctionRequest): """ 保存工作流为函数 将选中的节点和连线封装为可复用的函数节点 """ from ..core.node_loader import load_function_nodes # 验证函数名 function_name = validate_filename(request.function_name) if not function_name: raise HTTPException(status_code=400, detail="Invalid function name") # 构建函数文件路径 project_root = Path(__file__).resolve().parents[3] functions_dir = project_root / "cloud" / "custom_nodes" / "functions" functions_dir.mkdir(parents=True, exist_ok=True) function_file = functions_dir / f"{function_name}.json" # 检查是否已存在 if function_file.exists(): raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在") # 构建函数数据 function_data = { "function_name": function_name, "display_name": request.display_name, "description": request.description, "inputs": request.inputs, "outputs": request.outputs, "nodes": [node.dict() for node in request.nodes], "edges": [edge.dict() for edge in request.edges] } # 保存函数 try: with open(function_file, "w", encoding="utf-8") as f: json.dump(function_data, f, indent=2, ensure_ascii=False) # 重新加载函数节点 load_result = load_function_nodes() return { "success": True, "message": f"函数已保存: {request.display_name}", "function_name": function_name, "path": str(function_file.relative_to(project_root)), "reload_result": load_result } except Exception as e: raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") @router.get("/functions/list") async def list_functions(): """ 获取所有可用函数列表 """ project_root = Path(__file__).resolve().parents[3] functions_dir = project_root / "cloud" / "custom_nodes" / "functions" if not functions_dir.exists(): return {"functions": []} functions = [] for json_file in functions_dir.glob("*.json"): try: with open(json_file, 'r', encoding='utf-8') as f: data = json.load(f) functions.append({ "function_name": data.get("function_name", json_file.stem), "display_name": data.get("display_name", json_file.stem), "description": data.get("description", ""), "inputs": data.get("inputs", []), "outputs": data.get("outputs", []) }) except Exception as e: print(f"❌ 读取函数失败: {json_file.name} - {str(e)}") return {"functions": functions} @router.get("/functions/{function_name}") async def get_function(function_name: str): """ 获取函数详细信息(包含内部工作流) """ # 验证函数名 function_name = validate_filename(function_name) if not function_name: raise HTTPException(status_code=400, detail="Invalid function name") project_root = Path(__file__).resolve().parents[3] function_file = project_root / "cloud" / "custom_nodes" / "functions" / f"{function_name}.json" if not function_file.exists(): raise HTTPException(status_code=404, detail="Function not found") try: with open(function_file, 'r', encoding='utf-8') as f: function_data = json.load(f) return { "success": True, "data": function_data } except Exception as e: raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}")