TraceStudio-dev/server/app/api/endpoints_custom_nodes.py

450 lines
13 KiB
Python
Raw Permalink Normal View History

"""
自定义节点管理 API
提供节点的增删改查验证加载等功能
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Body
from fastapi.responses import FileResponse
from pydantic import BaseModel
from pathlib import Path
from typing import Dict, List, Optional
import shutil
from datetime import datetime
from ..core.node_validator import NodeValidator, validate_node_code
from ..core.node_loader import reload_custom_nodes
from ..core.security import sanitize_path, validate_filename
# Router that groups every custom-node endpoint under the /api/custom-nodes prefix.
# NOTE(review): the endpoints in this module call get_node_loader(), which is not
# among the names imported above — confirm where it is defined/imported.
router = APIRouter(prefix="/api/custom-nodes", tags=["自定义节点"])
# ============ 数据模型 ============
class NodeValidateRequest(BaseModel):
    """Request body accepted by the ``/validate`` endpoint."""

    # Source text of the node to be checked.
    code: str
    # Optional label for the source; falls back to "<string>" when omitted.
    filename: Optional[str] = "<string>"
class NodeSaveRequest(BaseModel):
    """Request body accepted by the ``/save`` endpoint."""

    # Name of the node file inside the custom nodes directory.
    filename: str
    # Source text to write into that file.
    code: str
    # When True, an already existing file is overwritten instead of
    # triggering the "confirm overwrite" response.
    force: bool = False
class NodeActionRequest(BaseModel):
    """Request body accepted by the ``/action`` endpoint."""

    # Node file the action applies to.
    filename: str
    # One of: 'load', 'unload', 'delete', 'rename'.
    action: str
    # Target file name; only consulted by the 'rename' action.
    new_filename: Optional[str] = None
# ============ API 端点 ============
@router.get("/list")
async def list_custom_nodes():
    """
    List all custom nodes.

    Returns a payload with the loader's scan result under ``nodes`` and its
    length under ``total`` (per the original note, entries carry their
    validation status).

    Raises:
        HTTPException: status 500 when scanning fails for any reason.
    """
    try:
        scanned = get_node_loader().scan_nodes()
        return dict(success=True, total=len(scanned), nodes=scanned)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"扫描节点失败: {str(e)}")
@router.post("/validate")
async def validate_node(request: NodeValidateRequest):
    """
    Validate node source code without saving it.

    Only ``request.code`` is passed to the validator; the result is returned
    under the ``validation`` key.

    Raises:
        HTTPException: status 500 when the validator itself raises.
    """
    try:
        report = validate_node_code(request.code)
        return dict(success=True, validation=report)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"验证失败: {str(e)}")
@router.get("/read/{filename}")
async def read_node(filename: str):
    """
    Read the source code of a node file (used by the front-end editor).

    Returns the file name, its UTF-8 text, its size in bytes and its
    modification time as an ISO-8601 string.

    Raises:
        HTTPException: 400 for a rejected file name, 404 when the file is
            missing, 500 when reading it fails.
    """
    # Reject unsafe names before touching the file system.
    name_ok, name_error = NodeValidator.is_safe_filename(filename)
    if not name_ok:
        raise HTTPException(status_code=400, detail=name_error)

    target = get_node_loader().custom_nodes_dir / filename
    if not target.exists():
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        source = target.read_text(encoding='utf-8')
        # File metadata reported alongside the code.
        info = target.stat()
        modified_at = datetime.fromtimestamp(info.st_mtime)
        return {
            "success": True,
            "filename": filename,
            "code": source,
            "size": info.st_size,
            "modified": modified_at.isoformat(),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取文件失败: {str(e)}")
@router.post("/save")
async def save_node(request: NodeSaveRequest):
    """
    Save node source code, then load it.

    Steps: validate the file name, validate the code, refuse to overwrite an
    existing file unless ``request.force`` is set, back up the existing file,
    write the new code and finally ask the loader to (re)load the node.

    Returns:
        A dict with ``success``; on validation failure or a pending overwrite
        confirmation ``success`` is False and the reason is in ``message``.
        On success the validation report and the loader result are included.

    Raises:
        HTTPException: 400 for a rejected file name, 500 when creating the
            backup or writing the file fails.
    """
    # 1. Validate the file name.
    is_safe, error = NodeValidator.is_safe_filename(request.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    # 2. Validate the code.
    validation = validate_node_code(request.code)
    if not validation['valid']:
        return {
            "success": False,
            "message": "代码验证失败,请修复错误后重试",
            "validation": validation
        }

    # 3. Ask for confirmation before overwriting an existing file.
    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / request.filename
    already_exists = file_path.exists()
    if already_exists and not request.force:
        return {
            "success": False,
            "message": "文件已存在,请确认是否覆盖",
            "require_confirm": True,
            "validation": validation
        }

    # 4 + 5. Back up the old file (if any) and write the new one. Both steps
    # sit in the same try block so that a failed backup is reported with the
    # same structured error as a failed write, and never leads to an
    # overwrite without a backup.
    try:
        if already_exists:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            backup_name = f"{request.filename}.backup_{timestamp}"
            shutil.copy2(file_path, loader.custom_nodes_dir / backup_name)
        file_path.write_text(request.code, encoding='utf-8')
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}") from e

    # 6. Load the freshly written node.
    load_result = loader.load_node(request.filename, force_reload=True)
    return {
        "success": True,
        "message": "节点已保存并加载",
        "filename": request.filename,
        "validation": validation,
        "load_result": load_result
    }
@router.post("/action")
async def node_action(request: NodeActionRequest):
    """
    Run an action on a node file.

    Supported values of ``request.action``:
    - load: load the node
    - unload: unload the node
    - delete: unload the node and move its file to a timestamped
      ``.deleted_`` backup
    - rename: rename the node file to ``request.new_filename`` and load it

    Raises:
        HTTPException: 400 for an unsafe file name, a missing or unsafe
            ``new_filename``, an already existing rename target or an
            unsupported action; 404 when the file does not exist (except for
            'load'); 500 when the action itself fails.
    """
    # The file name comes straight from the client and is joined onto the
    # custom nodes directory, so it gets the same safety check the other
    # endpoints apply before any file-system access.
    is_safe, error = NodeValidator.is_safe_filename(request.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / request.filename
    if not file_path.exists() and request.action != 'load':
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        if request.action == 'load':
            result = loader.load_node(request.filename, force_reload=True)
            return {
                "success": result['success'],
                "message": result['message'],
                "classes": result['classes']
            }

        elif request.action == 'unload':
            loader.unload_node(request.filename)
            return {
                "success": True,
                "message": f"已卸载节点: {request.filename}"
            }

        elif request.action == 'delete':
            # Unload first, then move the file aside instead of removing it.
            loader.unload_node(request.filename)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            backup_name = f"{request.filename}.deleted_{timestamp}"
            backup_path = loader.custom_nodes_dir / backup_name
            shutil.move(file_path, backup_path)
            return {
                "success": True,
                "message": f"已删除节点(备份为 {backup_name})"
            }

        elif request.action == 'rename':
            if not request.new_filename:
                raise HTTPException(status_code=400, detail="缺少 new_filename 参数")

            # Validate the new file name as well.
            is_safe, error = NodeValidator.is_safe_filename(request.new_filename)
            if not is_safe:
                raise HTTPException(status_code=400, detail=error)

            new_path = loader.custom_nodes_dir / request.new_filename
            if new_path.exists():
                raise HTTPException(status_code=400, detail="目标文件名已存在")

            # Unload, rename on disk, then load under the new name.
            loader.unload_node(request.filename)
            file_path.rename(new_path)
            loader.load_node(request.new_filename)
            return {
                "success": True,
                "message": f"已重命名: {request.filename} -> {request.new_filename}"
            }

        else:
            raise HTTPException(status_code=400, detail=f"不支持的操作: {request.action}")

    except HTTPException:
        # Deliberate HTTP errors raised above must keep their status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
@router.get("/download/{filename}")
async def download_node(filename: str):
    """
    Download a node file.

    Raises:
        HTTPException: 400 for a rejected file name, 404 when the file does
            not exist.
    """
    # Reject unsafe names before building the path.
    name_ok, name_error = NodeValidator.is_safe_filename(filename)
    if not name_ok:
        raise HTTPException(status_code=400, detail=name_error)

    target = get_node_loader().custom_nodes_dir / filename
    if target.exists():
        return FileResponse(
            path=target,
            media_type='text/x-python',
            filename=filename,
        )
    raise HTTPException(status_code=404, detail="文件不存在")
@router.post("/upload")
async def upload_node(file: UploadFile = File(...)):
    """
    Upload a node file; it is validated and loaded right after upload.

    Returns:
        A dict with ``success``. It is False (with a ``message``) when the
        file already exists or the code fails validation; otherwise the
        validation report and the loader result are included.

    Raises:
        HTTPException: 400 for a rejected file name or content that is not
            UTF-8, 500 for any other failure (a file written during the
            attempt is removed first).
    """
    upload_name = file.filename

    # Reject unsafe names before building the path.
    name_ok, name_error = NodeValidator.is_safe_filename(upload_name)
    if not name_ok:
        raise HTTPException(status_code=400, detail=name_error)

    node_loader = get_node_loader()
    destination = node_loader.custom_nodes_dir / upload_name

    # Never overwrite through this endpoint.
    if destination.exists():
        return {
            "success": False,
            "message": "文件已存在,请先删除或使用不同的文件名",
            "require_confirm": True
        }

    try:
        raw_bytes = await file.read()
        source = raw_bytes.decode('utf-8')

        # Validate first; only valid code reaches the disk.
        report = validate_node_code(source)
        if report['valid']:
            destination.write_text(source, encoding='utf-8')
            outcome = node_loader.load_node(upload_name)
            return {
                "success": True,
                "message": "节点已上传并加载",
                "filename": upload_name,
                "validation": report,
                "load_result": outcome
            }

        return {
            "success": False,
            "message": "节点验证失败",
            "validation": report
        }
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="文件编码错误必须是UTF-8")
    except Exception as e:
        # Roll back: drop the file if it was already written.
        if destination.exists():
            destination.unlink()
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.get("/loaded")
async def get_loaded_nodes():
    """
    Return the nodes the loader currently reports as loaded.

    Each entry holds the class name and the ``metadata`` value stored for it.

    Raises:
        HTTPException: status 500 when the lookup fails for any reason.
    """
    try:
        registry = get_node_loader().get_all_loaded_nodes()
        entries = [
            {'class_name': name, 'metadata': info['metadata']}
            for name, info in registry.items()
        ]
        return dict(success=True, total=len(entries), nodes=entries)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取失败: {str(e)}")
@router.post("/reload-all")
async def reload_all_nodes():
    """
    Reload every custom node.

    The loader's result is returned under ``details``; its ``loaded`` and
    ``total`` entries are also summarised in ``message``.

    Raises:
        HTTPException: status 500 when reloading fails for any reason.
    """
    try:
        summary = get_node_loader().load_all_nodes()
        loaded_count = summary['loaded']
        total_count = summary['total']
        return dict(
            success=True,
            message=f"已加载 {loaded_count}/{total_count} 个节点文件",
            details=summary,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重新加载失败: {str(e)}")
@router.get("/example")
async def get_example_node():
    """
    Return example node source code together with a suggested file name.

    The example is handed out as text, so the indentation inside the string
    literal below is significant: it must form valid Python on its own.
    """
    example_code = '''"""
示例自定义节点
"""
from app.core.node_base import TraceNode


class ExampleFilterNode(TraceNode):
    """数据过滤节点"""

    @staticmethod
    def get_metadata():
        return {
            "display_name": "示例过滤节点",
            "description": "根据阈值过滤数据",
            "category": "Custom/Examples",
            "author": "Your Name",
            "version": "1.0.0",
            "inputs": [
                {
                    "name": "data",
                    "type": "DataFrame",
                    "description": "输入数据"
                }
            ],
            "outputs": [
                {
                    "name": "filtered_data",
                    "type": "DataFrame",
                    "description": "过滤后的数据"
                }
            ],
            "params": [
                {
                    "name": "threshold",
                    "type": "float",
                    "default": 0.5,
                    "description": "过滤阈值"
                },
                {
                    "name": "column",
                    "type": "str",
                    "default": "value",
                    "description": "过滤列名"
                }
            ]
        }

    def execute(self, inputs):
        """执行节点逻辑"""
        import pandas as pd

        # 获取输入
        data = inputs.get('data')
        threshold = self.params.get('threshold', 0.5)
        column = self.params.get('column', 'value')

        # 过滤数据
        if isinstance(data, pd.DataFrame):
            filtered = data[data[column] > threshold]
            return {'filtered_data': filtered}
        else:
            raise ValueError("输入必须是 DataFrame")
'''
    return {
        "success": True,
        "code": example_code,
        "filename": "example_filter_node.py"
    }