TraceStudio-dev/server/app/api/endpoints_graph.py
Boshuang Zhao 5790ec164f add web v2
2026-01-10 19:08:49 +08:00

515 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图执行和查询相关 API
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import json
from pathlib import Path
from server.app.core.node_base import DimensionMode, NodeType
from ..core.user_manager import get_user_path
from ..core.security import is_safe_path, validate_filename, sanitize_path
# Router on which the endpoints declared below are registered.
router = APIRouter()
class NodeSchema(BaseModel):
    """Node data structure."""

    # Identifier of the node; edges refer to nodes through `source`/`target`.
    id: str
    node_type: NodeType
    # Name of the implementation class. Optional in the schema, but the
    # preview and execute endpoints reject nodes that do not carry it.
    class_name: Optional[str] = None
    # Runtime parameters forwarded to the node / executor.
    params: Optional[Dict[str, Any]] = None
class EdgeSchema(BaseModel):
    """Edge (connection) data structure."""

    # The frontend may omit `id`, in which case the backend synthesizes one;
    # providing a stable id from the frontend is still recommended.
    id: Optional[str] = None
    source: str
    target: str
    dimension_mode: DimensionMode = DimensionMode.NONE
    # Port names; the execute endpoint falls back to "output"/"input" when
    # they are not supplied.
    source_port: Optional[str] = None
    target_port: Optional[str] = None
class GraphExecuteRequest(BaseModel):
    """Request body for executing a whole graph."""

    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    # Optional settings; the execute endpoint uses them as the global context.
    settings: Optional[Dict[str, Any]] = None
class NodePreviewRequest(BaseModel):
    """Request body for previewing a single node."""

    node: NodeSchema
    # Maximum number of rows / items returned in the preview.
    limit: int = 10
@router.get("/plugins")
async def get_plugins():
    """
    List every available node plugin.

    The plugins are read from ``NodeRegistry`` at request time and reshaped
    into the structure the frontend consumes.

    Returns:
        A dict with the plugin descriptors keyed by class name (``plugins``),
        their count (``total``) and the registry categories (``categories``).
    """
    from ..core.node_registry import NodeRegistry

    def describe(name, node_cls):
        # Build the frontend-facing descriptor for one registered node class.
        meta = node_cls.get_metadata()
        input_ports = meta.get("inputs", [])
        output_ports = meta.get("outputs", [])
        params = meta.get("params", [])
        return {
            "display_name": meta.get("display_name", name),
            "category": meta.get("category", "Uncategorized"),
            "description": meta.get("description", ""),
            "icon": meta.get("icon", "📦"),
            "node_type": meta.get("node_type", "REGULAR"),
            # The frontend's double-click detection needs the implementation
            # class name, exposed here as `class_name`.
            "class_name": name,
            "node_logic": "standard",
            "supports_preview": True,
            "inputs": input_ports,
            "outputs": output_ports,
            "param_schema": params,
            "context_vars": meta.get("context_vars", {}),
            "cache_policy": meta.get("cache_policy", "NONE"),
        }

    registered = NodeRegistry.list_all()
    plugins = {name: describe(name, node_cls) for name, node_cls in registered.items()}

    return {
        "plugins": plugins,
        "total": len(plugins),
        "categories": NodeRegistry.get_categories(),
    }
@router.post("/node/preview")
async def preview_node(request: NodePreviewRequest):
    """
    Preview the output of a single node.

    The node class is looked up in ``NodeRegistry``, instantiated, configured
    with the request parameters and run with no inputs in preview mode. Only
    the first output is previewed, truncated to ``request.limit`` entries.

    Returns:
        A dict with ``class_name``, ``status``, ``columns``, ``preview`` (a
        list of row dicts) and the ``context`` returned by the node.

    Raises:
        HTTPException: 400 when ``class_name`` is missing, 404 when the
            registry lookup yields nothing, 500 when running the node fails.
    """
    from ..core.node_registry import NodeRegistry
    import polars as pl

    try:
        # The implementation class name is mandatory.
        class_name = request.node.class_name
        if not class_name:
            raise HTTPException(status_code=400, detail="Missing class_name for node preview")

        node_class = NodeRegistry.get(class_name)
        # NOTE(review): assumes the registry may return None for an unknown
        # class; if it raises instead, the generic handler below still applies.
        if node_class is None:
            raise HTTPException(status_code=404, detail=f"Unknown node class: {class_name}")
        node_instance = node_class()

        # Accept parameters from either `data` or `params`. The schema visible
        # in this module only declares `params`, so `data` is read defensively
        # instead of assuming the attribute exists.
        pdata = getattr(request.node, "data", None) or request.node.params or {}
        if isinstance(pdata, dict):
            for param_name, param_value in pdata.items():
                # Prefer the node's set_param interface; fall back to a plain
                # attribute for implementations where it fails.
                try:
                    node_instance.set_param(param_name, param_value)
                except Exception:
                    setattr(node_instance, param_name, param_value)

        # In preview mode the node runs without upstream inputs.
        inputs = {}
        context = {
            "user_id": "guest",
            "preview_mode": True
        }

        result = await node_instance.wrap_process(inputs, context)
        outputs = result.get("outputs") or {}

        preview_data = []
        columns = []
        for output_value in outputs.values():
            if isinstance(output_value, pl.DataFrame):
                # polars DataFrame: first rows as dicts.
                columns = list(output_value.columns)
                preview_data = output_value.head(request.limit).to_dicts()
            elif isinstance(output_value, list):
                # List: wrap each item in a single-column row.
                columns = ["value"]
                preview_data = [{"value": item} for item in output_value[:request.limit]]
            else:
                # Scalar / other value: a single one-column row.
                columns = ["value"]
                preview_data = [{"value": output_value}]
            # Only the first output is previewed.
            break

        return {
            "class_name" : class_name,
            "status": "success",
            "columns": columns,
            "preview": preview_data,
            "context": result.get("context", {})
        }
    except HTTPException:
        # Keep deliberate HTTP errors (400/404) instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"节点预览失败: {str(e)}"
        )
@router.post("/graph/execute")
async def execute_graph(request: GraphExecuteRequest):
    """
    Execute a complete computation graph.

    Nodes and edges are converted to plain dicts and handed to
    ``WorkflowExecutor.execute`` together with ``request.settings`` as the
    global context.

    Returns:
        Whatever the executor returns as results on success.

    Raises:
        HTTPException: 400 when a node has no ``class_name``; 500 (with
            ``error``, ``type`` and ``traceback`` in the detail) when the
            executor reports failure or anything else goes wrong.
    """
    from ..core.workflow_executor import WorkflowExecutor
    import traceback

    try:
        executor = WorkflowExecutor(user_id="guest")  # TODO: take the user id from the request

        # Convert nodes; class_name is strictly required.
        nodes = []
        for node in request.nodes:
            if not node.class_name:
                raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name")
            nodes.append({
                "id": node.id,
                "node_type": node.node_type,
                "class_name": node.class_name,
                # Runtime parameters are passed to the executor as `params`.
                "params": node.params or {},
            })

        # Convert edges, filling in default port names and a synthesized id
        # when the frontend did not provide them.
        edges = []
        for edge in request.edges:
            s_handle = edge.source_port or "output"
            t_handle = edge.target_port or "input"
            eid = edge.id or f"{edge.source}:{s_handle}->{edge.target}:{t_handle}"
            edges.append({
                "id": eid,
                "source": edge.source,
                "source_port": s_handle,
                "target": edge.target,
                "target_port": t_handle,
                "dimension_mode": edge.dimension_mode
            })

        global_context = request.settings or {}

        success, results = await executor.execute(nodes, edges, global_context)
        if not success:
            raise Exception(str(results))
        return results
    except HTTPException:
        # Validation errors raised above must keep their own status code.
        raise
    except Exception as e:
        # NOTE(review): the traceback is returned to the client; this exposes
        # server internals and should be reconsidered for production use.
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        )
@router.get("/users/list")
async def list_users():
    """
    Return the list of users reported by ``user_manager.list_users``.

    Returns:
        A dict with the user list (``users``) and its length (``count``).
    """
    # Aliased so the helper is not confused with this endpoint's own name.
    from ..core.user_manager import list_users as fetch_users

    found = fetch_users()
    return {"users": found, "count": len(found)}
@router.post("/users/add")
async def add_user(payload: Dict[str, Any]):
    """
    Add a new user.

    Expects ``username`` (required; letters/digits plus ``_`` and ``-``) and
    an optional ``display_name`` (defaults to the username) in the payload.
    The user is appended to the users database and a workspace is created.

    Raises:
        HTTPException: 400 for a missing/invalid username or an existing
            user, 500 when the database cannot be saved.
    """
    from datetime import datetime
    from ..core.user_manager import load_users_db, save_users_db, create_user_workspace

    username = payload.get("username")
    if not username:
        raise HTTPException(status_code=400, detail="Missing username")
    # The payload is untyped JSON, so make sure the username is a string
    # before applying the character check.
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")
    display_name = payload.get("display_name", username)

    db = load_users_db()
    # Use one list for both the duplicate check and the append, creating it
    # when the database has no "users" key yet.
    users = db.setdefault("users", [])
    if any(u["username"] == username for u in users):
        raise HTTPException(status_code=400, detail="User already exists")

    users.append({
        "username": username,
        "display_name": display_name,
        "created_at": datetime.now().isoformat(),
        "active": True
    })

    if not save_users_db(db):
        raise HTTPException(status_code=500, detail="Failed to save user database")

    create_user_workspace(username)
    return {
        "success": True,
        "message": f"用户 {username} 创建成功"
    }
@router.post("/nodes/save")
async def save_node(payload: Dict[str, Any]):
    """
    Acknowledge a request to save a node configuration.

    NOTE: persistence is not implemented yet. The endpoint only checks that
    ``nodeId`` is present and returns a success response; ``nodeData`` in the
    payload is currently ignored.

    Raises:
        HTTPException: 400 when ``nodeId`` is missing.
    """
    node_id = payload.get("nodeId")
    if not node_id:
        raise HTTPException(status_code=400, detail="Missing nodeId")

    # TODO: implement the actual save logic (database or file).
    return {
        "success": True,
        "message": f"节点 {node_id} 配置已保存",
        "nodeId": node_id
    }
@router.post("/workflows/save")
async def save_workflow(payload: Dict[str, Any]):
    """
    Save a workflow into the user's ``workflows`` directory.

    Expects ``filename``, ``data`` and an optional ``username`` (default
    ``"guest"``) in the payload. ``.json`` is appended to the filename when
    missing.

    Raises:
        HTTPException: 400 for missing/invalid input, 403 when the target
            path is outside the user root, 500 when writing fails.
    """
    filename = payload.get("filename")
    username = payload.get("username", "guest")
    workflow_data = payload.get("data")

    if not filename:
        raise HTTPException(status_code=400, detail="Missing filename")
    if not workflow_data:
        raise HTTPException(status_code=400, detail="Missing workflow data")
    # The payload is untyped JSON; reject non-string names up front.
    if not isinstance(filename, str):
        raise HTTPException(status_code=400, detail="Invalid filename")
    # The user root is derived from `username`, so the safe-path check below
    # cannot detect traversal through the username itself. Apply the same
    # character rule that user creation enforces.
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    is_valid, error_msg = validate_filename(filename)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg)

    # Ensure the .json extension.
    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    workflow_dir = get_user_path(username, "workflows")
    workflow_file = workflow_dir / filename

    # Safety check first; only create directories once the path is accepted.
    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")
    workflow_dir.mkdir(parents=True, exist_ok=True)

    try:
        with open(workflow_file, "w", encoding="utf-8") as f:
            json.dump(workflow_data, f, indent=2, ensure_ascii=False)
        return {
            "success": True,
            "message": f"工作流已保存: {filename}",
            "path": str(workflow_file.relative_to(user_root.parent))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}")
@router.get("/workflows/load")
async def load_workflow(filename: str, username: str = "guest"):
    """
    Load a workflow from the user's ``workflows`` directory.

    When the requested name does not exist and lacks the ``.json`` extension,
    the name with ``.json`` appended is tried as well, because the save
    endpoint stores workflows under that extended name.

    Raises:
        HTTPException: 400 for an invalid username, 403 when the path is
            outside the user root, 404 when the file does not exist, 500 when
            reading or parsing fails.
    """
    # The user root is derived from `username`, so the safe-path check below
    # cannot detect traversal through the username itself. Apply the same
    # character rule that user creation enforces.
    if not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    filename = sanitize_path(filename)
    workflow_file = get_user_path(username, f"workflows/{filename}")

    user_root = get_user_path(username)
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    if not workflow_file.exists() and not str(filename).endswith(".json"):
        # Fall back to the name the save endpoint would have written.
        candidate = get_user_path(username, f"workflows/{filename}.json")
        if is_safe_path(user_root, candidate) and candidate.exists():
            filename = f"{filename}.json"
            workflow_file = candidate

    if not workflow_file.exists():
        raise HTTPException(status_code=404, detail="Workflow not found")

    try:
        with open(workflow_file, "r", encoding="utf-8") as f:
            workflow_data = json.load(f)
        return {
            "success": True,
            "filename": filename,
            "data": workflow_data
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"加载失败: {str(e)}")
class SaveFunctionRequest(BaseModel):
    """Request body for saving a set of nodes and edges as a function."""

    # Used to build the file name under which the function is stored.
    function_name: str
    display_name: str
    description: Optional[str] = ""
    # Internal workflow wrapped by the function.
    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    # Port descriptions of the function; stored as given.
    inputs: List[Dict[str, Any]]
    outputs: List[Dict[str, Any]]
@router.post("/functions/save")
async def save_function(request: SaveFunctionRequest):
    """
    Save a workflow as a function.

    The selected nodes and edges are written to
    ``cloud/custom_nodes/functions/<function_name>.json`` under the project
    root, after which the function nodes are reloaded.

    Raises:
        HTTPException: 400 for an invalid function name, 409 when a function
            with that name already exists, 500 when saving or reloading fails.
    """
    from ..core.node_loader import load_function_nodes

    # validate_filename is unpacked as an (is_valid, error_message) pair by
    # the workflow save endpoint in this module; using its return value
    # directly as the name would always be truthy and would put the tuple's
    # repr into the file name.
    # NOTE(review): confirm the return shape against core.security.
    function_name = request.function_name
    is_valid, error_msg = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg or "Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    functions_dir.mkdir(parents=True, exist_ok=True)
    function_file = functions_dir / f"{function_name}.json"

    if function_file.exists():
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")

    function_data = {
        "function_name": function_name,
        "display_name": request.display_name,
        "description": request.description,
        "inputs": request.inputs,
        "outputs": request.outputs,
        "nodes": [node.dict() for node in request.nodes],
        "edges": [edge.dict() for edge in request.edges]
    }

    try:
        # Serialize before touching the file system so a serialization error
        # cannot leave a partial file that would block later saves with 409.
        serialized = json.dumps(function_data, indent=2, ensure_ascii=False)
        # "x" fails if the file appeared since the existence check above.
        with open(function_file, "x", encoding="utf-8") as f:
            f.write(serialized)

        # Reload function nodes so the new function becomes available.
        load_result = load_function_nodes()
        return {
            "success": True,
            "message": f"函数已保存: {request.display_name}",
            "function_name": function_name,
            "path": str(function_file.relative_to(project_root)),
            "reload_result": load_result
        }
    except FileExistsError:
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}")
@router.get("/functions/list")
async def list_functions():
    """
    List all available functions.

    Reads every ``*.json`` file in ``cloud/custom_nodes/functions`` under the
    project root. Files that cannot be read or parsed are reported on stdout
    and skipped.

    Returns:
        A dict whose ``functions`` entry is a list of function summaries
        (empty when the directory does not exist).
    """
    root = Path(__file__).resolve().parents[3]
    functions_dir = root / "cloud" / "custom_nodes" / "functions"

    functions = []
    if functions_dir.exists():
        for json_file in functions_dir.glob("*.json"):
            fallback_name = json_file.stem
            try:
                data = json.loads(json_file.read_text(encoding='utf-8'))
                summary = {
                    "function_name": data.get("function_name", fallback_name),
                    "display_name": data.get("display_name", fallback_name),
                    "description": data.get("description", ""),
                    "inputs": data.get("inputs", []),
                    "outputs": data.get("outputs", []),
                }
                functions.append(summary)
            except Exception as e:
                print(f"❌ 读取函数失败: {json_file.name} - {str(e)}")

    return {"functions": functions}
@router.get("/functions/{function_name}")
async def get_function(function_name: str):
    """
    Return the full definition of a function, including its inner workflow.

    Raises:
        HTTPException: 400 for an invalid function name, 404 when no such
            function file exists, 500 when reading or parsing fails.
    """
    # validate_filename is unpacked as an (is_valid, error_message) pair by
    # the workflow save endpoint in this module; using its return value
    # directly as the name would always be truthy and would put the tuple's
    # repr into the file name.
    # NOTE(review): confirm the return shape against core.security.
    is_valid, error_msg = validate_filename(function_name)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg or "Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    function_file = project_root / "cloud" / "custom_nodes" / "functions" / f"{function_name}.json"
    if not function_file.exists():
        raise HTTPException(status_code=404, detail="Function not found")

    try:
        with open(function_file, 'r', encoding='utf-8') as f:
            function_data = json.load(f)
        return {
            "success": True,
            "data": function_data
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}")