[version][1.0.1] 支持自定义节点重载
This commit is contained in:
parent
87b38cae9c
commit
4844b285f9
109
cloud/custom_nodes/assert_loading_test.py
Normal file
109
cloud/custom_nodes/assert_loading_test.py
Normal file
@ -0,0 +1,109 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from server.app.core.node_base import (
|
||||
TraceNode,
|
||||
input_port,
|
||||
output_port,
|
||||
param,
|
||||
)
|
||||
from server.app.core.node_registry import register_node
|
||||
import polars as pl
|
||||
|
||||
@register_node
|
||||
class DungeonFilterNode2(TraceNode):
|
||||
"""
|
||||
副本过滤器节点
|
||||
"""
|
||||
|
||||
CATEGORY = "Filiter"
|
||||
DISPLAY_NAME = "副本过滤器2"
|
||||
DESCRIPTION = "依据规则筛选出副本内的数据"
|
||||
ICON = "📥"
|
||||
|
||||
@input_port("metadata", "DataTable", description="输入的元数据表")
|
||||
@param("sep", "String", default=",", description="分隔符,例如 , 或 \t")
|
||||
@output_port("metadata", "DataTable", description="导出的元数据表")
|
||||
@output_port("table", "DataTable", description="加载得到的表(polars.DataFrame)")
|
||||
def process(self, inputs: Dict[str, Any], context: Optional[object] = None) -> Dict[str, Any]:
|
||||
metadata = inputs.get("metadata", None)
|
||||
if metadata is None or not isinstance(metadata, pl.DataFrame):
|
||||
raise ValueError("metadata 输入必须为 polars.DataFrame")
|
||||
|
||||
# 1. 过滤主线程
|
||||
main_thread_id = 2
|
||||
if "ThreadId" in metadata.columns:
|
||||
main_thread_df = metadata.filter(pl.col("ThreadId") == main_thread_id)
|
||||
else:
|
||||
main_thread_df = metadata
|
||||
|
||||
# 2. 找到所有进入副本的时间点
|
||||
if "Metadata" not in main_thread_df.columns:
|
||||
raise ValueError("缺少 Metadata 列")
|
||||
enter_mask = main_thread_df["Metadata"].str.contains("PackageName:/Game/Asset/Audio/FMOD/Events/bgm/cbt01/level/zhuiji")
|
||||
enter_idx = enter_mask.to_list()
|
||||
enter_indices = [i for i, v in enumerate(enter_idx) if v]
|
||||
if not enter_indices:
|
||||
raise ValueError("未找到进入副本的事件")
|
||||
|
||||
# 3. 找到所有离开副本的时间点
|
||||
leave_mask = main_thread_df["Metadata"].str.contains("Token:TriggerPlayerFinish")
|
||||
leave_idx = leave_mask.to_list()
|
||||
leave_indices = [i for i, v in enumerate(leave_idx) if v]
|
||||
# 允许 leave_indices 为空,后续逻辑自动用最后一条数据为结尾
|
||||
|
||||
# 4. 匹配每一次进入-离开区间,生成 Dungeon 区间
|
||||
|
||||
dungeon_ranges = []
|
||||
leave_ptr = 0
|
||||
n_rows = main_thread_df.height
|
||||
for dungeon_num, enter_index in enumerate(enter_indices, 1):
|
||||
# 找到第一个大于enter_index的leave_index
|
||||
while leave_ptr < len(leave_indices) and leave_indices[leave_ptr] <= enter_index:
|
||||
leave_ptr += 1
|
||||
if leave_ptr < len(leave_indices):
|
||||
leave_index = leave_indices[leave_ptr] - 1
|
||||
leave_ptr += 1
|
||||
else:
|
||||
# 没有匹配到离开副本事件,取最后一条数据为结尾
|
||||
leave_index = n_rows - 1
|
||||
# 只有当进入点在结尾之前才输出区间
|
||||
if enter_index <= leave_index:
|
||||
dungeon_ranges.append((enter_index, leave_index, dungeon_num))
|
||||
|
||||
if not dungeon_ranges:
|
||||
raise ValueError("未能匹配到任何副本区间(所有进入点都在数据结尾之后)")
|
||||
|
||||
# 5. 合并所有区间,并新增 Dungeon 列
|
||||
dfs = []
|
||||
for enter_index, leave_index, dungeon_num in dungeon_ranges:
|
||||
df = main_thread_df.slice(enter_index, leave_index - enter_index + 1).with_columns([
|
||||
pl.lit(dungeon_num).alias("Dungeon")
|
||||
])
|
||||
dfs.append(df)
|
||||
if dfs:
|
||||
filtered_df = pl.concat(dfs)
|
||||
else:
|
||||
filtered_df = pl.DataFrame([])
|
||||
|
||||
# 示例:将筛选到的区间数量写入节点私有上下文,并记录到全局上下文的统计
|
||||
try:
|
||||
# 某些旧实现仍然接收 dict,上下文可能是 ExecutionContext 或纯 dict
|
||||
node_id = getattr(self, "node_id", None) or "unknown"
|
||||
dungeon_count = len(dungeon_ranges)
|
||||
if hasattr(context, "update_node_private"):
|
||||
context.update_node_private(node_id, "dungeon_count", dungeon_count)
|
||||
elif isinstance(context, dict):
|
||||
ns = context.setdefault("nodes", {})
|
||||
ns.setdefault(node_id, {})["dungeon_count"] = dungeon_count
|
||||
|
||||
if hasattr(context, "update_global"):
|
||||
prev = context.get_global("total_dungeons", 0)
|
||||
context.update_global("total_dungeons", prev + dungeon_count)
|
||||
elif isinstance(context, dict):
|
||||
context.setdefault("global", {})["total_dungeons"] = context.get("global", {}).get("total_dungeons", 0) + dungeon_count
|
||||
except Exception:
|
||||
# 不要阻塞主流程:上下文更新失败应当是非致命的
|
||||
pass
|
||||
#print(f"DungeonFilterNode: filtered to {len(dungeon_ranges)} dungeon ranges {filtered_df.shape}.")
|
||||
outputs = {"metadata": filtered_df}
|
||||
return outputs
|
||||
|
||||
@ -170,8 +170,24 @@ def cmd_build(args):
|
||||
"""构建发布包"""
|
||||
print_step(f"清理并构建发布包到: {DIST_DIR}")
|
||||
if DIST_DIR.exists():
|
||||
shutil.rmtree(DIST_DIR)
|
||||
DIST_DIR.mkdir()
|
||||
# 遍历 DIST_DIR 下的所有文件和文件夹
|
||||
for item in DIST_DIR.iterdir():
|
||||
# 【关键】核心保护逻辑:如果名字是 .git,直接跳过
|
||||
if item.name == ".git":
|
||||
continue
|
||||
try:
|
||||
if item.is_dir():
|
||||
#如果是文件夹,递归删除
|
||||
shutil.rmtree(item)
|
||||
else:
|
||||
# 如果是文件,直接删除
|
||||
item.unlink()
|
||||
except Exception as e:
|
||||
print(f"!!! 删除失败: {item} - {e}")
|
||||
# Windows 下有时候会因为文件被占用删不掉,视情况决定是否要 raise 阻断流程
|
||||
else:
|
||||
# 如果目录压根不存在,才需要 mkdir
|
||||
DIST_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ignore = shutil.ignore_patterns(
|
||||
".git", ".gitignore", ".dockerignore", "__pycache__", "node_modules",
|
||||
@ -186,7 +202,7 @@ def cmd_build(args):
|
||||
shutil.copytree(src, dst, ignore=ignore)
|
||||
|
||||
shutil.copytree(ROOT / "cloud/custom_nodes", DIST_DIR / "cloud/custom_nodes", ignore=ignore)
|
||||
|
||||
|
||||
# 2. 复制生产配置
|
||||
prod_compose = ROOT / "docker-compose.yml"
|
||||
if prod_compose.exists():
|
||||
|
||||
@ -12,6 +12,7 @@ from server.app.core.node_base import DimensionMode, NodeType
|
||||
|
||||
from ..core.user_manager import get_user_path
|
||||
from ..core.security import is_safe_path, validate_filename, sanitize_path
|
||||
from ..core.node_loader import reload_custom_nodes
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -397,6 +398,30 @@ async def add_user(payload: Dict[str, Any]):
|
||||
}
|
||||
|
||||
|
||||
@router.post("/reload-custom-nodes")
|
||||
async def reload_custom_nodes_endpoint():
|
||||
"""
|
||||
触发后端重新加载所有自定义节点(包括内置节点的刷新统计)
|
||||
|
||||
Returns:
|
||||
包含加载统计和错误详情的 JSON
|
||||
"""
|
||||
try:
|
||||
result = reload_custom_nodes()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已加载 {result.get('loaded', 0)}/{result.get('total', 0)} 个节点",
|
||||
"details": result
|
||||
}
|
||||
except Exception as e:
|
||||
import traceback
|
||||
tb = traceback.format_exc()
|
||||
raise HTTPException(status_code=500, detail={
|
||||
"error": str(e),
|
||||
"traceback": tb
|
||||
})
|
||||
|
||||
|
||||
@router.post("/nodes/save")
|
||||
async def save_node(payload: Dict[str, Any]):
|
||||
"""
|
||||
|
||||
@ -39,31 +39,71 @@ def load_builtin_nodes() -> Dict[str, Any]:
|
||||
failed = 0
|
||||
errors = []
|
||||
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
for module_path in nodes_dir.glob("*.py"):
|
||||
module_name = module_path.stem
|
||||
if module_name in ignore_node_modules:
|
||||
continue
|
||||
try:
|
||||
# 导入模块(按完整包路径)
|
||||
module = importlib.import_module(f"server.app.nodes.{module_name}")
|
||||
# 使用文件路径导入模块,更稳健于不同运行上下文(容器、打包后的 dist 等)
|
||||
spec = importlib.util.spec_from_file_location(f"server.app.nodes.{module_name}", str(module_path))
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"无法为 {module_name} 创建模块规范")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# 支持热重载:若模块已在 sys.modules 中,先移除再加载
|
||||
if module.__name__ in sys.modules:
|
||||
del sys.modules[module.__name__]
|
||||
|
||||
spec.loader.exec_module(module)
|
||||
# 将模块放入 sys.modules,允许其它地方按包名引用(如果需要)
|
||||
sys.modules[module.__name__] = module
|
||||
|
||||
loaded += 1
|
||||
print(f"📦 加载节点模块: {module_name}")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
error_msg = f"❌ 加载失败: {module_name} - {str(e)}"
|
||||
# 输出更详细的异常堆栈信息,便于诊断
|
||||
import traceback
|
||||
tb = traceback.format_exc()
|
||||
error_msg = f"❌ 加载失败: {module_name} - {str(e)}\n{tb}"
|
||||
errors.append(error_msg)
|
||||
print(error_msg)
|
||||
# todo: 加载 custom_nodes 目录下的用户自定义节点模块,如果有的话
|
||||
if custom_nodes_dir.exists():
|
||||
# `custom_nodes.<name>` 导入;不需要在磁盘上创建 package 文件。
|
||||
if 'custom_nodes' not in sys.modules:
|
||||
cn_pkg = types.ModuleType('custom_nodes')
|
||||
cn_pkg.__path__ = [str(custom_nodes_dir)]
|
||||
sys.modules['custom_nodes'] = cn_pkg
|
||||
|
||||
for module_path in custom_nodes_dir.glob("*.py"):
|
||||
module_name = module_path.stem
|
||||
try:
|
||||
module = importlib.import_module(f"cloud.custom_nodes.{module_name}")
|
||||
# 使用文件路径导入自定义节点,模块名使用 custom_nodes.<module_name>
|
||||
spec = importlib.util.spec_from_file_location(f"custom_nodes.{module_name}", str(module_path))
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"无法为自定义节点 {module_name} 创建模块规范")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
if module.__name__ in sys.modules:
|
||||
del sys.modules[module.__name__]
|
||||
|
||||
spec.loader.exec_module(module)
|
||||
# 注册为 custom_nodes.<name>
|
||||
sys.modules[module.__name__] = module
|
||||
# 兼容性:如果现有代码仍然尝试以 cloud.custom_nodes.<name> 导入,
|
||||
# 可以在此处创建别名(可选)。当前不创建 cloud 别名以遵循你的要求。
|
||||
|
||||
loaded += 1
|
||||
print(f"📦 加载自定义节点模块: {module_name}")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
error_msg = f"❌ 加载自定义节点失败: {module_name} - {str(e)}"
|
||||
import traceback
|
||||
tb = traceback.format_exc()
|
||||
error_msg = f"❌ 加载自定义节点失败: {module_name} - {str(e)}\n{tb}"
|
||||
errors.append(error_msg)
|
||||
print(error_msg)
|
||||
# 获取注册统计
|
||||
|
||||
@ -23,9 +23,6 @@ storage:
|
||||
# 云存储根目录(相对于项目根目录 TraceStudio/)
|
||||
cloud_root: "./cloud"
|
||||
|
||||
# 自定义节点目录(相对于项目根目录 TraceStudio/)
|
||||
custom_nodes_dir: "./cloud/custom_nodes"
|
||||
|
||||
# 用户目录结构
|
||||
user_dirs:
|
||||
- "data"
|
||||
|
||||
@ -322,6 +322,35 @@ export default function HeaderBar(){
|
||||
>
|
||||
📂 导入工作流
|
||||
</button>
|
||||
<button
|
||||
onClick={async () => {
|
||||
setShowMenu(false)
|
||||
try {
|
||||
const data = await (await import('../core/services/CustomNodesService')).default.reloadAll()
|
||||
alert(`🔁 ${data?.message || '已触发自定义节点重载'}`)
|
||||
} catch (e) {
|
||||
alert('❌ 触发自定义节点重载失败')
|
||||
console.error(e)
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
width: '100%',
|
||||
padding: '12px 16px',
|
||||
background: 'transparent',
|
||||
border: 'none',
|
||||
borderBottom: '1px solid rgba(255,255,255,0.05)',
|
||||
textAlign: 'left',
|
||||
color: 'rgba(255,255,255,0.8)',
|
||||
fontSize: 13,
|
||||
fontWeight: 500,
|
||||
cursor: 'pointer',
|
||||
transition: 'background 0.15s'
|
||||
}}
|
||||
onMouseEnter={(e) => e.currentTarget.style.background = 'rgba(59,130,246,0.08)'}
|
||||
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
|
||||
>
|
||||
🔁 重载自定义节点
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExport}
|
||||
style={{
|
||||
|
||||
9
web/src/core/api/CustomNodesApi.ts
Normal file
9
web/src/core/api/CustomNodesApi.ts
Normal file
@ -0,0 +1,9 @@
|
||||
import { request } from './api'
|
||||
|
||||
export async function reloadAllCustomNodes() {
|
||||
return request('/api/custom-nodes/reload-all', { method: 'POST' })
|
||||
}
|
||||
|
||||
export async function getLoadedCustomNodes() {
|
||||
return request('/api/custom-nodes/loaded', { method: 'GET' })
|
||||
}
|
||||
@ -17,7 +17,7 @@ interface ApiResponse<T = any> {
|
||||
/**
|
||||
* 通用请求方法
|
||||
*/
|
||||
async function request<T = any>(
|
||||
export async function request<T = any>(
|
||||
endpoint: string,
|
||||
options: RequestInit = {}
|
||||
): Promise<ApiResponse<T>> {
|
||||
|
||||
22
web/src/core/services/CustomNodesService.ts
Normal file
22
web/src/core/services/CustomNodesService.ts
Normal file
@ -0,0 +1,22 @@
|
||||
import * as CustomNodesApi from '../api/CustomNodesApi'
|
||||
import RuntimeService from './RuntimeService'
|
||||
|
||||
const CustomNodesService = {
|
||||
async reloadAll() {
|
||||
const resp = await CustomNodesApi.reloadAllCustomNodes()
|
||||
if (resp.error) throw new Error(resp.error)
|
||||
// Refresh runtime manifest / node meta data
|
||||
await RuntimeService.reloadManifest()
|
||||
// sync runtime to store
|
||||
RuntimeService.syncGraphFromRuntime()
|
||||
return resp.data
|
||||
},
|
||||
|
||||
async getLoaded() {
|
||||
const resp = await CustomNodesApi.getLoadedCustomNodes()
|
||||
if (resp.error) throw new Error(resp.error)
|
||||
return resp.data
|
||||
}
|
||||
}
|
||||
|
||||
export default CustomNodesService
|
||||
Loading…
Reference in New Issue
Block a user