Files
LocalAgent/safety/rule_checker.py
Mimikko-zeus 8a538bb950 feat: refactor API key configuration and enhance application initialization
- Renamed `check_environment` to `check_api_key_configured` for clarity, simplifying the API key validation logic.
- Removed the blocking behavior of the API key check during application startup, allowing the app to run while providing a prompt for configuration.
- Updated `LocalAgentApp` to accept an `api_configured` parameter, enabling conditional messaging for API key setup.
- Enhanced the `SandboxRunner` to support backup management and improved execution result handling with detailed metrics.
- Integrated data governance strategies into the `HistoryManager`, ensuring compliance and improved data management.
- Added privacy settings and metrics tracking across various components to enhance user experience and application safety.
2026-02-27 14:32:30 +08:00

334 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
硬规则安全检查器
只检测最危险的操作,其他交给 LLM 审查
"""
import re
import ast
from typing import List
from dataclasses import dataclass
from .security_metrics import get_metrics
@dataclass
class RuleCheckResult:
    """Result of a rule check.

    Attributes:
        passed: Whether the check passed.
        violations: List of violation messages.
        warnings: List of warning messages (items left to LLM review).
    """

    passed: bool
    violations: List[str]
    warnings: List[str]
class RuleChecker:
    """
    Hard-rule static safety checker for Python source code.

    Hard violations (the code is rejected):
    1. Imports of network modules (socket, HTTP/FTP/SMTP clients, ...)
    2. Imports of process / FFI modules (subprocess, multiprocessing,
       asyncio, ctypes, cffi)
    3. Arbitrary code execution: eval, exec, compile, __import__
    4. System commands: os.system, os.popen, os.spawn*, os.exec*
    5. Absolute filesystem paths

    Other sensitive operations (file deletion, open(), ...) only produce
    warnings, which are meant to be left to an LLM review.

    The analysis is static and best-effort. Import aliases
    (``import os as o``, ``from os import system``, ``from os import *``)
    and the ``builtins.`` prefix are resolved, but indirection through
    assignments or ``getattr`` is not tracked.
    """

    # Network modules: importing one is a hard violation that is counted
    # under the 'network' metrics category.
    _NETWORK_IMPORTS = frozenset({
        'socket',     # low-level networking: can bind ports, open connections
        'requests',   # HTTP requests
        'urllib',     # URL handling
        'urllib3',    # HTTP client
        'http',       # HTTP support
        'ftplib',     # FTP
        'smtplib',    # mail
        'telnetlib',  # Telnet
        'xmlrpc',     # XML-RPC
        'httplib',    # HTTP library
        'httplib2',   # HTTP library
        'aiohttp',    # async HTTP
    })

    # [Hard block] most dangerous modules - rejected outright
    CRITICAL_FORBIDDEN_IMPORTS = set(_NETWORK_IMPORTS) | {
        'subprocess',       # runs arbitrary system commands
        'multiprocessing',  # may bind ports
        'asyncio',          # may include network operations
        'ctypes',           # can call arbitrary C functions
        'cffi',             # foreign function interface
    }

    # [Hard block] most dangerous function calls - rejected outright
    CRITICAL_FORBIDDEN_CALLS = {
        # arbitrary code execution
        'eval',
        'exec',
        'compile',
        '__import__',
        # system command execution
        'os.system',
        'os.popen',
        'os.spawn',
        'os.spawnl',
        'os.spawnle',
        'os.spawnlp',
        'os.spawnlpe',
        'os.spawnv',
        'os.spawnve',
        'os.spawnvp',
        'os.spawnvpe',
        'os.execl',
        'os.execle',
        'os.execlp',
        'os.execlpe',
        'os.execv',
        'os.execve',
        'os.execvp',
        'os.execvpe',
    }

    # [Warning] modules needing LLM review (all were moved to the hard block)
    WARNING_IMPORTS = set()

    # [Warning] function calls needing LLM review
    WARNING_CALLS = {
        'os.remove',      # delete a file
        'os.unlink',      # delete a file
        'os.rmdir',       # delete a directory
        'os.removedirs',  # delete directories recursively
        'shutil.rmtree',  # delete a directory tree
        'shutil.move',    # move a file (the original may be lost)
        'open',           # open a file (path should be reviewed)
    }

    def check(self, code: str) -> "RuleCheckResult":
        """
        Check whether code complies with the safety rules.

        Every finding is also reported to the metrics object returned by
        ``get_metrics()``.

        Args:
            code: Python source code as a string.

        Returns:
            RuleCheckResult: ``passed`` is True only when there are no hard
            violations; warnings never affect ``passed``.
        """
        violations = []  # hard violations, code is rejected
        warnings = []    # warnings, left to LLM review
        metrics = get_metrics()

        # 1. Forbidden imports. The metrics category is chosen from the
        #    module name itself, so every network module counts as 'network'.
        for module, message in self._find_critical_imports(code):
            violations.append(message)
            if module in self._NETWORK_IMPORTS:
                metrics.add_static_block('network', message)
            else:
                metrics.add_static_block('dangerous_call', message)

        # 2. Forbidden function calls
        critical_call_violations = self._check_critical_calls(code)
        violations.extend(critical_call_violations)
        for v in critical_call_violations:
            metrics.add_static_block('dangerous_call', v)

        # 3. Absolute path access (hard block)
        path_violations = self._check_absolute_paths(code)
        violations.extend(path_violations)
        for v in path_violations:
            metrics.add_static_block('path', v)

        # 4. Warning-level imports
        warning_imports = self._check_warning_imports(code)
        warnings.extend(warning_imports)
        for w in warning_imports:
            metrics.add_static_warning('network', w)

        # 5. Warning-level function calls
        warning_calls = self._check_warning_calls(code)
        warnings.extend(warning_calls)
        for w in warning_calls:
            metrics.add_static_warning('file_operation', w)

        return RuleCheckResult(
            passed=len(violations) == 0,
            violations=violations,
            warnings=warnings
        )

    @staticmethod
    def _parse(code: str):
        """
        Parse code into an AST.

        Returns:
            The module node, or None when the source cannot be parsed.
            ValueError is handled as well as SyntaxError because, depending
            on the Python version, source containing NUL bytes raises one
            or the other.
        """
        try:
            return ast.parse(code)
        except (SyntaxError, ValueError):
            return None

    @staticmethod
    def _collect_import_aliases(tree: ast.AST) -> tuple:
        """
        Collect the names that import statements bind anywhere in tree.

        Returns:
            A pair ``(aliases, star_modules)``. ``aliases`` maps a bound name
            to the list of dotted targets it may refer to
            (``import os as o`` gives ``{'o': ['os']}``,
            ``from os import system`` gives ``{'system': ['os.system']}``).
            ``star_modules`` lists the modules used in ``from m import *``.
            Scopes are ignored on purpose: every binding found in the file
            is treated as possibly active.
        """
        aliases = {}
        star_modules = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # A plain "import a.b" binds "a" to itself: nothing to map
                    if alias.asname:
                        aliases.setdefault(alias.asname, []).append(alias.name)
            elif isinstance(node, ast.ImportFrom):
                # Relative imports are skipped: they name modules relative to
                # the importing package, not absolute module names
                if not node.module or node.level:
                    continue
                for alias in node.names:
                    if alias.name == '*':
                        star_modules.append(node.module)
                    else:
                        bound = alias.asname or alias.name
                        aliases.setdefault(bound, []).append(
                            f"{node.module}.{alias.name}"
                        )
        return aliases, star_modules

    def _candidate_call_names(self, node: ast.Call, aliases: dict,
                              star_modules: list) -> List[str]:
        """
        Return every dotted name the call may refer to.

        The literal name comes first, so calls written directly are matched
        and reported exactly as written. Names obtained by resolving import
        aliases, star imports and the ``builtins.`` prefix follow.
        """
        raw = self._get_call_name(node)
        if not raw:
            return []
        names = [raw]

        # Only resolve chains rooted at a plain name; for something like
        # "factory().system()" the leading attribute is not an imported name.
        root = node.func
        while isinstance(root, ast.Attribute):
            root = root.value
        if not isinstance(root, ast.Name):
            return names

        head, dot, tail = raw.partition('.')
        for target in aliases.get(head, ()):
            names.append(target + dot + tail)
        if not dot:
            # A bare name may have been bound by "from module import *"
            names.extend(f"{module}.{raw}" for module in star_modules)

        # builtins.eval(...) is the same function as eval(...)
        prefix = 'builtins.'
        names.extend([n[len(prefix):] for n in names if n.startswith(prefix)])
        return names

    def _find_critical_imports(self, code: str) -> list:
        """
        Find forbidden imports.

        Returns:
            A list of ``(root_module, message)`` pairs, where ``root_module``
            is the member of CRITICAL_FORBIDDEN_IMPORTS that was hit.
        """
        found = []
        tree = self._parse(code)
        if tree is None:
            # Unparseable source: fall back to a textual search. Sorted so
            # the order of the messages is deterministic.
            for module in sorted(self.CRITICAL_FORBIDDEN_IMPORTS):
                pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
                if re.search(pattern, code):
                    found.append((module, f"严禁使用模块: {module}"))
            return found

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    module_name = alias.name.split('.')[0]
                    if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
                        found.append((module_name, f"严禁使用模块: {alias.name}(可能执行危险操作)"))
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    module_name = node.module.split('.')[0]
                    if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
                        found.append((module_name, f"严禁使用模块: {node.module}(可能执行危险操作)"))
        return found

    def _check_critical_imports(self, code: str) -> List[str]:
        """Return one violation message per forbidden import found in code."""
        return [message for _, message in self._find_critical_imports(code)]

    def _check_critical_calls(self, code: str) -> List[str]:
        """
        Return one violation message per forbidden function call.

        Calls are matched by their literal dotted name and by the name
        obtained after resolving import aliases, so ``from os import system``
        followed by ``system(...)`` is reported as ``os.system``.
        """
        violations = []
        tree = self._parse(code)
        if tree is None:
            # Unparseable source: fall back to a textual search
            for func in sorted(self.CRITICAL_FORBIDDEN_CALLS):
                pattern = rf'\b{re.escape(func)}\s*\('
                if re.search(pattern, code):
                    violations.append(f"严禁调用: {func}")
            return violations

        aliases, star_modules = self._collect_import_aliases(tree)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            for name in self._candidate_call_names(node, aliases, star_modules):
                if name in self.CRITICAL_FORBIDDEN_CALLS:
                    violations.append(f"严禁调用: {name}(可能执行任意代码或命令)")
                    break  # one report per call site
        return violations

    def _check_warning_imports(self, code: str) -> List[str]:
        """
        Return one warning per import listed in WARNING_IMPORTS.

        Unparseable source yields no warnings.
        """
        warnings = []
        tree = self._parse(code)
        if tree is None:
            return warnings

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    module_name = alias.name.split('.')[0]
                    if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS:
                        warnings.append(f"使用了网络相关模块: {alias.name}")
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    module_name = node.module.split('.')[0]
                    if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS:
                        warnings.append(f"使用了网络相关模块: {node.module}")
        return warnings

    def _check_warning_calls(self, code: str) -> List[str]:
        """
        Return one warning per call listed in WARNING_CALLS.

        Import aliases are resolved the same way as for forbidden calls.
        Unparseable source yields no warnings.
        """
        warnings = []
        tree = self._parse(code)
        if tree is None:
            return warnings

        aliases, star_modules = self._collect_import_aliases(tree)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            for name in self._candidate_call_names(node, aliases, star_modules):
                if name in self.WARNING_CALLS:
                    warnings.append(f"使用了敏感操作: {name}")
                    break  # one report per call site
        return warnings

    def _check_absolute_paths(self, code: str) -> List[str]:
        """
        Check for absolute path access (hard block).

        Two passes are made:
        - a textual scan for Windows drive / UNC paths and well-known Unix
          root directories (/home, /usr, /etc, ...); lines that start with
          a comment marker are ignored and each pattern is reported once;
        - an AST scan for Path(...) / pathlib.Path(...) calls whose string
          literal arguments are absolute paths.
        """
        violations = []

        # Windows absolute path patterns. The lookbehind keeps the "drive
        # letter" from being the tail of a longer word, so URL schemes such
        # as "https://" are not mistaken for a drive path.
        windows_patterns = [
            r'(?<![A-Za-z])[A-Za-z]:\\',  # C:\, D:\
            r'(?<![A-Za-z])[A-Za-z]:/',   # C:/, D:/
            r'\\\\[^\\]+\\',              # UNC path \\server\share
        ]
        # Unix absolute path patterns
        unix_patterns = [
            r'(?:^|[\s"\'])(/home|/usr|/etc|/var|/tmp|/root|/opt|/bin|/sbin|/lib|/sys|/proc|/dev)',
        ]

        for pattern in windows_patterns + unix_patterns:
            for match in re.finditer(pattern, code):
                # The Unix pattern also consumes the delimiter in front of
                # the path (possibly a newline); group 1 is the path itself.
                # Using its position keeps the comment test on the right line.
                group = match.lastindex or 0
                pos = match.start(group)
                line_start = code.rfind('\n', 0, pos) + 1
                line_end = code.find('\n', pos)
                if line_end == -1:  # last line without a trailing newline
                    line_end = len(code)
                line = code[line_start:line_end]
                # Skip paths that appear on comment lines
                if not line.strip().startswith('#'):
                    violations.append(f"严禁访问绝对路径: {match.group(group)} (只能访问 workspace 目录)")
                    break  # report each pattern only once

        # Absolute paths passed to Path objects
        tree = self._parse(code)
        if tree is None:
            return violations

        aliases, star_modules = self._collect_import_aliases(tree)
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            names = self._candidate_call_names(node, aliases, star_modules)
            if not any(name in ('Path', 'pathlib.Path') for name in names):
                continue
            for arg in node.args:
                if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
                    path_str = arg.value
                    if self._is_absolute_path(path_str):
                        violations.append(f"严禁使用绝对路径: Path('{path_str}') (只能使用相对路径)")
        return violations

    def _is_absolute_path(self, path: str) -> bool:
        """Return True when path is a Windows drive, UNC or Unix absolute path."""
        # Windows absolute path
        if re.match(r'^[A-Za-z]:[/\\]', path):
            return True
        # UNC path
        if path.startswith(r'\\'):
            return True
        # Unix absolute path
        if path.startswith('/'):
            return True
        return False

    def _get_call_name(self, node: ast.Call) -> str:
        """
        Return the dotted name of the called function as written in code.

        ``f()`` gives ``'f'`` and ``a.b.c()`` gives ``'a.b.c'``. When an
        attribute chain is not rooted at a plain name (``g().h()``), only
        the attribute part is returned (``'h'``). Any other callee gives
        an empty string.
        """
        if isinstance(node.func, ast.Name):
            return node.func.id
        elif isinstance(node.func, ast.Attribute):
            parts = []
            current = node.func
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                parts.append(current.id)
            return '.'.join(reversed(parts))
        return ''
def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience function: run a new RuleChecker on code and return its result."""
    return RuleChecker().check(code)