TraceStudio-dev/server/tests/test_executor_and_cache.py
2026-01-09 21:37:02 +08:00

265 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TraceStudio 执行引擎和缓存系统集成测试
"""
import asyncio
import shutil
import sys
import tempfile
import time
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from server.app.core.cache_manager import CacheManager, MemoryCache, DiskCache, CacheEvictionPolicy
from server.app.core.workflow_executor import WorkflowExecutor, WorkflowGraph
from server.app.nodes.example_nodes import AddNode, MultiplyNode
# ============= 测试 1: 内存缓存 =============
def test_memory_cache():
    """In-memory cache: basic set/get, LRU eviction at capacity, and stats."""
    print("\n" + "="*60)
    print("测试 1: 内存缓存")
    print("="*60)
    lru_cache = MemoryCache(max_size=3, ttl=None, policy=CacheEvictionPolicy.LRU)
    # Fill the cache up to its capacity of three entries.
    for n in (1, 2, 3):
        lru_cache.set(f"key{n}", {"value": n})
    print("✅ 添加 3 个缓存项")
    # Reading key1 marks it as most recently used.
    assert lru_cache.get("key1") == {"value": 1}
    print("✅ 成功获取 key1")
    # A fourth insert overflows capacity and triggers an LRU eviction.
    lru_cache.set("key4", {"value": 4})
    print("✅ 添加第 4 个项,触发 LRU 淘汰")
    # key2 is now the least recently used entry, so it must be gone.
    assert lru_cache.get("key2") is None
    print("✅ key2 被淘汰LRU")
    stats = lru_cache.get_stats()
    print(f"\n缓存统计: {stats}")
def test_disk_cache():
    """Disk-backed cache: round-trip a value, report stats, then clean up.

    Uses a per-run temporary directory (instead of a fixed ``/tmp`` path) so
    the test is portable across platforms and never observes stale state left
    behind by a previous run; the directory is always removed, even when an
    assertion fails.
    """
    print("\n" + "="*60)
    print("测试 2: 磁盘缓存")
    print("="*60)
    # Fresh, isolated directory for this run only.
    cache_dir = tempfile.mkdtemp(prefix="tracestudio_cache_test_")
    cache = DiskCache(Path(cache_dir), ttl=None)
    try:
        # Write one entry and read it back unchanged.
        cache.set("test_key", {"data": "test_value", "count": 42})
        print("✅ 设置磁盘缓存")
        result = cache.get("test_key")
        assert result == {"data": "test_value", "count": 42}
        print("✅ 成功获取磁盘缓存数据")
        stats = cache.get_stats()
        print(f"\n缓存统计: {stats}")
        cache.clear()
        print("✅ 清理磁盘缓存")
    finally:
        # Best-effort removal of the temp directory regardless of outcome.
        shutil.rmtree(cache_dir, ignore_errors=True)
def test_cache_ttl():
    """Verify that cached entries expire once their TTL has elapsed."""
    print("\n" + "="*60)
    print("测试 3: 缓存 TTL")
    print("="*60)
    ttl_cache = MemoryCache(max_size=100, ttl=1)  # entries live for one second
    ttl_cache.set("expire_key", {"value": "will_expire"})
    print("✅ 设置缓存1秒过期")
    # Immediately after insertion the entry is still fresh.
    assert ttl_cache.get("expire_key") == {"value": "will_expire"}
    print("✅ 立即获取成功")
    print("⏳ 等待 1.1 秒...")
    time.sleep(1.1)  # let the entry age past its TTL
    # An expired entry must read back as missing.
    assert ttl_cache.get("expire_key") is None
    print("✅ 过期后返回 None")
# ============= 测试 4: 工作流图 =============
def test_workflow_graph():
    """Build a linear three-node graph; check deps, topo order, and cycles."""
    print("\n" + "="*60)
    print("测试 4: 工作流图")
    print("="*60)
    graph = WorkflowGraph()
    # Executor convention: `type` carries the NodeType value (e.g. 'normal')
    # while the implementing class name is stored under `class_name`.
    node_specs = [
        ("n1", {"offset": 0}, "AddNode"),
        ("n2", {"scale": 2.0}, "MultiplyNode"),
        ("n3", {"offset": 10}, "AddNode"),
    ]
    for node_id, params, cls_name in node_specs:
        graph.add_node(node_id, "normal", params)
        graph.nodes[node_id]["class_name"] = cls_name
    print("✅ 添加 3 个节点")
    # Wire the nodes into a chain: n1 -> n2 -> n3.
    graph.add_edge("n1", "result", "n2", "a")
    graph.add_edge("n2", "result", "n3", "a")
    print("✅ 添加连接: n1 -> n2 -> n3")
    # n3 depends (transitively) on both upstream nodes.
    deps = graph.get_dependencies("n3")
    assert "n1" in deps and "n2" in deps
    print(f"✅ n3 的依赖: {deps}")
    # Topological order of a chain is the chain itself.
    order = graph.topological_sort()
    assert order == ["n1", "n2", "n3"]
    print(f"✅ 拓扑排序: {order}")
    # A DAG must report no cycle.
    assert not graph.has_cycle()
    print("✅ 无循环依赖")
# ============= 测试 5: 工作流执行 =============
async def test_workflow_execution():
    """Run a two-node workflow (Add -> Multiply) through the executor.

    Best-effort demo: failures and exceptions are printed, not raised.
    """
    print("\n" + "="*60)
    print("测试 5: 工作流执行")
    print("="*60)
    executor = WorkflowExecutor(user_id="test_user")
    # Workflow definition: add_node.result feeds multiply_node.a.
    nodes = [
        {"id": "add_node", "type": "AddNode", "params": {"offset": 5}},
        {"id": "multiply_node", "type": "MultiplyNode", "params": {"scale": 2.0}},
    ]
    edges = [
        {
            "source": "add_node",
            "source_port": "result",
            "target": "multiply_node",
            "target_port": "a",
        }
    ]
    # add_node's a/b and multiply_node's b are supplied by the global context.
    global_context = {"a": 3, "b": 4}
    try:
        success, report = await executor.execute(
            nodes, edges, global_context=global_context
        )
        if not success:
            print(f"❌ 工作流执行失败: {report.get('error') if report else 'Unknown error'}")
            if report:
                print(f" 详情: {report}")
        else:
            print(f"✅ 工作流执行成功")
            print(f" 执行 ID: {report.get('execution_id')}")
            print(f" 总耗时: {report.get('total_duration', 0):.3f}s")
            print(f" 节点数: {len(report.get('nodes', {}))}")
            for node_id, info in report.get('nodes', {}).items():
                print(f" - {node_id}: {info.get('status')}")
    except Exception as e:
        print(f"❌ 执行异常: {e}")
        import traceback
        traceback.print_exc()
# ============= 测试 6: 缓存集成 =============
def test_cache_manager():
    """Two-tier cache facade: per-tier set/get, set_both, stats, and clear.

    The disk tier is backed by a per-run temporary directory rather than a
    fixed ``/tmp`` path, so the test is portable across platforms and always
    starts from a clean slate; the directory is removed even on failure.
    """
    print("\n" + "="*60)
    print("测试 6: 缓存管理器")
    print("="*60)
    # Fresh, isolated directory for this run's disk tier.
    disk_dir = tempfile.mkdtemp(prefix="tracestudio_cache_")
    try:
        CacheManager.init_memory_cache(max_size=100, ttl=None)
        CacheManager.init_disk_cache(Path(disk_dir), ttl=None)
        print("✅ 初始化内存和磁盘缓存")
        # Write one entry into each tier independently.
        CacheManager.set("test_key", {"data": "value"}, storage="memory")
        CacheManager.set("test_key2", {"data": "value2"}, storage="disk")
        print("✅ 分别设置内存和磁盘缓存")
        # Each entry must be readable from its own tier.
        assert CacheManager.get("test_key", prefer="memory") == {"data": "value"}
        assert CacheManager.get("test_key2", prefer="disk") == {"data": "value2"}
        print("✅ 成功获取数据")
        # set_both writes through to memory and disk at once.
        CacheManager.set_both("both_key", {"data": "both"})
        assert CacheManager.get("both_key") == {"data": "both"}
        print("✅ 同时设置内存和磁盘缓存成功")
        stats = CacheManager.get_stats()
        print(f"\n缓存统计:")
        if stats.get("memory"):
            print(f" 内存: {stats['memory']}")
        if stats.get("disk"):
            print(f" 磁盘: {stats['disk']}")
        CacheManager.clear()
        print("✅ 清理所有缓存")
    finally:
        # Best-effort removal of the temp directory regardless of outcome.
        shutil.rmtree(disk_dir, ignore_errors=True)
if __name__ == "__main__":
    print("\n" + "🧪 TraceStudio 执行引擎和缓存系统测试".center(60, "="))
    try:
        # Run the synchronous tests in declaration order.
        for sync_test in (
            test_memory_cache,
            test_disk_cache,
            test_cache_ttl,
            test_workflow_graph,
        ):
            sync_test()
        # The workflow-execution test is a coroutine and needs an event loop.
        asyncio.run(test_workflow_execution())
        test_cache_manager()
        print("\n" + "="*60)
        print("✅ 所有测试通过!")
        print("="*60 + "\n")
    except Exception as e:
        print("\n" + "="*60)
        print(f"❌ 测试失败: {e}")
        print("="*60 + "\n")
        import traceback
        traceback.print_exc()