84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
|
|
"""
|
|||
|
|
TraceStudio 节点图运行时数据结构设计
|
|||
|
|
|
|||
|
|
本文件定义了节点图执行所需的核心数据结构,确保前后端一致。
|
|||
|
|
"""
|
|||
|
|
from enum import Enum
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from app.core.node_base import NodeType
|
|||
|
|
|
|||
|
|
class DimensionMode(Enum):
    """Dimension-handling mode attached to a graph edge.

    Stored on ``EdgeMetadata.dimension_mode`` and serialized to the executor
    by its lowercase string ``value``.
    """

    # NOTE(review): the exact semantics of each mode are defined by the
    # executor, not in this file — confirm there before relying on them.

    # Default: no special dimension handling requested.
    NONE = "none"
    # Presumably fans a collection out into per-item values — TODO confirm.
    EXPAND = "expand"
    # Presumably gathers per-item values back into a collection — TODO confirm.
    COLLAPSE = "collapse"
    # Presumably repeats a single value across items — TODO confirm.
    BROADCAST = "broadcast"
|
|||
|
|
|
|||
|
|
@dataclass
class EdgeMetadata:
    """A directed link from one node's output port to another node's input port.

    When handed to the executor, these fields are renamed to the
    ``source`` / ``source_port`` / ``target`` / ``target_port`` keys.
    """

    # Upstream end: node identifier (presumably a NodeMetadata.node_id) and its port name.
    source_node: str
    source_port: str

    # Downstream end: node identifier (presumably a NodeMetadata.node_id) and its port name.
    target_node: str
    target_port: str

    # Dimension-handling mode for data crossing this edge; defaults to DimensionMode.NONE.
    dimension_mode: DimensionMode = DimensionMode.NONE
|
|||
|
|
|
|||
|
|
@dataclass
class NodeMetadata:
    """Description of a single node in the runtime graph.

    Turned into an executor node dict (``id`` / ``type`` / ``class_name`` /
    ``params`` plus optional sub-workflow keys) when the graph is exported.
    """

    # Identifier of this node; emitted under the executor's "id" key.
    node_id: str

    # Strict convention: must be one of the NodeType values
    # ("input" / "normal" / "output" / "function" / ...).
    type: str

    # Implementation class name (required; used to load the implementation
    # from the NodeRegistry).
    class_name: str

    # Node parameters; every instance gets its own fresh empty dict by default.
    params: Dict[str, Any] = field(default_factory=dict)

    # Optional embedded sub-workflow, as raw node / edge dicts; None when absent.
    sub_workflow_nodes: Optional[List[Dict]] = None
    sub_workflow_edges: Optional[List[Dict]] = None
|
|||
|
|
|
|||
|
|
@dataclass
class GraphRuntime:
    """Runtime node graph: the nodes plus the edges connecting them.

    Use :meth:`to_executor_format` to obtain the plain-dict payload that
    WorkflowExecutor accepts.
    """

    nodes: List[NodeMetadata]
    edges: List[EdgeMetadata]
    # Extension point: global context, run parameters, etc.

    def to_executor_format(self) -> Dict[str, Any]:
        """Convert this runtime graph into the nodes/edges lists WorkflowExecutor accepts.

        - EdgeMetadata fields are renamed to the keys the executor expects
          (``source``, ``source_port``, ``target``, ``target_port``), plus the
          edge's ``dimension_mode`` as its string value.
        - NodeMetadata ``sub_workflow_nodes`` / ``sub_workflow_edges`` are copied
          under the same key names, only when they are non-empty.
        - ``params`` is passed through as-is (a falsy value becomes ``{}``).

        Returns:
            ``{"nodes": [...], "edges": [...]}``, in the same order as
            ``self.nodes`` / ``self.edges``.

        Raises:
            ValueError: if a node's ``type`` is not a valid ``NodeType`` value,
                or an edge's ``dimension_mode`` is not a valid ``DimensionMode``.
        """
        # Dict literal entries evaluate in order, so nodes are validated before edges.
        return {
            "nodes": [self._node_to_executor_dict(n) for n in self.nodes],
            "edges": [self._edge_to_executor_dict(e) for e in self.edges],
        }

    @staticmethod
    def _node_to_executor_dict(n: NodeMetadata) -> Dict[str, Any]:
        """Validate one node's ``type`` and build its executor dict."""
        try:
            node_type = NodeType(n.type)
        except (ValueError, TypeError) as err:
            # Re-raise with the list of allowed values; chain the lookup error.
            raise ValueError(
                f"NodeMetadata.type 必须为 NodeType 的值之一: {[t.value for t in NodeType]} (got {n.type})"
            ) from err

        nd: Dict[str, Any] = {
            "id": n.node_id,
            # Emit the canonical value so a NodeType member also serializes
            # to its plain value rather than the enum object.
            "type": node_type.value,
            "class_name": n.class_name,
            "params": n.params or {},
        }
        # Only the flat sub_workflow_nodes / sub_workflow_edges convention is
        # supported; empty or None values are omitted from the output.
        if n.sub_workflow_nodes:
            nd["sub_workflow_nodes"] = n.sub_workflow_nodes
        if n.sub_workflow_edges:
            nd["sub_workflow_edges"] = n.sub_workflow_edges
        return nd

    @staticmethod
    def _edge_to_executor_dict(e: EdgeMetadata) -> Dict[str, Any]:
        """Build one edge's executor dict with the executor's key names."""
        # The dataclass does not enforce field types, so coerce through
        # DimensionMode: accepts an enum member or its string value (e.g.
        # "expand"); a missing/falsy mode falls back to NONE.
        mode = DimensionMode(e.dimension_mode) if e.dimension_mode else DimensionMode.NONE
        return {
            "source": e.source_node,
            "source_port": e.source_port,
            "target": e.target_node,
            "target_port": e.target_port,
            # The executor reportedly ignores dimension_mode on input dicts;
            # it is included (as a string) for completeness.
            "dimension_mode": mode.value,
        }
|