TraceStudio/server/app/core/node_loader.py

177 lines
4.9 KiB
Python
Raw Normal View History

2026-01-12 21:51:45 +08:00
"""
节点加载器
负责自动加载内置节点和自定义节点
支持函数节点加载
"""
import importlib
import sys
import json
from pathlib import Path
from typing import Dict, Any
from ..core.user_manager import CLOUD_ROOT
from ..core.cache_manager import CacheManager
def load_builtin_nodes() -> Dict[str, Any]:
"""
加载内置节点
Returns:
加载结果统计
"""
from .node_registry import NodeRegistry
# 确保项目根目录在 sys.path 中,便于按包名导入
project_root = Path(__file__).resolve().parents[2]
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
# 获取节点模块目录
nodes_dir = Path(__file__).parent.parent / "nodes"
custom_nodes_dir = CLOUD_ROOT / "custom_nodes"
# 要加载的模块列表
ignore_node_modules = [
"example_nodes",
"advanced_example_nodes",
]
loaded = 0
failed = 0
errors = []
for module_path in nodes_dir.glob("*.py"):
module_name = module_path.stem
if module_name in ignore_node_modules:
continue
try:
# 导入模块(按完整包路径)
module = importlib.import_module(f"server.app.nodes.{module_name}")
loaded += 1
print(f"📦 加载节点模块: {module_name}")
except Exception as e:
failed += 1
error_msg = f"❌ 加载失败: {module_name} - {str(e)}"
errors.append(error_msg)
print(error_msg)
# todo: 加载 custom_nodes 目录下的用户自定义节点模块,如果有的话
if custom_nodes_dir.exists():
for module_path in custom_nodes_dir.glob("*.py"):
module_name = module_path.stem
try:
module = importlib.import_module(f"cloud.custom_nodes.{module_name}")
loaded += 1
print(f"📦 加载自定义节点模块: {module_name}")
except Exception as e:
failed += 1
error_msg = f"❌ 加载自定义节点失败: {module_name} - {str(e)}"
errors.append(error_msg)
print(error_msg)
# 获取注册统计
stats = NodeRegistry.get_stats()
return {
"modules_loaded": loaded,
"modules_failed": failed,
"nodes_registered": stats["total"],
"categories": stats["categories"],
"errors": errors
}
def load_custom_nodes() -> Dict[str, Any]:
"""
加载自定义节点
Returns:
加载结果统计
"""
# TODO: 实现自定义节点加载逻辑
return {
"loaded": 0,
"failed": 0,
"errors": []
}
def load_function_nodes() -> Dict[str, Any]:
"""
加载函数节点 functions 目录
Returns:
加载结果统计
"""
from .node_registry import NodeRegistry
from .function_nodes import create_function_node_class
# 获取 functions 目录
project_root = Path(__file__).resolve().parents[3]
functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
if not functions_dir.exists():
return {
"loaded": 0,
"failed": 0,
"errors": ["functions 目录不存在"]
}
loaded = 0
failed = 0
errors = []
# 遍历所有 .json 文件
for json_file in functions_dir.glob("*.json"):
try:
# 读取函数定义
with open(json_file, 'r', encoding='utf-8') as f:
workflow_data = json.load(f)
function_name = workflow_data.get("function_name", json_file.stem)
# 创建函数节点类
node_class = create_function_node_class(function_name, workflow_data)
# 注册到节点注册表
NodeRegistry.register(node_class)
loaded += 1
print(f"📦 加载函数节点: {function_name}")
except Exception as e:
failed += 1
error_msg = f"❌ 加载函数失败: {json_file.name} - {str(e)}"
errors.append(error_msg)
print(error_msg)
return {
"loaded": loaded,
"failed": failed,
"errors": errors
}
def reload_custom_nodes() -> Dict[str, Any]:
"""
重新加载所有节点包括内置和自定义
Returns:
加载结果统计
"""
from .node_registry import NodeRegistry
# 清空注册表
NodeRegistry.clear()
# 加载内置节点
builtin_result = load_builtin_nodes()
# 加载自定义节点
custom_result = load_custom_nodes()
total = builtin_result["nodes_registered"] + custom_result["loaded"]
return {
"total": total,
"loaded": total,
"builtin": builtin_result,
"custom": custom_result
}