TraceStudio-dist/server/app/nodes/io_nodes.py
"""
IO 节点集合
用于文件系统操作、路径处理等
"""
import os
import glob
from pathlib import Path
from ..core.user_manager import CLOUD_ROOT
from typing import Any, Dict, Optional, List
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
context_var,
NodeType,
CachePolicy
)
from server.app.core.node_registry import register_node
import yaml
def load_system_config():
"""加载系统配置"""
config_path = Path(__file__).parent.parent.parent / "system_config.yaml"
if config_path.exists():
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
return {}
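
# Illustration only: the nodes below read `storage.cloud_root` from
# system_config.yaml, so a minimal config is assumed to look roughly like
# this (hypothetical values):
#
#     storage:
#       cloud_root: ./cloud

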
@register_node
class DirectoryScanner(TraceNode):
    """
    Directory scanner.

    Features:
        1. Scan files under a given directory
        2. Filter by extension (e.g. .utrace)
        3. Optionally scan subdirectories recursively
        4. Match files with glob patterns
        5. Output a list (array) of file paths

    Typical uses:
        - Batch-process multiple .utrace files
        - Find files of a specific type
        - Serve as an array data source, combined with the EXPAND dimension transform
    """
    CATEGORY = "IO/Scanner"
    DISPLAY_NAME = "Directory Scanner"
    DESCRIPTION = "Scan a directory and output a list of file paths"
    ICON = "📁"
    NODE_TYPE = NodeType.INPUT
    CACHE_POLICY = CachePolicy.NONE  # rescan on every run

    @output_port("files", "Array<String>", description="List (array) of file paths")
    @output_port("count", "Number", description="Number of files")
    @param("directory", "String", default="", description="Directory to scan (relative to the user directory)", required=True)
    @param("pattern", "String", default="*.utrace", description="File match pattern (glob supported)", required=True)
    @param("recursive", "Boolean", default=False, description="Recursively scan subdirectories")
    @param("sort_by", "String", default="name", description="Sort order",
           options=["name", "size", "modified", "created", "none"])
    @param("reverse_sort", "Boolean", default=False, description="Reverse the sort order")
    @param("max_files", "Number", default=0, description="Maximum number of files (0 = unlimited)", min=0, step=1)
    @context_var("scan_path", "String", description="Full path actually scanned")
    @context_var("file_count", "Integer", description="Number of files found")
    @context_var("total_size", "String", description="Total size of the files")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        # Read parameters
        directory = self.get_param("directory", "")
        pattern = self.get_param("pattern", "*.utrace")
        recursive = self.get_param("recursive", False)
        sort_by = self.get_param("sort_by", "name")
        reverse_sort = self.get_param("reverse_sort", False)
        max_files = int(self.get_param("max_files", 0))

        scan_path = CLOUD_ROOT / directory if directory else CLOUD_ROOT
        if not scan_path.exists():
            raise FileNotFoundError(f"Directory does not exist: {scan_path}")
        if not scan_path.is_dir():
            raise ValueError(f"Path is not a directory: {scan_path}")

        # Scan for files
        files = []
        if recursive:
            # Recursive scan of all subdirectories
            search_pattern = f"**/{pattern}"
            file_paths = glob.glob(str(scan_path / search_pattern), recursive=True)
        else:
            # Scan the current directory only
            file_paths = glob.glob(str(scan_path / pattern))

        # Convert to Path objects, keeping files only (skip directories)
        file_objects = []
        for fp in file_paths:
            p = Path(fp)
            if p.is_file():
                file_objects.append(p)

        # Sort
        if sort_by != "none":
            if sort_by == "name":
                file_objects.sort(key=lambda x: x.name, reverse=reverse_sort)
            elif sort_by == "size":
                file_objects.sort(key=lambda x: x.stat().st_size, reverse=reverse_sort)
            elif sort_by == "modified":
                file_objects.sort(key=lambda x: x.stat().st_mtime, reverse=reverse_sort)
            elif sort_by == "created":
                file_objects.sort(key=lambda x: x.stat().st_ctime, reverse=reverse_sort)

        # Limit the number of files
        if max_files > 0:
            file_objects = file_objects[:max_files]

        # Compute total size
        total_size = sum(f.stat().st_size for f in file_objects)

        # Convert to path strings relative to the cloud root, using forward slashes
        for file_obj in file_objects:
            rel_path = file_obj.relative_to(CLOUD_ROOT)
            files.append(str(rel_path).replace("\\", "/"))

        return {
            "files": files,
            "count": len(files)
        }

    @staticmethod
    def _format_file_size(size_bytes: int) -> str:
        """Format a byte count as a human-readable size string."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size_bytes < 1024.0:
                return f"{size_bytes:.2f} {unit}"
            size_bytes /= 1024.0
        return f"{size_bytes:.2f} TB"
@register_node
class PathFilter(TraceNode):
    """
    Path filter.

    Filters a list of paths down to those that match the given conditions.
    """
    CATEGORY = "IO/Filter"
    DISPLAY_NAME = "Path Filter"
    DESCRIPTION = "Filter a list of file paths by the given conditions"
    ICON = "🔍"

    @input_port("paths", "Array<String>", description="Input list of paths")
    @output_port("filtered", "Array<String>", description="Filtered list of paths")
    @output_port("count", "Number", description="Number of paths after filtering")
    @param("include_pattern", "String", default="", description="Include pattern (wildcards supported)")
    @param("exclude_pattern", "String", default="", description="Exclude pattern (wildcards supported)")
    @param("min_size", "Number", default=0, description="Minimum file size (bytes)", min=0)
    @param("max_size", "Number", default=0, description="Maximum file size in bytes (0 = unlimited)", min=0)
    @param("case_sensitive", "Boolean", default=False, description="Case-sensitive matching")
    @context_var("filtered_count", "Integer", description="Number of files after filtering")
    @context_var("removed_count", "Integer", description="Number of files removed")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        import fnmatch

        paths = inputs.get("paths", [])
        if not isinstance(paths, list):
            paths = [paths]

        include_pattern = self.get_param("include_pattern", "")
        exclude_pattern = self.get_param("exclude_pattern", "")
        min_size = self.get_param("min_size", 0)
        max_size = self.get_param("max_size", 0)
        case_sensitive = self.get_param("case_sensitive", False)

        # Load the config to resolve the base path for this user
        config = load_system_config()
        cloud_root = Path(config.get("storage", {}).get("cloud_root", "./cloud"))
        user_id = (context or {}).get("user_id", "guest")
        user_base = cloud_root / "users" / user_id

        filtered = []
        for path_str in paths:
            # Build the full path so the file size can be checked
            if Path(path_str).is_absolute():
                full_path = Path(path_str)
            else:
                full_path = user_base / path_str

            # Skip paths that do not exist or are not regular files
            if not full_path.exists() or not full_path.is_file():
                continue

            # Use the file name for pattern matching
            filename = Path(path_str).name
            if not case_sensitive:
                filename = filename.lower()
                check_include = include_pattern.lower() if include_pattern else ""
                check_exclude = exclude_pattern.lower() if exclude_pattern else ""
            else:
                check_include = include_pattern
                check_exclude = exclude_pattern

            # Include-pattern check
            if include_pattern:
                if not fnmatch.fnmatch(filename, check_include):
                    continue

            # Exclude-pattern check
            if exclude_pattern:
                if fnmatch.fnmatch(filename, check_exclude):
                    continue

            # File-size checks
            file_size = full_path.stat().st_size
            if min_size > 0 and file_size < min_size:
                continue
            if max_size > 0 and file_size > max_size:
                continue

            filtered.append(path_str)

        return {"filtered": filtered, "count": len(filtered)}
@register_node
class PathBuilder(TraceNode):
    """
    Path builder.

    Combines a directory and a file name into a full path.
    """
    CATEGORY = "IO/Builder"
    DISPLAY_NAME = "Path Builder"
    DESCRIPTION = "Combine a directory and a file name into a path"
    ICON = "🔨"

    @input_port("directory", "String", description="Directory path")
    @input_port("filename", "String", description="File name")
    @output_port("path", "String", description="Full path")
    @param("separator", "String", default="/", description="Path separator", options=["/", "\\"])
    @param("normalize", "Boolean", default=True, description="Normalize the path")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        directory = inputs.get("directory", "")
        filename = inputs.get("filename", "")
        separator = self.get_param("separator", "/")
        normalize = self.get_param("normalize", True)

        # Combine the parts
        if directory and filename:
            # Strip trailing separators from the directory
            clean_dir = directory.rstrip('/').rstrip('\\')
            path = f"{clean_dir}{separator}{filename}"
        elif directory:
            path = directory
        elif filename:
            path = filename
        else:
            path = ""

        # Normalize
        if normalize and path:
            path = str(Path(path)).replace("\\", separator)

        return {"path": path}
@register_node
class FileInfo(TraceNode):
    """
    File info reader.

    Reads detailed information about a file (size, modification time, etc.).
    """
    CATEGORY = "IO/Info"
    DISPLAY_NAME = "File Info"
    DESCRIPTION = "Read detailed information about a file"
    ICON = ""

    @input_port("file_path", "String", description="File path")
    @output_port("exists", "Boolean", description="Whether the file exists")
    @output_port("size", "Number", description="File size (bytes)")
    @output_port("size_formatted", "String", description="Human-readable file size")
    @output_port("name", "String", description="File name")
    @output_port("extension", "String", description="File extension")
    @output_port("directory", "String", description="Containing directory")
    @context_var("modified_time", "String", description="Modification time")
    @context_var("created_time", "String", description="Creation time")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        from datetime import datetime

        file_path = inputs.get("file_path", "")
        if not file_path:
            raise ValueError("A file_path input is required")

        # Load the config to resolve the base path
        config = load_system_config()
        cloud_root = Path(config.get("storage", {}).get("cloud_root", "./cloud"))
        user_id = (context or {}).get("user_id", "guest")

        # Build the full path
        if Path(file_path).is_absolute():
            full_path = Path(file_path)
        else:
            full_path = cloud_root / "users" / user_id / file_path

        # Check whether the file exists
        exists = full_path.exists() and full_path.is_file()
        if exists:
            stat = full_path.stat()
            size = stat.st_size
            size_formatted = self._format_file_size(size)
            name = full_path.name
            extension = full_path.suffix
            directory = str(full_path.parent)
            modified_time = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
            created_time = datetime.fromtimestamp(stat.st_ctime).strftime("%Y-%m-%d %H:%M:%S")
        else:
            size = 0
            size_formatted = "0 B"
            name = Path(file_path).name
            extension = Path(file_path).suffix
            directory = str(Path(file_path).parent)
            modified_time = ""
            created_time = ""

        return {
            "exists": exists,
            "size": size,
            "size_formatted": size_formatted,
            "name": name,
            "extension": extension,
            "directory": directory
        }

    @staticmethod
    def _format_file_size(size_bytes: int) -> str:
        """Format a byte count as a human-readable size string."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size_bytes < 1024.0:
                return f"{size_bytes:.2f} {unit}"
            size_bytes /= 1024.0
        return f"{size_bytes:.2f} TB"
@register_node
class ArrayToString(TraceNode):
    """
    Array to string.

    Joins the elements of an array into a single string.
    """
    CATEGORY = "Array/Transform"
    DISPLAY_NAME = "Array to String"
    DESCRIPTION = "Join array elements into a string with a separator"
    ICON = "📝"

    @input_port("array", "Array", description="Input array")
    @output_port("string", "String", description="Joined string")
    @output_port("length", "Number", description="Number of array elements")
    @param("separator", "String", default=", ", description="Separator")
    @param("prefix", "String", default="", description="Prefix")
    @param("suffix", "String", default="", description="Suffix")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        array = inputs.get("array", [])
        if not isinstance(array, list):
            array = [array]

        separator = self.get_param("separator", ", ")
        prefix = self.get_param("prefix", "")
        suffix = self.get_param("suffix", "")

        # Convert each element to a string and join
        string_list = [str(item) for item in array]
        result = prefix + separator.join(string_list) + suffix

        return {"string": result, "length": len(array)}
@register_node
class StringToArray(TraceNode):
    """
    String to array.

    Splits a string into an array by a separator.
    """
    CATEGORY = "Array/Transform"
    DISPLAY_NAME = "String to Array"
    DESCRIPTION = "Split a string into an array by a separator"
    ICON = "✂️"

    @input_port("string", "String", description="Input string")
    @output_port("array", "Array<String>", description="Split array")
    @output_port("count", "Number", description="Number of elements")
    @param("separator", "String", default=",", description="Separator")
    @param("strip_whitespace", "Boolean", default=True, description="Strip whitespace from each element")
    @param("remove_empty", "Boolean", default=True, description="Remove empty elements")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        string = inputs.get("string", "")
        separator = self.get_param("separator", ",")
        strip_whitespace = self.get_param("strip_whitespace", True)
        remove_empty = self.get_param("remove_empty", True)

        # Split the string
        array = string.split(separator)

        # Strip whitespace from each element
        if strip_whitespace:
            array = [item.strip() for item in array]

        # Drop empty elements
        if remove_empty:
            array = [item for item in array if item]

        return {"array": array, "count": len(array)}
@register_node
class SaveDataframe(TraceNode):
    """Save a DataFrame to a CSV file."""
    CATEGORY = "IO/Save"
    DISPLAY_NAME = "Save Table as CSV"
    DESCRIPTION = "Save a DataFrame as a CSV file"
    ICON = "💾"

    @input_port("df", "DataTable", description="DataFrame to save", required=True)
    @param("filename", "String", default="output.csv", description="File name to save to (relative to the user directory)", required=True)
    @output_port("path", "String", description="Path of the saved file")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        import polars as pl

        df = inputs.get("df", None)
        if df is None or not isinstance(df, pl.DataFrame):
            raise ValueError("Input must be a polars.DataFrame")

        filepath = self.get_param("filename", "output.csv")
        if not filepath:
            raise ValueError("A file name must be specified")

        save_path = CLOUD_ROOT / filepath
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Keep rows in a stable order when a Metadata column is present
        if df.columns and "Metadata" in df.columns:
            df = df.sort("Metadata")

        # Write out as CSV
        df.write_csv(str(save_path))
        return {"path": str(save_path)}