TraceStudio-dist/server/app/services/agent_client.py
2026-01-13 16:41:31 +08:00

157 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent 客户端:通过 HTTP 调用宿主上的 Agent 服务,并提供路径映射工具。"""
from __future__ import annotations
import json
import os
import urllib.request
import urllib.error
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Any, Dict, List, Optional, Union
# 确保这里的引用不会导致循环依赖
from server.app.core.user_manager import CLOUD_ROOT
class AgentClient:
    """HTTP client for the Agent service running on the host machine.

    Besides forwarding tool invocations (see ``run``), it translates paths
    between the container's view of the cloud directory (POSIX, e.g.
    ``/opt/tracestudio/cloud``) and the host's view (Windows, e.g.
    ``D:\\XGame\\cloud``). Mapping is active only when both roots are set.
    """

    # Extra seconds the HTTP request may wait beyond the tool ``timeout`` sent to
    # the Agent. Presumably the Agent enforces ``timeout`` on the tool itself
    # (TODO confirm); without a margin the client would give up at the same
    # instant and race the Agent's reply.
    HTTP_TIMEOUT_MARGIN: float = 10.0

    def __init__(
        self,
        base_url: str = "http://localhost:8100",
        *,
        host_cloud_root: Optional[str] = None,
        container_cloud_root: Optional[str] = None,
    ):
        """Create a client.

        Args:
            base_url: Root URL of the Agent service; a trailing ``/`` is stripped.
            host_cloud_root: Cloud root as seen by the host (Windows path).
                Empty or None disables path mapping.
            container_cloud_root: Cloud root as seen inside the container
                (POSIX path). Empty or None disables path mapping.
        """
        self.base_url = base_url.rstrip("/")
        # Always PureWindowsPath, even when running on Linux, so the host path
        # keeps its drive letter and backslash semantics.
        self.host_cloud_root = PureWindowsPath(host_cloud_root) if host_cloud_root else None
        # The container side is always POSIX (Linux).
        self.container_cloud_root = PurePosixPath(container_cloud_root) if container_cloud_root else None
        print(f"[AgentClient] 初始化base_url={self.base_url}, host_cloud_root={self.host_cloud_root}, container_cloud_root={self.container_cloud_root}")

    @classmethod
    def from_env(cls, *, container_cloud_root: Optional[Union[Path, str]] = None) -> "AgentClient":
        """Build a client from environment variables.

        Reads ``AGENT_BASE_URL`` (default ``http://host.docker.internal:8100``)
        and ``AGENT_HOST_CLOUD_ROOT`` (unset or empty disables path mapping).

        Args:
            container_cloud_root: Container-side cloud root, as str or Path.
        """
        base_url = os.environ.get("AGENT_BASE_URL", "http://host.docker.internal:8100")
        host_root = os.environ.get("AGENT_HOST_CLOUD_ROOT")
        # Accept both str and Path for the container root.
        container_root = str(container_cloud_root) if container_cloud_root else None
        return cls(base_url=base_url, host_cloud_root=host_root, container_cloud_root=container_root)

    # ---------------------- path mapping ----------------------

    def map_cloud_to_host(self, path: Union[str, Path]) -> str:
        """Map a container path (``/opt/...``) to the host path (``D:\\...``).

        Returns:
            The host path as a string. If mapping is not configured, the input
            is returned unchanged (as a string). If the path is not under the
            container root, its POSIX-normalized form is returned.
        """
        if not self.container_cloud_root or not self.host_cloud_root:
            # Always return str: ``run`` serializes this value into JSON.
            return str(path)
        posix_path = PurePosixPath(path)
        try:
            # e.g. /opt/tracestudio/cloud/traces/test.utrace -> traces/test.utrace
            rel = posix_path.relative_to(self.container_cloud_root)
        except ValueError:
            # Not under the container root (e.g. /tmp/xxx): return as-is.
            return str(posix_path)
        # host_cloud_root is a PureWindowsPath, so the join is Windows-style.
        return str(self.host_cloud_root / rel)

    def map_host_to_container(self, path: Union[str, Path]) -> Path:
        """Map a host path returned by the Agent (``D:\\...``) back into the container.

        Returns:
            The container path. If mapping is not configured, or the path is
            not under the host root, the input is returned unchanged as a Path.
            It is not routed through PureWindowsPath, which would turn POSIX
            separators into backslashes.
        """
        if not self.host_cloud_root or not self.container_cloud_root:
            return Path(path)
        try:
            # PureWindowsPath compares case-insensitively, like Windows itself.
            rel = PureWindowsPath(path).relative_to(self.host_cloud_root)
        except ValueError:
            return Path(path)
        # Joining onto a PurePosixPath gives the POSIX container path.
        return Path(self.container_cloud_root / rel)

    # ---------------------- invocation ----------------------

    @classmethod
    def _http_timeout(cls, timeout: Optional[float]) -> Optional[float]:
        """Return the socket timeout for the HTTP call (None means wait indefinitely)."""
        if timeout is None:
            return None
        return timeout + cls.HTTP_TIMEOUT_MARGIN

    @staticmethod
    def _http_error_detail(err: urllib.error.HTTPError) -> Any:
        """Best-effort extraction of the ``detail`` field from an HTTP error body.

        Falls back to ``str(err)`` when the body is missing, is not JSON, or is
        not a JSON object.
        """
        try:
            body = err.read().decode("utf-8") if err.fp else ""
            return json.loads(body).get("detail", str(err))
        except Exception:
            return str(err)

    def run(
        self,
        tool: str,
        args: Optional[List[str]] = None,
        *,
        workdir: Optional[str] = None,
        timeout: Optional[int] = None,
        env: Optional[Dict[str, str]] = None,
        capture_output: Optional[bool] = None,
        strip_output: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """Invoke ``tool`` on the Agent via ``POST {base_url}/run``.

        Optional arguments are only included in the request when set.

        Args:
            tool: Tool name forwarded to the Agent.
            args: Tool arguments (default: empty list).
            workdir: Working directory as a container path. It is mapped to the
                host path before sending.
            timeout: Tool timeout forwarded to the Agent. The HTTP request waits
                ``HTTP_TIMEOUT_MARGIN`` seconds longer; None means no timeout.
            env: Extra environment variables (omitted when empty).
            capture_output: Forwarded as-is when not None.
            strip_output: Forwarded as-is when not None.

        Returns:
            The decoded JSON response body.

        Raises:
            RuntimeError: On an HTTP error status (with the server's ``detail``
                when available), or on any other failure such as a connection
                error, a timeout, or an invalid JSON reply.
        """
        payload: Dict[str, Any] = {"tool": tool, "args": args or []}
        if workdir:
            # The Agent runs on the host, so it needs the host-side workdir,
            # e.g. /opt/cloud/logs -> D:\cloud\logs.
            host_workdir = self.map_cloud_to_host(workdir)
            if host_workdir:
                payload["workdir"] = host_workdir
        if timeout is not None:
            payload["timeout"] = timeout
        if env:
            payload["env"] = env
        if capture_output is not None:
            payload["capture_output"] = capture_output
        if strip_output is not None:
            payload["strip_output"] = strip_output

        req = urllib.request.Request(
            f"{self.base_url}/run",
            data=json.dumps(payload).encode("utf-8"),
            # Explicit charset so the server decodes the body as UTF-8.
            headers={"Content-Type": "application/json; charset=utf-8"},
            method="POST",
        )
        try:
            with urllib.request.urlopen(req, timeout=self._http_timeout(timeout)) as resp:
                return json.loads(resp.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            detail = self._http_error_detail(e)
            raise RuntimeError(f"Agent HTTP {e.code}: {detail}") from e
        except Exception as e:
            raise RuntimeError(f"Agent 调用失败: {e}") from e
# Shared module-level client, created at import time from the environment.
# Note: CLOUD_ROOT must already be the resolved absolute container cloud root
# (str or Path) by the time this module is imported.
agent: AgentClient = AgentClient.from_env(container_cloud_root=CLOUD_ROOT)

# Public API of this module.
__all__ = [
    "AgentClient",
    "agent",
]