"""
Hard-rule safety checker.

Statically scans code that is about to be executed and reports dangerous
operations: network access, destructive file operations, external command
execution, and references to paths outside the workspace.
"""
import re
import ast
from typing import Dict, List, Tuple
from dataclasses import dataclass


@dataclass
class RuleCheckResult:
    """Result of a rule check."""
    passed: bool  # True when no violation was found
    violations: List[str]  # human-readable description of each violation


class RuleChecker:
    """
    Hard-rule checker.

    Statically scans Python source for the following dangerous operations:
    1. Network requests: requests, socket, urllib, http.client
    2. Dangerous file operations: os.remove, shutil.rmtree, os.unlink
    3. External command execution: subprocess, os.system, os.popen
    4. Access to paths outside the workspace

    NOTE(review): this is a best-effort static filter, not a sandbox.
    Dynamic tricks such as ``getattr(os, "system")`` or names built from
    strings at runtime are not detected.
    """

    # Modules that may not be imported. Matched against the full dotted
    # name and against its top-level package.
    FORBIDDEN_IMPORTS = {
        'requests',
        'socket',
        'urllib',
        'urllib.request',
        'urllib.parse',
        'urllib.error',
        'http.client',
        'httplib',
        'ftplib',
        'smtplib',
        'telnetlib',
        'subprocess',
    }

    # Functions that may not be called ("module.function" or a bare name).
    FORBIDDEN_CALLS = {
        'os.remove',
        'os.unlink',
        'os.rmdir',
        'os.removedirs',
        'os.system',
        'os.popen',
        'os.spawn',
        'os.spawnl',
        'os.spawnle',
        'os.spawnlp',
        'os.spawnlpe',
        'os.spawnv',
        'os.spawnve',
        'os.spawnvp',
        'os.spawnvpe',
        'os.exec',
        'os.execl',
        'os.execle',
        'os.execlp',
        'os.execlpe',
        'os.execv',
        'os.execve',
        'os.execvp',
        'os.execvpe',
        'shutil.rmtree',
        'shutil.move',  # a move can make the source file disappear
        'eval',
        'exec',
        'compile',
        '__import__',
    }

    # Suspicious path patterns (regular expressions, matched case-insensitively).
    DANGEROUS_PATH_PATTERNS = [
        r'[A-Za-z]:\\',  # Windows absolute path
        r'\\\\',  # UNC path
        r'/etc/',
        r'/usr/',
        r'/bin/',
        r'/home/',
        r'/root/',
        r'\.\./',  # parent-directory traversal
        # A bare ".." (parent directory). The lookarounds stop it matching
        # inside a longer run of dots such as the Ellipsis literal "...".
        r'(?<![.\w])\.\.(?![.\w])',
    ]

    def check(self, code: str) -> RuleCheckResult:
        """
        Check whether code complies with the safety rules.

        Args:
            code: Python source code string.

        Returns:
            RuleCheckResult: ``passed`` is True only when no import, call or
            path violation was found; ``violations`` lists every finding in
            the order imports, calls, paths.
        """
        violations: List[str] = []
        # 1. Forbidden imports
        violations.extend(self._check_imports(code))
        # 2. Forbidden function calls
        violations.extend(self._check_calls(code))
        # 3. Suspicious paths
        violations.extend(self._check_paths(code))
        return RuleCheckResult(passed=not violations, violations=violations)

    def _check_imports(self, code: str) -> List[str]:
        """Return one violation message per forbidden ``import``/``from`` statement."""
        violations = []
        try:
            tree = ast.parse(code)
        except SyntaxError:
            # Code that does not parse: fall back to a regex scan.
            for module in self.FORBIDDEN_IMPORTS:
                pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
                if re.search(pattern, code):
                    violations.append(f"禁止导入模块: {module}")
            return violations

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    module_name = alias.name.split('.')[0]
                    if alias.name in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS:
                        violations.append(f"禁止导入模块: {alias.name}")
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    module_name = node.module.split('.')[0]
                    if node.module in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS:
                        violations.append(f"禁止导入模块: {node.module}")
        return violations

    def _check_calls(self, code: str) -> List[str]:
        """
        Return one violation message per call to a forbidden function.

        Import aliases are resolved, so ``import os as o; o.system(...)`` and
        ``from shutil import rmtree; rmtree(...)`` are caught as
        ``os.system`` and ``shutil.rmtree``.
        """
        violations = []
        try:
            tree = ast.parse(code)
        except SyntaxError:
            # Code that does not parse: fall back to a regex scan.
            for func in self.FORBIDDEN_CALLS:
                pattern = rf'\b{re.escape(func)}\s*\('
                if re.search(pattern, code):
                    violations.append(f"禁止调用函数: {func}")
            return violations

        aliases = self._collect_import_aliases(tree)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            call_name = self._get_call_name(node)
            # Check the name as written first, so previously detected calls
            # keep their original message. Then check the alias-resolved name.
            if call_name in self.FORBIDDEN_CALLS:
                violations.append(f"禁止调用函数: {call_name}")
                continue
            resolved = self._resolve_alias(call_name, aliases)
            if resolved in self.FORBIDDEN_CALLS:
                violations.append(f"禁止调用函数: {resolved}")
        return violations

    @staticmethod
    def _collect_import_aliases(tree: ast.AST) -> Dict[str, str]:
        """Map each locally bound import name to the dotted name it refers to."""
        aliases: Dict[str, str] = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # A plain "import os" already binds the real name.
                    if alias.asname:
                        aliases[alias.asname] = alias.name
            elif isinstance(node, ast.ImportFrom) and node.module and not node.level:
                # Absolute "from X import y [as z]" only; relative imports
                # refer to project-local modules.
                for alias in node.names:
                    if alias.name != '*':
                        aliases[alias.asname or alias.name] = f"{node.module}.{alias.name}"
        return aliases

    @staticmethod
    def _resolve_alias(call_name: str, aliases: Dict[str, str]) -> str:
        """Replace the first component of a dotted call name with its import target, if any."""
        head, sep, rest = call_name.partition('.')
        target = aliases.get(head)
        if target is None:
            return call_name
        return f"{target}{sep}{rest}"

    def _get_call_name(self, node: ast.Call) -> str:
        """
        Return the dotted name of the called object as written, e.g. "os.path.join".

        Returns an empty string when the callee is not a plain name or an
        attribute chain, e.g. ``f()()`` or ``obj[0]()``.
        """
        if isinstance(node.func, ast.Name):
            return node.func.id
        elif isinstance(node.func, ast.Attribute):
            parts = []
            current = node.func
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                parts.append(current.id)
            return '.'.join(reversed(parts))
        return ''

    def _check_paths(self, code: str) -> List[str]:
        """
        Return at most one violation per suspicious path pattern.

        An occurrence is exempt when "workspace" appears (case-insensitively)
        within 50 characters before or after that occurrence's own position.
        NOTE(review): this proximity exemption is heuristic and can be
        triggered by any nearby text containing "workspace".
        """
        violations = []
        for pattern in self.DANGEROUS_PATH_PATTERNS:
            for match in re.finditer(pattern, code, re.IGNORECASE):
                start = match.start()
                # Use this occurrence's position, not the first occurrence
                # of the same text elsewhere in the code.
                context = code[max(0, start - 50):start + 50].lower()
                if 'workspace' not in context:
                    violations.append(f"检测到可疑路径模式: {match.group(0)}")
                    break
        return violations


def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience wrapper: check code safety with a fresh ``RuleChecker``."""
    checker = RuleChecker()
    return checker.check(code)