133 lines
3.7 KiB
Python
133 lines
3.7 KiB
Python
"""
|
|
LLM 软规则审查器
|
|
使用 LLM 进行代码安全审查
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
from typing import Optional
|
|
from dataclasses import dataclass
|
|
from dotenv import load_dotenv
|
|
|
|
from llm.client import get_client, LLMClientError, ENV_PATH
|
|
from llm.prompts import SAFETY_REVIEW_SYSTEM, SAFETY_REVIEW_USER
|
|
|
|
|
|
@dataclass()
class LLMReviewResult:
    """Outcome of a single LLM safety review.

    Attributes:
        passed: Verdict flag; True means the code was approved to run.
        reason: Human-readable explanation accompanying the verdict.
        raw_response: The model's unparsed reply, or None when no reply
            was obtained (e.g. the LLM call itself failed).
    """

    passed: bool
    reason: str
    raw_response: Optional[str] = None
|
|
|
|
|
|
class LLMReviewer:
    """
    LLM safety reviewer.

    Uses a large language model to perform a semantic-level security review
    of generated code. Every failure path (client error, unexpected
    exception, unparseable or malformed verdict) produces ``passed=False``.
    """

    # Lower-cased, whitespace-stripped string verdicts that count as approval.
    _APPROVING_STRINGS = frozenset({"true", "yes", "1", "pass"})

    def __init__(self):
        # Load environment variables from ENV_PATH before reading the model name.
        load_dotenv(ENV_PATH)
        # None when GENERATION_MODEL_NAME is unset; forwarded to client.chat as-is.
        self.model_name = os.getenv("GENERATION_MODEL_NAME")

    def review(
        self,
        user_input: str,
        execution_plan: str,
        code: str
    ) -> "LLMReviewResult":
        """
        Review the safety of a piece of code.

        Args:
            user_input: The user's original request.
            execution_plan: The execution plan for that request.
            code: The code to be reviewed.

        Returns:
            LLMReviewResult: The review verdict. LLM client errors and other
            ``Exception``s are reported as ``passed=False`` instead of being
            raised.
        """
        try:
            client = get_client()

            user_prompt = SAFETY_REVIEW_USER.format(
                user_input=user_input,
                execution_plan=execution_plan,
                code=code,
            )
            messages = [
                {"role": "system", "content": SAFETY_REVIEW_SYSTEM},
                {"role": "user", "content": user_prompt},
            ]

            # Low temperature for a stable verdict; the reply is parsed as
            # JSON by _parse_response.
            response = client.chat(
                messages=messages,
                model=self.model_name,
                temperature=0.1,
                max_tokens=512,
            )

            return self._parse_response(response)

        except LLMClientError as e:
            # The LLM call failed: reject conservatively.
            return LLMReviewResult(
                passed=False,
                reason=f"安全审查失败({e}),出于安全考虑拒绝执行",
            )
        except Exception as e:
            # Boundary catch-all (e.g. a prompt-template KeyError): also
            # fail closed rather than letting the caller run unreviewed code.
            return LLMReviewResult(
                passed=False,
                reason=f"安全审查异常({e}),出于安全考虑拒绝执行",
            )

    def _parse_response(self, response: str) -> "LLMReviewResult":
        """
        Parse the LLM reply into an LLMReviewResult.

        The reply must contain a JSON object with a ``pass`` verdict and an
        optional ``reason``. A non-string reply, invalid JSON, or a JSON value
        that is not an object is treated as a rejection.
        """
        data = None
        if isinstance(response, str):
            try:
                data = json.loads(self._extract_json(response))
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                data = None

        if not isinstance(data, dict):
            # Unparseable or malformed verdict: reject conservatively.
            return LLMReviewResult(
                passed=False,
                reason="审查结果解析失败,出于安全考虑拒绝执行",
                raw_response=response,
            )

        return LLMReviewResult(
            passed=self._coerce_passed(data.get("pass", False)),
            reason=self._coerce_reason(data.get("reason")),
            raw_response=response,
        )

    @classmethod
    def _coerce_passed(cls, value: object) -> bool:
        """
        Interpret a ``pass`` value strictly, defaulting to rejection.

        Only ``True``, the number 1, or one of the strings "true", "yes",
        "1", "pass" (case- and whitespace-insensitive) approves. Containers,
        None and every other value reject, so a malformed verdict such as
        ``[false]`` or ``{"ok": true}`` can no longer pass via truthiness.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in cls._APPROVING_STRINGS
        if isinstance(value, (int, float)):
            return value == 1
        return False

    @staticmethod
    def _coerce_reason(value: object) -> str:
        """Return ``reason`` as a string, using a placeholder when absent or null."""
        if value is None:
            return "未提供原因"
        return value if isinstance(value, str) else str(value)

    def _extract_json(self, text: str) -> str:
        """
        Return the span from the first '{' to the last '}' in ``text``.

        This strips surrounding prose or markdown fences from a single JSON
        object. If no such span exists, ``text`` is returned unchanged so that
        ``json.loads`` reports the failure. When the reply holds several
        brace groups, the span is usually invalid JSON. The review then fails
        closed instead of guessing which object is the real verdict.
        """
        start = text.find('{')
        end = text.rfind('}')
        # end > start >= 0 already implies end != -1.
        if start != -1 and end > start:
            return text[start:end + 1]
        return text
|
|
|
|
|
|
def review_code_safety(
    user_input: str,
    execution_plan: str,
    code: str
) -> LLMReviewResult:
    """Convenience wrapper: review ``code`` with a freshly built LLMReviewer.

    Args:
        user_input: The user's original request.
        execution_plan: The execution plan for that request.
        code: The code to be reviewed.

    Returns:
        The LLMReviewResult produced by ``LLMReviewer.review``.
    """
    return LLMReviewer().review(user_input, execution_plan, code)
|
|
|