"""
自定义节点管理 API

提供节点的增删改查、验证、加载等功能

Custom node management API: list, read, save, upload, download, validate,
load/unload/delete/rename and reload custom node files.
"""
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

from fastapi import APIRouter, Body, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from pydantic import BaseModel

from ..core.node_loader import get_node_loader, reload_custom_nodes
from ..core.node_validator import NodeValidator, validate_node_code
from ..core.security import sanitize_path, validate_filename
# Router shared by every custom-node management endpoint below; all routes
# are mounted under the /api/custom-nodes prefix.
router = APIRouter(
    prefix="/api/custom-nodes",
    tags=["自定义节点"],
)
# ============ Data models ============


class NodeValidateRequest(BaseModel):
    """Request body for ``POST /validate``: node source to check without saving."""

    # Python source code of the node module to validate.
    code: str
    # Display filename for the code; the /validate endpoint currently passes
    # only ``code`` to the validator and ignores this field.
    filename: Optional[str] = "<string>"
class NodeSaveRequest(BaseModel):
    """Request body for ``POST /save``."""

    # Target filename inside the custom nodes directory.
    filename: str
    # Complete Python source to write to that file.
    code: str
    # Overwrite confirmation: when False, saving over an existing file is
    # refused and the client is asked to confirm; when True it is overwritten
    # (after a backup copy is made).
    force: bool = False
class NodeActionRequest(BaseModel):
    """Request body for ``POST /action``."""

    # Node file the action applies to.
    filename: str
    # Action name: 'load', 'unload', 'delete' or 'rename'.
    action: str
    # New filename; only used (and required) when action == 'rename'.
    new_filename: Optional[str] = None
# ============ API endpoints ============


@router.get("/list")
async def list_custom_nodes():
    """
    List every custom node file known to the loader.

    Returns the entries produced by ``scan_nodes()`` (per the original
    description they include each file's validation status — confirm against
    the loader) together with their count. Any failure becomes HTTP 500.
    """
    try:
        scanned = get_node_loader().scan_nodes()
        payload = {"success": True, "total": len(scanned), "nodes": scanned}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"扫描节点失败: {str(exc)}")
    return payload
@router.post("/validate")
async def validate_node(request: NodeValidateRequest):
    """
    Validate node source code without saving it.

    Only ``request.code`` is passed to ``validate_node_code``; the validator's
    result is returned unchanged under ``validation``. Unexpected validator
    errors are reported as HTTP 500.
    """
    try:
        outcome = validate_node_code(request.code)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"验证失败: {str(exc)}")
    return {"success": True, "validation": outcome}
@router.get("/read/{filename}")
async def read_node(filename: str):
    """
    Return a node file's source and metadata, for loading into the editor.

    Raises:
        HTTPException 400: ``NodeValidator.is_safe_filename`` rejects the name.
        HTTPException 404: the file is not in the custom nodes directory.
        HTTPException 500: reading or stat-ing the file failed.
    """
    safe, reason = NodeValidator.is_safe_filename(filename)
    if not safe:
        raise HTTPException(status_code=400, detail=reason)

    target = get_node_loader().custom_nodes_dir / filename
    if not target.exists():
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        source = target.read_text(encoding='utf-8')
        info = target.stat()
        payload = {
            "success": True,
            "filename": filename,
            "code": source,
            "size": info.st_size,
            # Naive local-time ISO timestamp of the last modification.
            "modified": datetime.fromtimestamp(info.st_mtime).isoformat(),
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"读取文件失败: {str(exc)}")
    return payload
@router.post("/save")
async def save_node(request: NodeSaveRequest):
    """
    Validate, save and force-reload a custom node file.

    Steps: check the filename, validate the code, refuse to overwrite an
    existing file unless ``request.force`` is set, back up any existing file,
    write the new code, then reload it with ``force_reload=True``.

    Returns:
        ``{"success": False, "validation": ...}`` when the code is invalid;
        ``{"success": False, "require_confirm": True, ...}`` when the file
        exists and ``force`` is not set; otherwise ``{"success": True, ...}``
        with the validation and the loader's ``load_result``. Note that
        ``success`` reflects the save — the load outcome is only reported
        inside ``load_result`` and is not checked here.

    Raises:
        HTTPException 400: unsafe filename.
        HTTPException 500: backing up, writing, or loading the file failed.
    """
    # 1. Validate the filename before touching the filesystem.
    is_safe, error = NodeValidator.is_safe_filename(request.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    # 2. Validate the code; invalid code is reported back, never saved.
    validation = validate_node_code(request.code)
    if not validation['valid']:
        return {
            "success": False,
            "message": "代码验证失败,请修复错误后重试",
            "validation": validation,
        }

    # 3. Refuse to overwrite unless the client confirmed with force=True.
    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / request.filename
    already_exists = file_path.exists()
    if already_exists and not request.force:
        return {
            "success": False,
            "message": "文件已存在,请确认是否覆盖",
            "require_confirm": True,
            "validation": validation,
        }

    # 4. Keep a timestamped copy of the previous version before overwriting.
    #    A failed backup aborts the save so the old version is never lost.
    if already_exists:
        backup_name = f"{request.filename}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        backup_path = loader.custom_nodes_dir / backup_name
        try:
            shutil.copy2(file_path, backup_path)
        except OSError as e:
            raise HTTPException(status_code=500, detail=f"备份旧文件失败: {str(e)}") from e

    # 5. Write the new source.
    try:
        file_path.write_text(request.code, encoding='utf-8')
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}") from e

    # 6. Force a reload so the new code replaces any previously loaded version.
    #    Loader exceptions are converted to a 500 with a readable detail
    #    (the file itself has already been saved at this point).
    try:
        load_result = loader.load_node(request.filename, force_reload=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"文件已保存,但加载节点失败: {str(e)}") from e

    return {
        "success": True,
        "message": "节点已保存并加载",
        "filename": request.filename,
        "validation": validation,
        "load_result": load_result,
    }
@router.post("/action")
async def node_action(request: NodeActionRequest):
    """
    Perform an action on a custom node file.

    Supported ``request.action`` values:
      - ``load``:   reload the node with ``force_reload=True``; the file
                    existence check is skipped, so a missing file is left to
                    the loader to report.
      - ``unload``: unload the node.
      - ``delete``: unload, then soft-delete by moving the file to
                    ``<filename>.deleted_<YYYYmmdd_HHMMSS>`` in the same
                    directory.
      - ``rename``: unload, rename to ``request.new_filename``, then load it
                    under the new name.

    Raises:
        HTTPException 400: unsafe ``filename``/``new_filename``, missing
            ``new_filename``, rename target already exists, or unsupported
            action.
        HTTPException 404: the file does not exist (all actions except load).
        HTTPException 500: any other failure.
    """
    # Validate the source filename like every other endpoint does. Without a
    # check, a name with path components (e.g. "../x") resolves outside
    # custom_nodes_dir, letting delete/rename move arbitrary files.
    is_safe, error = NodeValidator.is_safe_filename(request.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / request.filename

    if not file_path.exists() and request.action != 'load':
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        if request.action == 'load':
            result = loader.load_node(request.filename, force_reload=True)
            return {
                "success": result['success'],
                "message": result['message'],
                "classes": result['classes'],
            }

        elif request.action == 'unload':
            loader.unload_node(request.filename)
            return {
                "success": True,
                "message": f"已卸载节点: {request.filename}",
            }

        elif request.action == 'delete':
            # Unload first, then move the file aside instead of deleting it
            # outright so it can be recovered.
            loader.unload_node(request.filename)
            backup_name = f"{request.filename}.deleted_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            backup_path = loader.custom_nodes_dir / backup_name
            shutil.move(file_path, backup_path)
            return {
                "success": True,
                "message": f"已删除节点(备份为 {backup_name})",
            }

        elif request.action == 'rename':
            if not request.new_filename:
                raise HTTPException(status_code=400, detail="缺少 new_filename 参数")

            # The new name goes through the same safety check.
            is_safe, error = NodeValidator.is_safe_filename(request.new_filename)
            if not is_safe:
                raise HTTPException(status_code=400, detail=error)

            new_path = loader.custom_nodes_dir / request.new_filename
            if new_path.exists():
                raise HTTPException(status_code=400, detail="目标文件名已存在")

            loader.unload_node(request.filename)
            file_path.rename(new_path)
            # The load result under the new name is not reported back.
            loader.load_node(request.new_filename)
            return {
                "success": True,
                "message": f"已重命名: {request.filename} -> {request.new_filename}",
            }

        else:
            raise HTTPException(status_code=400, detail=f"不支持的操作: {request.action}")

    except HTTPException:
        # Deliberate client errors raised above pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
@router.get("/download/{filename}")
async def download_node(filename: str):
    """
    Send a node file as an attachment with media type ``text/x-python``.

    Raises HTTPException 400 when ``NodeValidator.is_safe_filename`` rejects
    the name, and 404 when the file is not in the custom nodes directory.
    """
    ok, problem = NodeValidator.is_safe_filename(filename)
    if not ok:
        raise HTTPException(status_code=400, detail=problem)

    node_file = get_node_loader().custom_nodes_dir / filename
    if not node_file.exists():
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(path=node_file, media_type='text/x-python', filename=filename)
@router.post("/upload")
async def upload_node(file: UploadFile = File(...)):
    """
    Upload a node file, validate it, save it and load it.

    Existing files are never overwritten: if the name is taken, a
    ``require_confirm`` response is returned and nothing is written.

    Returns:
        ``{"success": False, ...}`` when the file already exists or fails
        validation; otherwise ``{"success": True, ...}`` with the validation
        and the loader's ``load_result``.

    Raises:
        HTTPException 400: missing or unsafe filename, or content not UTF-8.
        HTTPException 500: any other failure; a file written by this request
            is removed (best effort) before the error is reported.
    """
    # UploadFile.filename can be None when the client sends no filename;
    # reject it explicitly instead of handing None to the validator.
    if not file.filename:
        raise HTTPException(status_code=400, detail="缺少文件名")

    is_safe, error = NodeValidator.is_safe_filename(file.filename)
    if not is_safe:
        raise HTTPException(status_code=400, detail=error)

    loader = get_node_loader()
    file_path = loader.custom_nodes_dir / file.filename

    if file_path.exists():
        return {
            "success": False,
            "message": "文件已存在,请先删除或使用不同的文件名",
            "require_confirm": True,
        }

    # Set just before writing, so the error handler only removes a file this
    # request created (possibly partially) and never a pre-existing one.
    written = False
    try:
        content = await file.read()
        code = content.decode('utf-8')

        validation = validate_node_code(code)
        if not validation['valid']:
            return {
                "success": False,
                "message": "节点验证失败",
                "validation": validation,
            }

        written = True
        file_path.write_text(code, encoding='utf-8')

        load_result = loader.load_node(file.filename)

        return {
            "success": True,
            "message": "节点已上传并加载",
            "filename": file.filename,
            "validation": validation,
            "load_result": load_result,
        }

    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="文件编码错误,必须是UTF-8")
    except Exception as e:
        if written:
            try:
                file_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the original failure is reported below
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
@router.get("/loaded")
async def get_loaded_nodes():
    """
    List the node classes currently held in memory by the loader.

    Each entry carries the class name and the ``metadata`` stored for it by
    ``get_all_loaded_nodes()``. Any failure becomes HTTP 500.
    """
    try:
        registry = get_node_loader().get_all_loaded_nodes()
        entries = [
            {'class_name': name, 'metadata': info['metadata']}
            for name, info in registry.items()
        ]
        payload = {"success": True, "total": len(entries), "nodes": entries}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取失败: {str(exc)}")
    return payload
@router.post("/reload-all")
async def reload_all_nodes():
    """
    Reload every custom node file via the loader's ``load_all_nodes()``.

    The loader's summary is returned under ``details``; its ``loaded`` and
    ``total`` counts are used in the message. Any failure becomes HTTP 500.
    """
    try:
        summary = get_node_loader().load_all_nodes()
        payload = {
            "success": True,
            "message": f"已加载 {summary['loaded']}/{summary['total']} 个节点文件",
            "details": summary,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"重新加载失败: {str(exc)}")
    return payload
@router.get("/example")
async def get_example_node():
    """
    Return a sample custom node module for users to start from.

    The sample subclasses ``app.core.node_base.TraceNode`` and shows the
    ``get_metadata()`` / ``execute()`` layout. The source is kept properly
    indented so that it is valid Python when pasted into the editor.
    """
    example_code = '''"""
示例自定义节点
"""
from app.core.node_base import TraceNode


class ExampleFilterNode(TraceNode):
    """数据过滤节点"""

    @staticmethod
    def get_metadata():
        return {
            "display_name": "示例过滤节点",
            "description": "根据阈值过滤数据",
            "category": "Custom/Examples",
            "author": "Your Name",
            "version": "1.0.0",
            "inputs": [
                {
                    "name": "data",
                    "type": "DataFrame",
                    "description": "输入数据"
                }
            ],
            "outputs": [
                {
                    "name": "filtered_data",
                    "type": "DataFrame",
                    "description": "过滤后的数据"
                }
            ],
            "params": [
                {
                    "name": "threshold",
                    "type": "float",
                    "default": 0.5,
                    "description": "过滤阈值"
                },
                {
                    "name": "column",
                    "type": "str",
                    "default": "value",
                    "description": "过滤列名"
                }
            ]
        }

    def execute(self, inputs):
        """执行节点逻辑"""
        import pandas as pd

        # 获取输入
        data = inputs.get('data')
        threshold = self.params.get('threshold', 0.5)
        column = self.params.get('column', 'value')

        # 过滤数据
        if isinstance(data, pd.DataFrame):
            filtered = data[data[column] > threshold]
            return {'filtered_data': filtered}
        else:
            raise ValueError("输入必须是 DataFrame")
'''

    return {
        "success": True,
        "code": example_code,
        "filename": "example_filter_node.py",
    }