TraceStudio/server/app/core/node_validator.py

267 lines
8.7 KiB
Python
Raw Permalink Normal View History

2026-01-12 21:51:45 +08:00
"""
自定义节点验证器
检查节点代码的安全性和规范性
"""
import ast
import re
from typing import Dict, List, Optional, Tuple
from pathlib import Path
class NodeValidationError(Exception):
    """Error type for node validation failures."""
class NodeValidator:
    """Static validator for custom node source code.

    The source is parsed into an AST and inspected; it is never executed.
    Findings are accumulated in ``self.errors`` / ``self.warnings`` and are
    reset at the start of every ``validate_code`` call.
    """

    # Blacklisted names consulted by the import checks.
    DANGEROUS_MODULES = {
        'os', 'subprocess', 'sys', 'importlib', 'builtins',
        '__import__', 'eval', 'exec', 'compile', 'open',
        'file', 'input', 'raw_input'
    }

    # Function names that must not be called by bare name.
    DANGEROUS_CALLS = {
        'eval', 'exec', 'compile', '__import__',
        'system', 'popen', 'spawn', 'fork'
    }

    # Base class every node class has to inherit from.
    REQUIRED_BASE_CLASS = 'TraceNode'

    def __init__(self):
        self.errors: List[str] = []
        self.warnings: List[str] = []

    @staticmethod
    def _failure_result(message: str) -> Dict:
        """Build the result dict used when validation cannot proceed at all."""
        return {
            'valid': False,
            'errors': [message],
            'warnings': [],
            'node_classes': [],
            'metadata': {}
        }

    def validate_code(self, code: str, filename: str = '<string>') -> Dict:
        """Validate node source code.

        Args:
            code: Python source code as a string.
            filename: Name passed to the parser (used in parser diagnostics).

        Returns:
            A dict with the keys:
                'valid': bool, True when no errors were recorded.
                'errors': List[str]
                'warnings': List[str]
                'node_classes': List[str], names of the detected node classes.
                'metadata': Dict, metadata extracted from ``get_metadata``.
        """
        self.errors = []
        self.warnings = []

        # 1. Basic syntax check.
        try:
            tree = ast.parse(code, filename=filename)
        except SyntaxError as e:
            return self._failure_result(f"语法错误 (行 {e.lineno}): {e.msg}")
        except (ValueError, RecursionError, MemoryError) as e:
            # ast.parse can fail with non-SyntaxError exceptions on hostile
            # input (e.g. ValueError for null bytes on older interpreters).
            # Report them as an invalid result instead of crashing.
            return self._failure_result(
                f"代码解析失败: {type(e).__name__}: {e}"
            )

        # 2. AST security check.
        self._check_security(tree)
        # 3. Node class definitions.
        node_classes = self._check_node_classes(tree)
        # 4. Metadata extraction.
        metadata = self._extract_metadata(tree)

        return {
            'valid': len(self.errors) == 0,
            'errors': self.errors,
            'warnings': self.warnings,
            'node_classes': node_classes,
            'metadata': metadata
        }

    def _is_dangerous_module(self, module_name: Optional[str]) -> bool:
        """Return True if the module or its top-level package is blacklisted.

        ``module_name`` is None for ``from . import x`` style imports.
        """
        if not module_name:
            return False
        # Compare the root package as well so that ``os.path`` is caught by
        # the ``os`` entry.
        root = module_name.split('.', 1)[0]
        return (
            module_name in self.DANGEROUS_MODULES
            or root in self.DANGEROUS_MODULES
        )

    def _check_security(self, tree: ast.AST):
        """Record errors/warnings for blacklisted imports and calls."""
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if self._is_dangerous_module(alias.name):
                        self.errors.append(
                            f"❌ 禁止导入危险模块: {alias.name} (行 {node.lineno})"
                        )
            elif isinstance(node, ast.ImportFrom):
                if self._is_dangerous_module(node.module):
                    self.errors.append(
                        f"❌ 禁止从危险模块导入: {node.module} (行 {node.lineno})"
                    )
            elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                called = node.func.id
                if called in self.DANGEROUS_CALLS:
                    self.errors.append(
                        f"❌ 禁止调用危险函数: {called} (行 {node.lineno})"
                    )
                elif called == 'open':
                    # File access is allowed but flagged. This check lives in
                    # the same ast.Call branch as the one above; as a separate
                    # ``elif isinstance(node, ast.Call)`` it would be
                    # unreachable.
                    self.warnings.append(
                        f"⚠️ 检测到文件操作 (行 {node.lineno}),请确保路径安全"
                    )

    def _inherits_required_base(self, class_node: ast.ClassDef) -> bool:
        """Return True if a direct base is REQUIRED_BASE_CLASS.

        Both ``TraceNode`` and qualified forms such as ``pkg.TraceNode`` match.
        """
        for base in class_node.bases:
            if isinstance(base, ast.Name) and base.id == self.REQUIRED_BASE_CLASS:
                return True
            if (isinstance(base, ast.Attribute)
                    and base.attr == self.REQUIRED_BASE_CLASS):
                return True
        return False

    def _check_node_classes(self, tree: ast.AST) -> List[str]:
        """Collect node class names and check their required methods.

        Records an error for each node class without ``execute`` and a warning
        for each one without ``get_metadata``. Records an error when no node
        class is found at all.
        """
        node_classes = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            if not self._inherits_required_base(node):
                continue

            node_classes.append(node.name)
            # Only plain ``def`` methods defined directly in the class body.
            methods = {n.name for n in node.body if isinstance(n, ast.FunctionDef)}
            if 'execute' not in methods:
                self.errors.append(
                    f"❌ 节点类 {node.name} 缺少 execute() 方法 (行 {node.lineno})"
                )
            if 'get_metadata' not in methods:
                self.warnings.append(
                    f"⚠️ 节点类 {node.name} 建议实现 get_metadata() 静态方法"
                )

        if not node_classes:
            self.errors.append(
                f"❌ 未找到继承自 {self.REQUIRED_BASE_CLASS} 的节点类"
            )
        return node_classes

    def _extract_metadata(self, tree: ast.AST) -> Dict:
        """Extract the dict literal returned by a ``get_metadata`` method.

        Only top-level ``return {...}`` statements whose dict consists purely
        of literals are understood. Anything else is skipped. When several
        matches exist, the last one visited wins.
        """
        metadata = {}
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for item in node.body:
                if not (isinstance(item, ast.FunctionDef)
                        and item.name == 'get_metadata'):
                    continue
                for stmt in item.body:
                    if not (isinstance(stmt, ast.Return)
                            and isinstance(stmt.value, ast.Dict)):
                        continue
                    try:
                        # Literal-only evaluation; no code is executed.
                        metadata = ast.literal_eval(stmt.value)
                    except (ValueError, TypeError, SyntaxError,
                            MemoryError, RecursionError):
                        # Best effort: a non-literal dict leaves the
                        # metadata unchanged.
                        pass
        return metadata

    def validate_file(self, filepath: Path) -> Dict:
        """Read ``filepath`` as UTF-8 and validate its contents.

        Any exception raised while reading or validating is reported as an
        invalid result rather than propagated.
        """
        try:
            code = filepath.read_text(encoding='utf-8')
            return self.validate_code(code, str(filepath))
        except Exception as e:
            return self._failure_result(f"读取文件失败: {str(e)}")

    @staticmethod
    def is_safe_filename(filename: str) -> Tuple[bool, Optional[str]]:
        """Check whether a file name is acceptable for a node file.

        Returns:
            (is_safe, error_message); error_message is None when safe.
        """
        # Must be a .py file.
        if not filename.endswith('.py'):
            return False, "文件必须是 .py 扩展名"

        # Must not start with a double underscore.
        name = filename[:-3]
        if name.startswith('__'):
            return False, "文件名不能以 __ 开头"

        # Letters, digits and underscores only.
        if not re.match(r'^[a-zA-Z0-9_]+\.py$', filename):
            return False, "文件名只能包含字母、数字和下划线"

        # Reserved names (compared case-insensitively).
        reserved = {'__init__', 'test', 'setup', 'config'}
        if name.lower() in reserved:
            return False, f"文件名 '{name}' 是保留名称"

        return True, None
class NodeSandbox:
    """Node execution sandbox (simplified).

    Only a wall-clock timeout is applied. The node runs in a daemon thread of
    the current process, and a timed-out thread is abandoned, not stopped.
    """

    MAX_EXECUTION_TIME = 30  # seconds
    MAX_MEMORY_MB = 512  # not enforced by any code in this class

    @staticmethod
    def execute_with_limits(node_instance, inputs: Dict,
                            timeout: Optional[int] = None):
        """Run ``node_instance.execute(inputs)`` with a time limit.

        TODO: implement a real sandbox using RestrictedPython or containers.

        Args:
            node_instance: Object exposing an ``execute(inputs)`` method.
            inputs: Passed through to ``execute`` unchanged.
            timeout: Limit in seconds; a falsy value (None or 0) selects
                ``MAX_EXECUTION_TIME``.

        Returns:
            A dict with the keys 'output' (the return value of ``execute``,
            or None), 'error' (message string or None) and 'timeout' (bool).
            The dict is built after the wait ends, so a worker that finishes
            late cannot modify a result that already reported a timeout.
        """
        import threading

        timeout = timeout or NodeSandbox.MAX_EXECUTION_TIME

        # Private to the worker thread; never handed to the caller directly.
        outcome = {'output': None, 'error': None}

        def target():
            try:
                outcome['output'] = node_instance.execute(inputs)
            except Exception as e:
                outcome['error'] = str(e)

        thread = threading.Thread(target=target)
        thread.daemon = True  # do not block interpreter exit on a stuck node
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            return {
                'output': None,
                'error': f"节点执行超时 ({timeout}秒)",
                'timeout': True
            }
        return {
            'output': outcome['output'],
            'error': outcome['error'],
            'timeout': False
        }
# 便捷函数
def validate_node_code(code: str) -> Dict:
    """Validate node source code with a fresh NodeValidator (shortcut)."""
    return NodeValidator().validate_code(code)
def validate_node_file(filepath: str) -> Dict:
    """Validate the node file at ``filepath`` with a fresh NodeValidator (shortcut)."""
    target_path = Path(filepath)
    return NodeValidator().validate_file(target_path)