""" 硬规则安全检查器 只检测最危险的操作,其他交给 LLM 审查 """ import re import ast from typing import List from dataclasses import dataclass @dataclass class RuleCheckResult: """规则检查结果""" passed: bool violations: List[str] # 违规项列表 warnings: List[str] # 警告项(交给 LLM 审查) class RuleChecker: """ 硬规则检查器 只硬性禁止最危险的操作: 1. 网络模块: socket(底层网络) 2. 执行任意代码: eval, exec, compile 3. 执行系统命令: subprocess, os.system, os.popen 4. 动态导入: __import__ 其他操作(如文件删除、路径访问等)生成警告,交给 LLM 审查 """ # 【硬性禁止】最危险的模块 - 直接拒绝 CRITICAL_FORBIDDEN_IMPORTS = { 'socket', # 底层网络,可绑定端口、建立连接 'subprocess', # 执行任意系统命令 'multiprocessing', # 可能绑定端口 'asyncio', # 可能包含网络操作 'ctypes', # 可调用任意 C 函数 'cffi', # 外部函数接口 } # 【硬性禁止】最危险的函数调用 - 直接拒绝 CRITICAL_FORBIDDEN_CALLS = { # 执行任意代码 'eval', 'exec', 'compile', '__import__', # 执行系统命令 'os.system', 'os.popen', 'os.spawn', 'os.spawnl', 'os.spawnle', 'os.spawnlp', 'os.spawnlpe', 'os.spawnv', 'os.spawnve', 'os.spawnvp', 'os.spawnvpe', 'os.execl', 'os.execle', 'os.execlp', 'os.execlpe', 'os.execv', 'os.execve', 'os.execvp', 'os.execvpe', } # 【警告】需要 LLM 审查的模块 WARNING_IMPORTS = { 'requests', # HTTP 请求 'urllib', # URL 处理 'http.client', # HTTP 客户端 'ftplib', # FTP 'smtplib', # 邮件 'telnetlib', # Telnet } # 【警告】需要 LLM 审查的函数调用 WARNING_CALLS = { 'os.remove', # 删除文件 'os.unlink', # 删除文件 'os.rmdir', # 删除目录 'os.removedirs', # 递归删除目录 'shutil.rmtree', # 递归删除目录树 'shutil.move', # 移动文件(可能丢失原文件) 'open', # 打开文件(检查路径) } def check(self, code: str) -> RuleCheckResult: """ 检查代码是否符合安全规则 Args: code: Python 代码字符串 Returns: RuleCheckResult: 检查结果 """ violations = [] # 硬性违规,直接拒绝 warnings = [] # 警告,交给 LLM 审查 # 1. 检查硬性禁止的导入 critical_import_violations = self._check_critical_imports(code) violations.extend(critical_import_violations) # 2. 检查硬性禁止的函数调用 critical_call_violations = self._check_critical_calls(code) violations.extend(critical_call_violations) # 3. 检查警告级别的导入 warning_imports = self._check_warning_imports(code) warnings.extend(warning_imports) # 4. 检查警告级别的函数调用 warning_calls = self._check_warning_calls(code) warnings.extend(warning_calls) return RuleCheckResult( passed=len(violations) == 0, violations=violations, warnings=warnings ) def _check_critical_imports(self, code: str) -> List[str]: """检查硬性禁止的导入""" violations = [] try: tree = ast.parse(code) for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: module_name = alias.name.split('.')[0] if module_name in self.CRITICAL_FORBIDDEN_IMPORTS: violations.append(f"严禁使用模块: {alias.name}(可能执行危险操作)") elif isinstance(node, ast.ImportFrom): if node.module: module_name = node.module.split('.')[0] if module_name in self.CRITICAL_FORBIDDEN_IMPORTS: violations.append(f"严禁使用模块: {node.module}(可能执行危险操作)") except SyntaxError: for module in self.CRITICAL_FORBIDDEN_IMPORTS: pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b' if re.search(pattern, code): violations.append(f"严禁使用模块: {module}") return violations def _check_critical_calls(self, code: str) -> List[str]: """检查硬性禁止的函数调用""" violations = [] try: tree = ast.parse(code) for node in ast.walk(tree): if isinstance(node, ast.Call): call_name = self._get_call_name(node) if call_name in self.CRITICAL_FORBIDDEN_CALLS: violations.append(f"严禁调用: {call_name}(可能执行任意代码或命令)") except SyntaxError: for func in self.CRITICAL_FORBIDDEN_CALLS: pattern = rf'\b{re.escape(func)}\s*\(' if re.search(pattern, code): violations.append(f"严禁调用: {func}") return violations def _check_warning_imports(self, code: str) -> List[str]: """检查警告级别的导入""" warnings = [] try: tree = ast.parse(code) for node in ast.walk(tree): if isinstance(node, ast.Import): for alias in node.names: module_name = alias.name.split('.')[0] if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS: warnings.append(f"使用了网络相关模块: {alias.name}") elif isinstance(node, ast.ImportFrom): if node.module: module_name = node.module.split('.')[0] if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS: warnings.append(f"使用了网络相关模块: {node.module}") except SyntaxError: pass return warnings def _check_warning_calls(self, code: str) -> List[str]: """检查警告级别的函数调用""" warnings = [] try: tree = ast.parse(code) for node in ast.walk(tree): if isinstance(node, ast.Call): call_name = self._get_call_name(node) if call_name in self.WARNING_CALLS: warnings.append(f"使用了敏感操作: {call_name}") except SyntaxError: pass return warnings def _get_call_name(self, node: ast.Call) -> str: """获取函数调用的完整名称""" if isinstance(node.func, ast.Name): return node.func.id elif isinstance(node.func, ast.Attribute): parts = [] current = node.func while isinstance(current, ast.Attribute): parts.append(current.attr) current = current.value if isinstance(current, ast.Name): parts.append(current.id) return '.'.join(reversed(parts)) return '' def check_code_safety(code: str) -> RuleCheckResult: """便捷函数:检查代码安全性""" checker = RuleChecker() return checker.check(code)