225 lines
5.9 KiB
Python
225 lines
5.9 KiB
Python
"""
|
||
节点注册中心 (v2.0)
|
||
支持装饰器自动注册和元数据收集
|
||
"""
|
||
from typing import Dict, Type, List, Optional
|
||
from .node_base import TraceNode, auto_collect_specs
|
||
|
||
|
||
class NodeRegistry:
|
||
"""节点注册中心 - 管理所有可用的节点类型"""
|
||
|
||
_nodes: Dict[str, Type[TraceNode]] = {}
|
||
|
||
@classmethod
|
||
def register(cls, node_class: Type[TraceNode]) -> Type[TraceNode]:
|
||
"""
|
||
注册节点类
|
||
|
||
Args:
|
||
node_class: 节点类(必须继承 TraceNode)
|
||
|
||
Returns:
|
||
注册后的节点类(支持链式调用)
|
||
"""
|
||
if not issubclass(node_class, TraceNode):
|
||
raise TypeError(f"{node_class.__name__} 必须继承 TraceNode")
|
||
|
||
class_name = node_class.__name__
|
||
|
||
# 自动收集装饰器标记的属性
|
||
node_class = auto_collect_specs(node_class)
|
||
|
||
# 检查重复注册
|
||
if class_name in cls._nodes:
|
||
print(f"⚠️ 警告: 节点 {class_name} 已存在,将被覆盖")
|
||
|
||
# 注册节点
|
||
cls._nodes[class_name] = node_class
|
||
|
||
# 输出日志
|
||
display_name = node_class.DISPLAY_NAME or class_name
|
||
category = node_class.CATEGORY
|
||
print(f"✅ 注册节点: {display_name} [{class_name}] ({category})")
|
||
|
||
return node_class
|
||
|
||
@classmethod
|
||
def unregister(cls, class_name: str) -> bool:
|
||
"""
|
||
注销节点
|
||
|
||
Args:
|
||
class_name: 节点类名
|
||
|
||
Returns:
|
||
是否成功注销
|
||
"""
|
||
if class_name in cls._nodes:
|
||
del cls._nodes[class_name]
|
||
print(f"🗑️ 注销节点: {class_name}")
|
||
return True
|
||
return False
|
||
|
||
@classmethod
|
||
def get(cls, class_name: str) -> Type[TraceNode]:
|
||
"""
|
||
获取节点类
|
||
|
||
Args:
|
||
class_name: 节点类名
|
||
|
||
Returns:
|
||
节点类
|
||
|
||
Raises:
|
||
KeyError: 节点未注册
|
||
"""
|
||
if class_name not in cls._nodes:
|
||
raise KeyError(f"未找到节点: {class_name}. 可用节点: {list(cls._nodes.keys())}")
|
||
return cls._nodes[class_name]
|
||
|
||
@classmethod
|
||
def exists(cls, class_name: str) -> bool:
|
||
"""
|
||
检查节点是否存在
|
||
|
||
Args:
|
||
class_name: 节点类名
|
||
|
||
Returns:
|
||
是否存在
|
||
"""
|
||
return class_name in cls._nodes
|
||
|
||
@classmethod
|
||
def list_all(cls) -> Dict[str, Type[TraceNode]]:
|
||
"""
|
||
列出所有注册的节点
|
||
|
||
Returns:
|
||
节点字典 {class_name: node_class}
|
||
"""
|
||
return cls._nodes.copy()
|
||
|
||
@classmethod
|
||
def get_by_category(cls, category: str) -> List[Type[TraceNode]]:
|
||
"""
|
||
按分类获取节点
|
||
|
||
Args:
|
||
category: 分类路径(支持前缀匹配)
|
||
如 "Data" 匹配 "Data/Load", "Data/Transform" 等
|
||
|
||
Returns:
|
||
节点类列表
|
||
"""
|
||
return [
|
||
node_cls for node_cls in cls._nodes.values()
|
||
if node_cls.CATEGORY.startswith(category)
|
||
]
|
||
|
||
@classmethod
|
||
def get_metadata_list(cls) -> List[Dict]:
|
||
"""
|
||
获取所有节点的元数据(供前端使用)
|
||
|
||
Returns:
|
||
元数据列表
|
||
"""
|
||
return [
|
||
node_cls.get_metadata()
|
||
for node_cls in cls._nodes.values()
|
||
]
|
||
|
||
@classmethod
|
||
def get_metadata(cls, class_name: str) -> Optional[Dict]:
|
||
"""
|
||
获取指定节点的元数据
|
||
|
||
Args:
|
||
class_name: 节点类名
|
||
|
||
Returns:
|
||
元数据字典,如果节点不存在返回 None
|
||
"""
|
||
if class_name not in cls._nodes:
|
||
return None
|
||
return cls._nodes[class_name].get_metadata()
|
||
|
||
@classmethod
|
||
def get_categories(cls) -> List[str]:
|
||
"""
|
||
获取所有分类路径(去重)
|
||
|
||
Returns:
|
||
分类列表
|
||
"""
|
||
categories = set()
|
||
for node_cls in cls._nodes.values():
|
||
category = node_cls.CATEGORY
|
||
# 支持多级分类(如 "Data/Transform/Filter")
|
||
parts = category.split('/')
|
||
for i in range(1, len(parts) + 1):
|
||
categories.add('/'.join(parts[:i]))
|
||
return sorted(categories)
|
||
|
||
@classmethod
|
||
def clear(cls):
|
||
"""清空所有注册的节点(主要用于测试)"""
|
||
count = len(cls._nodes)
|
||
cls._nodes.clear()
|
||
print(f"🧹 已清空节点注册表 ({count} 个节点)")
|
||
|
||
@classmethod
|
||
def get_stats(cls) -> Dict:
|
||
"""
|
||
获取注册统计信息
|
||
|
||
Returns:
|
||
统计信息字典
|
||
"""
|
||
total = len(cls._nodes)
|
||
by_category = {}
|
||
|
||
for node_cls in cls._nodes.values():
|
||
category = node_cls.CATEGORY
|
||
by_category[category] = by_category.get(category, 0) + 1
|
||
|
||
return {
|
||
"total": total,
|
||
"by_category": by_category,
|
||
"categories": cls.get_categories()
|
||
}
|
||
|
||
|
||
# ============= 装饰器 =============
|
||
|
||
def register_node(node_class: Type[TraceNode] = None):
|
||
"""
|
||
节点注册装饰器
|
||
|
||
用法:
|
||
方式一(推荐):
|
||
@register_node
|
||
class MyNode(TraceNode):
|
||
...
|
||
|
||
方式二:
|
||
@register_node()
|
||
class MyNode(TraceNode):
|
||
...
|
||
|
||
Args:
|
||
node_class: 节点类(由装饰器自动传入)
|
||
|
||
Returns:
|
||
注册后的节点类
|
||
"""
|
||
if node_class is None:
|
||
# 支持 @register_node() 形式
|
||
return lambda cls: NodeRegistry.register(cls)
|
||
|
||
# 支持 @register_node 形式
|
||
return NodeRegistry.register(node_class)
|