TraceStudio-dev/server/app/core/node_registry.py

225 lines
5.9 KiB
Python
Raw Normal View History

"""
节点注册中心 (v2.0)
支持装饰器自动注册和元数据收集
"""
from typing import Dict, Type, List, Optional
from .node_base import TraceNode, auto_collect_specs
class NodeRegistry:
"""节点注册中心 - 管理所有可用的节点类型"""
_nodes: Dict[str, Type[TraceNode]] = {}
@classmethod
def register(cls, node_class: Type[TraceNode]) -> Type[TraceNode]:
"""
注册节点类
Args:
node_class: 节点类必须继承 TraceNode
Returns:
注册后的节点类支持链式调用
"""
if not issubclass(node_class, TraceNode):
raise TypeError(f"{node_class.__name__} 必须继承 TraceNode")
class_name = node_class.__name__
# 自动收集装饰器标记的属性
node_class = auto_collect_specs(node_class)
# 检查重复注册
if class_name in cls._nodes:
print(f"⚠️ 警告: 节点 {class_name} 已存在,将被覆盖")
# 注册节点
cls._nodes[class_name] = node_class
# 输出日志
display_name = node_class.DISPLAY_NAME or class_name
category = node_class.CATEGORY
print(f"✅ 注册节点: {display_name} [{class_name}] ({category})")
return node_class
@classmethod
def unregister(cls, class_name: str) -> bool:
"""
注销节点
Args:
class_name: 节点类名
Returns:
是否成功注销
"""
if class_name in cls._nodes:
del cls._nodes[class_name]
print(f"🗑️ 注销节点: {class_name}")
return True
return False
@classmethod
def get(cls, class_name: str) -> Type[TraceNode]:
"""
获取节点类
Args:
class_name: 节点类名
Returns:
节点类
Raises:
KeyError: 节点未注册
"""
if class_name not in cls._nodes:
raise KeyError(f"未找到节点: {class_name}. 可用节点: {list(cls._nodes.keys())}")
return cls._nodes[class_name]
@classmethod
def exists(cls, class_name: str) -> bool:
"""
检查节点是否存在
Args:
class_name: 节点类名
Returns:
是否存在
"""
return class_name in cls._nodes
@classmethod
def list_all(cls) -> Dict[str, Type[TraceNode]]:
"""
列出所有注册的节点
Returns:
节点字典 {class_name: node_class}
"""
return cls._nodes.copy()
@classmethod
def get_by_category(cls, category: str) -> List[Type[TraceNode]]:
"""
按分类获取节点
Args:
category: 分类路径支持前缀匹配
"Data" 匹配 "Data/Load", "Data/Transform"
Returns:
节点类列表
"""
return [
node_cls for node_cls in cls._nodes.values()
if node_cls.CATEGORY.startswith(category)
]
@classmethod
def get_metadata_list(cls) -> List[Dict]:
"""
获取所有节点的元数据供前端使用
Returns:
元数据列表
"""
return [
node_cls.get_metadata()
for node_cls in cls._nodes.values()
]
@classmethod
def get_metadata(cls, class_name: str) -> Optional[Dict]:
"""
获取指定节点的元数据
Args:
class_name: 节点类名
Returns:
元数据字典如果节点不存在返回 None
"""
if class_name not in cls._nodes:
return None
return cls._nodes[class_name].get_metadata()
@classmethod
def get_categories(cls) -> List[str]:
"""
获取所有分类路径去重
Returns:
分类列表
"""
categories = set()
for node_cls in cls._nodes.values():
category = node_cls.CATEGORY
# 支持多级分类(如 "Data/Transform/Filter"
parts = category.split('/')
for i in range(1, len(parts) + 1):
categories.add('/'.join(parts[:i]))
return sorted(categories)
@classmethod
def clear(cls):
"""清空所有注册的节点(主要用于测试)"""
count = len(cls._nodes)
cls._nodes.clear()
print(f"🧹 已清空节点注册表 ({count} 个节点)")
@classmethod
def get_stats(cls) -> Dict:
"""
获取注册统计信息
Returns:
统计信息字典
"""
total = len(cls._nodes)
by_category = {}
for node_cls in cls._nodes.values():
category = node_cls.CATEGORY
by_category[category] = by_category.get(category, 0) + 1
return {
"total": total,
"by_category": by_category,
"categories": cls.get_categories()
}
# ============= 装饰器 =============
def register_node(node_class: Type[TraceNode] = None):
"""
节点注册装饰器
用法
方式一推荐:
@register_node
class MyNode(TraceNode):
...
方式二:
@register_node()
class MyNode(TraceNode):
...
Args:
node_class: 节点类由装饰器自动传入
Returns:
注册后的节点类
"""
if node_class is None:
# 支持 @register_node() 形式
return lambda cls: NodeRegistry.register(cls)
# 支持 @register_node 形式
return NodeRegistry.register(node_class)