TraceStudio-dev/server/app/api/endpoints_graph.py

604 lines
20 KiB
Python
Raw Normal View History

"""
图执行和查询相关 API
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import json
from pathlib import Path
from server.app.core.node_base import DimensionMode
from ..core.user_manager import get_user_path
from ..core.security import is_safe_path, validate_filename, sanitize_path
# Router shared by every endpoint in this module (plugins, node preview, graph
# execution, users, workflows and function nodes).
router = APIRouter()
class NodeSchema(BaseModel):
    """Node data structure sent by the frontend."""
    # Node identifier (edges presumably refer to it via EdgeSchema.source/target).
    id: str
    # Node type string; /graph/execute forwards it to the executor unchanged.
    type: str
    # Implementation class name looked up in NodeRegistry. /node/preview requires it;
    # /graph/execute also accepts it from data["meta"]["class_name"].
    class_name: Optional[str] = None
    # Optional function reference, forwarded to the executor as-is by /graph/execute.
    function: Optional[str] = None
    # The frontend sends parameters either as `data` or as the legacy `params` field
    # (one or the other); endpoints use `data` when truthy and fall back to `params`.
    data: Optional[Dict[str, Any]] = None
    params: Optional[Dict[str, Any]] = None
class EdgeSchema(BaseModel):
    """Edge (connection) data structure sent by the frontend."""
    # The frontend may omit the id; /graph/execute then synthesizes
    # "{source}:{source_port}->{target}:{target_port}". A stable client-side id is
    # still recommended.
    id: Optional[str] = None
    source: str
    target: str
    # Dimension handling mode passed through to the executor; its semantics are
    # defined by DimensionMode (not visible in this module).
    dimension_mode: DimensionMode = DimensionMode.NONE
    # Port names; /graph/execute defaults them to "output" / "input" when missing.
    source_port: Optional[str] = None
    target_port: Optional[str] = None
class GraphExecuteRequest(BaseModel):
    """Request body for POST /graph/execute."""
    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    # Global context handed to the executor; an empty dict is used when omitted.
    settings: Optional[Dict[str, Any]] = None
class NodePreviewRequest(BaseModel):
    """Request body for POST /node/preview."""
    node: NodeSchema
    # Maximum number of preview rows returned.
    limit: int = 10
@router.get("/plugins")
async def get_plugins():
    """
    List every available node plugin.

    Plugin data is read dynamically from NodeRegistry and reshaped into the
    structure the frontend expects, keyed by implementation class name.
    Port and parameter lists are passed through from the node metadata as-is.
    """
    from ..core.node_registry import NodeRegistry

    def _describe(impl_name, impl_cls):
        # One frontend-facing plugin descriptor per registered node class.
        meta = impl_cls.get_metadata()
        return {
            "display_name": meta.get("display_name", impl_name),
            "category": meta.get("category", "Uncategorized"),
            "description": meta.get("description", ""),
            "icon": meta.get("icon", "📦"),
            "node_type": meta.get("node_type", "REGULAR"),
            # The frontend's double-click detection relies on the implementation class name.
            "class_name": impl_name,
            "node_logic": "standard",
            "supports_preview": True,
            "inputs": meta.get("inputs", []),
            "outputs": meta.get("outputs", []),
            "param_schema": meta.get("params", []),
            "context_vars": meta.get("context_vars", {}),
            "cache_policy": meta.get("cache_policy", "NONE"),
        }

    registered = NodeRegistry.list_all()
    catalog = {impl_name: _describe(impl_name, impl_cls) for impl_name, impl_cls in registered.items()}
    return {
        "plugins": catalog,
        "total": len(catalog),
        "categories": NodeRegistry.get_categories(),
    }
@router.post("/node/preview")
async def preview_node(request: NodePreviewRequest):
    """
    Preview the output of a single node.

    The node class is looked up in NodeRegistry by ``class_name``. The node's
    parameters (from ``data`` or the legacy ``params``) are applied, and the node
    is run once with no inputs in a guest preview context. Only the first output
    is previewed.

    Returns:
        dict with ``class_name``, ``status``, ``columns``, ``preview`` (a list of
        row dicts, at most ``request.limit`` rows) and the node's ``context``.

    Raises:
        HTTPException(400): ``class_name`` is missing.
        HTTPException(404): NodeRegistry.get returned None for ``class_name``.
        HTTPException(500): any other failure while building or running the node.
    """
    from ..core.node_registry import NodeRegistry
    import polars as pl

    try:
        class_name = request.node.class_name
        if not class_name:
            raise HTTPException(status_code=400, detail="Missing class_name for node preview")
        node_class = NodeRegistry.get(class_name)
        if node_class is None:
            raise HTTPException(status_code=404, detail=f"Unknown node class: {class_name}")
        node_instance = node_class()

        # The frontend may send parameters as `data` or as the legacy `params`.
        pdata = request.node.data or request.node.params or {}
        for param_name, param_value in (pdata.items() if isinstance(pdata, dict) else []):
            # Prefer the node's set_param API. Fall back to a plain attribute for
            # implementations where set_param is missing or rejects the value.
            try:
                node_instance.set_param(param_name, param_value)
            except Exception:
                setattr(node_instance, param_name, param_value)

        # Previews run without upstream inputs, as the guest user, flagged as preview mode.
        context = {
            "user_id": "guest",
            "preview_mode": True
        }
        result = await node_instance.wrap_process({}, context)
        outputs = result.get("outputs") or {}

        # A negative limit would make head()/slicing drop rows from the end instead.
        limit = max(request.limit, 0)
        columns: List[str] = []
        preview_data: List[Dict[str, Any]] = []
        # Only the first output is previewed.
        for output_value in outputs.values():
            if isinstance(output_value, pl.DataFrame):
                columns = list(output_value.columns)
                preview_data = output_value.head(limit).to_dicts()
            elif isinstance(output_value, list):
                columns = ["value"]
                preview_data = [{"value": item} for item in output_value[:limit]]
            else:
                # Any other value is shown as a single row.
                columns = ["value"]
                preview_data = [{"value": output_value}]
            break

        return {
            "class_name": class_name,
            "status": "success",
            "columns": columns,
            "preview": preview_data,
            "context": result.get("context", {})
        }
    except HTTPException:
        # Client errors raised above must keep their status code instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"节点预览失败: {str(e)}"
        ) from e
def _graph_executor_nodes(nodes: List[Any]) -> List[Dict[str, Any]]:
    """Convert request nodes into the node dicts passed to WorkflowExecutor.execute.

    Parameters come from ``data`` (preferred) or the legacy ``params`` field and are
    forwarded under the single key ``params``. The class name is taken from
    ``class_name``, falling back to ``params["meta"]["class_name"]``.

    Raises:
        HTTPException(400): a node provides no class name in either place.
    """
    converted: List[Dict[str, Any]] = []
    for node in nodes:
        pdata = node.data or node.params or {}
        meta = pdata.get("meta") if isinstance(pdata, dict) else None
        # `meta` may be missing, null or a non-dict; only a dict can carry class_name.
        meta_class_name = meta.get("class_name") if isinstance(meta, dict) else None
        class_name = node.class_name or meta_class_name
        if not class_name:
            raise HTTPException(status_code=400, detail=f"Node {node.id} missing class_name")
        converted.append({
            "id": node.id,
            "type": node.type,
            "class_name": class_name,
            # Runtime parameters are always handed to the executor as `params`.
            "params": pdata,
            "function": node.function,
        })
    return converted


def _graph_executor_edges(edges: List[Any]) -> List[Dict[str, Any]]:
    """Convert request edges into executor edge dicts.

    Missing ports default to "output"/"input". A missing id is synthesized as
    "{source}:{source_port}->{target}:{target_port}".
    """
    converted: List[Dict[str, Any]] = []
    for edge in edges:
        source_port = edge.source_port or "output"
        target_port = edge.target_port or "input"
        converted.append({
            "id": edge.id or f"{edge.source}:{source_port}->{edge.target}:{target_port}",
            "source": edge.source,
            "source_port": source_port,
            "target": edge.target,
            "target_port": target_port,
            "dimension_mode": edge.dimension_mode,
        })
    return converted


def _graph_split_report(report: Any):
    """Return ``(node_infos, node_results)`` from an executor report.

    Two report shapes are supported:
    - the newer shape, a dict with ``node_infos``/``node_results``;
    - the older shape, a plain ``node_id -> info`` mapping.

    Anything else yields two empty dicts.
    """
    if not isinstance(report, dict):
        return {}, {}
    if "node_infos" in report or "node_results" in report:
        return report.get("node_infos") or {}, report.get("node_results") or {}
    return report, {}


def _graph_info_status(info: Any) -> Any:
    """Extract a JSON-friendly status from a node info (dict, object or other)."""
    if info is None:
        return "unknown"
    if isinstance(info, dict):
        return info.get("status")
    if hasattr(info, "status"):
        status = info.status
        # Enum-like statuses expose `.value`; plain values are passed through.
        return getattr(status, "value", status)
    return str(info)


def _graph_info_error(info: Any) -> Any:
    """Extract the error from a node info (dict or object), or None."""
    if info is None:
        return None
    if isinstance(info, dict):
        return info.get("error")
    return getattr(info, "error", None)


def _graph_serialize_outputs(value: Any, preview_limit: int = 20) -> Any:
    """Make node outputs JSON-serializable for the response.

    polars DataFrames, including ones nested in lists, tuples or dicts, are
    replaced by a preview dict:
    ``{"__type": "DataFrame", "columns", "preview", "rows"}``.
    The preview holds at most ``preview_limit`` rows. All other values are
    returned unchanged.
    """
    import polars as pl

    def _convert(v):
        if isinstance(v, pl.DataFrame):
            return {
                "__type": "DataFrame",
                "columns": list(v.columns),
                "preview": v.head(preview_limit).to_dicts(),
                "rows": v.height,
            }
        if isinstance(v, (list, tuple)):
            return [_convert(x) for x in v]
        if isinstance(v, dict):
            return {k: _convert(val) for k, val in v.items()}
        return v

    return _convert(value)


@router.post("/graph/execute")
async def execute_graph(request: GraphExecuteRequest):
    """
    Execute a complete computation graph.

    Nodes and edges are normalized for WorkflowExecutor and run as the guest
    user, with ``settings`` as the global context. Per-node status, error and
    serialized outputs are then returned.

    Returns:
        dict with ``status``, ``message``, ``execution_time`` (seconds, 3 decimals),
        ``results`` (node_id -> {status, outputs, error}) and ``stats``.
        ``stats.execution_order`` lists node ids in the order the executor
        reported them.

    Raises:
        HTTPException(400): a node has no class name.
        HTTPException(500): execution failed or the report could not be processed.
    """
    from ..core.workflow_executor import WorkflowExecutor
    import logging
    import time
    import traceback

    try:
        start_time = time.perf_counter()
        executor = WorkflowExecutor(user_id="guest")  # TODO: take the user id from the request
        nodes = _graph_executor_nodes(request.nodes)
        edges = _graph_executor_edges(request.edges)
        global_context = request.settings or {}

        success, results = await executor.execute(nodes, edges, global_context)
        if not success:
            logging.getLogger(__name__).error("Graph execution failed: %s", results)
            raise Exception(str(results))
        execution_time = time.perf_counter() - start_time

        node_infos, node_results = _graph_split_report(results)
        # dict.fromkeys de-duplicates while keeping the executor's reporting order,
        # so execution_order below is deterministic (a set would make it arbitrary).
        node_ids = list(dict.fromkeys([*node_infos, *node_results]))
        results_out: Dict[str, Any] = {
            nid: {
                "status": _graph_info_status(node_infos.get(nid)),
                "outputs": _graph_serialize_outputs(node_results.get(nid)),
                "error": _graph_info_error(node_infos.get(nid)),
            }
            for nid in node_ids
        }

        return {
            "status": "success",
            "message": f"成功执行 {len(request.nodes)} 个节点",
            "execution_time": round(execution_time, 3),
            "results": results_out,
            "stats": {
                "total_nodes": len(request.nodes),
                "total_edges": len(request.edges),
                "execution_order": list(results_out.keys())
            }
        }
    except HTTPException:
        # Client errors (e.g. a missing class_name) must keep their status code.
        raise
    except Exception as e:
        # NOTE(review): the full server traceback is returned to the client. That is
        # handy for the editor UI, but it discloses internals; consider gating it behind
        # a debug flag.
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }
        ) from e
@router.get("/users/list")
async def list_users():
    """
    Return the users reported by ``user_manager.list_users`` and how many there are.

    The original documentation describes these as the active users from the database.
    """
    # Aliased so the helper does not shadow this endpoint's own name.
    from ..core.user_manager import list_users as _fetch_users

    user_rows = _fetch_users()
    return {"users": user_rows, "count": len(user_rows)}
@router.post("/users/add")
async def add_user(payload: Dict[str, Any]):
    """
    Create a new user record and its workspace.

    Payload keys:
        username: required. Alphanumeric characters plus ``_`` and ``-``.
        display_name: optional; defaults to the username.

    Raises:
        HTTPException(400): the username is missing, not a string, badly formatted,
            or already taken.
        HTTPException(500): the user database could not be saved.
    """
    from datetime import datetime
    from ..core.user_manager import load_users_db, save_users_db, create_user_workspace

    username = payload.get("username")
    display_name = payload.get("display_name", username)
    if not username:
        raise HTTPException(status_code=400, detail="Missing username")
    # A non-string username (e.g. a JSON number) would otherwise crash on .replace()
    # and surface as a 500 instead of a validation error.
    if not isinstance(username, str) or not username.replace('_', '').replace('-', '').isalnum():
        raise HTTPException(status_code=400, detail="Invalid username format")

    db = load_users_db()
    # The duplicate check tolerates a database without a "users" list, so the
    # insert below must too; setdefault creates the list when it is missing.
    users = db.setdefault("users", [])
    if any(u.get("username") == username for u in users):
        raise HTTPException(status_code=400, detail="User already exists")

    users.append({
        "username": username,
        "display_name": display_name,
        "created_at": datetime.now().isoformat(),
        "active": True
    })
    if not save_users_db(db):
        raise HTTPException(status_code=500, detail="Failed to save user database")

    create_user_workspace(username)
    return {
        "success": True,
        "message": f"用户 {username} 创建成功"
    }
@router.post("/nodes/save")
async def save_node(payload: Dict[str, Any]):
    """
    Save a node's configuration on the server.

    This is currently a stub. ``nodeId`` is required and echoed back, but
    ``nodeData`` is not persisted yet.

    Raises:
        HTTPException(400): ``nodeId`` is missing.
    """
    node_id = payload.get("nodeId")
    if not node_id:
        raise HTTPException(status_code=400, detail="Missing nodeId")
    # TODO: persist payload["nodeData"] (database or file).
    return {"success": True, "message": f"节点 {node_id} 配置已保存", "nodeId": node_id}
@router.post("/workflows/save")
async def save_workflow(payload: Dict[str, Any]):
    """
    Save a workflow as JSON in the user's ``workflows`` directory.

    Payload keys:
        filename: required; ``.json`` is appended when missing.
        username: optional; defaults to "guest".
        data: required; the workflow object to store.

    Returns:
        dict with ``success``, ``message`` and ``path``. The path is relative to
        the parent of the user's root directory.

    Raises:
        HTTPException(400): the filename is missing or invalid, or the workflow data is missing.
        HTTPException(403): the target file would lie outside the user's directory.
        HTTPException(500): the directory or file could not be written.
    """
    filename = payload.get("filename")
    username = payload.get("username", "guest")
    workflow_data = payload.get("data")

    if not filename:
        raise HTTPException(status_code=400, detail="Missing filename")
    if not workflow_data:
        raise HTTPException(status_code=400, detail="Missing workflow data")
    # A non-string filename (e.g. a JSON number) would otherwise crash below with a 500.
    if not isinstance(filename, str):
        raise HTTPException(status_code=400, detail="Invalid filename")

    is_valid, error_msg = validate_filename(filename)
    if not is_valid:
        raise HTTPException(status_code=400, detail=error_msg)

    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    user_root = get_user_path(username)
    workflow_dir = get_user_path(username, "workflows")
    workflow_file = workflow_dir / filename
    # Check containment before touching the filesystem, so a rejected request
    # creates no directories.
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        workflow_dir.mkdir(parents=True, exist_ok=True)
        with open(workflow_file, "w", encoding="utf-8") as f:
            json.dump(workflow_data, f, indent=2, ensure_ascii=False)
        return {
            "success": True,
            "message": f"工作流已保存: {filename}",
            "path": str(workflow_file.relative_to(user_root.parent))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
@router.get("/workflows/load")
async def load_workflow(filename: str, username: str = "guest"):
    """
    Load a workflow JSON file from the user's ``workflows`` directory.

    The filename is first cleaned with ``sanitize_path``. save_workflow stores
    bare names as ``<name>.json``, so when the exact name does not exist and has
    no ``.json`` suffix, ``<name>.json`` is tried as well.

    Returns:
        dict with ``success``, ``filename`` (the name actually loaded) and ``data``.

    Raises:
        HTTPException(403): the path resolves outside the user's directory.
        HTTPException(404): no matching workflow file exists.
        HTTPException(500): the file could not be read or parsed.
    """
    filename = sanitize_path(filename)
    user_root = get_user_path(username)
    workflow_file = get_user_path(username, f"workflows/{filename}")
    if not is_safe_path(user_root, workflow_file):
        raise HTTPException(status_code=403, detail="Access denied")

    # Mirror save_workflow's extension rule so save("x") followed by load("x") works.
    if not workflow_file.exists() and not str(filename).endswith(".json"):
        json_name = f"{filename}.json"
        json_file = get_user_path(username, f"workflows/{json_name}")
        if is_safe_path(user_root, json_file) and json_file.exists():
            filename, workflow_file = json_name, json_file

    if not workflow_file.exists():
        raise HTTPException(status_code=404, detail="Workflow not found")

    try:
        with open(workflow_file, "r", encoding="utf-8") as f:
            workflow_data = json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"加载失败: {str(e)}") from e

    return {
        "success": True,
        "filename": filename,
        "data": workflow_data
    }
class SaveFunctionRequest(BaseModel):
    """Request body for saving a sub-graph as a reusable function node."""
    # Used by /functions/save as the stored file's base name "<function_name>.json".
    function_name: str
    display_name: str
    description: Optional[str] = ""
    # Nodes and edges of the sub-graph being packaged.
    nodes: List[NodeSchema]
    edges: List[EdgeSchema]
    # Function port descriptors; stored as-is (their schema is not checked here).
    inputs: List[Dict[str, Any]]
    outputs: List[Dict[str, Any]]
@router.post("/functions/save")
async def save_function(request: SaveFunctionRequest):
    """
    Save the selected nodes and edges as a reusable function node.

    The function is written to ``cloud/custom_nodes/functions/<function_name>.json``
    under the project root, and the function nodes are then reloaded.

    Returns:
        dict with ``success``, ``message``, ``function_name``, ``path`` (relative
        to the project root) and ``reload_result``.

    Raises:
        HTTPException(400): the function name is invalid or escapes the functions directory.
        HTTPException(409): a function with this name already exists.
        HTTPException(500): writing the file or reloading the function nodes failed.
    """
    from ..core.node_loader import load_function_nodes

    # NOTE(review): save_workflow unpacks validate_filename() as (is_valid, error_msg),
    # while this endpoint historically used its return value as the cleaned name.
    # A non-empty tuple is always truthy, so under the tuple contract invalid names
    # were never rejected. Both shapes are accepted here until the contract is confirmed.
    checked = validate_filename(request.function_name)
    if isinstance(checked, tuple):
        function_name = request.function_name if checked and checked[0] else None
    else:
        function_name = checked
    if not function_name:
        raise HTTPException(status_code=400, detail="Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    function_file = functions_dir / f"{function_name}.json"
    # Defense in depth: the resulting file must sit directly inside functions_dir.
    if function_file.resolve().parent != functions_dir.resolve():
        raise HTTPException(status_code=400, detail="Invalid function name")
    functions_dir.mkdir(parents=True, exist_ok=True)

    function_data = {
        "function_name": function_name,
        "display_name": request.display_name,
        "description": request.description,
        "inputs": request.inputs,
        "outputs": request.outputs,
        "nodes": [node.dict() for node in request.nodes],
        "edges": [edge.dict() for edge in request.edges]
    }

    try:
        # Serialize first so a non-serializable payload leaves no partial file behind
        # (a partial file would block every later save with a 409).
        serialized = json.dumps(function_data, indent=2, ensure_ascii=False)
        # "x" mode makes the existence check and the creation atomic, so concurrent
        # saves of the same name cannot overwrite each other.
        with open(function_file, "x", encoding="utf-8") as f:
            f.write(serialized)
    except FileExistsError:
        raise HTTPException(status_code=409, detail=f"函数 {function_name} 已存在") from None
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e

    try:
        # Reload the function nodes so the new function is picked up.
        load_result = load_function_nodes()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e

    return {
        "success": True,
        "message": f"函数已保存: {request.display_name}",
        "function_name": function_name,
        "path": str(function_file.relative_to(project_root)),
        "reload_result": load_result
    }
@router.get("/functions/list")
async def list_functions():
    """
    List all saved function definitions.

    Every ``*.json`` file in ``cloud/custom_nodes/functions`` is read, in sorted
    path order so the listing is stable across calls, and its summary fields are
    returned. A file that cannot be read, is not valid JSON, or is not a JSON
    object is reported on stdout and skipped.
    """
    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    if not functions_dir.exists():
        return {"functions": []}

    functions = []
    for json_file in sorted(functions_dir.glob("*.json")):
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("top-level JSON value is not an object")
        except (OSError, ValueError) as e:
            # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            print(f"❌ 读取函数失败: {json_file.name} - {str(e)}")
            continue
        functions.append({
            "function_name": data.get("function_name", json_file.stem),
            "display_name": data.get("display_name", json_file.stem),
            "description": data.get("description", ""),
            "inputs": data.get("inputs", []),
            "outputs": data.get("outputs", [])
        })
    return {"functions": functions}
@router.get("/functions/{function_name}")
async def get_function(function_name: str):
    """
    Return a function's full definition, including its internal workflow.

    Raises:
        HTTPException(400): the function name is invalid or escapes the functions directory.
        HTTPException(404): no such function file exists.
        HTTPException(500): the file could not be read or parsed.
    """
    # NOTE(review): save_workflow unpacks validate_filename() as (is_valid, error_msg),
    # while this endpoint historically used its return value as the cleaned name.
    # Under the tuple contract the old lookup always targeted "(True, ...).json" and
    # returned 404. Both shapes are accepted here until the contract is confirmed.
    checked = validate_filename(function_name)
    if isinstance(checked, tuple):
        function_name = function_name if checked and checked[0] else None
    else:
        function_name = checked
    if not function_name:
        raise HTTPException(status_code=400, detail="Invalid function name")

    project_root = Path(__file__).resolve().parents[3]
    functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
    function_file = functions_dir / f"{function_name}.json"
    # Defense in depth: only files directly inside functions_dir may be served.
    if function_file.resolve().parent != functions_dir.resolve():
        raise HTTPException(status_code=400, detail="Invalid function name")
    if not function_file.exists():
        raise HTTPException(status_code=404, detail="Function not found")

    try:
        with open(function_file, 'r', encoding='utf-8') as f:
            function_data = json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取失败: {str(e)}") from e

    return {
        "success": True,
        "data": function_data
    }