267 lines
8.7 KiB
Python
267 lines
8.7 KiB
Python
|
|
"""
Custom node validator.

Checks node source code for safety and for conformance with the node
conventions.
"""
|
|||
|
|
import ast
|
|||
|
|
import re
|
|||
|
|
from typing import Dict, List, Optional, Tuple
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NodeValidationError(Exception):
    """Node validation error."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NodeValidator:
    """Static validator for custom node source code.

    Parses node code into an AST, rejects blacklisted imports and calls,
    checks that at least one class derives from ``REQUIRED_BASE_CLASS`` and
    defines ``execute``, and extracts literal metadata returned by
    ``get_metadata``. Findings of the most recent ``validate_code`` run are
    kept in ``self.errors`` and ``self.warnings``.
    """

    # Blacklisted modules and functions.
    DANGEROUS_MODULES = {
        'os', 'subprocess', 'sys', 'importlib', 'builtins',
        '__import__', 'eval', 'exec', 'compile', 'open',
        'file', 'input', 'raw_input'
    }

    # Blacklisted function calls.
    DANGEROUS_CALLS = {
        'eval', 'exec', 'compile', '__import__',
        'system', 'popen', 'spawn', 'fork'
    }

    # Base class every node class must inherit from.
    REQUIRED_BASE_CLASS = 'TraceNode'

    def __init__(self):
        self.errors: List[str] = []
        self.warnings: List[str] = []

    @staticmethod
    def _failure(message: str) -> Dict:
        """Build the result dict for a run that failed before any checks ran."""
        return {
            'valid': False,
            'errors': [message],
            'warnings': [],
            'node_classes': [],
            'metadata': {}
        }

    def validate_code(self, code: str, filename: str = '<string>') -> Dict:
        """Validate node source code.

        Args:
            code: Python source code as a string.
            filename: File name passed to the parser (used in error reports).

        Returns:
            A dict with the keys:
                'valid': bool, True when no errors were recorded.
                'errors': List[str] of blocking problems.
                'warnings': List[str] of non-blocking findings.
                'node_classes': List[str], names of detected node classes.
                'metadata': Dict extracted from ``get_metadata``.
        """
        self.errors = []
        self.warnings = []

        # 1. Basic syntax check.
        try:
            tree = ast.parse(code, filename=filename)
        except SyntaxError as e:
            return self._failure(f"语法错误 (行 {e.lineno}): {e.msg}")
        except ValueError as e:
            # Older interpreters raise ValueError (not SyntaxError) for source
            # containing NUL bytes; report it instead of letting it escape.
            return self._failure(f"语法错误: {e}")

        # 2. AST security check.
        self._check_security(tree)

        # 3. Node class definitions.
        node_classes = self._check_node_classes(tree)

        # 4. Metadata extraction.
        metadata = self._extract_metadata(tree)

        return {
            'valid': len(self.errors) == 0,
            'errors': self.errors,
            'warnings': self.warnings,
            'node_classes': node_classes,
            'metadata': metadata
        }

    def _is_dangerous_module(self, module_name: Optional[str]) -> bool:
        """Return True when the top-level package of *module_name* is blacklisted.

        ``None`` (a relative ``from . import x``) is never considered dangerous.
        """
        if not module_name:
            return False
        # Compare the root package so that ``os.path`` is caught like ``os``.
        return module_name.split('.', 1)[0] in self.DANGEROUS_MODULES

    def _check_security(self, tree: ast.AST) -> None:
        """Record errors for blacklisted imports/calls and warn about open()."""
        for node in ast.walk(tree):
            # Dangerous imports.
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if self._is_dangerous_module(alias.name):
                        self.errors.append(
                            f"❌ 禁止导入危险模块: {alias.name} (行 {node.lineno})"
                        )

            elif isinstance(node, ast.ImportFrom):
                if self._is_dangerous_module(node.module):
                    self.errors.append(
                        f"❌ 禁止从危险模块导入: {node.module} (行 {node.lineno})"
                    )

            # Dangerous calls and file operations. Both checks live in a single
            # branch: a second ``elif`` on ast.Call could never be reached.
            # NOTE(review): only bare-name calls are inspected; attribute calls
            # such as ``os.system(...)`` rely on the import check above.
            elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                called = node.func.id
                if called in self.DANGEROUS_CALLS:
                    self.errors.append(
                        f"❌ 禁止调用危险函数: {called} (行 {node.lineno})"
                    )
                elif called == 'open':
                    self.warnings.append(
                        f"⚠️ 检测到文件操作 (行 {node.lineno}),请确保路径安全"
                    )

    def _inherits_required_base(self, class_node: ast.ClassDef) -> bool:
        """Return True when a base is ``TraceNode`` or ``<anything>.TraceNode``."""
        for base in class_node.bases:
            if isinstance(base, ast.Name) and base.id == self.REQUIRED_BASE_CLASS:
                return True
            if isinstance(base, ast.Attribute) and base.attr == self.REQUIRED_BASE_CLASS:
                return True
        return False

    def _check_node_classes(self, tree: ast.AST) -> List[str]:
        """Collect node class names and check their required methods.

        Records an error for a node class without ``execute``, a warning for
        one without ``get_metadata``, and an error when no node class exists.

        Returns:
            Names of all classes that inherit from ``REQUIRED_BASE_CLASS``.
        """
        node_classes = []

        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            if not self._inherits_required_base(node):
                continue

            node_classes.append(node.name)

            # Only plain ``def`` methods defined directly in the class body count.
            methods = {n.name for n in node.body if isinstance(n, ast.FunctionDef)}

            if 'execute' not in methods:
                self.errors.append(
                    f"❌ 节点类 {node.name} 缺少 execute() 方法 (行 {node.lineno})"
                )

            if 'get_metadata' not in methods:
                self.warnings.append(
                    f"⚠️ 节点类 {node.name} 建议实现 get_metadata() 静态方法"
                )

        if not node_classes:
            self.errors.append(
                f"❌ 未找到继承自 {self.REQUIRED_BASE_CLASS} 的节点类"
            )

        return node_classes

    def _extract_metadata(self, tree: ast.AST) -> Dict:
        """Extract the literal dict returned by a class's ``get_metadata``.

        Only top-level ``return {...}`` statements made of pure literals are
        understood; anything else is skipped. When several candidates exist,
        the last one visited wins.

        Returns:
            The extracted dict, or an empty dict when nothing could be read.
        """
        metadata = {}

        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for item in node.body:
                if not (isinstance(item, ast.FunctionDef) and item.name == 'get_metadata'):
                    continue
                for stmt in item.body:
                    if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Dict):
                        try:
                            # Simple extraction: literals only.
                            metadata = ast.literal_eval(stmt.value)
                        except (ValueError, TypeError, SyntaxError,
                                MemoryError, RecursionError):
                            # Best effort: a non-literal dict is not an error,
                            # it just yields no metadata.
                            pass

        return metadata

    def validate_file(self, filepath: Path) -> Dict:
        """Read *filepath* as UTF-8 and validate its contents.

        Any exception raised while reading or validating is reported as a
        failed result instead of being propagated.
        """
        try:
            code = filepath.read_text(encoding='utf-8')
            return self.validate_code(code, str(filepath))
        except Exception as e:
            return self._failure(f"读取文件失败: {str(e)}")

    @staticmethod
    def is_safe_filename(filename: str) -> Tuple[bool, Optional[str]]:
        """Check whether *filename* is acceptable for a node file.

        Returns:
            ``(is_safe, error_message)``; ``error_message`` is ``None`` when
            the name is accepted.
        """
        # Must be a .py file.
        if not filename.endswith('.py'):
            return False, "文件必须是 .py 扩展名"

        # Must not start with "__" (Python internal files).
        name = filename[:-3]
        if name.startswith('__'):
            return False, "文件名不能以 __ 开头"

        # Letters, digits and underscores only.
        if not re.match(r'^[a-zA-Z0-9_]+\.py$', filename):
            return False, "文件名只能包含字母、数字和下划线"

        # Reserved names ('__init__' is already rejected by the "__" check).
        reserved = {'__init__', 'test', 'setup', 'config'}
        if name.lower() in reserved:
            return False, f"文件名 '{name}' 是保留名称"

        return True, None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NodeSandbox:
    """Node execution sandbox (simplified: wall-clock timeout only)."""

    MAX_EXECUTION_TIME = 30  # seconds
    MAX_MEMORY_MB = 512  # declared limit; not enforced by this class

    @staticmethod
    def execute_with_limits(node_instance, inputs: Dict, timeout: Optional[int] = None):
        """Run ``node_instance.execute(inputs)`` in a daemon thread with a timeout.

        TODO: implement a real sandbox (RestrictedPython or containers).

        Args:
            node_instance: Object exposing an ``execute(inputs)`` method.
            inputs: Value passed through to ``execute``.
            timeout: Seconds to wait; a falsy value falls back to
                ``MAX_EXECUTION_TIME``.

        Returns:
            A dict with 'output' (the return value, or None), 'error' (the
            exception text or the timeout message, or None) and 'timeout'
            (True when the wait expired). On timeout the worker thread is not
            stopped; it is simply abandoned.
        """
        import threading

        timeout = timeout or NodeSandbox.MAX_EXECUTION_TIME

        # The worker writes here, never into the dict handed to the caller, so
        # a node finishing after the timeout cannot alter a returned result.
        outcome = {'output': None, 'error': None}

        def target():
            try:
                outcome['output'] = node_instance.execute(inputs)
            except Exception as e:
                outcome['error'] = str(e)

        thread = threading.Thread(target=target, name='node-sandbox-exec')
        thread.daemon = True
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            return {
                'output': None,
                'error': f"节点执行超时 ({timeout}秒)",
                'timeout': True
            }

        return {
            'output': outcome['output'],
            'error': outcome['error'],
            'timeout': False
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 便捷函数
|
|||
|
|
def validate_node_code(code: str) -> Dict:
    """Validate node source code (shortcut).

    Creates a fresh ``NodeValidator`` and returns the result of its
    ``validate_code(code)`` call unchanged.
    """
    return NodeValidator().validate_code(code)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_node_file(filepath: str) -> Dict:
    """Validate a node file (shortcut).

    Wraps *filepath* in a ``Path`` and returns the result of
    ``validate_file`` on a fresh ``NodeValidator`` unchanged.
    """
    node_path = Path(filepath)
    return NodeValidator().validate_file(node_path)
|