TraceStudio-dev/server/app/core/node_loader.py

177 lines
4.9 KiB
Python
Raw Permalink Normal View History

"""
节点加载器
负责自动加载内置节点和自定义节点
支持函数节点加载
"""
import importlib
import sys
import json
from pathlib import Path
from typing import Dict, Any
from ..core.user_manager import CLOUD_ROOT
from ..core.cache_manager import CacheManager
def load_builtin_nodes() -> Dict[str, Any]:
"""
加载内置节点
Returns:
加载结果统计
"""
from .node_registry import NodeRegistry
# 确保项目根目录在 sys.path 中,便于按包名导入
project_root = Path(__file__).resolve().parents[2]
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
# 获取节点模块目录
nodes_dir = Path(__file__).parent.parent / "nodes"
custom_nodes_dir = CLOUD_ROOT / "custom_nodes"
# 要加载的模块列表
ignore_node_modules = [
"example_nodes",
"advanced_example_nodes",
]
loaded = 0
failed = 0
errors = []
for module_path in nodes_dir.glob("*.py"):
module_name = module_path.stem
if module_name in ignore_node_modules:
continue
try:
# 导入模块(按完整包路径)
module = importlib.import_module(f"server.app.nodes.{module_name}")
loaded += 1
print(f"📦 加载节点模块: {module_name}")
except Exception as e:
failed += 1
error_msg = f"❌ 加载失败: {module_name} - {str(e)}"
errors.append(error_msg)
print(error_msg)
# todo: 加载 custom_nodes 目录下的用户自定义节点模块,如果有的话
if custom_nodes_dir.exists():
for module_path in custom_nodes_dir.glob("*.py"):
module_name = module_path.stem
try:
module = importlib.import_module(f"cloud.custom_nodes.{module_name}")
loaded += 1
print(f"📦 加载自定义节点模块: {module_name}")
except Exception as e:
failed += 1
error_msg = f"❌ 加载自定义节点失败: {module_name} - {str(e)}"
errors.append(error_msg)
print(error_msg)
# 获取注册统计
stats = NodeRegistry.get_stats()
return {
"modules_loaded": loaded,
"modules_failed": failed,
"nodes_registered": stats["total"],
"categories": stats["categories"],
"errors": errors
}
def load_custom_nodes() -> Dict[str, Any]:
"""
加载自定义节点
Returns:
加载结果统计
"""
# TODO: 实现自定义节点加载逻辑
return {
"loaded": 0,
"failed": 0,
"errors": []
}
def load_function_nodes() -> Dict[str, Any]:
"""
加载函数节点 functions 目录
Returns:
加载结果统计
"""
from .node_registry import NodeRegistry
from .function_nodes import create_function_node_class
# 获取 functions 目录
project_root = Path(__file__).resolve().parents[3]
functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
if not functions_dir.exists():
return {
"loaded": 0,
"failed": 0,
"errors": ["functions 目录不存在"]
}
loaded = 0
failed = 0
errors = []
# 遍历所有 .json 文件
for json_file in functions_dir.glob("*.json"):
try:
# 读取函数定义
with open(json_file, 'r', encoding='utf-8') as f:
workflow_data = json.load(f)
function_name = workflow_data.get("function_name", json_file.stem)
# 创建函数节点类
node_class = create_function_node_class(function_name, workflow_data)
# 注册到节点注册表
NodeRegistry.register(node_class)
loaded += 1
print(f"📦 加载函数节点: {function_name}")
except Exception as e:
failed += 1
error_msg = f"❌ 加载函数失败: {json_file.name} - {str(e)}"
errors.append(error_msg)
print(error_msg)
return {
"loaded": loaded,
"failed": failed,
"errors": errors
}
def reload_custom_nodes() -> Dict[str, Any]:
"""
重新加载所有节点包括内置和自定义
Returns:
加载结果统计
"""
from .node_registry import NodeRegistry
# 清空注册表
NodeRegistry.clear()
# 加载内置节点
builtin_result = load_builtin_nodes()
# 加载自定义节点
custom_result = load_custom_nodes()
total = builtin_result["nodes_registered"] + custom_result["loaded"]
return {
"total": total,
"loaded": total,
"builtin": builtin_result,
"custom": custom_result
}