feat: implement streaming support for chat and enhance safety review process

- Updated .env.example to include an API key placeholder and configuration instructions.
- Refactored main.py to support streaming responses from the LLM, improving user experience during chat interactions.
- Enhanced LLMClient to include methods for streaming chat and collecting responses.
- Modified safety review process to pass static analysis warnings to the LLM for better code safety evaluation.
- Improved UI components in chat_view.py to handle streaming messages effectively.
This commit is contained in:
Mimikko-zeus
2026-01-07 09:43:40 +08:00
parent dad0d2629a
commit 1ba5f0f7d6
7 changed files with 406 additions and 145 deletions

View File

@@ -1,11 +1,11 @@
"""
硬规则安全检查器
静态扫描执行代码,检测危险操作
检测危险操作,其他交给 LLM 审查
"""
import re
import ast
from typing import List, Tuple
from typing import List
from dataclasses import dataclass
@@ -14,41 +14,41 @@ class RuleCheckResult:
"""规则检查结果"""
passed: bool
violations: List[str] # 违规项列表
warnings: List[str] # 警告项(交给 LLM 审查)
class RuleChecker:
"""
硬规则检查器
静态扫描代码,检测以下危险操作:
1. 网络请求: requests, socket, urllib, http.client
2. 危险文件操作: os.remove, shutil.rmtree, os.unlink
3. 执行外部命令: subprocess, os.system, os.popen
4. 访问非 workspace 路径
只硬性禁止最危险操作:
1. 网络模块: socket底层网络
2. 执行任意代码: eval, exec, compile
3. 执行系统命令: subprocess, os.system, os.popen
4. 动态导入: __import__
其他操作(如文件删除、路径访问等)生成警告,交给 LLM 审查
"""
# 禁止导入的模块
FORBIDDEN_IMPORTS = {
'requests',
'socket',
'urllib',
'urllib.request',
'urllib.parse',
'urllib.error',
'http.client',
'httplib',
'ftplib',
'smtplib',
'telnetlib',
'subprocess',
# 【硬性禁止】最危险的模块 - 直接拒绝
CRITICAL_FORBIDDEN_IMPORTS = {
'socket', # 底层网络,可绑定端口、建立连接
'subprocess', # 执行任意系统命令
'multiprocessing', # 可能绑定端口
'asyncio', # 可能包含网络操作
'ctypes', # 可调用任意 C 函数
'cffi', # 外部函数接口
}
# 禁止调用的函数(模块.函数 或 单独函数名)
FORBIDDEN_CALLS = {
'os.remove',
'os.unlink',
'os.rmdir',
'os.removedirs',
# 【硬性禁止】最危险的函数调用 - 直接拒绝
CRITICAL_FORBIDDEN_CALLS = {
# 执行任意代码
'eval',
'exec',
'compile',
'__import__',
# 执行系统命令
'os.system',
'os.popen',
'os.spawn',
@@ -60,7 +60,6 @@ class RuleChecker:
'os.spawnve',
'os.spawnvp',
'os.spawnvpe',
'os.exec',
'os.execl',
'os.execle',
'os.execlp',
@@ -69,26 +68,28 @@ class RuleChecker:
'os.execve',
'os.execvp',
'os.execvpe',
'shutil.rmtree',
'shutil.move', # move 可能导致原文件丢失
'eval',
'exec',
'compile',
'__import__',
}
# 危险路径模式(正则)
DANGEROUS_PATH_PATTERNS = [
r'[A-Za-z]:\\', # Windows 绝对路径
r'\\\\', # UNC 路径
r'/etc/',
r'/usr/',
r'/bin/',
r'/home/',
r'/root/',
r'\.\./', # 父目录遍历
r'\.\.', # 父目录
]
# 【警告】需要 LLM 审查的模块
WARNING_IMPORTS = {
'requests', # HTTP 请求
'urllib', # URL 处理
'http.client', # HTTP 客户端
'ftplib', # FTP
'smtplib', # 邮件
'telnetlib', # Telnet
}
# 【警告】需要 LLM 审查的函数调用
WARNING_CALLS = {
'os.remove', # 删除文件
'os.unlink', # 删除文件
'os.rmdir', # 删除目录
'os.removedirs', # 递归删除目录
'shutil.rmtree', # 递归删除目录树
'shutil.move', # 移动文件(可能丢失原文件)
'open', # 打开文件(检查路径)
}
def check(self, code: str) -> RuleCheckResult:
"""
@@ -100,27 +101,33 @@ class RuleChecker:
Returns:
RuleCheckResult: 检查结果
"""
violations = []
violations = [] # 硬性违规,直接拒绝
warnings = [] # 警告,交给 LLM 审查
# 1. 检查禁止的导入
import_violations = self._check_imports(code)
violations.extend(import_violations)
# 1. 检查硬性禁止的导入
critical_import_violations = self._check_critical_imports(code)
violations.extend(critical_import_violations)
# 2. 检查禁止的函数调用
call_violations = self._check_calls(code)
violations.extend(call_violations)
# 2. 检查硬性禁止的函数调用
critical_call_violations = self._check_critical_calls(code)
violations.extend(critical_call_violations)
# 3. 检查危险路径
path_violations = self._check_paths(code)
violations.extend(path_violations)
# 3. 检查警告级别的导入
warning_imports = self._check_warning_imports(code)
warnings.extend(warning_imports)
# 4. 检查警告级别的函数调用
warning_calls = self._check_warning_calls(code)
warnings.extend(warning_calls)
return RuleCheckResult(
passed=len(violations) == 0,
violations=violations
violations=violations,
warnings=warnings
)
def _check_imports(self, code: str) -> List[str]:
"""检查禁止的导入"""
def _check_critical_imports(self, code: str) -> List[str]:
"""检查硬性禁止的导入"""
violations = []
try:
@@ -130,26 +137,25 @@ class RuleChecker:
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split('.')[0]
if alias.name in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS:
violations.append(f"禁止导入模块: {alias.name}")
if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
violations.append(f"严禁使用模块: {alias.name}(可能执行危险操作)")
elif isinstance(node, ast.ImportFrom):
if node.module:
module_name = node.module.split('.')[0]
if node.module in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS:
violations.append(f"禁止导入模块: {node.module}")
if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
violations.append(f"严禁使用模块: {node.module}(可能执行危险操作)")
except SyntaxError:
# 如果代码有语法错误,使用正则匹配
for module in self.FORBIDDEN_IMPORTS:
for module in self.CRITICAL_FORBIDDEN_IMPORTS:
pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
if re.search(pattern, code):
violations.append(f"禁止导入模块: {module}")
violations.append(f"严禁使用模块: {module}")
return violations
def _check_calls(self, code: str) -> List[str]:
"""检查禁止的函数调用"""
def _check_critical_calls(self, code: str) -> List[str]:
"""检查硬性禁止的函数调用"""
violations = []
try:
@@ -158,18 +164,60 @@ class RuleChecker:
for node in ast.walk(tree):
if isinstance(node, ast.Call):
call_name = self._get_call_name(node)
if call_name in self.FORBIDDEN_CALLS:
violations.append(f"禁止调用函数: {call_name}")
if call_name in self.CRITICAL_FORBIDDEN_CALLS:
violations.append(f"严禁调用: {call_name}(可能执行任意代码或命令)")
except SyntaxError:
# 如果代码有语法错误,使用正则匹配
for func in self.FORBIDDEN_CALLS:
for func in self.CRITICAL_FORBIDDEN_CALLS:
pattern = rf'\b{re.escape(func)}\s*\('
if re.search(pattern, code):
violations.append(f"禁止调用函数: {func}")
violations.append(f"严禁调用: {func}")
return violations
def _check_warning_imports(self, code: str) -> List[str]:
"""检查警告级别的导入"""
warnings = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split('.')[0]
if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS:
warnings.append(f"使用了网络相关模块: {alias.name}")
elif isinstance(node, ast.ImportFrom):
if node.module:
module_name = node.module.split('.')[0]
if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS:
warnings.append(f"使用了网络相关模块: {node.module}")
except SyntaxError:
pass
return warnings
def _check_warning_calls(self, code: str) -> List[str]:
"""检查警告级别的函数调用"""
warnings = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
call_name = self._get_call_name(node)
if call_name in self.WARNING_CALLS:
warnings.append(f"使用了敏感操作: {call_name}")
except SyntaxError:
pass
return warnings
def _get_call_name(self, node: ast.Call) -> str:
"""获取函数调用的完整名称"""
if isinstance(node.func, ast.Name):
@@ -184,25 +232,9 @@ class RuleChecker:
parts.append(current.id)
return '.'.join(reversed(parts))
return ''
def _check_paths(self, code: str) -> List[str]:
"""检查危险路径访问"""
violations = []
for pattern in self.DANGEROUS_PATH_PATTERNS:
matches = re.findall(pattern, code, re.IGNORECASE)
if matches:
# 排除 workspace 相关的合法路径
for match in matches:
if 'workspace' not in code[max(0, code.find(match)-50):code.find(match)+50].lower():
violations.append(f"检测到可疑路径模式: {match}")
break
return violations
def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience wrapper: run a fresh ``RuleChecker`` over *code*.

    Args:
        code: Source text to scan.

    Returns:
        The ``RuleCheckResult`` produced by ``RuleChecker.check``.
    """
    return RuleChecker().check(code)