2026-01-09 21:37:02 +08:00
|
|
|
|
"""
|
|
|
|
|
|
图执行和查询相关 API
|
|
|
|
|
|
"""
|
2026-01-12 03:32:51 +08:00
|
|
|
|
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
2026-01-09 21:37:02 +08:00
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
import json
|
|
|
|
|
|
from pathlib import Path
|
2026-01-12 03:32:51 +08:00
|
|
|
|
import asyncio
|
2026-01-09 21:37:02 +08:00
|
|
|
|
|
2026-01-10 19:08:49 +08:00
|
|
|
|
from server.app.core.node_base import DimensionMode, NodeType
|
2026-01-09 21:37:02 +08:00
|
|
|
|
|
|
|
|
|
|
from ..core.user_manager import get_user_path
|
|
|
|
|
|
from ..core.security import is_safe_path, validate_filename, sanitize_path
|
|
|
|
|
|
|
|
|
|
|
|
# Router collecting the graph-execution, preview, user, workflow and function endpoints below.
router: APIRouter = APIRouter()
|
|
|
|
|
|
|
2026-01-12 03:32:51 +08:00
|
|
|
|
# Simple per-user execution locks that stop repeated triggers from overloading the backend.
# Currently keyed by user_id (default "guest"); could later be scoped per workspace or per graph.
_EXECUTION_LOCKS: Dict[str, asyncio.Lock] = {}


def _get_user_lock(user_id: str = "guest") -> asyncio.Lock:
    """Return the execution lock for *user_id*, creating and registering it on first use.

    The same ``asyncio.Lock`` object is returned for every call with the same
    ``user_id``. No ``await`` happens here, so lookup-and-insert cannot interleave
    with another coroutine on the same event loop.
    """
    try:
        return _EXECUTION_LOCKS[user_id]
    except KeyError:
        fresh_lock = asyncio.Lock()
        _EXECUTION_LOCKS[user_id] = fresh_lock
        return fresh_lock
|
|
|
|
|
|
|
2026-01-09 21:37:02 +08:00
|
|
|
|
|
|
|
|
|
|
class NodeSchema(BaseModel):
    """A single graph node as sent by the frontend.

    NOTE(review): no ``data`` field is declared, so with pydantic's default
    configuration an incoming ``data`` key is presumably dropped - confirm against
    the project's pydantic config.
    """

    # Node identifier; presumably what EdgeSchema.source / EdgeSchema.target refer to.
    id: str
    node_type: NodeType
    # Implementation class name; optional in the schema, but the preview and
    # execute endpoints reject nodes without it.
    class_name: Optional[str] = None
    # Runtime parameters handed to the node implementation.
    params: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EdgeSchema(BaseModel):
    """A directed connection between two node ports."""

    # The frontend may omit the id (the backend synthesizes one), but sending a
    # stable id from the frontend is recommended.
    id: Optional[str] = None
    source: str  # source node id
    target: str  # target node id
    dimension_mode: DimensionMode = DimensionMode.NONE
    # Port names; execute_graph falls back to "output" / "input" when they are omitted.
    source_port: Optional[str] = None
    target_port: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraphExecuteRequest(BaseModel):
    """Request body for executing a complete computation graph."""

    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    # Optional global settings; execute_graph passes them to the executor as the global context.
    settings: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodePreviewRequest(BaseModel):
    """Request body for previewing a single node's output."""

    node: NodeSchema
    # Maximum number of rows / list items included in the preview.
    limit: int = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/plugins")
async def get_plugins():
    """List every node plugin registered in ``NodeRegistry``.

    Returns a dict with:
        plugins: mapping of implementation class name -> frontend descriptor
            (display info, ports, parameter schema, context vars, cache policy).
        total: number of plugins.
        categories: whatever ``NodeRegistry.get_categories()`` returns.
    """
    from ..core.node_registry import NodeRegistry

    def describe(name, node_cls):
        # Translate a node class's metadata into the shape the frontend expects.
        meta = node_cls.get_metadata()
        return {
            "display_name": meta.get("display_name", name),
            "category": meta.get("category", "Uncategorized"),
            "description": meta.get("description", ""),
            "icon": meta.get("icon", "📦"),
            "node_type": meta.get("node_type", "REGULAR"),
            # The frontend's double-click detection relies on the implementation class name.
            "class_name": name,
            "node_logic": "standard",
            "supports_preview": True,
            "inputs": meta.get("inputs", []),
            "outputs": meta.get("outputs", []),
            "param_schema": meta.get("params", []),
            "context_vars": meta.get("context_vars", {}),
            "cache_policy": meta.get("cache_policy", "NONE"),
        }

    registered = NodeRegistry.list_all()
    plugins = {name: describe(name, cls) for name, cls in registered.items()}

    return {
        "plugins": plugins,
        "total": len(plugins),
        "categories": NodeRegistry.get_categories(),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/node/preview")
async def preview_node(request: NodePreviewRequest):
    """Run a single node in preview mode and return a tabular preview of its output.

    The node class is looked up in ``NodeRegistry`` by ``class_name``,
    instantiated, given its parameters, and executed with empty inputs and a
    ``{"user_id": "guest", "preview_mode": True}`` context. Only the first entry
    of the node's ``outputs`` is previewed:

    * polars ``DataFrame`` -> its column names and first ``limit`` rows;
    * ``list`` -> up to ``limit`` items, each wrapped as ``{"value": item}``;
    * anything else -> a single ``{"value": ...}`` row.

    Raises:
        HTTPException(400): ``class_name`` is missing.
        HTTPException(404): ``NodeRegistry.get`` returned ``None`` for ``class_name``.
        HTTPException(500): any other failure while building or running the node.
    """
    from ..core.node_registry import NodeRegistry
    import polars as pl

    # Validate before the try-block so the 400 is not re-wrapped as a 500 below.
    class_name = request.node.class_name
    if not class_name:
        raise HTTPException(status_code=400, detail="Missing class_name for node preview")

    # A negative limit would make head()/slicing drop rows from the end instead.
    limit = max(request.limit, 0)

    try:
        node_class = NodeRegistry.get(class_name)
        if node_class is None:
            raise HTTPException(status_code=404, detail=f"Unknown node class: {class_name}")
        node_instance = node_class()

        # Parameters may arrive as ``data`` or ``params``. NodeSchema declares only
        # ``params``, so ``data`` is read defensively instead of as an attribute
        # (direct access raised AttributeError and made every preview fail).
        pdata = getattr(request.node, "data", None) or request.node.params or {}
        if isinstance(pdata, dict):
            for param_name, param_value in pdata.items():
                try:
                    node_instance.set_param(param_name, param_value)
                except Exception:
                    # Not every node implementation supports set_param; fall back
                    # to a plain attribute assignment.
                    setattr(node_instance, param_name, param_value)

        context = {
            "user_id": "guest",
            "preview_mode": True,
        }
        # Preview mode runs the node without upstream inputs.
        result = await node_instance.wrap_process({}, context)
        outputs = result.get("outputs") or {}

        preview_data = []
        columns = []
        # Only the first output entry is previewed.
        first_output = next(iter(outputs.items()), None)
        if first_output is not None:
            _, output_value = first_output
            if isinstance(output_value, pl.DataFrame):
                columns = list(output_value.columns)
                preview_data = output_value.head(limit).to_dicts()
            elif isinstance(output_value, list):
                columns = ["value"]
                preview_data = [{"value": item} for item in output_value[:limit]]
            else:
                columns = ["value"]
                preview_data = [{"value": output_value}]

        return {
            "class_name": class_name,
            "status": "success",
            "columns": columns,
            "preview": preview_data,
            "context": result.get("context", {}),
        }
    except HTTPException:
        # Keep deliberate client-facing status codes intact.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"节点预览失败: {str(e)}"
        ) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/graph/execute")
async def execute_graph(request: GraphExecuteRequest):
    """Execute a complete computation graph and return the executor's results.

    Nodes and edges are converted into the plain-dict format passed to
    ``WorkflowExecutor.execute``; ``request.settings`` (or ``{}``) becomes the
    global context. Only one execution per user may run at a time.

    Raises:
        HTTPException(409): another execution for the same user is already running.
        HTTPException(400): a node is missing ``class_name``.
        HTTPException(500): the executor reported failure or raised; ``detail``
            contains ``error``, ``type`` and ``traceback``.
    """
    from ..core.workflow_executor import WorkflowExecutor
    import traceback

    try:
        user_id = "guest"  # TODO: obtain the user ID from the request / authentication
        lock = _get_user_lock(user_id)
        # No await happens between this check and ``async with lock`` below, so
        # another coroutine on the same event loop cannot slip in between.
        if lock.locked():
            raise HTTPException(status_code=409, detail="已有执行在进行中,请稍后重试")

        # Convert nodes; class_name is strictly required.
        nodes = []
        for node in request.nodes:
            if not node.class_name:
                raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name")
            nodes.append({
                "id": node.id,
                "node_type": node.node_type,
                "class_name": node.class_name,
                # Runtime parameters are always passed to the executor under `params`.
                "params": node.params or {},
            })

        # Convert edges; missing ports default to "output" / "input" and a
        # deterministic id is synthesized when the frontend sends none.
        edges = []
        for edge in request.edges:
            s_handle = edge.source_port or "output"
            t_handle = edge.target_port or "input"
            edges.append({
                "id": edge.id or f"{edge.source}:{s_handle}->{edge.target}:{t_handle}",
                "source": edge.source,
                "source_port": s_handle,
                "target": edge.target,
                "target_port": t_handle,
                "dimension_mode": edge.dimension_mode,
            })

        global_context = request.settings or {}

        # Build the executor only once the request has been validated.
        executor = WorkflowExecutor(user_id=user_id)
        async with lock:
            success, results = await executor.execute(nodes, edges, global_context)
        if not success:
            raise Exception(str(results))
        return results

    except HTTPException:
        # Deliberate 400/409 responses must not be masked as 500 by the handler below.
        raise
    except Exception as e:
        # NOTE(review): the traceback is returned to the client; consider removing
        # it outside of development environments.
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        ) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-12 03:32:51 +08:00
|
|
|
|
def _strip_handle_prefix(handle: Any, prefix: str) -> Any:
    """Remove a leading *prefix* (e.g. ``"output-"``) from a frontend port handle.

    Only a leading occurrence is removed; non-string handles are returned unchanged.
    """
    if isinstance(handle, str) and handle.startswith(prefix):
        return handle[len(prefix):]
    return handle


@router.websocket("/ws/graph/execute")
async def ws_execute_graph(websocket: WebSocket):
    """Execute a computation graph over a WebSocket, streaming events to the client.

    The client must first send one JSON message::

        {
            "nodes": [...],
            "edges": [...],
            "settings": { ... }
        }

    The server then streams events: start / node_started / node_completed /
    progress / error / completed.

    Protocol errors are reported as an ``{"event": "error", ...}`` message
    followed by a close:
    - 1008 when another execution for the user is running;
    - 1003 when a node has no ``class_name``;
    - 1011 for unexpected failures.
    """
    await websocket.accept()
    try:
        init_msg = await websocket.receive_json()
        from ..core.workflow_executor import WorkflowExecutor

        user_id = "guest"  # TODO: obtain the user ID from the handshake / authentication
        lock = _get_user_lock(user_id)
        if lock.locked():
            # Reject concurrent executions with an error event.
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": "已有执行在进行中,请稍后重试"}
            })
            await websocket.close(code=1008)
            return

        # Normalize nodes into the same shape the HTTP route produces.
        nodes_in = []
        for node in init_msg.get("nodes", []):
            class_name = node.get("class_name")
            if not class_name:
                # Older frontend payloads may keep the class under data.meta;
                # tolerate `data`/`meta` being null.
                meta = (node.get("data") or {}).get("meta") or {}
                class_name = meta.get("class_name")
            if not class_name:
                await websocket.send_json({
                    "event": "error",
                    "run_id": None,
                    "data": {"message": f"Node {node.get('id')} missing class_name"}
                })
                await websocket.close(code=1003)
                return
            nodes_in.append({
                "id": node.get("id"),
                "node_type": node.get("node_type"),
                "class_name": class_name,
                "params": node.get("params") or {},
            })

        # Normalize edges. React-Flow-style handles ("output-x" / "input-y") are
        # also accepted and their leading prefix is stripped.
        edges_in = []
        for edge in init_msg.get("edges", []):
            s_handle = _strip_handle_prefix(
                edge.get("source_port") or edge.get("sourceHandle") or "output", "output-"
            )
            t_handle = _strip_handle_prefix(
                edge.get("target_port") or edge.get("targetHandle") or "input", "input-"
            )
            eid = edge.get("id") or f"{edge.get('source')}:{s_handle}->{edge.get('target')}:{t_handle}"
            # Match EdgeSchema's default instead of calling DimensionMode(None).
            raw_mode = edge.get("dimension_mode")
            dimension_mode = DimensionMode.NONE if raw_mode is None else DimensionMode(raw_mode)
            edges_in.append({
                "id": eid,
                "source": edge.get("source"),
                "source_port": s_handle,
                "target": edge.get("target"),
                "target_port": t_handle,
                "dimension_mode": dimension_mode,
            })

        settings = init_msg.get("settings") or {}
        executor = WorkflowExecutor(user_id=user_id)

        # Hold the lock for the whole stream so the same user cannot run twice at once.
        async with lock:
            async for event in executor.execute_stream(nodes_in, edges_in, settings):
                await websocket.send_json(event)
        await websocket.close()
    except WebSocketDisconnect:
        # The client disconnected.
        return
    except Exception as e:
        # Best-effort reporting: the socket may already be closed (e.g. the failure
        # came from a send), in which case send/close would raise again.
        try:
            await websocket.send_json({
                "event": "error",
                "run_id": None,
                "data": {"message": str(e)}
            })
        except Exception:
            pass
        try:
            await websocket.close(code=1011)
        except Exception:
            pass
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-09 21:37:02 +08:00
|
|
|
|
@router.get("/users/list")
async def list_users():
    """Return the active users reported by ``core.user_manager.list_users`` and their count."""
    # Aliased so the helper does not shadow this endpoint's own name inside the body.
    from ..core.user_manager import list_users as fetch_users

    users = fetch_users()
    return {"users": users, "count": len(users)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/users/add")
async def add_user(payload: Dict[str, Any]):
    """Register a new user in the user database and create their workspace.

    Payload keys:
        username: required; letters/digits plus ``_`` and ``-`` only
            (``str.isalnum`` also accepts non-ASCII letters and digits).
        display_name: optional; falls back to ``username`` when missing or empty.

    Raises:
        HTTPException(400): username missing, not a string, badly formatted, or already taken.
        HTTPException(500): the user database could not be saved.
    """
    from ..core.user_manager import load_users_db, save_users_db, create_user_workspace
    from datetime import datetime

    username = payload.get("username")
    if not username:
        raise HTTPException(status_code=400, detail="Missing username")

    # JSON may carry a non-string value; reject it before calling str methods on it.
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    display_name = payload.get("display_name") or username

    db = load_users_db()
    # Create the list if absent, so the lookup and the append below agree.
    users = db.setdefault("users", [])

    if any(u.get("username") == username for u in users):
        raise HTTPException(status_code=400, detail="User already exists")

    users.append({
        "username": username,
        "display_name": display_name,
        "created_at": datetime.now().isoformat(),
        "active": True
    })

    if not save_users_db(db):
        raise HTTPException(status_code=500, detail="Failed to save user database")

    create_user_workspace(username)

    return {
        "success": True,
        "message": f"用户 {username} 创建成功"
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/nodes/save")
async def save_node(payload: Dict[str, Any]):
    """Acknowledge a request to save a node's configuration.

    NOTE: persistence is not implemented yet. ``nodeData`` is currently ignored;
    the endpoint only echoes ``nodeId`` back with a success message.

    Raises:
        HTTPException(400): ``nodeId`` is missing or empty.
    """
    node_id = payload.get("nodeId")
    if not node_id:
        raise HTTPException(status_code=400, detail="Missing nodeId")

    # TODO: persist payload.get("nodeData") (database or file).
    return {
        "success": True,
        "message": f"节点 {node_id} 配置已保存",
        "nodeId": node_id,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/workflows/save")
async def save_workflow(payload: Dict[str, Any]):
    """Save a workflow as ``<filename>.json`` in the user's ``workflows`` directory.

    Payload keys:
        filename: required; validated with ``validate_filename`` and given a
            ``.json`` suffix if missing.
        data: required, non-empty; written as pretty-printed UTF-8 JSON.
        username: optional, default ``"guest"``; letters/digits plus ``_`` and ``-``.

    Raises:
        HTTPException(400): missing/invalid filename or username, or missing data.
        HTTPException(403): the target path escapes the user's root directory.
        HTTPException(500): writing the file failed.
    """
    filename = payload.get("filename")
    username = payload.get("username", "guest")
    workflow_data = payload.get("data")

    if not filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    if not workflow_data:
        raise HTTPException(status_code=400, detail="Missing workflow data")

    # The username selects the root directory itself, and is_safe_path below only
    # checks containment *within* that root. Validate it (same rule as add_user)
    # before building any path or creating directories from it.
    if not isinstance(username, str) or not username.replace("_", "").replace("-", "").isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    is_valid, error_msg = validate_filename(filename)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg)

    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    workflow_dir = get_user_path(username, "workflows")
    workflow_dir.mkdir(parents=True, exist_ok=True)
    workflow_file = workflow_dir / filename

    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        with open(workflow_file, "w", encoding="utf-8") as f:
            json.dump(workflow_data, f, indent=2, ensure_ascii=False)

        return {
            "success": True,
            "message": f"工作流已保存: {filename}",
            "path": str(workflow_file.relative_to(user_root.parent))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/workflows/load")
async def load_workflow(filename: str, username: str = "guest"):
    """Load a workflow JSON file from the user's ``workflows`` directory.

    Raises:
        HTTPException(400): ``username`` has an invalid format.
        HTTPException(403): the resolved path escapes the user's root directory.
        HTTPException(404): no regular file exists at that path.
        HTTPException(500): reading or parsing the file failed.
    """
    # The username selects the root that is_safe_path checks against, so it must
    # be validated itself (same rule as add_user).
    if not username.replace("_", "").replace("-", "").isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    filename = sanitize_path(filename)
    workflow_file = get_user_path(username, f"workflows/{filename}")

    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    # is_file() rather than exists(): a directory would otherwise fail in open() with a 500.
    if not workflow_file.is_file():
        raise HTTPException(status_code=404, detail="Workflow not found")

    try:
        with open(workflow_file, "r", encoding="utf-8") as f:
            workflow_data = json.load(f)

        return {
            "success": True,
            "filename": filename,
            "data": workflow_data
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"加载失败: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SaveFunctionRequest(BaseModel):
    """Request body for saving a sub-workflow as a reusable function node."""

    # Used as the stem of the saved JSON file name.
    function_name: str
    display_name: str
    description: Optional[str] = ""
    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    # Port descriptors of the function; stored verbatim (their schema is not checked here).
    inputs: List[Dict[str, Any]]
    outputs: List[Dict[str, Any]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/functions/save")
async def save_function(request: SaveFunctionRequest):
    """Save a workflow as a reusable function node.

    The selected nodes and edges are written to
    ``cloud/custom_nodes/functions/<function_name>.json`` under the project root,
    and function nodes are then reloaded via ``load_function_nodes``.

    Raises:
        HTTPException(400): the function name fails ``validate_filename``.
        HTTPException(409): a function with that name already exists.
        HTTPException(500): serialization, writing, or reloading failed.
    """
    from ..core.node_loader import load_function_nodes

    # validate_filename returns an (is_valid, error_message) pair, as unpacked in
    # save_workflow. It does not return a sanitized name, so the requested name is
    # used as-is once it is valid.
    function_name = request.function_name
    is_valid, error_msg = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg or "Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    functions_dir.mkdir(parents=True, exist_ok=True)

    function_file = functions_dir / f"{function_name}.json"

    if function_file.exists():
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")

    function_data = {
        "function_name": function_name,
        "display_name": request.display_name,
        "description": request.description,
        "inputs": request.inputs,
        "outputs": request.outputs,
        "nodes": [node.dict() for node in request.nodes],
        "edges": [edge.dict() for edge in request.edges]
    }

    try:
        # Serialize first, so an unserializable payload cannot leave a partial file
        # behind that would make every retry fail with 409.
        content = json.dumps(function_data, indent=2, ensure_ascii=False)
        # Mode "x" creates the file exclusively, closing the gap between the
        # exists() check above and this write.
        with open(function_file, "x", encoding="utf-8") as f:
            f.write(content)

        load_result = load_function_nodes()

        return {
            "success": True,
            "message": f"函数已保存: {request.display_name}",
            "function_name": function_name,
            "path": str(function_file.relative_to(project_root)),
            "reload_result": load_result
        }
    except FileExistsError as e:
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/functions/list")
async def list_functions():
    """List the saved function nodes in ``cloud/custom_nodes/functions``.

    Each ``*.json`` file contributes a summary (name, display name, description,
    inputs, outputs), with the file stem as the fallback name. Files that cannot
    be read or parsed are reported with ``print`` and skipped.
    """
    functions_dir = Path(__file__).resolve().parents[3] / "cloud" / "custom_nodes" / "functions"
    if not functions_dir.exists():
        return {"functions": []}

    summaries = []
    for path in functions_dir.glob("*.json"):
        try:
            with open(path, 'r', encoding='utf-8') as fh:
                spec = json.load(fh)
            stem = path.stem
            summaries.append({
                "function_name": spec.get("function_name", stem),
                "display_name": spec.get("display_name", stem),
                "description": spec.get("description", ""),
                "inputs": spec.get("inputs", []),
                "outputs": spec.get("outputs", []),
            })
        except Exception as exc:
            print(f"❌ 读取函数失败: {path.name} - {str(exc)}")

    return {"functions": summaries}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/functions/{function_name}")
async def get_function(function_name: str):
    """Return a saved function's full definition, including its internal workflow.

    Raises:
        HTTPException(400): the function name fails ``validate_filename``.
        HTTPException(404): no such function file exists.
        HTTPException(500): reading or parsing the file failed.
    """
    # validate_filename returns an (is_valid, error_message) pair (see
    # save_workflow); a tuple is always truthy, so it must be unpacked.
    is_valid, error_msg = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg or "Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    function_file = project_root / "cloud" / "custom_nodes" / "functions" / f"{function_name}.json"

    # is_file() rather than exists(): a directory would otherwise fail in open() with a 500.
    if not function_file.is_file():
        raise HTTPException(status_code=404, detail="Function not found")

    try:
        with open(function_file, 'r', encoding='utf-8') as f:
            function_data = json.load(f)

        return {
            "success": True,
            "data": function_data
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}") from e
|
|
|
|
|
|
|