feat: implement streaming support for chat and enhance safety review process
- Updated .env.example to include API key placeholder and configuration instructions. - Refactored main.py to support streaming responses from the LLM, improving user experience during chat interactions. - Enhanced LLMClient to include methods for streaming chat and collecting responses. - Modified safety review process to pass static analysis warnings to the LLM for better code safety evaluation. - Improved UI components in chat_view.py to handle streaming messages effectively.
This commit is contained in:
18
.env.example
18
.env.example
@@ -1,4 +1,16 @@
|
|||||||
LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions
|
# ========================================
|
||||||
LLM_API_KEY=sk-fxsxbgatrjjhsnjpkdfgfngukqoqqgitjpxfqfxifcipaqpc
|
# LocalAgent 配置文件示例
|
||||||
|
# ========================================
|
||||||
|
# 使用方法:
# 1. 复制此文件为 .env
|
||||||
|
# 2. 填入你的 API Key 和其他配置
# ========================================
|
||||||
|
|
||||||
|
# SiliconFlow API 配置
|
||||||
|
# 获取 API Key: https://siliconflow.cn
|
||||||
|
LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions
|
||||||
|
LLM_API_KEY=your_api_key_here
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
# 意图识别模型(推荐使用小模型,速度快)
|
||||||
INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
|
INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
|
||||||
GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
|
||||||
|
# 代码生成模型(推荐使用大模型,效果好)
GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
||||||
|
|||||||
150
llm/client.py
150
llm/client.py
@@ -1,12 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
LLM 统一调用客户端
|
LLM 统一调用客户端
|
||||||
所有模型通过 SiliconFlow API 调用
|
所有模型通过 SiliconFlow API 调用
|
||||||
|
支持流式和非流式两种模式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import requests
|
import requests
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Generator, Callable
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# 获取项目根目录
|
# 获取项目根目录
|
||||||
@@ -25,12 +27,19 @@ class LLMClient:
|
|||||||
|
|
||||||
使用方式:
|
使用方式:
|
||||||
client = LLMClient()
|
client = LLMClient()
|
||||||
|
|
||||||
|
# 非流式调用
|
||||||
response = client.chat(
|
response = client.chat(
|
||||||
messages=[{"role": "user", "content": "你好"}],
|
messages=[{"role": "user", "content": "你好"}],
|
||||||
model="Qwen/Qwen2.5-7B-Instruct",
|
model="Qwen/Qwen2.5-7B-Instruct"
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=1024
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 流式调用
|
||||||
|
for chunk in client.chat_stream(
|
||||||
|
messages=[{"role": "user", "content": "你好"}],
|
||||||
|
model="Qwen/Qwen2.5-7B-Instruct"
|
||||||
|
):
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -49,22 +58,21 @@ class LLMClient:
|
|||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 1024
|
max_tokens: int = 1024,
|
||||||
|
timeout: int = 180
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
调用 LLM 进行对话
|
调用 LLM 进行对话(非流式)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}]
|
messages: 消息列表
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
temperature: 温度参数,控制随机性
|
temperature: 温度参数
|
||||||
max_tokens: 最大生成 token 数
|
max_tokens: 最大生成 token 数
|
||||||
|
timeout: 超时时间(秒),默认 180 秒
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLM 生成的文本内容
|
LLM 生成的文本内容
|
||||||
|
|
||||||
Raises:
|
|
||||||
LLMClientError: 网络异常或 API 返回错误
|
|
||||||
"""
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -84,10 +92,10 @@ class LLMClient:
|
|||||||
self.api_url,
|
self.api_url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=60
|
timeout=timeout
|
||||||
)
|
)
|
||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
raise LLMClientError("请求超时,请检查网络连接")
|
raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
raise LLMClientError("网络连接失败,请检查网络设置")
|
raise LLMClientError("网络连接失败,请检查网络设置")
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
@@ -109,6 +117,121 @@ class LLMClient:
|
|||||||
return content
|
return content
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
raise LLMClientError(f"解析 API 响应失败: {str(e)}")
|
raise LLMClientError(f"解析 API 响应失败: {str(e)}")
|
||||||
|
|
||||||
|
def chat_stream(
    self,
    messages: list[dict],
    model: str,
    temperature: float = 0.7,
    max_tokens: int = 2048,
    timeout: int = 180
) -> Generator[str, None, None]:
    """
    Call the LLM chat endpoint in streaming mode and yield text as it arrives.

    The request is sent with ``"stream": True`` and the response body is
    parsed as a Server-Sent Events stream: each ``data:`` line carries one
    JSON chunk, and the literal payload ``[DONE]`` ends the stream.
    Because this is a generator, no request is made until the caller
    starts iterating.

    Args:
        messages: Message list sent as the ``messages`` field of the payload.
        model: Model name.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        timeout: Timeout in seconds passed to ``requests.post``.

    Yields:
        Every non-empty ``choices[0].delta.content`` fragment, in arrival
        order.

    Raises:
        LLMClientError: On timeout, connection failure, any other request
            error (including one raised while the body is being read), or
            a non-200 status code.
    """
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": messages,
        "stream": True,
        "temperature": temperature,
        "max_tokens": max_tokens
    }

    try:
        response = requests.post(
            self.api_url,
            headers=headers,
            json=payload,
            timeout=timeout,
            stream=True
        )
    except requests.exceptions.Timeout:
        raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
    except requests.exceptions.ConnectionError:
        raise LLMClientError("网络连接失败,请检查网络设置")
    except requests.exceptions.RequestException as e:
        raise LLMClientError(f"网络请求异常: {str(e)}")

    try:
        if response.status_code != 200:
            error_msg = f"API 返回错误 (状态码: {response.status_code})"
            try:
                error_detail = response.json()
                if "error" in error_detail:
                    error_msg += f": {error_detail['error']}"
            except (ValueError, TypeError):
                # Body is not JSON (or not a container): fall back to raw text.
                error_msg += f": {response.text[:200]}"
            raise LLMClientError(error_msg)

        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines only separate SSE events

                line = raw_line.decode('utf-8')
                if not line.startswith('data:'):
                    continue  # SSE comments and other fields carry no text

                # The SSE format allows one optional space after the colon.
                data = line[5:]
                if data.startswith(' '):
                    data = data[1:]

                if data == '[DONE]':
                    break

                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue  # skip malformed chunks instead of aborting

                choices = chunk.get('choices') if isinstance(chunk, dict) else None
                if not choices:
                    continue

                # "delta" or "content" may be present but null.
                delta = choices[0].get('delta') or {}
                content = delta.get('content')
                if content:
                    yield content
        except requests.exceptions.RequestException as e:
            # Errors raised while reading the body surface as LLMClientError,
            # matching the errors raised when the request itself fails.
            raise LLMClientError(f"网络请求异常: {str(e)}") from e
    finally:
        # With stream=True the connection is released only once the body is
        # fully read or the response is closed; close it on every exit path
        # (error, [DONE], or the consumer abandoning the generator).
        response.close()
|
||||||
|
|
||||||
|
def chat_stream_collect(
    self,
    messages: list[dict],
    model: str,
    temperature: float = 0.7,
    max_tokens: int = 2048,
    timeout: int = 180,
    on_chunk: Optional[Callable[[str], None]] = None
) -> str:
    """
    Run a streaming chat call and return the concatenated result.

    Args:
        messages: Message list forwarded to ``chat_stream``.
        model: Model name.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        timeout: Timeout in seconds.
        on_chunk: Optional callback invoked with each fragment as soon as
            it is received.

    Returns:
        All fragments yielded by ``chat_stream`` joined into one string.
    """
    stream = self.chat_stream(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout
    )

    pieces = []
    for piece in stream:
        pieces.append(piece)
        if on_chunk:
            on_chunk(piece)

    return ''.join(pieces)
|
||||||
|
|
||||||
|
|
||||||
# 全局单例(延迟初始化)
|
# 全局单例(延迟初始化)
|
||||||
@@ -121,4 +244,3 @@ def get_client() -> LLMClient:
|
|||||||
if _client is None:
|
if _client is None:
|
||||||
_client = LLMClient()
|
_client = LLMClient()
|
||||||
return _client
|
return _client
|
||||||
|
|
||||||
|
|||||||
@@ -155,17 +155,26 @@ CODE_GENERATION_USER = """执行计划:
|
|||||||
# 安全审查 Prompt
|
# 安全审查 Prompt
|
||||||
# ========================================
|
# ========================================
|
||||||
|
|
||||||
SAFETY_REVIEW_SYSTEM = """你是一个代码安全审查员。检查代码是否符合安全规范。
|
SAFETY_REVIEW_SYSTEM = """你是一个代码安全审查员。你的任务是判断代码是否安全可执行。
|
||||||
|
|
||||||
检查项:
|
【核心原则】
|
||||||
1. 是否只操作 workspace/input 和 workspace/output 目录
|
- 代码只应操作 workspace/input(读取)和 workspace/output(写入)
|
||||||
2. 是否有网络请求代码(requests, socket, urllib)
|
- 不应有网络请求、执行系统命令等危险操作
|
||||||
3. 是否有危险的文件删除操作(os.remove, shutil.rmtree)
|
- 代码逻辑应与用户需求一致
|
||||||
4. 是否有执行外部命令的代码(subprocess, os.system)
|
|
||||||
5. 代码逻辑是否与用户需求一致
|
【审查要点】
|
||||||
|
1. 路径安全:是否只访问 workspace 目录?是否有路径遍历风险?
|
||||||
|
2. 网络安全:是否有网络请求?(如果用户明确要求下载等网络操作,需拒绝)
|
||||||
|
3. 文件安全:删除操作是否合理?(如果是清理临时文件可以接受,删除用户文件需拒绝)
|
||||||
|
4. 逻辑一致:代码是否实现了用户的需求?
|
||||||
|
|
||||||
|
【判断标准】
|
||||||
|
- 如果代码安全且符合需求 → pass: true
|
||||||
|
- 如果有安全风险或不符合需求 → pass: false
|
||||||
|
- 对于边界情况,倾向于通过(用户已确认执行)
|
||||||
|
|
||||||
输出JSON格式:
|
输出JSON格式:
|
||||||
{"pass": true或false, "reason": "中文审查结论,一句话"}"""
|
{"pass": true或false, "reason": "中文审查结论,简洁说明"}"""
|
||||||
|
|
||||||
SAFETY_REVIEW_USER = """用户需求:{user_input}
|
SAFETY_REVIEW_USER = """用户需求:{user_input}
|
||||||
|
|
||||||
|
|||||||
67
main.py
67
main.py
@@ -171,30 +171,44 @@ class LocalAgentApp:
|
|||||||
f"识别为对话模式 (原因: {intent_result.reason})",
|
f"识别为对话模式 (原因: {intent_result.reason})",
|
||||||
'system'
|
'system'
|
||||||
)
|
)
|
||||||
self.chat_view.add_message("正在生成回复...", 'system')
|
|
||||||
|
|
||||||
# 在后台线程调用 LLM
|
# 开始流式消息
|
||||||
def do_chat():
|
self.chat_view.start_stream_message('assistant')
|
||||||
|
|
||||||
|
# 在后台线程调用 LLM(流式)
|
||||||
|
def do_chat_stream():
|
||||||
client = get_client()
|
client = get_client()
|
||||||
model = os.getenv("GENERATION_MODEL_NAME")
|
model = os.getenv("GENERATION_MODEL_NAME")
|
||||||
return client.chat(
|
|
||||||
|
full_response = []
|
||||||
|
for chunk in client.chat_stream(
|
||||||
messages=[{"role": "user", "content": user_input}],
|
messages=[{"role": "user", "content": user_input}],
|
||||||
model=model,
|
model=model,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=2048
|
max_tokens=2048,
|
||||||
)
|
timeout=300
|
||||||
|
):
|
||||||
|
full_response.append(chunk)
|
||||||
|
# 通过队列发送 chunk 到主线程更新 UI
|
||||||
|
self.result_queue.put((self._on_chat_chunk, (chunk,)))
|
||||||
|
|
||||||
|
return ''.join(full_response)
|
||||||
|
|
||||||
self._run_in_thread(
|
self._run_in_thread(
|
||||||
do_chat,
|
do_chat_stream,
|
||||||
self._on_chat_result
|
self._on_chat_complete
|
||||||
)
|
)
|
||||||
|
|
||||||
def _on_chat_result(self, response: Optional[str], error: Optional[Exception]):
|
def _on_chat_chunk(self, chunk: str):
|
||||||
|
"""收到对话片段回调(主线程)"""
|
||||||
|
self.chat_view.append_stream_chunk(chunk)
|
||||||
|
|
||||||
|
def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]):
|
||||||
"""对话完成回调"""
|
"""对话完成回调"""
|
||||||
|
self.chat_view.end_stream_message()
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
self.chat_view.add_message(f"对话失败: {str(error)}", 'error')
|
self.chat_view.add_message(f"对话失败: {str(error)}", 'error')
|
||||||
else:
|
|
||||||
self.chat_view.add_message(response, 'assistant')
|
|
||||||
|
|
||||||
self.chat_view.set_input_enabled(True)
|
self.chat_view.set_input_enabled(True)
|
||||||
|
|
||||||
@@ -261,13 +275,18 @@ class LocalAgentApp:
|
|||||||
self.current_task = None
|
self.current_task = None
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 保存警告信息,传递给 LLM 审查
|
||||||
|
self.current_task['warnings'] = rule_result.warnings
|
||||||
|
|
||||||
# 在后台线程进行 LLM 安全审查
|
# 在后台线程进行 LLM 安全审查
|
||||||
self._run_in_thread(
|
self._run_in_thread(
|
||||||
review_code_safety,
|
lambda: review_code_safety(
|
||||||
self._on_safety_reviewed,
|
self.current_task['user_input'],
|
||||||
self.current_task['user_input'],
|
self.current_task['execution_plan'],
|
||||||
self.current_task['execution_plan'],
|
code,
|
||||||
code
|
rule_result.warnings # 传递警告给 LLM
|
||||||
|
),
|
||||||
|
self._on_safety_reviewed
|
||||||
)
|
)
|
||||||
|
|
||||||
def _on_safety_reviewed(self, review_result, error: Optional[Exception]):
|
def _on_safety_reviewed(self, review_result, error: Optional[Exception]):
|
||||||
@@ -293,28 +312,31 @@ class LocalAgentApp:
|
|||||||
self._show_task_guide()
|
self._show_task_guide()
|
||||||
|
|
||||||
def _generate_execution_plan(self, user_input: str) -> str:
|
def _generate_execution_plan(self, user_input: str) -> str:
|
||||||
"""生成执行计划"""
|
"""生成执行计划(使用流式传输)"""
|
||||||
client = get_client()
|
client = get_client()
|
||||||
model = os.getenv("GENERATION_MODEL_NAME")
|
model = os.getenv("GENERATION_MODEL_NAME")
|
||||||
|
|
||||||
response = client.chat(
|
# 使用流式传输,避免超时
|
||||||
|
response = client.chat_stream_collect(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": EXECUTION_PLAN_SYSTEM},
|
{"role": "system", "content": EXECUTION_PLAN_SYSTEM},
|
||||||
{"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)}
|
{"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
temperature=0.3,
|
temperature=0.3,
|
||||||
max_tokens=1024
|
max_tokens=1024,
|
||||||
|
timeout=300 # 5分钟超时
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _generate_code(self, user_input: str, execution_plan: str) -> str:
|
def _generate_code(self, user_input: str, execution_plan: str) -> str:
|
||||||
"""生成执行代码"""
|
"""生成执行代码(使用流式传输)"""
|
||||||
client = get_client()
|
client = get_client()
|
||||||
model = os.getenv("GENERATION_MODEL_NAME")
|
model = os.getenv("GENERATION_MODEL_NAME")
|
||||||
|
|
||||||
response = client.chat(
|
# 使用流式传输,避免超时
|
||||||
|
response = client.chat_stream_collect(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": CODE_GENERATION_SYSTEM},
|
{"role": "system", "content": CODE_GENERATION_SYSTEM},
|
||||||
{"role": "user", "content": CODE_GENERATION_USER.format(
|
{"role": "user", "content": CODE_GENERATION_USER.format(
|
||||||
@@ -324,7 +346,8 @@ class LocalAgentApp:
|
|||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=2048
|
max_tokens=4096, # 代码可能较长
|
||||||
|
timeout=300 # 5分钟超时
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提取代码块
|
# 提取代码块
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ LLM 软规则审查器
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
@@ -36,7 +36,8 @@ class LLMReviewer:
|
|||||||
self,
|
self,
|
||||||
user_input: str,
|
user_input: str,
|
||||||
execution_plan: str,
|
execution_plan: str,
|
||||||
code: str
|
code: str,
|
||||||
|
warnings: Optional[List[str]] = None
|
||||||
) -> LLMReviewResult:
|
) -> LLMReviewResult:
|
||||||
"""
|
"""
|
||||||
审查代码安全性
|
审查代码安全性
|
||||||
@@ -45,6 +46,7 @@ class LLMReviewer:
|
|||||||
user_input: 用户原始需求
|
user_input: 用户原始需求
|
||||||
execution_plan: 执行计划
|
execution_plan: 执行计划
|
||||||
code: 待审查的代码
|
code: 待审查的代码
|
||||||
|
warnings: 静态检查产生的警告列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMReviewResult: 审查结果
|
LLMReviewResult: 审查结果
|
||||||
@@ -52,20 +54,26 @@ class LLMReviewer:
|
|||||||
try:
|
try:
|
||||||
client = get_client()
|
client = get_client()
|
||||||
|
|
||||||
|
# 构建警告信息
|
||||||
|
warning_text = ""
|
||||||
|
if warnings and len(warnings) > 0:
|
||||||
|
warning_text = "\n\n【静态检查警告】请重点审查以下内容:\n" + "\n".join(f"- {w}" for w in warnings)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": SAFETY_REVIEW_SYSTEM},
|
{"role": "system", "content": SAFETY_REVIEW_SYSTEM},
|
||||||
{"role": "user", "content": SAFETY_REVIEW_USER.format(
|
{"role": "user", "content": SAFETY_REVIEW_USER.format(
|
||||||
user_input=user_input,
|
user_input=user_input,
|
||||||
execution_plan=execution_plan,
|
execution_plan=execution_plan,
|
||||||
code=code
|
code=code
|
||||||
)}
|
) + warning_text}
|
||||||
]
|
]
|
||||||
|
|
||||||
response = client.chat(
|
response = client.chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=512
|
max_tokens=512,
|
||||||
|
timeout=120
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._parse_response(response)
|
return self._parse_response(response)
|
||||||
@@ -124,9 +132,9 @@ class LLMReviewer:
|
|||||||
def review_code_safety(
|
def review_code_safety(
|
||||||
user_input: str,
|
user_input: str,
|
||||||
execution_plan: str,
|
execution_plan: str,
|
||||||
code: str
|
code: str,
|
||||||
|
warnings: Optional[List[str]] = None
|
||||||
) -> LLMReviewResult:
|
) -> LLMReviewResult:
|
||||||
"""便捷函数:审查代码安全性"""
|
"""便捷函数:审查代码安全性"""
|
||||||
reviewer = LLMReviewer()
|
reviewer = LLMReviewer()
|
||||||
return reviewer.review(user_input, execution_plan, code)
|
return reviewer.review(user_input, execution_plan, code, warnings)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
硬规则安全检查器
|
硬规则安全检查器
|
||||||
静态扫描执行代码,检测危险操作
|
只检测最危险的操作,其他交给 LLM 审查
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import ast
|
import ast
|
||||||
from typing import List, Tuple
|
from typing import List
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@@ -14,41 +14,41 @@ class RuleCheckResult:
|
|||||||
"""规则检查结果"""
|
"""规则检查结果"""
|
||||||
passed: bool
|
passed: bool
|
||||||
violations: List[str] # 违规项列表
|
violations: List[str] # 违规项列表
|
||||||
|
warnings: List[str] # 警告项(交给 LLM 审查)
|
||||||
|
|
||||||
|
|
||||||
class RuleChecker:
|
class RuleChecker:
|
||||||
"""
|
"""
|
||||||
硬规则检查器
|
硬规则检查器
|
||||||
|
|
||||||
静态扫描代码,检测以下危险操作:
|
只硬性禁止最危险的操作:
|
||||||
1. 网络请求: requests, socket, urllib, http.client
|
1. 网络模块: socket(底层网络)
|
||||||
2. 危险文件操作: os.remove, shutil.rmtree, os.unlink
|
2. 执行任意代码: eval, exec, compile
|
||||||
3. 执行外部命令: subprocess, os.system, os.popen
|
3. 执行系统命令: subprocess, os.system, os.popen
|
||||||
4. 访问非 workspace 路径
|
4. 动态导入: __import__
|
||||||
|
|
||||||
|
其他操作(如文件删除、路径访问等)生成警告,交给 LLM 审查
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 禁止导入的模块
|
# 【硬性禁止】最危险的模块 - 直接拒绝
|
||||||
FORBIDDEN_IMPORTS = {
|
CRITICAL_FORBIDDEN_IMPORTS = {
|
||||||
'requests',
|
'socket', # 底层网络,可绑定端口、建立连接
|
||||||
'socket',
|
'subprocess', # 执行任意系统命令
|
||||||
'urllib',
|
'multiprocessing', # 可能绑定端口
|
||||||
'urllib.request',
|
'asyncio', # 可能包含网络操作
|
||||||
'urllib.parse',
|
'ctypes', # 可调用任意 C 函数
|
||||||
'urllib.error',
|
'cffi', # 外部函数接口
|
||||||
'http.client',
|
|
||||||
'httplib',
|
|
||||||
'ftplib',
|
|
||||||
'smtplib',
|
|
||||||
'telnetlib',
|
|
||||||
'subprocess',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 禁止调用的函数(模块.函数 或 单独函数名)
|
# 【硬性禁止】最危险的函数调用 - 直接拒绝
|
||||||
FORBIDDEN_CALLS = {
|
CRITICAL_FORBIDDEN_CALLS = {
|
||||||
'os.remove',
|
# 执行任意代码
|
||||||
'os.unlink',
|
'eval',
|
||||||
'os.rmdir',
|
'exec',
|
||||||
'os.removedirs',
|
'compile',
|
||||||
|
'__import__',
|
||||||
|
|
||||||
|
# 执行系统命令
|
||||||
'os.system',
|
'os.system',
|
||||||
'os.popen',
|
'os.popen',
|
||||||
'os.spawn',
|
'os.spawn',
|
||||||
@@ -60,7 +60,6 @@ class RuleChecker:
|
|||||||
'os.spawnve',
|
'os.spawnve',
|
||||||
'os.spawnvp',
|
'os.spawnvp',
|
||||||
'os.spawnvpe',
|
'os.spawnvpe',
|
||||||
'os.exec',
|
|
||||||
'os.execl',
|
'os.execl',
|
||||||
'os.execle',
|
'os.execle',
|
||||||
'os.execlp',
|
'os.execlp',
|
||||||
@@ -69,26 +68,28 @@ class RuleChecker:
|
|||||||
'os.execve',
|
'os.execve',
|
||||||
'os.execvp',
|
'os.execvp',
|
||||||
'os.execvpe',
|
'os.execvpe',
|
||||||
'shutil.rmtree',
|
|
||||||
'shutil.move', # move 可能导致原文件丢失
|
|
||||||
'eval',
|
|
||||||
'exec',
|
|
||||||
'compile',
|
|
||||||
'__import__',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 危险路径模式(正则)
|
# 【警告】需要 LLM 审查的模块
|
||||||
DANGEROUS_PATH_PATTERNS = [
|
WARNING_IMPORTS = {
|
||||||
r'[A-Za-z]:\\', # Windows 绝对路径
|
'requests', # HTTP 请求
|
||||||
r'\\\\', # UNC 路径
|
'urllib', # URL 处理
|
||||||
r'/etc/',
|
'http.client', # HTTP 客户端
|
||||||
r'/usr/',
|
'ftplib', # FTP
|
||||||
r'/bin/',
|
'smtplib', # 邮件
|
||||||
r'/home/',
|
'telnetlib', # Telnet
|
||||||
r'/root/',
|
}
|
||||||
r'\.\./', # 父目录遍历
|
|
||||||
r'\.\.', # 父目录
|
# 【警告】需要 LLM 审查的函数调用
|
||||||
]
|
WARNING_CALLS = {
|
||||||
|
'os.remove', # 删除文件
|
||||||
|
'os.unlink', # 删除文件
|
||||||
|
'os.rmdir', # 删除目录
|
||||||
|
'os.removedirs', # 递归删除目录
|
||||||
|
'shutil.rmtree', # 递归删除目录树
|
||||||
|
'shutil.move', # 移动文件(可能丢失原文件)
|
||||||
|
'open', # 打开文件(检查路径)
|
||||||
|
}
|
||||||
|
|
||||||
def check(self, code: str) -> RuleCheckResult:
|
def check(self, code: str) -> RuleCheckResult:
|
||||||
"""
|
"""
|
||||||
@@ -100,27 +101,33 @@ class RuleChecker:
|
|||||||
Returns:
|
Returns:
|
||||||
RuleCheckResult: 检查结果
|
RuleCheckResult: 检查结果
|
||||||
"""
|
"""
|
||||||
violations = []
|
violations = [] # 硬性违规,直接拒绝
|
||||||
|
warnings = [] # 警告,交给 LLM 审查
|
||||||
|
|
||||||
# 1. 检查禁止的导入
|
# 1. 检查硬性禁止的导入
|
||||||
import_violations = self._check_imports(code)
|
critical_import_violations = self._check_critical_imports(code)
|
||||||
violations.extend(import_violations)
|
violations.extend(critical_import_violations)
|
||||||
|
|
||||||
# 2. 检查禁止的函数调用
|
# 2. 检查硬性禁止的函数调用
|
||||||
call_violations = self._check_calls(code)
|
critical_call_violations = self._check_critical_calls(code)
|
||||||
violations.extend(call_violations)
|
violations.extend(critical_call_violations)
|
||||||
|
|
||||||
# 3. 检查危险路径
|
# 3. 检查警告级别的导入
|
||||||
path_violations = self._check_paths(code)
|
warning_imports = self._check_warning_imports(code)
|
||||||
violations.extend(path_violations)
|
warnings.extend(warning_imports)
|
||||||
|
|
||||||
|
# 4. 检查警告级别的函数调用
|
||||||
|
warning_calls = self._check_warning_calls(code)
|
||||||
|
warnings.extend(warning_calls)
|
||||||
|
|
||||||
return RuleCheckResult(
|
return RuleCheckResult(
|
||||||
passed=len(violations) == 0,
|
passed=len(violations) == 0,
|
||||||
violations=violations
|
violations=violations,
|
||||||
|
warnings=warnings
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_imports(self, code: str) -> List[str]:
|
def _check_critical_imports(self, code: str) -> List[str]:
|
||||||
"""检查禁止的导入"""
|
"""检查硬性禁止的导入"""
|
||||||
violations = []
|
violations = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -130,26 +137,25 @@ class RuleChecker:
|
|||||||
if isinstance(node, ast.Import):
|
if isinstance(node, ast.Import):
|
||||||
for alias in node.names:
|
for alias in node.names:
|
||||||
module_name = alias.name.split('.')[0]
|
module_name = alias.name.split('.')[0]
|
||||||
if alias.name in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS:
|
if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
|
||||||
violations.append(f"禁止导入模块: {alias.name}")
|
violations.append(f"严禁使用模块: {alias.name}(可能执行危险操作)")
|
||||||
|
|
||||||
elif isinstance(node, ast.ImportFrom):
|
elif isinstance(node, ast.ImportFrom):
|
||||||
if node.module:
|
if node.module:
|
||||||
module_name = node.module.split('.')[0]
|
module_name = node.module.split('.')[0]
|
||||||
if node.module in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS:
|
if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
|
||||||
violations.append(f"禁止导入模块: {node.module}")
|
violations.append(f"严禁使用模块: {node.module}(可能执行危险操作)")
|
||||||
|
|
||||||
except SyntaxError:
|
except SyntaxError:
|
||||||
# 如果代码有语法错误,使用正则匹配
|
for module in self.CRITICAL_FORBIDDEN_IMPORTS:
|
||||||
for module in self.FORBIDDEN_IMPORTS:
|
|
||||||
pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
|
pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
|
||||||
if re.search(pattern, code):
|
if re.search(pattern, code):
|
||||||
violations.append(f"禁止导入模块: {module}")
|
violations.append(f"严禁使用模块: {module}")
|
||||||
|
|
||||||
return violations
|
return violations
|
||||||
|
|
||||||
def _check_calls(self, code: str) -> List[str]:
|
def _check_critical_calls(self, code: str) -> List[str]:
|
||||||
"""检查禁止的函数调用"""
|
"""检查硬性禁止的函数调用"""
|
||||||
violations = []
|
violations = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -158,18 +164,60 @@ class RuleChecker:
|
|||||||
for node in ast.walk(tree):
|
for node in ast.walk(tree):
|
||||||
if isinstance(node, ast.Call):
|
if isinstance(node, ast.Call):
|
||||||
call_name = self._get_call_name(node)
|
call_name = self._get_call_name(node)
|
||||||
if call_name in self.FORBIDDEN_CALLS:
|
if call_name in self.CRITICAL_FORBIDDEN_CALLS:
|
||||||
violations.append(f"禁止调用函数: {call_name}")
|
violations.append(f"严禁调用: {call_name}(可能执行任意代码或命令)")
|
||||||
|
|
||||||
except SyntaxError:
|
except SyntaxError:
|
||||||
# 如果代码有语法错误,使用正则匹配
|
for func in self.CRITICAL_FORBIDDEN_CALLS:
|
||||||
for func in self.FORBIDDEN_CALLS:
|
|
||||||
pattern = rf'\b{re.escape(func)}\s*\('
|
pattern = rf'\b{re.escape(func)}\s*\('
|
||||||
if re.search(pattern, code):
|
if re.search(pattern, code):
|
||||||
violations.append(f"禁止调用函数: {func}")
|
violations.append(f"严禁调用: {func}")
|
||||||
|
|
||||||
return violations
|
return violations
|
||||||
|
|
||||||
|
def _check_warning_imports(self, code: str) -> List[str]:
|
||||||
|
"""检查警告级别的导入"""
|
||||||
|
warnings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Import):
|
||||||
|
for alias in node.names:
|
||||||
|
module_name = alias.name.split('.')[0]
|
||||||
|
if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS:
|
||||||
|
warnings.append(f"使用了网络相关模块: {alias.name}")
|
||||||
|
|
||||||
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
if node.module:
|
||||||
|
module_name = node.module.split('.')[0]
|
||||||
|
if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS:
|
||||||
|
warnings.append(f"使用了网络相关模块: {node.module}")
|
||||||
|
|
||||||
|
except SyntaxError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
def _check_warning_calls(self, code: str) -> List[str]:
|
||||||
|
"""检查警告级别的函数调用"""
|
||||||
|
warnings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
tree = ast.parse(code)
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
call_name = self._get_call_name(node)
|
||||||
|
if call_name in self.WARNING_CALLS:
|
||||||
|
warnings.append(f"使用了敏感操作: {call_name}")
|
||||||
|
|
||||||
|
except SyntaxError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return warnings
|
||||||
|
|
||||||
def _get_call_name(self, node: ast.Call) -> str:
|
def _get_call_name(self, node: ast.Call) -> str:
|
||||||
"""获取函数调用的完整名称"""
|
"""获取函数调用的完整名称"""
|
||||||
if isinstance(node.func, ast.Name):
|
if isinstance(node.func, ast.Name):
|
||||||
@@ -184,25 +232,9 @@ class RuleChecker:
|
|||||||
parts.append(current.id)
|
parts.append(current.id)
|
||||||
return '.'.join(reversed(parts))
|
return '.'.join(reversed(parts))
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
def _check_paths(self, code: str) -> List[str]:
|
|
||||||
"""检查危险路径访问"""
|
|
||||||
violations = []
|
|
||||||
|
|
||||||
for pattern in self.DANGEROUS_PATH_PATTERNS:
|
|
||||||
matches = re.findall(pattern, code, re.IGNORECASE)
|
|
||||||
if matches:
|
|
||||||
# 排除 workspace 相关的合法路径
|
|
||||||
for match in matches:
|
|
||||||
if 'workspace' not in code[max(0, code.find(match)-50):code.find(match)+50].lower():
|
|
||||||
violations.append(f"检测到可疑路径模式: {match}")
|
|
||||||
break
|
|
||||||
|
|
||||||
return violations
|
|
||||||
|
|
||||||
|
|
||||||
def check_code_safety(code: str) -> RuleCheckResult:
|
def check_code_safety(code: str) -> RuleCheckResult:
|
||||||
"""便捷函数:检查代码安全性"""
|
"""便捷函数:检查代码安全性"""
|
||||||
checker = RuleChecker()
|
checker = RuleChecker()
|
||||||
return checker.check(code)
|
return checker.check(code)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
聊天视图组件
|
聊天视图组件
|
||||||
处理普通对话的 UI 展示
|
处理普通对话的 UI 展示 - 支持流式消息
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
@@ -16,6 +16,7 @@ class ChatView:
|
|||||||
- 消息显示区域
|
- 消息显示区域
|
||||||
- 输入框
|
- 输入框
|
||||||
- 发送按钮
|
- 发送按钮
|
||||||
|
- 流式消息支持
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -33,6 +34,10 @@ class ChatView:
|
|||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.on_send = on_send
|
self.on_send = on_send
|
||||||
|
|
||||||
|
# 流式消息状态
|
||||||
|
self._stream_active = False
|
||||||
|
self._stream_tag = None
|
||||||
|
|
||||||
self._create_widgets()
|
self._create_widgets()
|
||||||
|
|
||||||
def _create_widgets(self):
|
def _create_widgets(self):
|
||||||
@@ -71,6 +76,7 @@ class ChatView:
|
|||||||
self.message_area.tag_configure('assistant', foreground='#81c784', font=('Microsoft YaHei UI', 11))
|
self.message_area.tag_configure('assistant', foreground='#81c784', font=('Microsoft YaHei UI', 11))
|
||||||
self.message_area.tag_configure('system', foreground='#ffb74d', font=('Microsoft YaHei UI', 10, 'italic'))
|
self.message_area.tag_configure('system', foreground='#ffb74d', font=('Microsoft YaHei UI', 10, 'italic'))
|
||||||
self.message_area.tag_configure('error', foreground='#ef5350', font=('Microsoft YaHei UI', 10))
|
self.message_area.tag_configure('error', foreground='#ef5350', font=('Microsoft YaHei UI', 10))
|
||||||
|
self.message_area.tag_configure('streaming', foreground='#81c784', font=('Microsoft YaHei UI', 11))
|
||||||
|
|
||||||
# 输入区域框架
|
# 输入区域框架
|
||||||
input_frame = tk.Frame(self.frame, bg='#1e1e1e')
|
input_frame = tk.Frame(self.frame, bg='#1e1e1e')
|
||||||
@@ -147,6 +153,55 @@ class ChatView:
|
|||||||
self.message_area.see(tk.END)
|
self.message_area.see(tk.END)
|
||||||
self.message_area.config(state=tk.DISABLED)
|
self.message_area.config(state=tk.DISABLED)
|
||||||
|
|
||||||
|
def start_stream_message(self, tag: str = 'assistant'):
    """
    Begin a streamed message in the message area.

    Marks the stream as active, remembers the tag, and inserts a newline
    followed by the speaker prefix for that tag.

    Args:
        tag: Message type; also used as the text tag for the inserted
            prefix. Unknown tags get an empty prefix.
    """
    self._stream_active = True
    self._stream_tag = tag

    labels = {
        'user': '[你] ',
        'assistant': '[助手] ',
        'system': '[系统] ',
        'error': '[错误] '
    }
    header = "\n" + labels.get(tag, '')

    area = self.message_area
    area.config(state=tk.NORMAL)
    area.insert(tk.END, header, tag)
    area.see(tk.END)
    # The widget is intentionally left in NORMAL state so that later
    # chunks can be appended; end_stream_message disables it again.
|
||||||
|
|
||||||
|
def append_stream_chunk(self, chunk: str):
    """
    Append one fragment to the streamed message currently in progress.

    Does nothing when no stream is active.

    Args:
        chunk: Text fragment to append.
    """
    if self._stream_active:
        area = self.message_area
        area.insert(tk.END, chunk, self._stream_tag)
        area.see(tk.END)
        # Flush pending redraws so the new text shows up right away.
        area.update_idletasks()
|
||||||
|
|
||||||
|
def end_stream_message(self):
    """Finish the active streamed message; a no-op when none is active."""
    if not self._stream_active:
        return

    area = self.message_area
    area.insert(tk.END, "\n")
    area.see(tk.END)
    area.config(state=tk.DISABLED)

    self._stream_active = False
    self._stream_tag = None
|
||||||
|
|
||||||
def clear_messages(self):
|
def clear_messages(self):
|
||||||
"""清空消息区域"""
|
"""清空消息区域"""
|
||||||
self.message_area.config(state=tk.NORMAL)
|
self.message_area.config(state=tk.NORMAL)
|
||||||
|
|||||||
Reference in New Issue
Block a user