42 lines
1.7 KiB
Python
42 lines
1.7 KiB
Python
from typing import Any, Dict, Optional
|
||
from server.app.core.node_base import TraceNode, input_port, output_port, param
|
||
from server.app.core.node_registry import register_node
|
||
import polars as pl
|
||
from pathlib import Path
|
||
from server.app.core.user_manager import CLOUD_ROOT
|
||
|
||
@register_node
class CsvLoaderNode(TraceNode):
    """Simple CSV loader backed by Polars.

    Input: a CSV file path, which is joined onto ``CLOUD_ROOT``.
    Output: the parsed ``polars.DataFrame``, returned under the ``table`` key.
    """

    CATEGORY = "IO/Load"
    DISPLAY_NAME = "CSV 加载器"
    DESCRIPTION = "从 CSV 文件加载为 polars.DataFrame(支持用户目录下路径)"
    ICON = "📥"

    @input_port("path", "String", description="CSV 文件路径(相对于用户目录或绝对路径)", required=True)
    @param("sep", "String", default=",", description="分隔符,例如 , 或 \t")
    @output_port("table", "DataTable", description="加载得到的表(polars.DataFrame)")
    def process(self, inputs: Dict[str, Any], context: Optional[Dict] = None) -> Dict[str, Any]:
        """Load the CSV file named by ``inputs["path"]`` into a DataFrame.

        Args:
            inputs: Port values; ``"path"`` is the CSV location, joined onto
                ``CLOUD_ROOT``.
            context: Unused by this node.

        Returns:
            ``{"table": df}`` where ``df`` is the result of ``pl.read_csv``.

        Raises:
            FileNotFoundError: If no path was supplied, or the resolved path
                does not exist or is not a regular file.
            RuntimeError: If Polars fails to read the file; the underlying
                exception is attached as ``__cause__``.
        """
        raw_path = inputs.get("path")
        # A missing/empty required input is reported directly instead of being
        # turned into a bogus ".csv" file name or an opaque TypeError on join.
        if raw_path is None or raw_path == "":
            raise FileNotFoundError(f"CSV 文件未找到: {raw_path!r}")

        # NOTE(review): assuming CLOUD_ROOT is a pathlib.Path, an absolute
        # `raw_path` replaces the root entirely and ".." segments can climb out
        # of it. The port description advertises absolute paths, so this is
        # left as-is — confirm whether confinement to CLOUD_ROOT is required.
        file_path = CLOUD_ROOT / raw_path

        if not file_path.exists() or not file_path.is_file():
            raise FileNotFoundError(f"CSV 文件未找到: {file_path}")

        sep = str(self.get_param("sep", ","))
        # UI text fields deliver the two characters backslash + "t"; map that
        # (case-insensitively) to a real tab character.
        sep_char = "\t" if sep.lower() == "\\t" else sep

        # polars will infer dtypes; allow user to override in future
        try:
            df = pl.read_csv(file_path, separator=sep_char)
        except Exception as e:
            # Keep the RuntimeError type, but chain the original error so the
            # Polars traceback is not lost.
            raise RuntimeError(f"读取 CSV 失败: {e}") from e

        return {"table": df}