TraceStudio-dist/server/app/core/node_registry.py
2026-01-13 16:41:31 +08:00

225 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
节点注册中心 (v2.0)
支持装饰器自动注册和元数据收集
"""
from typing import Dict, Type, List, Optional
from .node_base import TraceNode, auto_collect_specs
class NodeRegistry:
"""节点注册中心 - 管理所有可用的节点类型"""
_nodes: Dict[str, Type[TraceNode]] = {}
@classmethod
def register(cls, node_class: Type[TraceNode]) -> Type[TraceNode]:
"""
注册节点类
Args:
node_class: 节点类(必须继承 TraceNode
Returns:
注册后的节点类(支持链式调用)
"""
if not issubclass(node_class, TraceNode):
raise TypeError(f"{node_class.__name__} 必须继承 TraceNode")
class_name = node_class.__name__
# 自动收集装饰器标记的属性
node_class = auto_collect_specs(node_class)
# 检查重复注册
if class_name in cls._nodes:
print(f"⚠️ 警告: 节点 {class_name} 已存在,将被覆盖")
# 注册节点
cls._nodes[class_name] = node_class
# 输出日志
display_name = node_class.DISPLAY_NAME or class_name
category = node_class.CATEGORY
print(f"✅ 注册节点: {display_name} [{class_name}] ({category})")
return node_class
@classmethod
def unregister(cls, class_name: str) -> bool:
"""
注销节点
Args:
class_name: 节点类名
Returns:
是否成功注销
"""
if class_name in cls._nodes:
del cls._nodes[class_name]
print(f"🗑️ 注销节点: {class_name}")
return True
return False
@classmethod
def get(cls, class_name: str) -> Type[TraceNode]:
"""
获取节点类
Args:
class_name: 节点类名
Returns:
节点类
Raises:
KeyError: 节点未注册
"""
if class_name not in cls._nodes:
raise KeyError(f"未找到节点: {class_name}. 可用节点: {list(cls._nodes.keys())}")
return cls._nodes[class_name]
@classmethod
def exists(cls, class_name: str) -> bool:
"""
检查节点是否存在
Args:
class_name: 节点类名
Returns:
是否存在
"""
return class_name in cls._nodes
@classmethod
def list_all(cls) -> Dict[str, Type[TraceNode]]:
"""
列出所有注册的节点
Returns:
节点字典 {class_name: node_class}
"""
return cls._nodes.copy()
@classmethod
def get_by_category(cls, category: str) -> List[Type[TraceNode]]:
"""
按分类获取节点
Args:
category: 分类路径(支持前缀匹配)
"Data" 匹配 "Data/Load", "Data/Transform"
Returns:
节点类列表
"""
return [
node_cls for node_cls in cls._nodes.values()
if node_cls.CATEGORY.startswith(category)
]
@classmethod
def get_metadata_list(cls) -> List[Dict]:
"""
获取所有节点的元数据(供前端使用)
Returns:
元数据列表
"""
return [
node_cls.get_metadata()
for node_cls in cls._nodes.values()
]
@classmethod
def get_metadata(cls, class_name: str) -> Optional[Dict]:
"""
获取指定节点的元数据
Args:
class_name: 节点类名
Returns:
元数据字典,如果节点不存在返回 None
"""
if class_name not in cls._nodes:
return None
return cls._nodes[class_name].get_metadata()
@classmethod
def get_categories(cls) -> List[str]:
"""
获取所有分类路径(去重)
Returns:
分类列表
"""
categories = set()
for node_cls in cls._nodes.values():
category = node_cls.CATEGORY
# 支持多级分类(如 "Data/Transform/Filter"
parts = category.split('/')
for i in range(1, len(parts) + 1):
categories.add('/'.join(parts[:i]))
return sorted(categories)
@classmethod
def clear(cls):
"""清空所有注册的节点(主要用于测试)"""
count = len(cls._nodes)
cls._nodes.clear()
print(f"🧹 已清空节点注册表 ({count} 个节点)")
@classmethod
def get_stats(cls) -> Dict:
"""
获取注册统计信息
Returns:
统计信息字典
"""
total = len(cls._nodes)
by_category = {}
for node_cls in cls._nodes.values():
category = node_cls.CATEGORY
by_category[category] = by_category.get(category, 0) + 1
return {
"total": total,
"by_category": by_category,
"categories": cls.get_categories()
}
# ============= 装饰器 =============
def register_node(node_class: Type[TraceNode] = None):
"""
节点注册装饰器
用法:
方式一(推荐):
@register_node
class MyNode(TraceNode):
...
方式二:
@register_node()
class MyNode(TraceNode):
...
Args:
node_class: 节点类(由装饰器自动传入)
Returns:
注册后的节点类
"""
if node_class is None:
# 支持 @register_node() 形式
return lambda cls: NodeRegistry.register(cls)
# 支持 @register_node 形式
return NodeRegistry.register(node_class)