225 lines
5.9 KiB
Python
225 lines
5.9 KiB
Python
|
|
"""
|
|||
|
|
节点注册中心 (v2.0)
|
|||
|
|
支持装饰器自动注册和元数据收集
|
|||
|
|
"""
|
|||
|
|
from typing import Dict, Type, List, Optional
|
|||
|
|
from .node_base import TraceNode, auto_collect_specs
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NodeRegistry:
|
|||
|
|
"""节点注册中心 - 管理所有可用的节点类型"""
|
|||
|
|
|
|||
|
|
_nodes: Dict[str, Type[TraceNode]] = {}
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def register(cls, node_class: Type[TraceNode]) -> Type[TraceNode]:
|
|||
|
|
"""
|
|||
|
|
注册节点类
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
node_class: 节点类(必须继承 TraceNode)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
注册后的节点类(支持链式调用)
|
|||
|
|
"""
|
|||
|
|
if not issubclass(node_class, TraceNode):
|
|||
|
|
raise TypeError(f"{node_class.__name__} 必须继承 TraceNode")
|
|||
|
|
|
|||
|
|
class_name = node_class.__name__
|
|||
|
|
|
|||
|
|
# 自动收集装饰器标记的属性
|
|||
|
|
node_class = auto_collect_specs(node_class)
|
|||
|
|
|
|||
|
|
# 检查重复注册
|
|||
|
|
if class_name in cls._nodes:
|
|||
|
|
print(f"⚠️ 警告: 节点 {class_name} 已存在,将被覆盖")
|
|||
|
|
|
|||
|
|
# 注册节点
|
|||
|
|
cls._nodes[class_name] = node_class
|
|||
|
|
|
|||
|
|
# 输出日志
|
|||
|
|
display_name = node_class.DISPLAY_NAME or class_name
|
|||
|
|
category = node_class.CATEGORY
|
|||
|
|
print(f"✅ 注册节点: {display_name} [{class_name}] ({category})")
|
|||
|
|
|
|||
|
|
return node_class
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def unregister(cls, class_name: str) -> bool:
|
|||
|
|
"""
|
|||
|
|
注销节点
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
class_name: 节点类名
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
是否成功注销
|
|||
|
|
"""
|
|||
|
|
if class_name in cls._nodes:
|
|||
|
|
del cls._nodes[class_name]
|
|||
|
|
print(f"🗑️ 注销节点: {class_name}")
|
|||
|
|
return True
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get(cls, class_name: str) -> Type[TraceNode]:
|
|||
|
|
"""
|
|||
|
|
获取节点类
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
class_name: 节点类名
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
节点类
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
KeyError: 节点未注册
|
|||
|
|
"""
|
|||
|
|
if class_name not in cls._nodes:
|
|||
|
|
raise KeyError(f"未找到节点: {class_name}. 可用节点: {list(cls._nodes.keys())}")
|
|||
|
|
return cls._nodes[class_name]
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def exists(cls, class_name: str) -> bool:
|
|||
|
|
"""
|
|||
|
|
检查节点是否存在
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
class_name: 节点类名
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
是否存在
|
|||
|
|
"""
|
|||
|
|
return class_name in cls._nodes
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def list_all(cls) -> Dict[str, Type[TraceNode]]:
|
|||
|
|
"""
|
|||
|
|
列出所有注册的节点
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
节点字典 {class_name: node_class}
|
|||
|
|
"""
|
|||
|
|
return cls._nodes.copy()
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get_by_category(cls, category: str) -> List[Type[TraceNode]]:
|
|||
|
|
"""
|
|||
|
|
按分类获取节点
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
category: 分类路径(支持前缀匹配)
|
|||
|
|
如 "Data" 匹配 "Data/Load", "Data/Transform" 等
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
节点类列表
|
|||
|
|
"""
|
|||
|
|
return [
|
|||
|
|
node_cls for node_cls in cls._nodes.values()
|
|||
|
|
if node_cls.CATEGORY.startswith(category)
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get_metadata_list(cls) -> List[Dict]:
|
|||
|
|
"""
|
|||
|
|
获取所有节点的元数据(供前端使用)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
元数据列表
|
|||
|
|
"""
|
|||
|
|
return [
|
|||
|
|
node_cls.get_metadata()
|
|||
|
|
for node_cls in cls._nodes.values()
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get_metadata(cls, class_name: str) -> Optional[Dict]:
|
|||
|
|
"""
|
|||
|
|
获取指定节点的元数据
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
class_name: 节点类名
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
元数据字典,如果节点不存在返回 None
|
|||
|
|
"""
|
|||
|
|
if class_name not in cls._nodes:
|
|||
|
|
return None
|
|||
|
|
return cls._nodes[class_name].get_metadata()
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get_categories(cls) -> List[str]:
|
|||
|
|
"""
|
|||
|
|
获取所有分类路径(去重)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
分类列表
|
|||
|
|
"""
|
|||
|
|
categories = set()
|
|||
|
|
for node_cls in cls._nodes.values():
|
|||
|
|
category = node_cls.CATEGORY
|
|||
|
|
# 支持多级分类(如 "Data/Transform/Filter")
|
|||
|
|
parts = category.split('/')
|
|||
|
|
for i in range(1, len(parts) + 1):
|
|||
|
|
categories.add('/'.join(parts[:i]))
|
|||
|
|
return sorted(categories)
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def clear(cls):
|
|||
|
|
"""清空所有注册的节点(主要用于测试)"""
|
|||
|
|
count = len(cls._nodes)
|
|||
|
|
cls._nodes.clear()
|
|||
|
|
print(f"🧹 已清空节点注册表 ({count} 个节点)")
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get_stats(cls) -> Dict:
|
|||
|
|
"""
|
|||
|
|
获取注册统计信息
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
统计信息字典
|
|||
|
|
"""
|
|||
|
|
total = len(cls._nodes)
|
|||
|
|
by_category = {}
|
|||
|
|
|
|||
|
|
for node_cls in cls._nodes.values():
|
|||
|
|
category = node_cls.CATEGORY
|
|||
|
|
by_category[category] = by_category.get(category, 0) + 1
|
|||
|
|
|
|||
|
|
return {
|
|||
|
|
"total": total,
|
|||
|
|
"by_category": by_category,
|
|||
|
|
"categories": cls.get_categories()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============= 装饰器 =============
|
|||
|
|
|
|||
|
|
def register_node(node_class: Type[TraceNode] = None):
|
|||
|
|
"""
|
|||
|
|
节点注册装饰器
|
|||
|
|
|
|||
|
|
用法:
|
|||
|
|
方式一(推荐):
|
|||
|
|
@register_node
|
|||
|
|
class MyNode(TraceNode):
|
|||
|
|
...
|
|||
|
|
|
|||
|
|
方式二:
|
|||
|
|
@register_node()
|
|||
|
|
class MyNode(TraceNode):
|
|||
|
|
...
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
node_class: 节点类(由装饰器自动传入)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
注册后的节点类
|
|||
|
|
"""
|
|||
|
|
if node_class is None:
|
|||
|
|
# 支持 @register_node() 形式
|
|||
|
|
return lambda cls: NodeRegistry.register(cls)
|
|||
|
|
|
|||
|
|
# 支持 @register_node 形式
|
|||
|
|
return NodeRegistry.register(node_class)
|