243 lines
6.7 KiB
Python
243 lines
6.7 KiB
Python
"""
|
||
TraceStudio 节点开发测试
|
||
测试装饰器自动收集系统
|
||
"""
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from server.app.core.node_registry import NodeRegistry
|
||
from server.app.nodes.example_nodes import (
|
||
AddNode,
|
||
MultiplyNode,
|
||
CSVLoaderNode,
|
||
FilterRowsNode,
|
||
SelectColumnsNode,
|
||
ConcatNode,
|
||
TableOutputNode,
|
||
StatisticsNode
|
||
)
|
||
|
||
|
||
def test_node_registration():
|
||
"""测试节点注册"""
|
||
print("\n" + "="*60)
|
||
print("测试 1: 节点注册")
|
||
print("="*60)
|
||
|
||
stats = NodeRegistry.get_stats()
|
||
print(f"\n已注册节点总数: {stats['total']}")
|
||
print(f"\n分类分布:")
|
||
for category, count in sorted(stats['by_category'].items()):
|
||
print(f" {category}: {count} 个")
|
||
|
||
print(f"\n所有分类:")
|
||
for category in stats['categories']:
|
||
print(f" - {category}")
|
||
|
||
|
||
def test_node_metadata():
|
||
"""测试节点元数据生成"""
|
||
print("\n" + "="*60)
|
||
print("测试 2: 节点元数据")
|
||
print("="*60)
|
||
|
||
# 测试 AddNode
|
||
metadata = AddNode.get_metadata()
|
||
print(f"\n{metadata['display_name']} ({metadata['class_name']})")
|
||
print(f"分类: {metadata['category']}")
|
||
print(f"描述: {metadata['description']}")
|
||
|
||
print(f"\n输入端口 ({len(metadata['inputs'])} 个):")
|
||
for inp in metadata['inputs']:
|
||
print(f" - {inp['name']} ({inp['type']}): {inp.get('description', '')}")
|
||
|
||
print(f"\n输出端口 ({len(metadata['outputs'])} 个):")
|
||
for out in metadata['outputs']:
|
||
print(f" - {out['name']} ({out['type']}): {out.get('description', '')}")
|
||
|
||
print(f"\n参数 ({len(metadata['params'])} 个):")
|
||
for par in metadata['params']:
|
||
print(f" - {par['name']} ({par['type']}): {par.get('description', '')} [默认: {par.get('default')}]")
|
||
|
||
print(f"\n上下文变量 ({len(metadata['context'])} 个):")
|
||
for ctx in metadata['context']:
|
||
print(f" - {ctx['name']} ({ctx['type']}): {ctx.get('description', '')}")
|
||
|
||
|
||
def test_node_execution():
|
||
"""测试节点执行"""
|
||
print("\n" + "="*60)
|
||
print("测试 3: 节点执行")
|
||
print("="*60)
|
||
|
||
# 测试 AddNode
|
||
add_node = AddNode(node_id="add_1", params={"offset": 10})
|
||
result = add_node.wrap_process(
|
||
inputs={"a": 5, "b": 3},
|
||
context={}
|
||
)
|
||
|
||
print(f"\nAddNode 执行结果:")
|
||
print(f" 输出: {result['outputs']}")
|
||
print(f" 上下文: {result['context']}")
|
||
|
||
# 测试 MultiplyNode
|
||
multiply_node = MultiplyNode(node_id="mul_1", params={"scale": 2.0})
|
||
result = multiply_node.wrap_process(
|
||
inputs={"a": 4, "b": 5},
|
||
context={}
|
||
)
|
||
|
||
print(f"\nMultiplyNode 执行结果:")
|
||
print(f" 输出: {result['outputs']}")
|
||
print(f" 上下文: {result['context']}")
|
||
|
||
|
||
def test_filter_node():
|
||
"""测试过滤节点"""
|
||
print("\n" + "="*60)
|
||
print("测试 4: 数据过滤节点")
|
||
print("="*60)
|
||
|
||
import pandas as pd
|
||
|
||
# 创建测试数据
|
||
test_data = pd.DataFrame({
|
||
'name': ['Alice', 'Bob', 'Charlie', 'David'],
|
||
'age': [25, 30, 35, 28],
|
||
'score': [85, 92, 78, 88]
|
||
})
|
||
|
||
print("\n原始数据:")
|
||
print(test_data)
|
||
|
||
# 创建过滤节点
|
||
filter_node = FilterRowsNode(
|
||
node_id="filter_1",
|
||
params={"column": "score", "operator": ">", "value": 80}
|
||
)
|
||
|
||
result = filter_node.wrap_process(
|
||
inputs={"table": test_data},
|
||
context={}
|
||
)
|
||
|
||
print("\n过滤后数据 (score > 80):")
|
||
print(result['outputs']['filtered'])
|
||
print(f"\n上下文: {result['context']}")
|
||
|
||
|
||
def test_concat_node():
|
||
"""测试聚合节点(多输入)"""
|
||
print("\n" + "="*60)
|
||
print("测试 5: 数据合并节点(多输入)")
|
||
print("="*60)
|
||
|
||
import pandas as pd
|
||
|
||
# 创建测试数据
|
||
table1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
|
||
table2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]})
|
||
table3 = pd.DataFrame({'a': [9, 10], 'b': [11, 12]})
|
||
|
||
print("\n输入表 1:")
|
||
print(table1)
|
||
print("\n输入表 2:")
|
||
print(table2)
|
||
print("\n输入表 3:")
|
||
print(table3)
|
||
|
||
# 创建合并节点
|
||
concat_node = ConcatNode(
|
||
node_id="concat_1",
|
||
params={"ignore_index": True}
|
||
)
|
||
|
||
result = concat_node.wrap_process(
|
||
inputs={"tables": [table1, table2, table3]},
|
||
context={}
|
||
)
|
||
|
||
print("\n合并后:")
|
||
print(result['outputs']['concatenated'])
|
||
print(f"\n上下文: {result['context']}")
|
||
|
||
|
||
def test_validation():
|
||
"""测试输入验证"""
|
||
print("\n" + "="*60)
|
||
print("测试 6: 输入验证")
|
||
print("="*60)
|
||
|
||
add_node = AddNode(node_id="add_2", params={})
|
||
|
||
# 测试成功验证
|
||
try:
|
||
add_node.validate_inputs({"a": 1, "b": 2})
|
||
print("✅ 验证通过: 所有必需输入已提供")
|
||
except ValueError as e:
|
||
print(f"❌ 验证失败: {e}")
|
||
|
||
# 测试失败验证(缺少输入)
|
||
try:
|
||
add_node.validate_inputs({"a": 1})
|
||
print("❌ 验证应该失败但通过了")
|
||
except ValueError as e:
|
||
print(f"✅ 验证正确失败: {e}")
|
||
|
||
|
||
def test_all_nodes_metadata():
|
||
"""测试所有节点的元数据"""
|
||
print("\n" + "="*60)
|
||
print("测试 7: 所有节点元数据摘要")
|
||
print("="*60)
|
||
|
||
all_metadata = NodeRegistry.get_metadata_list()
|
||
|
||
print(f"\n共 {len(all_metadata)} 个节点:\n")
|
||
|
||
for meta in all_metadata:
|
||
print(f"📦 {meta['display_name']} ({meta['class_name']})")
|
||
print(f" 分类: {meta['category']}")
|
||
print(f" 输入: {len(meta['inputs'])} | 输出: {len(meta['outputs'])} | "
|
||
f"参数: {len(meta['params'])} | 上下文: {len(meta['context'])}")
|
||
print()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
print("\n" + "🧪 TraceStudio 节点系统测试".center(60, "="))
|
||
|
||
try:
|
||
test_node_registration()
|
||
test_node_metadata()
|
||
test_node_execution()
|
||
|
||
# 依赖 pandas 的测试
|
||
try:
|
||
import pandas
|
||
test_filter_node()
|
||
test_concat_node()
|
||
except ImportError:
|
||
print("\n" + "="*60)
|
||
print("⚠️ 跳过 pandas 相关测试(未安装 pandas)")
|
||
print(" 安装命令: pip install pandas")
|
||
print("="*60)
|
||
|
||
test_validation()
|
||
test_all_nodes_metadata()
|
||
|
||
print("\n" + "="*60)
|
||
print("✅ 所有测试通过!")
|
||
print("="*60 + "\n")
|
||
|
||
except Exception as e:
|
||
print("\n" + "="*60)
|
||
print(f"❌ 测试失败: {e}")
|
||
print("="*60 + "\n")
|
||
import traceback
|
||
traceback.print_exc()
|