217 lines
6.9 KiB
Python
217 lines
6.9 KiB
Python
|
|
"""
|
|||
|
|
节点加载器
|
|||
|
|
负责自动加载内置节点和自定义节点
|
|||
|
|
支持函数节点加载
|
|||
|
|
"""
|
|||
|
|
import importlib
|
|||
|
|
import sys
|
|||
|
|
import json
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Dict, Any
|
|||
|
|
from ..core.user_manager import CLOUD_ROOT
|
|||
|
|
from ..core.cache_manager import CacheManager
|
|||
|
|
|
|||
|
|
def load_builtin_nodes() -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
加载内置节点
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
加载结果统计
|
|||
|
|
"""
|
|||
|
|
from .node_registry import NodeRegistry
|
|||
|
|
|
|||
|
|
# 确保项目根目录在 sys.path 中,便于按包名导入
|
|||
|
|
project_root = Path(__file__).resolve().parents[2]
|
|||
|
|
if str(project_root) not in sys.path:
|
|||
|
|
sys.path.insert(0, str(project_root))
|
|||
|
|
|
|||
|
|
# 获取节点模块目录
|
|||
|
|
nodes_dir = Path(__file__).parent.parent / "nodes"
|
|||
|
|
custom_nodes_dir = CLOUD_ROOT / "custom_nodes"
|
|||
|
|
|
|||
|
|
# 要加载的模块列表
|
|||
|
|
ignore_node_modules = [
|
|||
|
|
"example_nodes",
|
|||
|
|
"advanced_example_nodes",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
loaded = 0
|
|||
|
|
failed = 0
|
|||
|
|
errors = []
|
|||
|
|
|
|||
|
|
import importlib.util
|
|||
|
|
import types
|
|||
|
|
|
|||
|
|
for module_path in nodes_dir.glob("*.py"):
|
|||
|
|
module_name = module_path.stem
|
|||
|
|
if module_name in ignore_node_modules:
|
|||
|
|
continue
|
|||
|
|
try:
|
|||
|
|
# 使用文件路径导入模块,更稳健于不同运行上下文(容器、打包后的 dist 等)
|
|||
|
|
spec = importlib.util.spec_from_file_location(f"server.app.nodes.{module_name}", str(module_path))
|
|||
|
|
if spec is None or spec.loader is None:
|
|||
|
|
raise ImportError(f"无法为 {module_name} 创建模块规范")
|
|||
|
|
module = importlib.util.module_from_spec(spec)
|
|||
|
|
|
|||
|
|
# 支持热重载:若模块已在 sys.modules 中,先移除再加载
|
|||
|
|
if module.__name__ in sys.modules:
|
|||
|
|
del sys.modules[module.__name__]
|
|||
|
|
|
|||
|
|
spec.loader.exec_module(module)
|
|||
|
|
# 将模块放入 sys.modules,允许其它地方按包名引用(如果需要)
|
|||
|
|
sys.modules[module.__name__] = module
|
|||
|
|
|
|||
|
|
loaded += 1
|
|||
|
|
print(f"📦 加载节点模块: {module_name}")
|
|||
|
|
except Exception as e:
|
|||
|
|
failed += 1
|
|||
|
|
# 输出更详细的异常堆栈信息,便于诊断
|
|||
|
|
import traceback
|
|||
|
|
tb = traceback.format_exc()
|
|||
|
|
error_msg = f"❌ 加载失败: {module_name} - {str(e)}\n{tb}"
|
|||
|
|
errors.append(error_msg)
|
|||
|
|
print(error_msg)
|
|||
|
|
# todo: 加载 custom_nodes 目录下的用户自定义节点模块,如果有的话
|
|||
|
|
if custom_nodes_dir.exists():
|
|||
|
|
# `custom_nodes.<name>` 导入;不需要在磁盘上创建 package 文件。
|
|||
|
|
if 'custom_nodes' not in sys.modules:
|
|||
|
|
cn_pkg = types.ModuleType('custom_nodes')
|
|||
|
|
cn_pkg.__path__ = [str(custom_nodes_dir)]
|
|||
|
|
sys.modules['custom_nodes'] = cn_pkg
|
|||
|
|
|
|||
|
|
for module_path in custom_nodes_dir.glob("*.py"):
|
|||
|
|
module_name = module_path.stem
|
|||
|
|
try:
|
|||
|
|
# 使用文件路径导入自定义节点,模块名使用 custom_nodes.<module_name>
|
|||
|
|
spec = importlib.util.spec_from_file_location(f"custom_nodes.{module_name}", str(module_path))
|
|||
|
|
if spec is None or spec.loader is None:
|
|||
|
|
raise ImportError(f"无法为自定义节点 {module_name} 创建模块规范")
|
|||
|
|
module = importlib.util.module_from_spec(spec)
|
|||
|
|
|
|||
|
|
if module.__name__ in sys.modules:
|
|||
|
|
del sys.modules[module.__name__]
|
|||
|
|
|
|||
|
|
spec.loader.exec_module(module)
|
|||
|
|
# 注册为 custom_nodes.<name>
|
|||
|
|
sys.modules[module.__name__] = module
|
|||
|
|
# 兼容性:如果现有代码仍然尝试以 cloud.custom_nodes.<name> 导入,
|
|||
|
|
# 可以在此处创建别名(可选)。当前不创建 cloud 别名以遵循你的要求。
|
|||
|
|
|
|||
|
|
loaded += 1
|
|||
|
|
print(f"📦 加载自定义节点模块: {module_name}")
|
|||
|
|
except Exception as e:
|
|||
|
|
failed += 1
|
|||
|
|
import traceback
|
|||
|
|
tb = traceback.format_exc()
|
|||
|
|
error_msg = f"❌ 加载自定义节点失败: {module_name} - {str(e)}\n{tb}"
|
|||
|
|
errors.append(error_msg)
|
|||
|
|
print(error_msg)
|
|||
|
|
# 获取注册统计
|
|||
|
|
stats = NodeRegistry.get_stats()
|
|||
|
|
return {
|
|||
|
|
"modules_loaded": loaded,
|
|||
|
|
"modules_failed": failed,
|
|||
|
|
"nodes_registered": stats["total"],
|
|||
|
|
"categories": stats["categories"],
|
|||
|
|
"errors": errors
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_custom_nodes() -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
加载自定义节点
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
加载结果统计
|
|||
|
|
"""
|
|||
|
|
# TODO: 实现自定义节点加载逻辑
|
|||
|
|
return {
|
|||
|
|
"loaded": 0,
|
|||
|
|
"failed": 0,
|
|||
|
|
"errors": []
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_function_nodes() -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
加载函数节点(从 functions 目录)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
加载结果统计
|
|||
|
|
"""
|
|||
|
|
from .node_registry import NodeRegistry
|
|||
|
|
from .function_nodes import create_function_node_class
|
|||
|
|
|
|||
|
|
# 获取 functions 目录
|
|||
|
|
project_root = Path(__file__).resolve().parents[3]
|
|||
|
|
functions_dir = project_root / "cloud" / "custom_nodes" / "functions"
|
|||
|
|
|
|||
|
|
if not functions_dir.exists():
|
|||
|
|
return {
|
|||
|
|
"loaded": 0,
|
|||
|
|
"failed": 0,
|
|||
|
|
"errors": ["functions 目录不存在"]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
loaded = 0
|
|||
|
|
failed = 0
|
|||
|
|
errors = []
|
|||
|
|
|
|||
|
|
# 遍历所有 .json 文件
|
|||
|
|
for json_file in functions_dir.glob("*.json"):
|
|||
|
|
try:
|
|||
|
|
# 读取函数定义
|
|||
|
|
with open(json_file, 'r', encoding='utf-8') as f:
|
|||
|
|
workflow_data = json.load(f)
|
|||
|
|
|
|||
|
|
function_name = workflow_data.get("function_name", json_file.stem)
|
|||
|
|
|
|||
|
|
# 创建函数节点类
|
|||
|
|
node_class = create_function_node_class(function_name, workflow_data)
|
|||
|
|
|
|||
|
|
# 注册到节点注册表
|
|||
|
|
NodeRegistry.register(node_class)
|
|||
|
|
|
|||
|
|
loaded += 1
|
|||
|
|
print(f"📦 加载函数节点: {function_name}")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
failed += 1
|
|||
|
|
error_msg = f"❌ 加载函数失败: {json_file.name} - {str(e)}"
|
|||
|
|
errors.append(error_msg)
|
|||
|
|
print(error_msg)
|
|||
|
|
|
|||
|
|
return {
|
|||
|
|
"loaded": loaded,
|
|||
|
|
"failed": failed,
|
|||
|
|
"errors": errors
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def reload_custom_nodes() -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
重新加载所有节点(包括内置和自定义)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
加载结果统计
|
|||
|
|
"""
|
|||
|
|
from .node_registry import NodeRegistry
|
|||
|
|
|
|||
|
|
# 清空注册表
|
|||
|
|
NodeRegistry.clear()
|
|||
|
|
|
|||
|
|
# 加载内置节点
|
|||
|
|
builtin_result = load_builtin_nodes()
|
|||
|
|
|
|||
|
|
# 加载自定义节点
|
|||
|
|
custom_result = load_custom_nodes()
|
|||
|
|
|
|||
|
|
total = builtin_result["nodes_registered"] + custom_result["loaded"]
|
|||
|
|
|
|||
|
|
return {
|
|||
|
|
"total": total,
|
|||
|
|
"loaded": total,
|
|||
|
|
"builtin": builtin_result,
|
|||
|
|
"custom": custom_result
|
|||
|
|
}
|