450 lines
13 KiB
Python
450 lines
13 KiB
Python
"""
|
||
自定义节点管理 API
|
||
提供节点的增删改查、验证、加载等功能
|
||
"""
|
||
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

from fastapi import APIRouter, HTTPException, UploadFile, File, Body
from fastapi.responses import FileResponse
from pydantic import BaseModel

from ..core.node_validator import NodeValidator, validate_node_code
from ..core.node_loader import get_node_loader, reload_custom_nodes
from ..core.security import sanitize_path, validate_filename
|
||
|
||
|
||
# Router for the custom-node management endpoints; every route below is
# mounted under /api/custom-nodes.
router = APIRouter(
    prefix="/api/custom-nodes",
    tags=["自定义节点"],
)
|
||
|
||
|
||
# ============ 数据模型 ============
|
||
|
||
class NodeValidateRequest(BaseModel):
    """Request body for ``POST /validate``: code to check without saving."""

    # Node source code to validate.
    code: str
    # Optional display name for the code; /validate does not currently pass
    # it to validate_node_code.
    filename: Optional[str] = '<string>'
|
||
|
||
|
||
class NodeSaveRequest(BaseModel):
    """Request body for ``POST /save``."""

    # Target file name inside the custom nodes directory.
    filename: str
    # Node source code to write.
    code: str
    # When True, overwrite an existing file (the old one is backed up first).
    force: bool = False
|
||
|
||
|
||
class NodeActionRequest(BaseModel):
    """Request body for ``POST /action``."""

    # Node file the action applies to.
    filename: str
    # One of 'load', 'unload', 'delete', 'rename'.
    action: str
    # New file name; only used (and required) when action == 'rename'.
    new_filename: Optional[str] = None
|
||
|
||
|
||
# ============ API 端点 ============
|
||
|
||
@router.get("/list")
async def list_custom_nodes():
    """
    List all custom node files.

    Delegates to the loader's ``scan_nodes()`` and returns its result
    together with the number of entries.

    Raises:
        HTTPException(500): scanning (or counting the result) failed.
    """
    try:
        scanned = get_node_loader().scan_nodes()
        payload = {"success": True, "total": len(scanned), "nodes": scanned}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"扫描节点失败: {str(exc)}")
    return payload
|
||
|
||
|
||
@router.post("/validate")
async def validate_node(request: NodeValidateRequest):
    """
    Validate node source code without saving it.

    Runs ``validate_node_code`` on ``request.code`` and returns its result
    under the ``validation`` key.

    Raises:
        HTTPException(500): the validator itself raised.
    """
    try:
        outcome = validate_node_code(request.code)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"验证失败: {str(exc)}")
    return {"success": True, "validation": outcome}
|
||
|
||
|
||
@router.get("/read/{filename}")
async def read_node(filename: str):
    """
    Return a node file's source code for the front-end editor.

    The response also carries the file size in bytes and its modification
    time as a local-time ISO-8601 string.

    Raises:
        HTTPException(400): the filename is rejected by the validator.
        HTTPException(404): no such file in the custom nodes directory.
        HTTPException(500): reading or stat-ing the file failed.
    """
    safe, reason = NodeValidator.is_safe_filename(filename)
    if not safe:
        raise HTTPException(status_code=400, detail=reason)

    target = get_node_loader().custom_nodes_dir / filename
    if not target.exists():
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        source_text = target.read_text(encoding='utf-8')
        info = target.stat()
        return {
            "success": True,
            "filename": filename,
            "code": source_text,
            "size": info.st_size,
            "modified": datetime.fromtimestamp(info.st_mtime).isoformat(),
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"读取文件失败: {str(exc)}")
|
||
|
||
|
||
@router.post("/save")
async def save_node(request: NodeSaveRequest):
    """
    Save node source code, then load it.

    Steps:
    1. Validate the filename.
    2. Validate the code; on failure return ``success: False`` with the
       validation result.
    3. If the file exists and ``request.force`` is not set, return
       ``success: False`` with ``require_confirm: True``.
    4. Back up an existing file as ``<filename>.backup_<YYYYmmdd_HHMMSS>``.
    5. Write the new code (UTF-8).
    6. Reload the node with ``force_reload=True``.

    On success returns ``success: True`` plus ``validation`` and
    ``load_result``. ``success`` reflects the save only; callers should
    inspect ``load_result`` to see whether loading worked.

    Raises:
        HTTPException(400): the filename is rejected by the validator.
        HTTPException(500): backing up the old file or writing the new one failed.
    """
    # 1. Validate the filename.
    is_safe, error = NodeValidator.is_safe_filename(request.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    # 2. Validate the code.
    validation = validate_node_code(request.code)
    if not validation['valid']:
        return {
            "success": False,
            "message": "代码验证失败,请修复错误后重试",
            "validation": validation,
        }

    # 3. Refuse to overwrite unless explicitly forced.
    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / request.filename

    if file_path.exists() and not request.force:
        return {
            "success": False,
            "message": "文件已存在,请确认是否覆盖",
            "require_confirm": True,
            "validation": validation,
        }

    # 4. Back up the old file (if any). A failed backup must abort the save;
    #    previously the OSError escaped as an undescribed 500.
    if file_path.exists():
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_path = loader.custom_nodes_dir / f"{request.filename}.backup_{timestamp}"
        try:
            shutil.copy2(file_path, backup_path)
        except OSError as e:
            raise HTTPException(status_code=500, detail=f"备份旧文件失败: {str(e)}") from e

    # 5. Write the new code.
    try:
        file_path.write_text(request.code, encoding='utf-8')
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}") from e

    # 6. Load the node, replacing any previously loaded version.
    load_result = loader.load_node(request.filename, force_reload=True)

    return {
        "success": True,
        "message": "节点已保存并加载",
        "filename": request.filename,
        "validation": validation,
        "load_result": load_result,
    }
|
||
|
||
|
||
@router.post("/action")
async def node_action(request: NodeActionRequest):
    """
    Perform an action on a custom node file.

    Supported ``request.action`` values:

    - ``load``: load the node via ``loader.load_node(..., force_reload=True)``.
    - ``unload``: unload the node.
    - ``delete``: unload the node, then move the file to
      ``<filename>.deleted_<YYYYmmdd_HHMMSS>`` in the same directory, so the
      file is kept as a backup rather than removed.
    - ``rename``: unload the node, rename the file to
      ``request.new_filename``, then load it under the new name.

    Raises:
        HTTPException(400): unsafe ``filename`` or ``new_filename``, missing
            ``new_filename`` for rename, rename target already exists, or an
            unsupported action.
        HTTPException(404): the file does not exist (every action except load).
        HTTPException(500): any other error raised while performing the action.
    """
    # Validate the source filename exactly like the read/save/download/upload
    # endpoints do. Without this, a value such as "../x" lets delete/rename
    # move or rename arbitrary files relative to custom_nodes_dir.
    is_safe, error = NodeValidator.is_safe_filename(request.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / request.filename

    # 'load' is exempt from the existence check; it is delegated to the loader.
    if not file_path.exists() and request.action != 'load':
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        if request.action == 'load':
            result = loader.load_node(request.filename, force_reload=True)
            return {
                "success": result['success'],
                "message": result['message'],
                "classes": result['classes'],
            }

        elif request.action == 'unload':
            loader.unload_node(request.filename)
            return {
                "success": True,
                "message": f"已卸载节点: {request.filename}",
            }

        elif request.action == 'delete':
            # Unload before moving the file away.
            loader.unload_node(request.filename)

            # Soft delete: keep the file under a timestamped backup name.
            backup_name = f"{request.filename}.deleted_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            backup_path = loader.custom_nodes_dir / backup_name
            shutil.move(file_path, backup_path)

            return {
                "success": True,
                "message": f"已删除节点(备份为 {backup_name})",
            }

        elif request.action == 'rename':
            if not request.new_filename:
                raise HTTPException(status_code=400, detail="缺少 new_filename 参数")

            # The target name must pass the same filename check as the source.
            is_safe, error = NodeValidator.is_safe_filename(request.new_filename)
            if not is_safe:
                raise HTTPException(status_code=400, detail=error)

            new_path = loader.custom_nodes_dir / request.new_filename
            if new_path.exists():
                raise HTTPException(status_code=400, detail="目标文件名已存在")

            # Unload under the old name, rename, then load under the new name.
            loader.unload_node(request.filename)
            file_path.rename(new_path)
            loader.load_node(request.new_filename)

            return {
                "success": True,
                "message": f"已重命名: {request.filename} -> {request.new_filename}",
            }

        else:
            raise HTTPException(status_code=400, detail=f"不支持的操作: {request.action}")

    except HTTPException:
        # Deliberate client errors raised above pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/download/{filename}")
async def download_node(filename: str):
    """
    Send a node file as a download (media type ``text/x-python``).

    Raises:
        HTTPException(400): the filename is rejected by the validator.
        HTTPException(404): no such file in the custom nodes directory.
    """
    safe, reason = NodeValidator.is_safe_filename(filename)
    if not safe:
        raise HTTPException(status_code=400, detail=reason)

    target = get_node_loader().custom_nodes_dir / filename
    if not target.exists():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(path=target, media_type='text/x-python', filename=filename)
|
||
|
||
|
||
@router.post("/upload")
async def upload_node(file: UploadFile = File(...)):
    """
    Upload a node file, validate it, save it and load it.

    The upload never overwrites. If a file with the same name already exists,
    the response is ``success: False`` with ``require_confirm: True``. Code
    that fails validation is not written, and the response is
    ``success: False`` with the validation result.

    Raises:
        HTTPException(400): missing or unsafe filename, or content that is
            not valid UTF-8.
        HTTPException(500): any other failure. If the failure happens after
            this request wrote the file, that file is removed.
    """
    filename = file.filename
    # UploadFile.filename is optional; reject a missing or empty name up front
    # instead of handing None/"" to the filename validator.
    if not filename:
        raise HTTPException(status_code=400, detail="缺少文件名")

    is_safe, error = NodeValidator.is_safe_filename(filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / filename

    if file_path.exists():
        return {
            "success": False,
            "message": "文件已存在,请先删除或使用不同的文件名",
            "require_confirm": True,
        }

    # Tracks whether this request wrote to file_path, so cleanup never deletes
    # a file this request did not create (e.g. one created concurrently).
    written = False
    try:
        content = await file.read()
        code = content.decode('utf-8')

        validation = validate_node_code(code)
        if not validation['valid']:
            return {
                "success": False,
                "message": "节点验证失败",
                "validation": validation,
            }

        # Set before writing so a partially written file is also cleaned up.
        written = True
        file_path.write_text(code, encoding='utf-8')

        load_result = loader.load_node(filename)

        return {
            "success": True,
            "message": "节点已上传并加载",
            "filename": filename,
            "validation": validation,
            "load_result": load_result,
        }

    except UnicodeDecodeError as e:
        raise HTTPException(status_code=400, detail="文件编码错误,必须是UTF-8") from e
    except Exception as e:
        if written:
            try:
                file_path.unlink()
            except OSError:
                # Best-effort cleanup; report the original failure below.
                pass
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/loaded")
async def get_loaded_nodes():
    """
    List the node classes currently held by the loader.

    Each entry contains the class name and the ``metadata`` stored for it by
    ``loader.get_all_loaded_nodes()``.

    Raises:
        HTTPException(500): querying the loader or reading an entry failed.
    """
    try:
        loaded = get_node_loader().get_all_loaded_nodes()
        summaries = [
            {'class_name': name, 'metadata': entry['metadata']}
            for name, entry in loaded.items()
        ]
        return {"success": True, "total": len(summaries), "nodes": summaries}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取失败: {str(exc)}")
|
||
|
||
|
||
@router.post("/reload-all")
async def reload_all_nodes():
    """
    Reload every custom node via ``loader.load_all_nodes()``.

    The loader's summary is returned under ``details``; its ``loaded`` and
    ``total`` counts are also formatted into ``message``.

    Raises:
        HTTPException(500): reloading failed.
    """
    try:
        summary = get_node_loader().load_all_nodes()
        message = f"已加载 {summary['loaded']}/{summary['total']} 个节点文件"
        return {"success": True, "message": message, "details": summary}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"重新加载失败: {str(exc)}")
|
||
|
||
|
||
# Source returned by GET /example: a minimal custom node that keeps the rows
# of a DataFrame whose ``column`` value is greater than ``threshold``.
_EXAMPLE_NODE_CODE = '''"""
示例自定义节点
"""
from app.core.node_base import TraceNode


class ExampleFilterNode(TraceNode):
    """数据过滤节点"""

    @staticmethod
    def get_metadata():
        return {
            "display_name": "示例过滤节点",
            "description": "根据阈值过滤数据",
            "category": "Custom/Examples",
            "author": "Your Name",
            "version": "1.0.0",
            "inputs": [
                {
                    "name": "data",
                    "type": "DataFrame",
                    "description": "输入数据"
                }
            ],
            "outputs": [
                {
                    "name": "filtered_data",
                    "type": "DataFrame",
                    "description": "过滤后的数据"
                }
            ],
            "params": [
                {
                    "name": "threshold",
                    "type": "float",
                    "default": 0.5,
                    "description": "过滤阈值"
                },
                {
                    "name": "column",
                    "type": "str",
                    "default": "value",
                    "description": "过滤列名"
                }
            ]
        }

    def execute(self, inputs):
        """执行节点逻辑"""
        import pandas as pd

        # 获取输入
        data = inputs.get('data')
        threshold = self.params.get('threshold', 0.5)
        column = self.params.get('column', 'value')

        # 过滤数据
        if isinstance(data, pd.DataFrame):
            filtered = data[data[column] > threshold]
            return {'filtered_data': filtered}
        else:
            raise ValueError("输入必须是 DataFrame")
'''

# Suggested file name for the example node.
_EXAMPLE_NODE_FILENAME = "example_filter_node.py"


@router.get("/example")
async def get_example_node():
    """Return example custom-node source code and a suggested file name."""
    return {
        "success": True,
        "code": _EXAMPLE_NODE_CODE,
        "filename": _EXAMPLE_NODE_FILENAME,
    }
|