TraceStudio-dev/server/app/core/node_validator.py
2026-01-09 21:37:02 +08:00

267 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
自定义节点验证器
检查节点代码的安全性和规范性
"""
import ast
import re
from typing import Dict, List, Optional, Tuple
from pathlib import Path
class NodeValidationError(Exception):
    """Exception type for custom-node validation failures.

    NOTE(review): nothing in this module raises it; the validator reports
    problems through its result dicts instead.
    """
class NodeValidator:
    """AST-based validator for user-supplied custom node code.

    Checks syntax, a blacklist of imports and bare-name calls, that at least
    one class derives from ``REQUIRED_BASE_CLASS`` and defines ``execute``,
    and extracts a literal dict returned by ``get_metadata()`` when present.

    NOTE: this is a best-effort name blacklist, not a security boundary.
    Dangerous objects reached indirectly (via ``getattr``, attribute chains,
    etc.) are not detected.

    ``errors`` / ``warnings`` hold the findings of the most recent
    ``validate_code`` call and are reset at the start of each call.
    """

    # Blacklisted import names. An import is rejected when its top-level
    # package is listed here. Several entries are builtin names (some
    # Python 2 only) rather than modules; they only match if imported by
    # that name.
    DANGEROUS_MODULES = {
        'os', 'subprocess', 'sys', 'importlib', 'builtins',
        '__import__', 'eval', 'exec', 'compile', 'open',
        'file', 'input', 'raw_input'
    }

    # Functions that may not be called by bare name (``eval(...)``).
    DANGEROUS_CALLS = {
        'eval', 'exec', 'compile', '__import__',
        'system', 'popen', 'spawn', 'fork'
    }

    # Base class every node class must inherit from.
    REQUIRED_BASE_CLASS = 'TraceNode'

    def __init__(self):
        self.errors: List[str] = []
        self.warnings: List[str] = []

    @staticmethod
    def _failure_result(message: str) -> Dict:
        """Build an invalid result dict carrying a single error message."""
        return {
            'valid': False,
            'errors': [message],
            'warnings': [],
            'node_classes': [],
            'metadata': {},
        }

    def validate_code(self, code: str, filename: str = '<string>') -> Dict:
        """Validate custom node source code.

        Args:
            code: Python source text.
            filename: name passed to ``ast.parse`` (used in error reporting).

        Returns:
            A dict with keys:
                'valid' (bool): True when no errors were recorded.
                'errors' (List[str]): blocking problems.
                'warnings' (List[str]): advisory problems.
                'node_classes' (List[str]): names of classes deriving from
                    REQUIRED_BASE_CLASS.
                'metadata' (Dict): literal dict returned by a
                    ``get_metadata`` method, or {} if none could be read.
        """
        self.errors = []
        self.warnings = []

        # 1. Syntax check.
        try:
            tree = ast.parse(code, filename=filename)
        except SyntaxError as e:
            return self._failure_result(f"语法错误 (行 {e.lineno}): {e.msg}")
        except ValueError as e:
            # Older Python versions raise ValueError (not SyntaxError) for
            # source containing null bytes; report it as invalid code
            # instead of letting it escape to the caller.
            return self._failure_result(f"语法错误: {e}")

        # 2. Security blacklist.
        self._check_security(tree)
        # 3. Node class structure.
        node_classes = self._check_node_classes(tree)
        # 4. Metadata extraction.
        metadata = self._extract_metadata(tree)

        return {
            'valid': len(self.errors) == 0,
            'errors': self.errors,
            'warnings': self.warnings,
            'node_classes': node_classes,
            'metadata': metadata
        }

    @classmethod
    def _is_dangerous_module(cls, module_name: Optional[str]) -> bool:
        """True if ``module_name`` or its top-level package is blacklisted.

        ``None`` (a relative ``from . import x``) is never dangerous.
        """
        if not module_name:
            return False
        # Compare the top-level package so `os.path` is caught via `os`.
        return module_name.split('.', 1)[0] in cls.DANGEROUS_MODULES

    def _check_security(self, tree: ast.AST):
        """Record blacklisted imports and calls, and warn on ``open()``.

        Errors go to ``self.errors``; the file-access notice goes to
        ``self.warnings``.
        """
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if self._is_dangerous_module(alias.name):
                        self.errors.append(
                            f"❌ 禁止导入危险模块: {alias.name} (行 {node.lineno})"
                        )
            elif isinstance(node, ast.ImportFrom):
                if self._is_dangerous_module(node.module):
                    self.errors.append(
                        f"❌ 禁止从危险模块导入: {node.module} (行 {node.lineno})"
                    )
            # Only bare-name calls are inspected. Both checks live in one
            # branch: a second `elif isinstance(node, ast.Call)` would
            # never be reached.
            elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                func_name = node.func.id
                if func_name in self.DANGEROUS_CALLS:
                    self.errors.append(
                        f"❌ 禁止调用危险函数: {func_name} (行 {node.lineno})"
                    )
                elif func_name == 'open':
                    self.warnings.append(
                        f"⚠️ 检测到文件操作 (行 {node.lineno}),请确保路径安全"
                    )

    def _is_required_base(self, base: ast.expr) -> bool:
        """True if a class-base expression names REQUIRED_BASE_CLASS.

        Accepts both ``TraceNode`` and a qualified ``pkg.TraceNode``.
        """
        if isinstance(base, ast.Name):
            return base.id == self.REQUIRED_BASE_CLASS
        if isinstance(base, ast.Attribute):
            return base.attr == self.REQUIRED_BASE_CLASS
        return False

    def _check_node_classes(self, tree: ast.AST) -> List[str]:
        """Collect node classes and check their required methods.

        A missing ``execute`` is an error and a missing ``get_metadata`` is a
        warning. An error is also recorded when no node class is found.

        Returns:
            Names of classes deriving from REQUIRED_BASE_CLASS, in walk order.
        """
        node_classes = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            if not any(self._is_required_base(base) for base in node.bases):
                continue

            node_classes.append(node.name)
            # Only plain `def`s directly in the class body count as methods.
            methods = {n.name for n in node.body if isinstance(n, ast.FunctionDef)}
            if 'execute' not in methods:
                self.errors.append(
                    f"❌ 节点类 {node.name} 缺少 execute() 方法 (行 {node.lineno})"
                )
            if 'get_metadata' not in methods:
                self.warnings.append(
                    f"⚠️ 节点类 {node.name} 建议实现 get_metadata() 静态方法"
                )

        if not node_classes:
            self.errors.append(
                f"❌ 未找到继承自 {self.REQUIRED_BASE_CLASS} 的节点类"
            )
        return node_classes

    def _extract_metadata(self, tree: ast.AST) -> Dict:
        """Statically extract the dict returned by a ``get_metadata`` method.

        Only top-level ``return {...}`` statements whose dict is a pure
        literal are used. If several match, the last one encountered wins.
        Returns {} when none can be evaluated.
        """
        metadata = {}
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for item in node.body:
                if not (isinstance(item, ast.FunctionDef) and item.name == 'get_metadata'):
                    continue
                for stmt in item.body:
                    if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Dict):
                        try:
                            metadata = ast.literal_eval(stmt.value)
                        except (ValueError, TypeError, SyntaxError,
                                MemoryError, RecursionError):
                            # Dict is not a pure literal (names, calls, ...);
                            # best-effort extraction, so skip it.
                            pass
        return metadata

    def validate_file(self, filepath: Path) -> Dict:
        """Read ``filepath`` as UTF-8 and validate its contents.

        Any exception raised while reading or validating is turned into an
        invalid result rather than propagated.
        """
        try:
            code = filepath.read_text(encoding='utf-8')
            return self.validate_code(code, str(filepath))
        except Exception as e:
            # Boundary: callers always receive a result dict.
            return self._failure_result(f"读取文件失败: {str(e)}")

    @staticmethod
    def is_safe_filename(filename: str) -> Tuple[bool, Optional[str]]:
        """Check whether ``filename`` is acceptable for a custom node file.

        Returns:
            (is_safe, error_message). error_message is None when safe.
        """
        # Must be a .py file.
        if not filename.endswith('.py'):
            return False, "文件必须是 .py 扩展名"

        # Must not start with "__" (reserved for Python internals).
        name = filename[:-3]
        if name.startswith('__'):
            return False, "文件名不能以 __ 开头"

        # Only ASCII letters, digits and underscores.
        if not re.match(r'^[a-zA-Z0-9_]+\.py$', filename):
            return False, "文件名只能包含字母、数字和下划线"

        # Reserved names (case-insensitive). '__init__' is already rejected
        # by the "__" check above; it is kept here as a safety net.
        reserved = {'__init__', 'test', 'setup', 'config'}
        if name.lower() in reserved:
            return False, f"文件名 '{name}' 是保留名称"

        return True, None
class NodeSandbox:
    """Simplified execution wrapper for node instances.

    Only a wall-clock timeout is applied. No memory, filesystem or import
    isolation is performed.
    """

    MAX_EXECUTION_TIME = 30  # seconds; default timeout for execute_with_limits
    MAX_MEMORY_MB = 512  # NOTE(review): declared but not enforced by this class

    @staticmethod
    def execute_with_limits(node_instance, inputs: Dict, timeout: Optional[int] = None) -> Dict:
        """Run ``node_instance.execute(inputs)`` in a worker thread with a timeout.

        Args:
            node_instance: object exposing an ``execute(inputs)`` method.
            inputs: passed unchanged to ``execute``.
            timeout: seconds to wait. A falsy value (None or 0) falls back
                to MAX_EXECUTION_TIME.

        Returns:
            ``{'output': ..., 'error': str or None, 'timeout': bool}``.

        Python threads cannot be killed. On timeout the worker keeps running
        as a daemon thread, and its eventual result is discarded rather than
        written into the dict already returned to the caller.

        TODO: implement a real sandbox (RestrictedPython or a container).
        """
        import threading

        timeout = timeout or NodeSandbox.MAX_EXECUTION_TIME
        # Written only by the worker; handed to the caller only after the
        # worker has finished, so the caller never sees it change.
        outcome = {'output': None, 'error': None, 'timeout': False}

        def target():
            try:
                outcome['output'] = node_instance.execute(inputs)
            except BaseException as e:
                # Thread boundary: also report SystemExit, which the threading
                # machinery would otherwise swallow silently. str() of a bare
                # exception (e.g. ValueError()) is '', so fall back to the
                # type name to keep the failure visible.
                outcome['error'] = str(e) or type(e).__name__

        worker = threading.Thread(target=target, daemon=True)
        worker.start()
        worker.join(timeout)

        if worker.is_alive():
            # Fresh dict: the still-running worker may later write to `outcome`.
            return {
                'output': None,
                'error': f"节点执行超时 ({timeout}秒)",
                'timeout': True,
            }
        return outcome
# Convenience functions
def validate_node_code(code: str, filename: str = '<string>') -> Dict:
    """Validate node source code with a fresh NodeValidator (shortcut).

    Args:
        code: Python source text.
        filename: name used in syntax-error reporting. Defaults to
            '<string>', the same default as NodeValidator.validate_code.

    Returns:
        The result dict produced by NodeValidator.validate_code.
    """
    return NodeValidator().validate_code(code, filename)
def validate_node_file(filepath: str) -> Dict:
    """Validate the node file at ``filepath`` with a fresh NodeValidator (shortcut).

    Returns:
        The result dict produced by NodeValidator.validate_file.
    """
    target_path = Path(filepath)
    checker = NodeValidator()
    return checker.validate_file(target_path)