Initial commit

This commit is contained in:
Mimikko-zeus
2026-01-07 00:17:46 +08:00
commit 4b3286f546
49 changed files with 2492 additions and 0 deletions

2
safety/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# 安全检查模块

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

132
safety/llm_reviewer.py Normal file
View File

@@ -0,0 +1,132 @@
"""
LLM 软规则审查器
使用 LLM 进行代码安全审查
"""
import os
import json
from typing import Optional
from dataclasses import dataclass
from dotenv import load_dotenv
from llm.client import get_client, LLMClientError, ENV_PATH
from llm.prompts import SAFETY_REVIEW_SYSTEM, SAFETY_REVIEW_USER
@dataclass
class LLMReviewResult:
    """Outcome of an LLM-based safety review."""

    # Verdict of the review: True means the code was accepted.
    passed: bool
    # Explanation that accompanies the verdict.
    reason: str
    # Unparsed model reply; stays None when no reply is attached.
    raw_response: Optional[str] = None
class LLMReviewer:
    """LLM-based safety reviewer.

    Sends the user request, the execution plan and the generated code to a
    large language model and turns its JSON reply into an LLMReviewResult.
    The reviewer fails closed: a client error, an unexpected exception or a
    reply that cannot be interpreted all yield ``passed=False``.
    """

    # String spellings of a positive verdict accepted from the model.
    _TRUTHY_STRINGS = frozenset({'true', 'yes', '1', 'pass'})

    def __init__(self):
        """Load the environment file and read the model name to use."""
        load_dotenv(ENV_PATH)
        self.model_name = os.getenv("GENERATION_MODEL_NAME")

    def review(
        self,
        user_input: str,
        execution_plan: str,
        code: str
    ) -> "LLMReviewResult":
        """Review the safety of a piece of code.

        Args:
            user_input: The user's original request.
            execution_plan: The execution plan for the request.
            code: The code to review.

        Returns:
            LLMReviewResult: The verdict; ``passed`` is False whenever the
            review could not be completed.
        """
        try:
            client = get_client()
            messages = [
                {"role": "system", "content": SAFETY_REVIEW_SYSTEM},
                {"role": "user", "content": SAFETY_REVIEW_USER.format(
                    user_input=user_input,
                    execution_plan=execution_plan,
                    code=code
                )}
            ]
            response = client.chat(
                messages=messages,
                model=self.model_name,
                temperature=0.1,
                max_tokens=512
            )
            return self._parse_response(response)
        except LLMClientError as e:
            # The LLM call failed; reject to stay on the safe side.
            return LLMReviewResult(
                passed=False,
                reason=f"安全审查失败({str(e)}),出于安全考虑拒绝执行"
            )
        except Exception as e:
            # Deliberate catch-all: any unexpected failure must reject the
            # code rather than let it run unreviewed.
            return LLMReviewResult(
                passed=False,
                reason=f"安全审查异常({str(e)}),出于安全考虑拒绝执行"
            )

    def _parse_response(self, response: str) -> "LLMReviewResult":
        """Parse the model reply into an LLMReviewResult.

        A reply that is not a string, contains no valid JSON, or whose JSON
        is not an object is reported as a parse failure with
        ``passed=False``.
        """
        parse_failure = LLMReviewResult(
            passed=False,
            reason="审查结果解析失败,出于安全考虑拒绝执行",
            raw_response=response if isinstance(response, str) else None
        )
        if not isinstance(response, str):
            return parse_failure

        try:
            data = json.loads(self._extract_json(response))
        except (json.JSONDecodeError, ValueError, TypeError):
            # Unparseable reply: judge conservatively.
            return parse_failure

        # Valid JSON that is not an object (list, number, ...) carries no
        # "pass" field to read.
        if not isinstance(data, dict):
            return parse_failure

        reason = data.get("reason", "未提供原因")
        if reason is None:
            reason = "未提供原因"
        elif not isinstance(reason, str):
            reason = str(reason)

        return LLMReviewResult(
            passed=self._coerce_passed(data.get("pass", False)),
            reason=reason,
            raw_response=response
        )

    @classmethod
    def _coerce_passed(cls, value) -> bool:
        """Convert the model's "pass" value into a strict boolean.

        Only an explicit True, the number 1, or one of the recognised
        positive strings counts as a pass. Anything else, including
        non-empty lists or dicts, is a rejection.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUTHY_STRINGS
        if isinstance(value, (int, float)):
            return value == 1
        return False

    def _extract_json(self, text: str) -> str:
        """Return the span from the first '{' to the last '}' in ``text``.

        When no such span exists the text is returned unchanged so that the
        caller's JSON parsing reports the failure.
        """
        start = text.find('{')
        end = text.rfind('}')
        if start != -1 and end != -1 and end > start:
            return text[start:end + 1]
        return text
def review_code_safety(
    user_input: str,
    execution_plan: str,
    code: str
) -> LLMReviewResult:
    """Convenience wrapper: run one safety review with a new LLMReviewer."""
    return LLMReviewer().review(user_input, execution_plan, code)

208
safety/rule_checker.py Normal file
View File

@@ -0,0 +1,208 @@
"""
硬规则安全检查器
静态扫描执行代码,检测危险操作
"""
import re
import ast
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class RuleCheckResult:
    """Outcome of a hard-rule safety check."""

    # Verdict of the check.
    passed: bool
    # List of violations that were found.
    violations: List[str]
class RuleChecker:
    """Hard-rule checker that statically scans code for dangerous operations.

    Detected categories:
    1. Network access: requests, socket, urllib, http.client, ...
    2. Dangerous file operations: os.remove, shutil.rmtree, os.unlink, ...
    3. External command execution: subprocess, os.system, os.popen, ...
    4. Access to paths outside the workspace.

    The scan is a best-effort static check; it does not execute the code.
    """

    # Modules that must not be imported.
    FORBIDDEN_IMPORTS = {
        'requests',
        'socket',
        'urllib',
        'urllib.request',
        'urllib.parse',
        'urllib.error',
        'http.client',
        'httplib',
        'ftplib',
        'smtplib',
        'telnetlib',
        'subprocess',
    }

    # Functions that must not be called ("module.function" or a bare name).
    FORBIDDEN_CALLS = {
        'os.remove',
        'os.unlink',
        'os.rmdir',
        'os.removedirs',
        'os.system',
        'os.popen',
        'os.spawn',
        'os.spawnl',
        'os.spawnle',
        'os.spawnlp',
        'os.spawnlpe',
        'os.spawnv',
        'os.spawnve',
        'os.spawnvp',
        'os.spawnvpe',
        'os.exec',
        'os.execl',
        'os.execle',
        'os.execlp',
        'os.execlpe',
        'os.execv',
        'os.execve',
        'os.execvp',
        'os.execvpe',
        'shutil.rmtree',
        'shutil.move',  # move can lose the original file
        'eval',
        'exec',
        'compile',
        '__import__',
    }

    # Regular expressions for suspicious path fragments.
    DANGEROUS_PATH_PATTERNS = [
        r'[A-Za-z]:\\',  # Windows absolute path
        r'\\\\',  # UNC path
        r'/etc/',
        r'/usr/',
        r'/bin/',
        r'/home/',
        r'/root/',
        r'\.\./',  # parent-directory traversal
        r'\.\.',  # parent directory
    ]

    def check(self, code: str) -> "RuleCheckResult":
        """Check whether code complies with the safety rules.

        Args:
            code: Python source code as a string.

        Returns:
            RuleCheckResult: ``passed`` is True only when no import, call or
            path violation was found.
        """
        violations = []
        violations.extend(self._check_imports(code))
        violations.extend(self._check_calls(code))
        violations.extend(self._check_paths(code))
        return RuleCheckResult(
            passed=not violations,
            violations=violations
        )

    def _is_forbidden_module(self, dotted_name: str) -> bool:
        """Return True if the name or any dotted prefix of it is forbidden."""
        parts = dotted_name.split('.')
        return any(
            '.'.join(parts[:i]) in self.FORBIDDEN_IMPORTS
            for i in range(1, len(parts) + 1)
        )

    def _check_imports(self, code: str) -> List[str]:
        """Return violations for forbidden imports.

        Besides ``import X`` / ``from X import ...`` of a forbidden module,
        this also flags ``from pkg import name`` where ``pkg.name`` is a
        forbidden module (e.g. ``from http import client``) or a forbidden
        function (e.g. ``from os import system``), and wildcard imports from
        a module that owns forbidden functions.
        """
        violations = []
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # The code does not parse (ValueError covers NUL bytes on older
            # interpreters); fall back to regex matching. Sorted so the
            # violation order is deterministic.
            for module in sorted(self.FORBIDDEN_IMPORTS):
                pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
                if re.search(pattern, code):
                    violations.append(f"禁止导入模块: {module}")
            return violations

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if self._is_forbidden_module(alias.name):
                        violations.append(f"禁止导入模块: {alias.name}")
            elif isinstance(node, ast.ImportFrom) and node.module:
                if self._is_forbidden_module(node.module):
                    violations.append(f"禁止导入模块: {node.module}")
                    continue
                for alias in node.names:
                    if alias.name == '*':
                        # A wildcard import would bind every forbidden
                        # function of that module as a bare name.
                        prefix = node.module + '.'
                        if any(f.startswith(prefix) for f in self.FORBIDDEN_CALLS):
                            violations.append(f"禁止导入函数: {node.module}.*")
                        continue
                    qualified = f"{node.module}.{alias.name}"
                    if qualified in self.FORBIDDEN_IMPORTS:
                        violations.append(f"禁止导入模块: {qualified}")
                    elif qualified in self.FORBIDDEN_CALLS:
                        violations.append(f"禁止导入函数: {qualified}")
        return violations

    def _check_calls(self, code: str) -> List[str]:
        """Return violations for forbidden function calls.

        Call names are checked both as written and after resolving import
        aliases, so ``import os as o; o.system(...)`` and
        ``from shutil import rmtree as rm; rm(...)`` are detected.
        """
        violations = []
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # The code does not parse; fall back to regex matching.
            for func in sorted(self.FORBIDDEN_CALLS):
                pattern = rf'\b{re.escape(func)}\s*\('
                if re.search(pattern, code):
                    violations.append(f"禁止调用函数: {func}")
            return violations

        nodes = list(ast.walk(tree))
        aliases = self._collect_aliases(nodes)
        for node in nodes:
            if not isinstance(node, ast.Call):
                continue
            raw_name = self._get_call_name(node)
            if raw_name in self.FORBIDDEN_CALLS:
                violations.append(f"禁止调用函数: {raw_name}")
                continue
            resolved = self._resolve_alias(raw_name, aliases)
            if resolved in self.FORBIDDEN_CALLS:
                violations.append(f"禁止调用函数: {resolved}")
        return violations

    @staticmethod
    def _collect_aliases(nodes) -> dict:
        """Map locally bound import names to the dotted names they refer to.

        Scopes are ignored: an alias bound anywhere in the code applies to
        every call, which errs on the side of reporting.
        """
        aliases = {}
        for node in nodes:
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if alias.asname:
                        aliases[alias.asname] = alias.name
            elif isinstance(node, ast.ImportFrom) and node.module:
                for alias in node.names:
                    if alias.name != '*':
                        bound = alias.asname or alias.name
                        aliases[bound] = f"{node.module}.{alias.name}"
        return aliases

    @staticmethod
    def _resolve_alias(name: str, aliases: dict) -> str:
        """Replace the leading component of ``name`` with its import target."""
        head, sep, tail = name.partition('.')
        target = aliases.get(head)
        if target is None:
            return name
        return f"{target}{sep}{tail}"

    def _get_call_name(self, node: ast.Call) -> str:
        """Return the dotted name of the called function.

        For an attribute chain rooted at a plain name (``a.b.c(...)``) the
        full chain is returned. When the chain is rooted at some other
        expression only the attribute parts are returned. Any other kind of
        callee yields an empty string.
        """
        if isinstance(node.func, ast.Name):
            return node.func.id
        elif isinstance(node.func, ast.Attribute):
            parts = []
            current = node.func
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                parts.append(current.id)
            return '.'.join(reversed(parts))
        return ''

    def _check_paths(self, code: str) -> List[str]:
        """Return violations for suspicious path fragments.

        Each pattern reports at most one violation. A match is tolerated
        when the text "workspace" appears within 50 characters of it.
        NOTE(review): that exemption is purely textual, so a nearby comment
        mentioning "workspace" also silences the match.
        """
        violations = []
        for pattern in self.DANGEROUS_PATH_PATTERNS:
            for found in re.finditer(pattern, code, re.IGNORECASE):
                # Use the position of this very match; looking the text up
                # again would always land on its first occurrence.
                start = found.start()
                window = code[max(0, start - 50):start + 50]
                if 'workspace' not in window.lower():
                    violations.append(f"检测到可疑路径模式: {found.group(0)}")
                    break
        return violations
def check_code_safety(code: str) -> RuleCheckResult:
    """Convenience wrapper: run the hard-rule check with a new RuleChecker."""
    return RuleChecker().check(code)