Initial commit
This commit is contained in:
132
safety/llm_reviewer.py
Normal file
132
safety/llm_reviewer.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
LLM 软规则审查器
|
||||
使用 LLM 进行代码安全审查
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from llm.client import get_client, LLMClientError, ENV_PATH
|
||||
from llm.prompts import SAFETY_REVIEW_SYSTEM, SAFETY_REVIEW_USER
|
||||
|
||||
|
||||
@dataclass
class LLMReviewResult:
    """Result of an LLM safety review."""
    # Whether the code passed the safety review.
    passed: bool
    # Human-readable explanation of the verdict.
    reason: str
    # Raw LLM response text, kept for debugging; None when unavailable
    # (e.g. the LLM call itself failed before producing a response).
    raw_response: Optional[str] = None
|
||||
|
||||
|
||||
class LLMReviewer:
    """LLM-based safety reviewer.

    Uses a large language model to perform a semantic-level safety
    review of generated code before it is executed. All failure modes
    (LLM errors, unparseable verdicts) fail closed: the result is
    ``passed=False``.
    """

    def __init__(self):
        # Load API credentials / model configuration from the shared .env file.
        load_dotenv(ENV_PATH)
        # NOTE(review): may be None if the env var is missing; the client
        # is presumably expected to fall back to a default model — confirm.
        self.model_name = os.getenv("GENERATION_MODEL_NAME")

    def review(
        self,
        user_input: str,
        execution_plan: str,
        code: str
    ) -> LLMReviewResult:
        """Review code for safety.

        Args:
            user_input: The user's original request.
            execution_plan: The planned execution steps.
            code: The code under review.

        Returns:
            LLMReviewResult: The review verdict. Any LLM failure or
            unexpected error is mapped to a rejecting (passed=False)
            result rather than raised.
        """
        try:
            client = get_client()

            messages = [
                {"role": "system", "content": SAFETY_REVIEW_SYSTEM},
                {"role": "user", "content": SAFETY_REVIEW_USER.format(
                    user_input=user_input,
                    execution_plan=execution_plan,
                    code=code
                )}
            ]

            # Low temperature for a deterministic, conservative verdict.
            response = client.chat(
                messages=messages,
                model=self.model_name,
                temperature=0.1,
                max_tokens=512
            )

            return self._parse_response(response)

        except LLMClientError as e:
            # LLM call failed — fail closed and refuse execution.
            return LLMReviewResult(
                passed=False,
                reason=f"安全审查失败({str(e)}),出于安全考虑拒绝执行"
            )
        except Exception as e:
            # Unexpected error — also fail closed.
            return LLMReviewResult(
                passed=False,
                reason=f"安全审查异常({str(e)}),出于安全考虑拒绝执行"
            )

    def _parse_response(self, response: str) -> LLMReviewResult:
        """Parse the raw LLM response into a structured result.

        Unparseable or non-object JSON yields a rejecting result
        (fail closed), with the raw response retained for debugging.
        """
        try:
            # Extract the JSON object from the (possibly chatty) response.
            json_str = self._extract_json(response)
            data = json.loads(json_str)

            # Guard: valid JSON that is not an object (e.g. a list or
            # scalar) must take the parse-failure path, not raise
            # AttributeError on .get() below.
            if not isinstance(data, dict):
                raise ValueError("review verdict is not a JSON object")

            passed = data.get("pass", False)
            reason = data.get("reason", "未提供原因")

            # The model may answer with a string instead of a JSON bool.
            if isinstance(passed, str):
                passed = passed.lower() in ('true', 'yes', '1', 'pass')

            return LLMReviewResult(
                passed=bool(passed),
                reason=reason,
                raw_response=response
            )

        except (json.JSONDecodeError, ValueError, TypeError):
            # Parse failed — fail closed.
            # (Fixed: this literal had a spurious f-prefix with no placeholders.)
            return LLMReviewResult(
                passed=False,
                reason="审查结果解析失败,出于安全考虑拒绝执行",
                raw_response=response
            )

    def _extract_json(self, text: str) -> str:
        """Extract the outermost JSON object span from *text*.

        Returns the substring from the first ``{`` to the last ``}``
        (inclusive); if no such span exists, returns *text* unchanged so
        ``json.loads`` can report the failure.
        """
        start = text.find('{')
        end = text.rfind('}')

        if start != -1 and end != -1 and end > start:
            return text[start:end + 1]

        return text
|
||||
|
||||
|
||||
def review_code_safety(
    user_input: str,
    execution_plan: str,
    code: str
) -> LLMReviewResult:
    """Convenience wrapper: run an LLM safety review of *code*.

    Builds a fresh :class:`LLMReviewer` and delegates to its
    :meth:`~LLMReviewer.review` method.
    """
    return LLMReviewer().review(user_input, execution_plan, code)
|
||||
|
||||
Reference in New Issue
Block a user