TraceStudio-dist/cloud/custom_nodes/csv_loader.py
2026-01-13 16:41:31 +08:00

42 lines
1.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict, Optional
from server.app.core.node_base import TraceNode, input_port, output_port, param
from server.app.core.node_registry import register_node
import polars as pl
from pathlib import Path
from server.app.core.user_manager import CLOUD_ROOT
@register_node
class CsvLoaderNode(TraceNode):
    """
    Simple CSV loader backed by Polars.

    Input:  path to a CSV file, either relative to ``CLOUD_ROOT`` or absolute.
    Output: the parsed ``polars.DataFrame`` on the ``table`` port (DataTable).
    """

    CATEGORY = "IO/Load"
    DISPLAY_NAME = "CSV 加载器"
    DESCRIPTION = "从 CSV 文件加载为 polars.DataFrame支持用户目录下路径"
    ICON = "📥"

    @staticmethod
    def _resolve_path(raw: Any) -> Path:
        """Turn the raw ``path`` input into a filesystem path.

        Surrounding whitespace is stripped and the result is joined onto
        ``CLOUD_ROOT``. Assuming ``CLOUD_ROOT`` is a ``pathlib.Path`` (the
        ``/`` operator requires it), joining an absolute path discards
        ``CLOUD_ROOT``, which is how absolute paths are supported.

        Raises:
            ValueError: if the input is missing, ``None`` or blank.
        """
        text = "" if raw is None else str(raw).strip()
        if not text:
            raise ValueError("CSV 文件路径为空")
        # NOTE(review): there is no containment check, so "../" segments or
        # absolute paths can reach files outside CLOUD_ROOT. Absolute paths are
        # an advertised feature, so restricting this is a product decision —
        # flagged rather than changed.
        return CLOUD_ROOT / text

    @staticmethod
    def _resolve_separator(raw: Any) -> str:
        """Map the ``sep`` parameter to the separator passed to Polars.

        The typed escape ``\\t`` (backslash + t, case-insensitive) becomes a
        real tab character. An unset (``None``) or empty value falls back to
        ``","``. Any other value is passed through unchanged; Polars itself
        rejects invalid separators.
        """
        sep = "" if raw is None else str(raw)
        if not sep:
            return ","
        if sep.lower() == "\\t":
            return "\t"
        return sep

    @input_port("path", "String", description="CSV 文件路径(相对于用户目录或绝对路径)", required=True)
    @param("sep", "String", default=",", description="分隔符,例如 , 或 \\t")
    @output_port("table", "DataTable", description="加载得到的表polars.DataFrame")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Read the CSV named by ``inputs["path"]`` into a Polars DataFrame.

        Args:
            inputs: node inputs; ``"path"`` holds the CSV path.
            context: not used by this node.

        Returns:
            ``{"table": df}``, where ``df`` is the ``polars.DataFrame`` read
            from the file with dtypes inferred by Polars.

        Raises:
            ValueError: if no path was supplied.
            FileNotFoundError: if the path is missing or is not a regular file.
            RuntimeError: if Polars fails to read the file (cause is chained).
        """
        file_path = self._resolve_path(inputs.get("path"))
        # is_file() is False both for missing paths and for directories.
        if not file_path.is_file():
            raise FileNotFoundError(f"CSV 文件未找到: {file_path}")

        sep_char = self._resolve_separator(self.get_param("sep", ","))

        # Polars infers dtypes; a user-supplied schema override may come later.
        try:
            df = pl.read_csv(file_path, separator=sep_char)
        except Exception as e:
            raise RuntimeError(f"读取 CSV 失败: {e}") from e
        return {"table": df}