TraceStudio-dev/server/tests/test_node_system.py
2026-01-09 21:37:02 +08:00

243 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TraceStudio 节点开发测试
测试装饰器自动收集系统
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from server.app.core.node_registry import NodeRegistry
from server.app.nodes.example_nodes import (
AddNode,
MultiplyNode,
CSVLoaderNode,
FilterRowsNode,
SelectColumnsNode,
ConcatNode,
TableOutputNode,
StatisticsNode
)
def test_node_registration():
"""测试节点注册"""
print("\n" + "="*60)
print("测试 1: 节点注册")
print("="*60)
stats = NodeRegistry.get_stats()
print(f"\n已注册节点总数: {stats['total']}")
print(f"\n分类分布:")
for category, count in sorted(stats['by_category'].items()):
print(f" {category}: {count}")
print(f"\n所有分类:")
for category in stats['categories']:
print(f" - {category}")
def test_node_metadata():
"""测试节点元数据生成"""
print("\n" + "="*60)
print("测试 2: 节点元数据")
print("="*60)
# 测试 AddNode
metadata = AddNode.get_metadata()
print(f"\n{metadata['display_name']} ({metadata['class_name']})")
print(f"分类: {metadata['category']}")
print(f"描述: {metadata['description']}")
print(f"\n输入端口 ({len(metadata['inputs'])} 个):")
for inp in metadata['inputs']:
print(f" - {inp['name']} ({inp['type']}): {inp.get('description', '')}")
print(f"\n输出端口 ({len(metadata['outputs'])} 个):")
for out in metadata['outputs']:
print(f" - {out['name']} ({out['type']}): {out.get('description', '')}")
print(f"\n参数 ({len(metadata['params'])} 个):")
for par in metadata['params']:
print(f" - {par['name']} ({par['type']}): {par.get('description', '')} [默认: {par.get('default')}]")
print(f"\n上下文变量 ({len(metadata['context'])} 个):")
for ctx in metadata['context']:
print(f" - {ctx['name']} ({ctx['type']}): {ctx.get('description', '')}")
def test_node_execution():
"""测试节点执行"""
print("\n" + "="*60)
print("测试 3: 节点执行")
print("="*60)
# 测试 AddNode
add_node = AddNode(node_id="add_1", params={"offset": 10})
result = add_node.wrap_process(
inputs={"a": 5, "b": 3},
context={}
)
print(f"\nAddNode 执行结果:")
print(f" 输出: {result['outputs']}")
print(f" 上下文: {result['context']}")
# 测试 MultiplyNode
multiply_node = MultiplyNode(node_id="mul_1", params={"scale": 2.0})
result = multiply_node.wrap_process(
inputs={"a": 4, "b": 5},
context={}
)
print(f"\nMultiplyNode 执行结果:")
print(f" 输出: {result['outputs']}")
print(f" 上下文: {result['context']}")
def test_filter_node():
"""测试过滤节点"""
print("\n" + "="*60)
print("测试 4: 数据过滤节点")
print("="*60)
import pandas as pd
# 创建测试数据
test_data = pd.DataFrame({
'name': ['Alice', 'Bob', 'Charlie', 'David'],
'age': [25, 30, 35, 28],
'score': [85, 92, 78, 88]
})
print("\n原始数据:")
print(test_data)
# 创建过滤节点
filter_node = FilterRowsNode(
node_id="filter_1",
params={"column": "score", "operator": ">", "value": 80}
)
result = filter_node.wrap_process(
inputs={"table": test_data},
context={}
)
print("\n过滤后数据 (score > 80):")
print(result['outputs']['filtered'])
print(f"\n上下文: {result['context']}")
def test_concat_node():
"""测试聚合节点(多输入)"""
print("\n" + "="*60)
print("测试 5: 数据合并节点(多输入)")
print("="*60)
import pandas as pd
# 创建测试数据
table1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
table2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]})
table3 = pd.DataFrame({'a': [9, 10], 'b': [11, 12]})
print("\n输入表 1:")
print(table1)
print("\n输入表 2:")
print(table2)
print("\n输入表 3:")
print(table3)
# 创建合并节点
concat_node = ConcatNode(
node_id="concat_1",
params={"ignore_index": True}
)
result = concat_node.wrap_process(
inputs={"tables": [table1, table2, table3]},
context={}
)
print("\n合并后:")
print(result['outputs']['concatenated'])
print(f"\n上下文: {result['context']}")
def test_validation():
"""测试输入验证"""
print("\n" + "="*60)
print("测试 6: 输入验证")
print("="*60)
add_node = AddNode(node_id="add_2", params={})
# 测试成功验证
try:
add_node.validate_inputs({"a": 1, "b": 2})
print("✅ 验证通过: 所有必需输入已提供")
except ValueError as e:
print(f"❌ 验证失败: {e}")
# 测试失败验证(缺少输入)
try:
add_node.validate_inputs({"a": 1})
print("❌ 验证应该失败但通过了")
except ValueError as e:
print(f"✅ 验证正确失败: {e}")
def test_all_nodes_metadata():
"""测试所有节点的元数据"""
print("\n" + "="*60)
print("测试 7: 所有节点元数据摘要")
print("="*60)
all_metadata = NodeRegistry.get_metadata_list()
print(f"\n{len(all_metadata)} 个节点:\n")
for meta in all_metadata:
print(f"📦 {meta['display_name']} ({meta['class_name']})")
print(f" 分类: {meta['category']}")
print(f" 输入: {len(meta['inputs'])} | 输出: {len(meta['outputs'])} | "
f"参数: {len(meta['params'])} | 上下文: {len(meta['context'])}")
print()
if __name__ == "__main__":
print("\n" + "🧪 TraceStudio 节点系统测试".center(60, "="))
try:
test_node_registration()
test_node_metadata()
test_node_execution()
# 依赖 pandas 的测试
try:
import pandas
test_filter_node()
test_concat_node()
except ImportError:
print("\n" + "="*60)
print("⚠️ 跳过 pandas 相关测试(未安装 pandas")
print(" 安装命令: pip install pandas")
print("="*60)
test_validation()
test_all_nodes_metadata()
print("\n" + "="*60)
print("✅ 所有测试通过!")
print("="*60 + "\n")
except Exception as e:
print("\n" + "="*60)
print(f"❌ 测试失败: {e}")
print("="*60 + "\n")
import traceback
traceback.print_exc()