209 lines
6.3 KiB
Python
209 lines
6.3 KiB
Python
"""
|
|
硬规则安全检查器
|
|
静态扫描执行代码,检测危险操作
|
|
"""
|
|
|
|
import re
|
|
import ast
|
|
from typing import List, Tuple
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class RuleCheckResult:
    """Outcome of a rule check.

    Holds a pass/fail flag together with the list of violation messages
    collected while scanning the code.
    """

    # Overall verdict of the check.
    passed: bool
    # Human-readable descriptions of every violation that was found.
    violations: List[str]
|
|
|
|
|
|
class RuleChecker:
    """Hard-rule checker that statically scans code for dangerous operations.

    The following categories are detected:

    1. Network access: requests, socket, urllib, http.client, ...
    2. Dangerous file operations: os.remove, shutil.rmtree, os.unlink, ...
    3. External command execution: subprocess, os.system, os.popen, ...
    4. Access to paths outside the workspace.

    The scan is purely static (AST based, with a regex fallback when the
    code cannot be parsed); it never executes the inspected code.
    """

    # Modules that must not be imported.
    FORBIDDEN_IMPORTS = {
        'requests',
        'socket',
        'urllib',
        'urllib.request',
        'urllib.parse',
        'urllib.error',
        'http.client',
        'httplib',
        'ftplib',
        'smtplib',
        'telnetlib',
        'subprocess',
    }

    # Functions that must not be called ("module.function" or a bare name).
    FORBIDDEN_CALLS = {
        'os.remove',
        'os.unlink',
        'os.rmdir',
        'os.removedirs',
        'os.system',
        'os.popen',
        'os.spawn',
        'os.spawnl',
        'os.spawnle',
        'os.spawnlp',
        'os.spawnlpe',
        'os.spawnv',
        'os.spawnve',
        'os.spawnvp',
        'os.spawnvpe',
        'os.exec',
        'os.execl',
        'os.execle',
        'os.execlp',
        'os.execlpe',
        'os.execv',
        'os.execve',
        'os.execvp',
        'os.execvpe',
        'shutil.rmtree',
        'shutil.move',  # a move can make the original file disappear
        'eval',
        'exec',
        'compile',
        '__import__',
    }

    # Regular expressions describing suspicious path fragments.
    DANGEROUS_PATH_PATTERNS = [
        r'[A-Za-z]:\\',  # Windows absolute path
        r'\\\\',  # UNC path
        r'/etc/',
        r'/usr/',
        r'/bin/',
        r'/home/',
        r'/root/',
        r'\.\./',  # parent-directory traversal
        r'\.\.',  # parent directory
    ]

    # Number of characters inspected on each side of a path match when
    # looking for the word "workspace".
    _PATH_CONTEXT_RADIUS = 50

    def check(self, code: str) -> "RuleCheckResult":
        """Check whether ``code`` complies with the safety rules.

        Args:
            code: Python source code as a string.

        Returns:
            RuleCheckResult: ``passed`` is True only when no violation was
            found; ``violations`` lists import, call and path violations
            in that order.
        """
        violations: List[str] = []
        violations.extend(self._check_imports(code))
        violations.extend(self._check_calls(code))
        violations.extend(self._check_paths(code))
        return RuleCheckResult(passed=not violations, violations=violations)

    def _is_forbidden_module(self, name: str) -> bool:
        """Return True if ``name`` or any dotted prefix of it is forbidden."""
        parts = name.split('.')
        return any(
            '.'.join(parts[:end]) in self.FORBIDDEN_IMPORTS
            for end in range(1, len(parts) + 1)
        )

    def _check_imports(self, code: str) -> List[str]:
        """Return one violation message per forbidden import in ``code``.

        Both ``import a.b`` and ``from a import b`` forms are inspected, so
        ``from http import client`` is reported just like
        ``import http.client``. If the code cannot be parsed, a regex scan
        over the raw text is used instead.
        """
        violations: List[str] = []

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # Unparseable code (syntax error, or e.g. embedded null bytes):
            # fall back to regex matching. Sorted for a stable message order.
            for module in sorted(self.FORBIDDEN_IMPORTS):
                escaped = re.escape(module)
                pattern = rf'\bimport\s+{escaped}\b|\bfrom\s+{escaped}\b'
                if re.search(pattern, code):
                    violations.append(f"禁止导入模块: {module}")
            return violations

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if self._is_forbidden_module(alias.name):
                        violations.append(f"禁止导入模块: {alias.name}")
            elif isinstance(node, ast.ImportFrom) and node.module:
                if self._is_forbidden_module(node.module):
                    violations.append(f"禁止导入模块: {node.module}")
                    continue
                # The package itself is allowed, but one of the imported
                # names may be a forbidden submodule (from http import client).
                for alias in node.names:
                    qualified = f"{node.module}.{alias.name}"
                    if qualified in self.FORBIDDEN_IMPORTS:
                        violations.append(f"禁止导入模块: {qualified}")

        return violations

    @staticmethod
    def _collect_import_aliases(tree: ast.AST) -> dict:
        """Map locally bound import names to the qualified names they refer to.

        ``import os as o`` yields ``{'o': 'os'}`` and
        ``from os import system as s`` yields ``{'s': 'os.system'}``.
        Relative imports are skipped because their target is not known here.
        """
        aliases = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if alias.asname:
                        aliases[alias.asname] = alias.name
            elif isinstance(node, ast.ImportFrom) and node.module and not node.level:
                for alias in node.names:
                    bound = alias.asname or alias.name
                    aliases[bound] = f"{node.module}.{alias.name}"
        return aliases

    @staticmethod
    def _resolve_alias(call_name: str, aliases: dict) -> str:
        """Replace the leading component of ``call_name`` using ``aliases``."""
        head, dot, rest = call_name.partition('.')
        target = aliases.get(head)
        if target is None:
            return call_name
        return f"{target}{dot}{rest}"

    def _check_calls(self, code: str) -> List[str]:
        """Return one violation message per forbidden function call in ``code``.

        Call names are checked as written and, if that does not match, after
        resolving import aliases, so ``import os as o; o.system(...)`` and
        ``from os import system; system(...)`` are both reported as
        ``os.system``. If the code cannot be parsed, a regex scan over the
        raw text is used instead.
        """
        violations: List[str] = []

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # Unparseable code: fall back to regex matching. Sorted for a
            # stable message order.
            for func in sorted(self.FORBIDDEN_CALLS):
                if re.search(rf'\b{re.escape(func)}\s*\(', code):
                    violations.append(f"禁止调用函数: {func}")
            return violations

        aliases = self._collect_import_aliases(tree)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            call_name = self._get_call_name(node)
            if call_name in self.FORBIDDEN_CALLS:
                violations.append(f"禁止调用函数: {call_name}")
                continue
            resolved = self._resolve_alias(call_name, aliases)
            if resolved != call_name and resolved in self.FORBIDDEN_CALLS:
                violations.append(f"禁止调用函数: {resolved}")

        return violations

    def _get_call_name(self, node: ast.Call) -> str:
        """Return the dotted name of the callee of ``node``.

        A plain name yields that name; an attribute chain yields its
        components joined with dots (the innermost value is included only
        when it is a plain name). Any other callee yields an empty string.
        """
        func = node.func
        if isinstance(func, ast.Name):
            return func.id
        if isinstance(func, ast.Attribute):
            parts = []
            current = func
            # Walk the chain from the outermost attribute inwards.
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                parts.append(current.id)
            return '.'.join(reversed(parts))
        return ''

    def _check_paths(self, code: str) -> List[str]:
        """Return at most one violation per dangerous path pattern.

        A match is tolerated when the word "workspace" (case-insensitive)
        appears in the text surrounding that particular match; the first
        match of a pattern without such context is reported.
        """
        violations: List[str] = []
        radius = self._PATH_CONTEXT_RADIUS

        for pattern in self.DANGEROUS_PATH_PATTERNS:
            for found in re.finditer(pattern, code, re.IGNORECASE):
                # Use the position of this very match; searching the text
                # for the matched string would always land on its first
                # occurrence and judge the wrong context.
                start = found.start()
                context = code[max(0, start - radius):start + radius].lower()
                if 'workspace' not in context:
                    violations.append(f"检测到可疑路径模式: {found.group(0)}")
                    break

        return violations
|
|
|
|
|
|
def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience wrapper: run a fresh ``RuleChecker`` over ``code``.

    Args:
        code: Python source code as a string.

    Returns:
        RuleCheckResult: the result produced by ``RuleChecker.check``.
    """
    return RuleChecker().check(code)
|
|
|