TraceStudio-dev/server/tests/test_node_system.py

243 lines
6.7 KiB
Python
Raw Normal View History

"""
TraceStudio 节点开发测试
测试装饰器自动收集系统
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from server.app.core.node_registry import NodeRegistry
from server.app.nodes.example_nodes import (
AddNode,
MultiplyNode,
CSVLoaderNode,
FilterRowsNode,
SelectColumnsNode,
ConcatNode,
TableOutputNode,
StatisticsNode
)
def test_node_registration():
"""测试节点注册"""
print("\n" + "="*60)
print("测试 1: 节点注册")
print("="*60)
stats = NodeRegistry.get_stats()
print(f"\n已注册节点总数: {stats['total']}")
print(f"\n分类分布:")
for category, count in sorted(stats['by_category'].items()):
print(f" {category}: {count}")
print(f"\n所有分类:")
for category in stats['categories']:
print(f" - {category}")
def test_node_metadata():
"""测试节点元数据生成"""
print("\n" + "="*60)
print("测试 2: 节点元数据")
print("="*60)
# 测试 AddNode
metadata = AddNode.get_metadata()
print(f"\n{metadata['display_name']} ({metadata['class_name']})")
print(f"分类: {metadata['category']}")
print(f"描述: {metadata['description']}")
print(f"\n输入端口 ({len(metadata['inputs'])} 个):")
for inp in metadata['inputs']:
print(f" - {inp['name']} ({inp['type']}): {inp.get('description', '')}")
print(f"\n输出端口 ({len(metadata['outputs'])} 个):")
for out in metadata['outputs']:
print(f" - {out['name']} ({out['type']}): {out.get('description', '')}")
print(f"\n参数 ({len(metadata['params'])} 个):")
for par in metadata['params']:
print(f" - {par['name']} ({par['type']}): {par.get('description', '')} [默认: {par.get('default')}]")
print(f"\n上下文变量 ({len(metadata['context'])} 个):")
for ctx in metadata['context']:
print(f" - {ctx['name']} ({ctx['type']}): {ctx.get('description', '')}")
def test_node_execution():
"""测试节点执行"""
print("\n" + "="*60)
print("测试 3: 节点执行")
print("="*60)
# 测试 AddNode
add_node = AddNode(node_id="add_1", params={"offset": 10})
result = add_node.wrap_process(
inputs={"a": 5, "b": 3},
context={}
)
print(f"\nAddNode 执行结果:")
print(f" 输出: {result['outputs']}")
print(f" 上下文: {result['context']}")
# 测试 MultiplyNode
multiply_node = MultiplyNode(node_id="mul_1", params={"scale": 2.0})
result = multiply_node.wrap_process(
inputs={"a": 4, "b": 5},
context={}
)
print(f"\nMultiplyNode 执行结果:")
print(f" 输出: {result['outputs']}")
print(f" 上下文: {result['context']}")
def test_filter_node():
"""测试过滤节点"""
print("\n" + "="*60)
print("测试 4: 数据过滤节点")
print("="*60)
import pandas as pd
# 创建测试数据
test_data = pd.DataFrame({
'name': ['Alice', 'Bob', 'Charlie', 'David'],
'age': [25, 30, 35, 28],
'score': [85, 92, 78, 88]
})
print("\n原始数据:")
print(test_data)
# 创建过滤节点
filter_node = FilterRowsNode(
node_id="filter_1",
params={"column": "score", "operator": ">", "value": 80}
)
result = filter_node.wrap_process(
inputs={"table": test_data},
context={}
)
print("\n过滤后数据 (score > 80):")
print(result['outputs']['filtered'])
print(f"\n上下文: {result['context']}")
def test_concat_node():
"""测试聚合节点(多输入)"""
print("\n" + "="*60)
print("测试 5: 数据合并节点(多输入)")
print("="*60)
import pandas as pd
# 创建测试数据
table1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
table2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]})
table3 = pd.DataFrame({'a': [9, 10], 'b': [11, 12]})
print("\n输入表 1:")
print(table1)
print("\n输入表 2:")
print(table2)
print("\n输入表 3:")
print(table3)
# 创建合并节点
concat_node = ConcatNode(
node_id="concat_1",
params={"ignore_index": True}
)
result = concat_node.wrap_process(
inputs={"tables": [table1, table2, table3]},
context={}
)
print("\n合并后:")
print(result['outputs']['concatenated'])
print(f"\n上下文: {result['context']}")
def test_validation():
"""测试输入验证"""
print("\n" + "="*60)
print("测试 6: 输入验证")
print("="*60)
add_node = AddNode(node_id="add_2", params={})
# 测试成功验证
try:
add_node.validate_inputs({"a": 1, "b": 2})
print("✅ 验证通过: 所有必需输入已提供")
except ValueError as e:
print(f"❌ 验证失败: {e}")
# 测试失败验证(缺少输入)
try:
add_node.validate_inputs({"a": 1})
print("❌ 验证应该失败但通过了")
except ValueError as e:
print(f"✅ 验证正确失败: {e}")
def test_all_nodes_metadata():
"""测试所有节点的元数据"""
print("\n" + "="*60)
print("测试 7: 所有节点元数据摘要")
print("="*60)
all_metadata = NodeRegistry.get_metadata_list()
print(f"\n{len(all_metadata)} 个节点:\n")
for meta in all_metadata:
print(f"📦 {meta['display_name']} ({meta['class_name']})")
print(f" 分类: {meta['category']}")
print(f" 输入: {len(meta['inputs'])} | 输出: {len(meta['outputs'])} | "
f"参数: {len(meta['params'])} | 上下文: {len(meta['context'])}")
print()
if __name__ == "__main__":
print("\n" + "🧪 TraceStudio 节点系统测试".center(60, "="))
try:
test_node_registration()
test_node_metadata()
test_node_execution()
# 依赖 pandas 的测试
try:
import pandas
test_filter_node()
test_concat_node()
except ImportError:
print("\n" + "="*60)
print("⚠️ 跳过 pandas 相关测试(未安装 pandas")
print(" 安装命令: pip install pandas")
print("="*60)
test_validation()
test_all_nodes_metadata()
print("\n" + "="*60)
print("✅ 所有测试通过!")
print("="*60 + "\n")
except Exception as e:
print("\n" + "="*60)
print(f"❌ 测试失败: {e}")
print("="*60 + "\n")
import traceback
traceback.print_exc()