TraceStudio-dev/server/app/nodes/io_nodes.py

458 lines
17 KiB
Python
Raw Permalink Normal View History

"""
IO 节点集合
用于文件系统操作路径处理等
"""
import os
import glob
from pathlib import Path
from ..core.user_manager import CLOUD_ROOT
from typing import Any, Dict, Optional, List
from server.app.core.node_base import (
TraceNode,
input_port,
output_port,
param,
context_var,
NodeType,
CachePolicy
)
from server.app.core.node_registry import register_node
import yaml
def load_system_config():
    """Load the system configuration from ``system_config.yaml``.

    The file is looked up three directory levels above this module
    (``Path(__file__).parent.parent.parent``).

    Returns:
        The parsed top-level mapping, or ``{}`` when the file is missing,
        empty, or its top-level YAML value is not a mapping.
    """
    config_path = Path(__file__).parent.parent.parent / "system_config.yaml"
    if not config_path.exists():
        return {}
    with open(config_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    # safe_load() yields None for an empty document (or a scalar/list for
    # other documents); callers immediately call .get() on the result, so
    # anything that is not a dict is normalised to an empty config.
    return data if isinstance(data, dict) else {}
@register_node
class DirectoryScanner(TraceNode):
    """
    Directory scanner node.

    Features:
    1. Scans files in a directory located under ``CLOUD_ROOT``.
    2. Filters by a glob pattern (default ``*.utrace``).
    3. Optionally recurses into sub-directories.
    4. Optionally sorts by name, size, modification time or creation time.
    5. Emits the matching file paths as an array.

    Use cases:
    - Batch processing of many ``.utrace`` files.
    - Finding files of a particular type.
    - Acting as an array data source together with EXPAND dimension conversion.
    """
    CATEGORY = "IO/Scanner"
    DISPLAY_NAME = "目录扫描器"
    DESCRIPTION = "扫描目录并输出文件路径列表"
    ICON = "📁"
    NODE_TYPE = NodeType.INPUT
    CACHE_POLICY = CachePolicy.NONE  # rescan on every execution

    @output_port("files", "Array<String>", description="文件路径列表(数组)")
    @output_port("count", "Number", description="文件数量")
    @param("directory", "String", default="", description="要扫描的目录(相对于用户目录)", required=True)
    @param("pattern", "String", default="*.utrace", description="文件匹配模式(支持 glob", required=True)
    @param("recursive", "Boolean", default=False, description="是否递归扫描子目录")
    @param("sort_by", "String", default="name", description="排序方式",
           options=["name", "size", "modified", "created", "none"])
    @param("reverse_sort", "Boolean", default=False, description="反向排序")
    @param("max_files", "Number", default=0, description="最大文件数0=无限制)", min=0, step=1)
    @context_var("scan_path", "String", description="实际扫描的完整路径")
    @context_var("file_count", "Integer", description="找到的文件数")
    @context_var("total_size", "String", description="文件总大小")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Scan ``CLOUD_ROOT / directory`` for files matching ``pattern``.

        Returns:
            ``{"files": [...], "count": n}``. Each entry in ``files`` is
            relative to ``CLOUD_ROOT`` and always uses ``/`` as separator.

        Raises:
            ValueError: ``directory`` is empty, points outside ``CLOUD_ROOT``,
                or is not a directory.
            FileNotFoundError: the directory does not exist.
        """
        directory = self.get_param("directory", "")
        pattern = self.get_param("pattern", "*.utrace")
        recursive = self.get_param("recursive", False)
        sort_by = self.get_param("sort_by", "name")
        reverse_sort = self.get_param("reverse_sort", False)
        max_files = int(self.get_param("max_files", 0))

        if not directory:
            raise ValueError("必须指定 directory 参数")

        # Normalise lexically (os.path.abspath does not follow symlinks) so an
        # absolute `directory` or ".." segments cannot escape CLOUD_ROOT.
        root = Path(os.path.abspath(CLOUD_ROOT))
        scan_path = Path(os.path.abspath(root / directory))
        if not self._is_within(scan_path, root):
            raise ValueError(f"目录超出允许范围:{directory}")
        if not scan_path.exists():
            raise FileNotFoundError(f"目录不存在:{scan_path}")
        if not scan_path.is_dir():
            raise ValueError(f"路径不是目录:{scan_path}")

        # "**/" with recursive=True also matches files directly in scan_path.
        glob_pattern = f"**/{pattern}" if recursive else pattern
        candidates = glob.glob(str(scan_path / glob_pattern), recursive=recursive)

        file_objects = []
        for fp in candidates:
            p = Path(os.path.abspath(fp))
            # An absolute pattern or one containing ".." can match paths outside
            # the root; drop those, and keep only regular files (no directories).
            if self._is_within(p, root) and p.is_file():
                file_objects.append(p)

        # Unknown values (including "none") leave glob's order untouched.
        # Note: st_ctime is creation time on Windows but metadata-change time on Unix.
        sort_keys = {
            "name": lambda p: p.name,
            "size": lambda p: p.stat().st_size,
            "modified": lambda p: p.stat().st_mtime,
            "created": lambda p: p.stat().st_ctime,
        }
        sort_key = sort_keys.get(sort_by)
        if sort_key is not None:
            file_objects.sort(key=sort_key, reverse=reverse_sort)

        if max_files > 0:
            file_objects = file_objects[:max_files]

        # NOTE(review): the context vars declared above (scan_path, file_count,
        # total_size) are not populated by this method; _format_file_size is
        # available for total_size once the publishing mechanism is confirmed.
        files = [p.relative_to(root).as_posix() for p in file_objects]
        return {
            "files": files,
            "count": len(files)
        }

    @staticmethod
    def _is_within(path: Path, root: Path) -> bool:
        """Return True if *path* equals *root* or lies (lexically) beneath it."""
        return path == root or root in path.parents

    @staticmethod
    def _format_file_size(size_bytes: int) -> str:
        """Format a byte count with two decimals in B/KB/MB/GB/TB (1024 steps)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size_bytes < 1024.0:
                return f"{size_bytes:.2f} {unit}"
            size_bytes /= 1024.0
        return f"{size_bytes:.2f} TB"
@register_node
class PathFilter(TraceNode):
    """
    Path filter node.

    Keeps only the input paths that refer to existing regular files whose file
    name matches the include/exclude glob patterns and whose size lies within
    the configured bounds.
    """
    CATEGORY = "IO/Filter"
    DISPLAY_NAME = "路径过滤器"
    DESCRIPTION = "根据条件过滤文件路径列表"
    ICON = "🔍"

    @input_port("paths", "Array<String>", description="输入路径列表")
    @output_port("filtered", "Array<String>", description="过滤后的路径列表")
    @output_port("count", "Number", description="过滤后的数量")
    @param("include_pattern", "String", default="", description="包含模式(支持通配符)")
    @param("exclude_pattern", "String", default="", description="排除模式(支持通配符)")
    @param("min_size", "Number", default=0, description="最小文件大小(字节)", min=0)
    @param("max_size", "Number", default=0, description="最大文件大小字节0=无限制)", min=0)
    @param("case_sensitive", "Boolean", default=False, description="大小写敏感")
    @context_var("filtered_count", "Integer", description="过滤后的文件数")
    @context_var("removed_count", "Integer", description="移除的文件数")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Filter ``inputs["paths"]`` by file name pattern and file size.

        Relative paths are resolved against
        ``<storage.cloud_root>/users/<user_id>`` from ``system_config.yaml``;
        absolute paths are used as given. Paths that do not point to an
        existing regular file are dropped.

        Returns:
            ``{"filtered": [...], "count": n}`` with the surviving input
            entries in their original order and form.
        """
        import fnmatch

        paths = inputs.get("paths", [])
        if paths is None:
            paths = []
        elif isinstance(paths, tuple):
            paths = list(paths)
        elif not isinstance(paths, list):
            paths = [paths]

        include_pattern = self.get_param("include_pattern", "") or ""
        exclude_pattern = self.get_param("exclude_pattern", "") or ""
        # Coerce so string-typed UI values still compare numerically.
        min_size = float(self.get_param("min_size", 0) or 0)
        max_size = float(self.get_param("max_size", 0) or 0)
        case_sensitive = self.get_param("case_sensitive", False)

        # Lower-case the patterns once instead of on every iteration.
        if not case_sensitive:
            include_pattern = include_pattern.lower()
            exclude_pattern = exclude_pattern.lower()

        # NOTE(review): this base differs from the CLOUD_ROOT used by
        # DirectoryScanner; confirm both point at the same tree.
        config = load_system_config()
        cloud_root = Path(config.get("storage", {}).get("cloud_root", "./cloud"))
        user_id = (context or {}).get("user_id", "guest")
        user_base = cloud_root / "users" / user_id

        filtered = []
        for path_str in paths:
            if not path_str:
                continue
            candidate = Path(path_str)
            full_path = candidate if candidate.is_absolute() else user_base / candidate
            # is_file() is False for missing paths and for directories.
            if not full_path.is_file():
                continue

            filename = candidate.name if case_sensitive else candidate.name.lower()
            # fnmatchcase instead of fnmatch: fnmatch applies os.path.normcase,
            # which makes matching case-insensitive on Windows regardless of
            # the case_sensitive flag.
            if include_pattern and not fnmatch.fnmatchcase(filename, include_pattern):
                continue
            if exclude_pattern and fnmatch.fnmatchcase(filename, exclude_pattern):
                continue

            file_size = full_path.stat().st_size
            if min_size > 0 and file_size < min_size:
                continue
            if max_size > 0 and file_size > max_size:
                continue
            filtered.append(path_str)

        return {"filtered": filtered, "count": len(filtered)}
@register_node
class PathBuilder(TraceNode):
    """
    Path builder node.

    Joins a directory and a file name with the chosen separator and can
    optionally normalise the result.
    """
    CATEGORY = "IO/Builder"
    DISPLAY_NAME = "路径构建器"
    DESCRIPTION = "组合目录和文件名构建路径"
    ICON = "🔨"

    @input_port("directory", "String", description="目录路径")
    @input_port("filename", "String", description="文件名")
    @output_port("path", "String", description="完整路径")
    @param("separator", "String", default="/", description="路径分隔符", options=["/", "\\"])
    @param("normalize", "Boolean", default=True, description="规范化路径")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Build ``directory + separator + filename``.

        If only one of the inputs is non-empty it is returned on its own. With
        ``normalize`` enabled, repeated separators and ``.`` segments are
        collapsed and every separator (``/`` or ``\\``) becomes ``separator``.
        The result is the same on every host OS.

        Returns:
            ``{"path": str}`` (empty when both inputs are empty).
        """
        from pathlib import PurePosixPath

        directory = inputs.get("directory", "")
        filename = inputs.get("filename", "")
        separator = self.get_param("separator", "/")
        normalize = self.get_param("normalize", True)

        if directory and filename:
            # Strip any mix of trailing "/" and "\" so the join never doubles up.
            clean_dir = directory.rstrip("/\\")
            path = f"{clean_dir}{separator}{filename}"
        else:
            path = directory or filename or ""

        if normalize and path:
            # Convert every separator to "/" and normalise through a POSIX view.
            # Plain str(Path(...)) keeps "/" on POSIX hosts, which left mixed
            # separators when separator == "\\".
            unified = str(Path(path)).replace("\\", "/")
            path = str(PurePosixPath(unified)).replace("/", separator)

        return {"path": path}
@register_node
class FileInfo(TraceNode):
    """
    File information node.

    Reports whether a path refers to an existing regular file. It also reports
    the file's size, name, extension and parent directory.
    """
    CATEGORY = "IO/Info"
    DISPLAY_NAME = "文件信息"
    DESCRIPTION = "读取文件的详细信息"
    ICON = ""

    @input_port("file_path", "String", description="文件路径")
    @output_port("exists", "Boolean", description="文件是否存在")
    @output_port("size", "Number", description="文件大小(字节)")
    @output_port("size_formatted", "String", description="格式化的文件大小")
    @output_port("name", "String", description="文件名")
    @output_port("extension", "String", description="文件扩展名")
    @output_port("directory", "String", description="所在目录")
    @context_var("modified_time", "String", description="修改时间")
    @context_var("created_time", "String", description="创建时间")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Describe the file at ``inputs["file_path"]``.

        Relative paths are resolved against
        ``<storage.cloud_root>/users/<user_id>`` from ``system_config.yaml``.
        Absolute paths are used unchanged.
        If no regular file exists there, the name/extension/directory are
        derived from the given path string and the size is reported as 0.

        Raises:
            ValueError: ``file_path`` is missing or empty.
        """
        from datetime import datetime

        file_path = inputs.get("file_path", "")
        if not file_path:
            raise ValueError("必须提供 file_path 输入")

        settings = load_system_config()
        root_dir = Path(settings.get("storage", {}).get("cloud_root", "./cloud"))
        owner = (context or {}).get("user_id", "guest")

        requested = Path(file_path)
        if requested.is_absolute():
            target = requested
        else:
            target = root_dir / "users" / owner / file_path

        if not (target.exists() and target.is_file()):
            return {
                "exists": False,
                "size": 0,
                "size_formatted": "0 B",
                "name": requested.name,
                "extension": requested.suffix,
                "directory": str(requested.parent)
            }

        info = target.stat()
        stamp_format = "%Y-%m-%d %H:%M:%S"
        # NOTE(review): these match the declared context vars but are not
        # returned or otherwise published by this method.
        modified_time = datetime.fromtimestamp(info.st_mtime).strftime(stamp_format)
        created_time = datetime.fromtimestamp(info.st_ctime).strftime(stamp_format)
        return {
            "exists": True,
            "size": info.st_size,
            "size_formatted": self._format_file_size(info.st_size),
            "name": target.name,
            "extension": target.suffix,
            "directory": str(target.parent)
        }

    @staticmethod
    def _format_file_size(size_bytes: int) -> str:
        """Render a byte count with two decimals, stepping units by 1024 up to TB."""
        units = ("B", "KB", "MB", "GB", "TB")
        value = size_bytes
        index = 0
        while index < len(units) - 1 and not value < 1024.0:
            value /= 1024.0
            index += 1
        return f"{value:.2f} {units[index]}"
@register_node
class ArrayToString(TraceNode):
    """
    Array-to-string node.

    Joins the elements of an array into one string.
    """
    CATEGORY = "Array/Transform"
    DISPLAY_NAME = "数组转字符串"
    DESCRIPTION = "将数组元素用分隔符连接成字符串"
    ICON = "📝"

    @input_port("array", "Array", description="输入数组")
    @output_port("string", "String", description="连接后的字符串")
    @output_port("length", "Number", description="数组元素个数")
    @param("separator", "String", default=", ", description="分隔符")
    @param("prefix", "String", default="", description="前缀")
    @param("suffix", "String", default="", description="后缀")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Join ``inputs["array"]`` with ``separator`` and wrap it in prefix/suffix.

        A non-list input is treated as a one-element list. Every element is
        converted with ``str()``.

        Returns:
            ``{"string": str, "length": int}``, where ``length`` is the number
            of elements joined.
        """
        items = inputs.get("array", [])
        items = items if isinstance(items, list) else [items]

        sep = self.get_param("separator", ", ")
        head = self.get_param("prefix", "")
        tail = self.get_param("suffix", "")

        joined = sep.join(map(str, items))
        return {"string": head + joined + tail, "length": len(items)}
@register_node
class StringToArray(TraceNode):
    """
    String-to-array node.

    Splits a string into an array using a separator.
    """
    CATEGORY = "Array/Transform"
    DISPLAY_NAME = "字符串转数组"
    DESCRIPTION = "将字符串按分隔符拆分成数组"
    ICON = "✂️"

    @input_port("string", "String", description="输入字符串")
    @output_port("array", "Array<String>", description="拆分后的数组")
    @output_port("count", "Number", description="元素个数")
    @param("separator", "String", default=",", description="分隔符")
    @param("strip_whitespace", "Boolean", default=True, description="去除空白字符")
    @param("remove_empty", "Boolean", default=True, description="移除空元素")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Split ``inputs["string"]`` on ``separator``.

        A ``None`` input is treated as ``""``. Other non-string inputs are
        converted with ``str()``. An empty or ``None`` separator splits on
        runs of whitespace instead of raising. Optionally strips each element
        and drops empty ones.

        Returns:
            ``{"array": list[str], "count": int}``.
        """
        text = inputs.get("string", "")
        if text is None:
            text = ""
        elif not isinstance(text, str):
            text = str(text)

        separator = self.get_param("separator", ",")
        strip_whitespace = self.get_param("strip_whitespace", True)
        remove_empty = self.get_param("remove_empty", True)

        # str.split("") raises ValueError("empty separator"); fall back to
        # whitespace splitting, which is also what split(None) does.
        array = text.split(separator) if separator else text.split()

        if strip_whitespace:
            array = [item.strip() for item in array]

        if remove_empty:
            array = [item for item in array if item]

        return {"array": array, "count": len(array)}
@register_node
class SaveDataframe(TraceNode):
    """Save a polars DataFrame to a CSV file under ``CLOUD_ROOT``."""
    CATEGORY = "IO/Save"
    DISPLAY_NAME = "保存表为CSV"
    DESCRIPTION = "将 DataFrame 保存为 CSV 文件"
    ICON = "💾"

    @input_port("df", "DataTable", description="要保存的 DataFrame", required=True)
    @param("filename", "String", default="output.csv", description="保存的文件名(相对用户目录)", required=True)
    @output_port("path", "String", description="保存后的文件路径")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Write ``inputs["df"]`` as CSV to ``CLOUD_ROOT / filename``.

        Parent directories are created as needed.

        Returns:
            ``{"path": str}`` with the path that was written.

        Raises:
            ValueError: the input is not a ``polars.DataFrame``, the file name
                is empty, or it resolves outside ``CLOUD_ROOT``.
        """
        import polars as pl

        df = inputs.get("df", None)
        # isinstance() is already False for None.
        if not isinstance(df, pl.DataFrame):
            raise ValueError("输入必须为 polars.DataFrame")

        filepath = self.get_param("filename", "output.csv")
        if not filepath:
            raise ValueError("必须指定文件名")

        save_path = CLOUD_ROOT / filepath
        # The file name is user-supplied: reject absolute paths and ".."
        # segments that would write outside CLOUD_ROOT. abspath normalises
        # lexically and does not follow symlinks.
        root = Path(os.path.abspath(CLOUD_ROOT))
        target = Path(os.path.abspath(save_path))
        if root not in target.parents:
            raise ValueError(f"文件路径超出允许范围:{filepath}")

        save_path.parent.mkdir(parents=True, exist_ok=True)
        df.write_csv(str(save_path))
        return {"path": str(save_path)}