Files
LocalAgent/safety/rule_checker.py
Mimikko-zeus 1ba5f0f7d6 feat: implement streaming support for chat and enhance safety review process
- Updated .env.example to include API key placeholder and configuration instructions.
- Refactored main.py to support streaming responses from the LLM, improving user experience during chat interactions.
- Enhanced LLMClient to include methods for streaming chat and collecting responses.
- Modified safety review process to pass static analysis warnings to the LLM for better code safety evaluation.
- Improved UI components in chat_view.py to handle streaming messages effectively.
2026-01-07 09:43:40 +08:00

241 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
硬规则安全检查器
只检测最危险的操作,其他交给 LLM 审查
"""
import re
import ast
from typing import List
from dataclasses import dataclass
@dataclass
class RuleCheckResult:
    """Outcome of one hard-rule safety check.

    Holds the pass/fail verdict together with two lists of messages:
    hard-rule violations, and softer warnings that are meant to be passed
    on to the LLM reviewer.
    """

    # Whether the checked code passed the hard rules.
    passed: bool
    # Hard-rule violations (one message per offending import/call).
    violations: List[str]
    # Soft findings forwarded to the LLM for review.
    warnings: List[str]
class RuleChecker:
"""
硬规则检查器
只硬性禁止最危险的操作:
1. 网络模块: socket底层网络
2. 执行任意代码: eval, exec, compile
3. 执行系统命令: subprocess, os.system, os.popen
4. 动态导入: __import__
其他操作(如文件删除、路径访问等)生成警告,交给 LLM 审查
"""
# 【硬性禁止】最危险的模块 - 直接拒绝
CRITICAL_FORBIDDEN_IMPORTS = {
'socket', # 底层网络,可绑定端口、建立连接
'subprocess', # 执行任意系统命令
'multiprocessing', # 可能绑定端口
'asyncio', # 可能包含网络操作
'ctypes', # 可调用任意 C 函数
'cffi', # 外部函数接口
}
# 【硬性禁止】最危险的函数调用 - 直接拒绝
CRITICAL_FORBIDDEN_CALLS = {
# 执行任意代码
'eval',
'exec',
'compile',
'__import__',
# 执行系统命令
'os.system',
'os.popen',
'os.spawn',
'os.spawnl',
'os.spawnle',
'os.spawnlp',
'os.spawnlpe',
'os.spawnv',
'os.spawnve',
'os.spawnvp',
'os.spawnvpe',
'os.execl',
'os.execle',
'os.execlp',
'os.execlpe',
'os.execv',
'os.execve',
'os.execvp',
'os.execvpe',
}
# 【警告】需要 LLM 审查的模块
WARNING_IMPORTS = {
'requests', # HTTP 请求
'urllib', # URL 处理
'http.client', # HTTP 客户端
'ftplib', # FTP
'smtplib', # 邮件
'telnetlib', # Telnet
}
# 【警告】需要 LLM 审查的函数调用
WARNING_CALLS = {
'os.remove', # 删除文件
'os.unlink', # 删除文件
'os.rmdir', # 删除目录
'os.removedirs', # 递归删除目录
'shutil.rmtree', # 递归删除目录树
'shutil.move', # 移动文件(可能丢失原文件)
'open', # 打开文件(检查路径)
}
def check(self, code: str) -> RuleCheckResult:
"""
检查代码是否符合安全规则
Args:
code: Python 代码字符串
Returns:
RuleCheckResult: 检查结果
"""
violations = [] # 硬性违规,直接拒绝
warnings = [] # 警告,交给 LLM 审查
# 1. 检查硬性禁止的导入
critical_import_violations = self._check_critical_imports(code)
violations.extend(critical_import_violations)
# 2. 检查硬性禁止的函数调用
critical_call_violations = self._check_critical_calls(code)
violations.extend(critical_call_violations)
# 3. 检查警告级别的导入
warning_imports = self._check_warning_imports(code)
warnings.extend(warning_imports)
# 4. 检查警告级别的函数调用
warning_calls = self._check_warning_calls(code)
warnings.extend(warning_calls)
return RuleCheckResult(
passed=len(violations) == 0,
violations=violations,
warnings=warnings
)
def _check_critical_imports(self, code: str) -> List[str]:
"""检查硬性禁止的导入"""
violations = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split('.')[0]
if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
violations.append(f"严禁使用模块: {alias.name}(可能执行危险操作)")
elif isinstance(node, ast.ImportFrom):
if node.module:
module_name = node.module.split('.')[0]
if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
violations.append(f"严禁使用模块: {node.module}(可能执行危险操作)")
except SyntaxError:
for module in self.CRITICAL_FORBIDDEN_IMPORTS:
pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
if re.search(pattern, code):
violations.append(f"严禁使用模块: {module}")
return violations
def _check_critical_calls(self, code: str) -> List[str]:
"""检查硬性禁止的函数调用"""
violations = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
call_name = self._get_call_name(node)
if call_name in self.CRITICAL_FORBIDDEN_CALLS:
violations.append(f"严禁调用: {call_name}(可能执行任意代码或命令)")
except SyntaxError:
for func in self.CRITICAL_FORBIDDEN_CALLS:
pattern = rf'\b{re.escape(func)}\s*\('
if re.search(pattern, code):
violations.append(f"严禁调用: {func}")
return violations
def _check_warning_imports(self, code: str) -> List[str]:
"""检查警告级别的导入"""
warnings = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split('.')[0]
if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS:
warnings.append(f"使用了网络相关模块: {alias.name}")
elif isinstance(node, ast.ImportFrom):
if node.module:
module_name = node.module.split('.')[0]
if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS:
warnings.append(f"使用了网络相关模块: {node.module}")
except SyntaxError:
pass
return warnings
def _check_warning_calls(self, code: str) -> List[str]:
"""检查警告级别的函数调用"""
warnings = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
call_name = self._get_call_name(node)
if call_name in self.WARNING_CALLS:
warnings.append(f"使用了敏感操作: {call_name}")
except SyntaxError:
pass
return warnings
def _get_call_name(self, node: ast.Call) -> str:
"""获取函数调用的完整名称"""
if isinstance(node.func, ast.Name):
return node.func.id
elif isinstance(node.func, ast.Attribute):
parts = []
current = node.func
while isinstance(current, ast.Attribute):
parts.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.append(current.id)
return '.'.join(reversed(parts))
return ''
def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience wrapper: run a fresh ``RuleChecker`` over ``code``.

    Args:
        code: Python source code string.

    Returns:
        RuleCheckResult: the result of ``RuleChecker().check(code)``.
    """
    return RuleChecker().check(code)