""" 节点注册中心 (v2.0) 支持装饰器自动注册和元数据收集 """ from typing import Dict, Type, List, Optional from .node_base import TraceNode, auto_collect_specs class NodeRegistry: """节点注册中心 - 管理所有可用的节点类型""" _nodes: Dict[str, Type[TraceNode]] = {} @classmethod def register(cls, node_class: Type[TraceNode]) -> Type[TraceNode]: """ 注册节点类 Args: node_class: 节点类(必须继承 TraceNode) Returns: 注册后的节点类(支持链式调用) """ if not issubclass(node_class, TraceNode): raise TypeError(f"{node_class.__name__} 必须继承 TraceNode") class_name = node_class.__name__ # 自动收集装饰器标记的属性 node_class = auto_collect_specs(node_class) # 检查重复注册 if class_name in cls._nodes: print(f"⚠️ 警告: 节点 {class_name} 已存在,将被覆盖") # 注册节点 cls._nodes[class_name] = node_class # 输出日志 display_name = node_class.DISPLAY_NAME or class_name category = node_class.CATEGORY print(f"✅ 注册节点: {display_name} [{class_name}] ({category})") return node_class @classmethod def unregister(cls, class_name: str) -> bool: """ 注销节点 Args: class_name: 节点类名 Returns: 是否成功注销 """ if class_name in cls._nodes: del cls._nodes[class_name] print(f"🗑️ 注销节点: {class_name}") return True return False @classmethod def get(cls, class_name: str) -> Type[TraceNode]: """ 获取节点类 Args: class_name: 节点类名 Returns: 节点类 Raises: KeyError: 节点未注册 """ if class_name not in cls._nodes: raise KeyError(f"未找到节点: {class_name}. 可用节点: {list(cls._nodes.keys())}") return cls._nodes[class_name] @classmethod def exists(cls, class_name: str) -> bool: """ 检查节点是否存在 Args: class_name: 节点类名 Returns: 是否存在 """ return class_name in cls._nodes @classmethod def list_all(cls) -> Dict[str, Type[TraceNode]]: """ 列出所有注册的节点 Returns: 节点字典 {class_name: node_class} """ return cls._nodes.copy() @classmethod def get_by_category(cls, category: str) -> List[Type[TraceNode]]: """ 按分类获取节点 Args: category: 分类路径(支持前缀匹配) 如 "Data" 匹配 "Data/Load", "Data/Transform" 等 Returns: 节点类列表 """ return [ node_cls for node_cls in cls._nodes.values() if node_cls.CATEGORY.startswith(category) ] @classmethod def get_metadata_list(cls) -> List[Dict]: """ 获取所有节点的元数据(供前端使用) Returns: 元数据列表 """ return [ node_cls.get_metadata() for node_cls in cls._nodes.values() ] @classmethod def get_metadata(cls, class_name: str) -> Optional[Dict]: """ 获取指定节点的元数据 Args: class_name: 节点类名 Returns: 元数据字典,如果节点不存在返回 None """ if class_name not in cls._nodes: return None return cls._nodes[class_name].get_metadata() @classmethod def get_categories(cls) -> List[str]: """ 获取所有分类路径(去重) Returns: 分类列表 """ categories = set() for node_cls in cls._nodes.values(): category = node_cls.CATEGORY # 支持多级分类(如 "Data/Transform/Filter") parts = category.split('/') for i in range(1, len(parts) + 1): categories.add('/'.join(parts[:i])) return sorted(categories) @classmethod def clear(cls): """清空所有注册的节点(主要用于测试)""" count = len(cls._nodes) cls._nodes.clear() print(f"🧹 已清空节点注册表 ({count} 个节点)") @classmethod def get_stats(cls) -> Dict: """ 获取注册统计信息 Returns: 统计信息字典 """ total = len(cls._nodes) by_category = {} for node_cls in cls._nodes.values(): category = node_cls.CATEGORY by_category[category] = by_category.get(category, 0) + 1 return { "total": total, "by_category": by_category, "categories": cls.get_categories() } # ============= 装饰器 ============= def register_node(node_class: Type[TraceNode] = None): """ 节点注册装饰器 用法: 方式一(推荐): @register_node class MyNode(TraceNode): ... 方式二: @register_node() class MyNode(TraceNode): ... Args: node_class: 节点类(由装饰器自动传入) Returns: 注册后的节点类 """ if node_class is None: # 支持 @register_node() 形式 return lambda cls: NodeRegistry.register(cls) # 支持 @register_node 形式 return NodeRegistry.register(node_class)