- Renamed `check_environment` to `check_api_key_configured` for clarity, simplifying the API key validation logic.
- Removed the blocking behavior of the API key check during application startup, allowing the app to run while providing a prompt for configuration.
- Updated `LocalAgentApp` to accept an `api_configured` parameter, enabling conditional messaging for API key setup.
- Enhanced the `SandboxRunner` to support backup management and improved execution result handling with detailed metrics.
- Integrated data governance strategies into the `HistoryManager`, ensuring compliance and improved data management.
- Added privacy settings and metrics tracking across various components to enhance user experience and application safety.
334 lines
12 KiB
Python
334 lines
12 KiB
Python
"""
|
||
硬规则安全检查器
|
||
只检测最危险的操作,其他交给 LLM 审查
|
||
"""
|
||
|
||
import re
|
||
import ast
|
||
from typing import List
|
||
from dataclasses import dataclass
|
||
|
||
from .security_metrics import get_metrics
|
||
|
||
|
||
@dataclass
class RuleCheckResult:
    """Outcome of one rule-check pass over a piece of code."""

    # True when no hard violation was found.
    passed: bool
    # Messages for hard violations (each one rejects the code).
    violations: List[str]
    # Messages for warning-level findings (meant for a later LLM review).
    warnings: List[str]
|
||
|
||
|
||
class RuleChecker:
    """Hard-rule static safety checker for Python source code.

    Only the most dangerous operations are rejected outright:

    1. Importing a module listed in ``CRITICAL_FORBIDDEN_IMPORTS``
       (network clients, process spawning, foreign-function interfaces).
    2. Calling a function listed in ``CRITICAL_FORBIDDEN_CALLS``
       (``eval``/``exec``/``compile``/``__import__`` and the ``os``
       command-execution family).
    3. Referring to an absolute filesystem path.

    Calls listed in ``WARNING_CALLS`` (file deletion, ``open`` ...) only
    produce warnings, which are meant to be handed to an LLM reviewer.

    The analysis is purely static.  Import aliases (``import os as o``,
    ``from os import system``) are resolved, but aliases created by plain
    assignment (``run = os.system``) or by ``getattr`` are not tracked.
    """

    # [Hard block] Most dangerous modules - rejected immediately.
    CRITICAL_FORBIDDEN_IMPORTS = {
        # Network modules (hard block)
        'socket',           # low-level networking: bind ports, open connections
        'requests',         # HTTP requests
        'urllib',           # URL handling
        'urllib3',          # HTTP client
        'http',             # HTTP related
        'ftplib',           # FTP
        'smtplib',          # e-mail
        'telnetlib',        # Telnet
        'xmlrpc',           # XML-RPC
        'httplib',          # HTTP library
        'httplib2',         # HTTP library
        'aiohttp',          # async HTTP

        # Command execution
        'subprocess',       # runs arbitrary system commands
        'multiprocessing',  # may bind ports
        'asyncio',          # may involve network operations
        'ctypes',           # can call arbitrary C functions
        'cffi',             # foreign function interface
    }

    # Members of CRITICAL_FORBIDDEN_IMPORTS that are reported to the metrics
    # collector under the 'network' category; the rest go to 'dangerous_call'.
    _NETWORK_IMPORTS = frozenset({
        'socket', 'requests', 'urllib', 'urllib3', 'http', 'ftplib',
        'smtplib', 'telnetlib', 'xmlrpc', 'httplib', 'httplib2', 'aiohttp',
    })

    # [Hard block] Most dangerous function calls - rejected immediately.
    CRITICAL_FORBIDDEN_CALLS = {
        # Arbitrary code execution
        'eval',
        'exec',
        'compile',
        '__import__',

        # System command execution
        'os.system',
        'os.popen',
        'os.spawn',
        'os.spawnl',
        'os.spawnle',
        'os.spawnlp',
        'os.spawnlpe',
        'os.spawnv',
        'os.spawnve',
        'os.spawnvp',
        'os.spawnvpe',
        'os.execl',
        'os.execle',
        'os.execlp',
        'os.execlpe',
        'os.execv',
        'os.execve',
        'os.execvp',
        'os.execvpe',
    }

    # [Warning] Modules needing LLM review (all entries were moved to the
    # hard-block list, so this is currently empty).
    WARNING_IMPORTS = set()

    # [Warning] Function calls needing LLM review.
    WARNING_CALLS = {
        'os.remove',        # delete a file
        'os.unlink',        # delete a file
        'os.rmdir',         # delete a directory
        'os.removedirs',    # delete directories recursively
        'shutil.rmtree',    # delete a directory tree
        'shutil.move',      # move a file (the original may be lost)
        'open',             # open a file (path should be checked)
    }

    def check(self, code: str) -> "RuleCheckResult":
        """Check whether *code* complies with the safety rules.

        Every finding is also reported to the collector returned by
        ``get_metrics()``.

        Args:
            code: Python source code as a string.

        Returns:
            RuleCheckResult: ``passed`` is True only when no hard violation
            was found; ``warnings`` never affect ``passed``.
        """
        violations: List[str] = []  # hard violations, reject immediately
        warnings: List[str] = []    # warnings, left for LLM review

        metrics = get_metrics()

        # 1. Hard-blocked imports.  The category is decided from the module
        #    name itself rather than from a substring of the message, so that
        #    ftplib / smtplib / telnetlib / xmlrpc are counted as 'network'.
        for module, message in self._find_critical_imports(code):
            violations.append(message)
            if module in self._NETWORK_IMPORTS:
                metrics.add_static_block('network', message)
            else:
                metrics.add_static_block('dangerous_call', message)

        # 2. Hard-blocked function calls.
        critical_calls = self._check_critical_calls(code)
        violations.extend(critical_calls)
        for message in critical_calls:
            metrics.add_static_block('dangerous_call', message)

        # 3. Absolute path access (hard block).
        path_violations = self._check_absolute_paths(code)
        violations.extend(path_violations)
        for message in path_violations:
            metrics.add_static_block('path', message)

        # 4. Warning-level imports.
        warning_imports = self._check_warning_imports(code)
        warnings.extend(warning_imports)
        for message in warning_imports:
            metrics.add_static_warning('network', message)

        # 5. Warning-level function calls.
        warning_calls = self._check_warning_calls(code)
        warnings.extend(warning_calls)
        for message in warning_calls:
            metrics.add_static_warning('file_operation', message)

        return RuleCheckResult(
            passed=not violations,
            violations=violations,
            warnings=warnings,
        )

    @staticmethod
    def _parse(code: str):
        """Return the AST of *code*, or ``None`` when it cannot be parsed.

        ``ValueError`` is caught alongside ``SyntaxError`` because some
        interpreter versions raise it for source containing NUL bytes;
        the input is untrusted, so neither may escape the checker.
        """
        try:
            return ast.parse(code)
        except (SyntaxError, ValueError):
            return None

    @staticmethod
    def _import_aliases(tree) -> dict:
        """Map each locally bound import name to the dotted names it may denote.

        ``import os as o`` gives ``{'o': {'os'}}`` and
        ``from os import system as run`` gives ``{'run': {'os.system'}}``.
        Relative and star imports are skipped because their targets cannot
        be resolved from the source alone.
        """
        aliases: dict = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if alias.asname:
                        aliases.setdefault(alias.asname, set()).add(alias.name)
            elif isinstance(node, ast.ImportFrom) and node.module and not node.level:
                for alias in node.names:
                    if alias.name != '*':
                        bound = alias.asname or alias.name
                        target = f"{node.module}.{alias.name}"
                        aliases.setdefault(bound, set()).add(target)
        return aliases

    @staticmethod
    def _resolve_call_names(call_name: str, aliases: dict) -> List[str]:
        """Return *call_name* followed by its canonical import-resolved forms.

        The literal name always comes first, so anything matched before
        alias resolution existed is still matched and reported identically.
        """
        prefix = 'builtins.'
        candidates = [call_name]
        head, dot, tail = call_name.partition('.')
        for target in sorted(aliases.get(head, ())):
            candidates.append(f"{target}{dot}{tail}")
        # builtins.eval is the very same function as the bare built-in.
        candidates.extend(
            name[len(prefix):] for name in list(candidates) if name.startswith(prefix)
        )
        return candidates

    def _find_calls(self, tree, targets) -> List[str]:
        """Return the matched name of every call in *tree* listed in *targets*."""
        aliases = self._import_aliases(tree)
        hits = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            for name in self._resolve_call_names(self._get_call_name(node), aliases):
                if name in targets:
                    hits.append(name)
                    break  # one report per call site
        return hits

    def _find_critical_imports(self, code: str) -> List[tuple]:
        """Return ``(top_level_module, message)`` for each hard-blocked import.

        Falls back to a regular-expression scan when *code* cannot be parsed.
        """
        tree = self._parse(code)
        found = []

        if tree is None:
            # Sorted so the fallback output order is deterministic.
            for module in sorted(self.CRITICAL_FORBIDDEN_IMPORTS):
                pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
                if re.search(pattern, code):
                    found.append((module, f"严禁使用模块: {module}"))
            return found

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imported = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported = [node.module]
            else:
                continue
            for name in imported:
                top_level = name.split('.')[0]
                if top_level in self.CRITICAL_FORBIDDEN_IMPORTS:
                    found.append((top_level, f"严禁使用模块: {name}(可能执行危险操作)"))
        return found

    def _check_critical_imports(self, code: str) -> List[str]:
        """Return one violation message per hard-blocked import in *code*."""
        return [message for _, message in self._find_critical_imports(code)]

    def _check_critical_calls(self, code: str) -> List[str]:
        """Return one violation message per hard-blocked call in *code*.

        Calls made through import aliases are reported under their canonical
        name (for example ``os.system``).  When *code* cannot be parsed, a
        regular-expression scan over the raw text is used instead.
        """
        tree = self._parse(code)
        if tree is None:
            return [
                f"严禁调用: {func}"
                for func in sorted(self.CRITICAL_FORBIDDEN_CALLS)
                if re.search(rf'\b{re.escape(func)}\s*\(', code)
            ]
        return [
            f"严禁调用: {name}(可能执行任意代码或命令)"
            for name in self._find_calls(tree, self.CRITICAL_FORBIDDEN_CALLS)
        ]

    def _check_warning_imports(self, code: str) -> List[str]:
        """Return one warning message per import listed in ``WARNING_IMPORTS``.

        Unparseable code yields no warnings.
        """
        tree = self._parse(code)
        if tree is None:
            return []

        warnings = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imported = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported = [node.module]
            else:
                continue
            for name in imported:
                # Match either the top-level package or the full dotted name.
                if name.split('.')[0] in self.WARNING_IMPORTS or name in self.WARNING_IMPORTS:
                    warnings.append(f"使用了网络相关模块: {name}")
        return warnings

    def _check_warning_calls(self, code: str) -> List[str]:
        """Return one warning message per call listed in ``WARNING_CALLS``.

        Import aliases are resolved the same way as for hard-blocked calls.
        Unparseable code yields no warnings.
        """
        tree = self._parse(code)
        if tree is None:
            return []
        return [
            f"使用了敏感操作: {name}"
            for name in self._find_calls(tree, self.WARNING_CALLS)
        ]

    def _check_absolute_paths(self, code: str) -> List[str]:
        """Return violations for absolute filesystem paths found in *code*.

        Two passes are made:

        * a regular-expression scan of the raw text for Windows drive paths
          (drive letter, colon, then a slash or backslash), UNC paths, and
          well-known Unix roots such as /home, /usr or /etc; matches on
          lines that start with ``#`` are ignored and each pattern reports
          at most once;
        * an AST scan for ``Path(...)`` / ``pathlib.Path(...)`` calls (also
          through import aliases) whose string-literal arguments are
          absolute.
        """
        violations = []

        # Windows absolute path patterns.
        windows_patterns = [
            r'[A-Za-z]:\\',      # drive letter followed by a backslash
            r'[A-Za-z]:/',       # drive letter followed by a forward slash
            r'\\\\[^\\]+\\',     # UNC path: two backslashes, server, backslash
        ]

        # Unix absolute path patterns.
        unix_patterns = [
            r'(?:^|[\s"\'])(/home|/usr|/etc|/var|/tmp|/root|/opt|/bin|/sbin|/lib|/sys|/proc|/dev)',
        ]

        for pattern in windows_patterns + unix_patterns:
            for match in re.finditer(pattern, code):
                # The Unix pattern may start on the whitespace *before* the
                # path (possibly the previous line's newline), so locate the
                # line from the path itself, not from the start of the match.
                path_pos = match.start(match.lastindex) if match.lastindex else match.start()
                line_start = code.rfind('\n', 0, path_pos) + 1
                line_end = code.find('\n', path_pos)
                if line_end == -1:
                    line_end = len(code)
                line = code[line_start:line_end]

                # Skip paths that only appear in a comment line.
                if not line.strip().startswith('#'):
                    violations.append(f"严禁访问绝对路径: {match.group()} (只能访问 workspace 目录)")
                    break  # report each pattern only once

        # Absolute paths handed to Path objects.
        tree = self._parse(code)
        if tree is None:
            return violations

        aliases = self._import_aliases(tree)
        path_constructors = ('Path', 'pathlib.Path')
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            names = self._resolve_call_names(self._get_call_name(node), aliases)
            if not any(name in path_constructors for name in names):
                continue
            for arg in node.args:
                if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
                    path_str = arg.value
                    if self._is_absolute_path(path_str):
                        violations.append(f"严禁使用绝对路径: Path('{path_str}') (只能使用相对路径)")

        return violations

    def _is_absolute_path(self, path: str) -> bool:
        """Return True when *path* is a Windows drive, UNC or Unix absolute path."""
        # Windows drive path: letter, colon, then a slash or backslash.
        if re.match(r'^[A-Za-z]:[/\\]', path):
            return True
        # UNC path (two leading backslashes) or Unix absolute path.
        return path.startswith((r'\\', '/'))

    def _get_call_name(self, node: ast.Call) -> str:
        """Return the dotted name of the function being called.

        ``f(x)`` gives ``'f'`` and ``a.b.c(x)`` gives ``'a.b.c'``.  When the
        attribute chain does not end in a plain name (for example
        ``g().h()``), only the trailing attribute names are returned
        (``'h'``).  Any other callee expression gives ``''``.
        """
        func = node.func
        if isinstance(func, ast.Name):
            return func.id
        if not isinstance(func, ast.Attribute):
            return ''

        parts = []
        current = func
        while isinstance(current, ast.Attribute):
            parts.append(current.attr)
            current = current.value
        if isinstance(current, ast.Name):
            parts.append(current.id)
        return '.'.join(reversed(parts))
|
||
|
||
|
||
def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience wrapper: run a fresh ``RuleChecker`` over *code*."""
    return RuleChecker().check(code)
|