feat: implement streaming support for chat and enhance safety review process

- Updated .env.example to include API key placeholder and configuration instructions.
- Refactored main.py to support streaming responses from the LLM, improving user experience during chat interactions.
- Enhanced LLMClient to include methods for streaming chat and collecting responses.
- Modified safety review process to pass static analysis warnings to the LLM for better code safety evaluation.
- Improved UI components in chat_view.py to handle streaming messages effectively.
This commit is contained in:
Mimikko-zeus
2026-01-07 09:43:40 +08:00
parent dad0d2629a
commit 1ba5f0f7d6
7 changed files with 406 additions and 145 deletions

View File

@@ -1,4 +1,16 @@
LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions # ========================================
LLM_API_KEY=sk-fxsxbgatrjjhsnjpkdfgfngukqoqqgitjpxfqfxifcipaqpc # LocalAgent 閰嶇疆鏂囦欢绀轰緥
# ========================================
# 浣跨敤鏂规硶锛?# 1. 澶嶅埗姝ゆ枃浠朵负 .env
# 2. 濉叆浣犵殑 API Key 鍜屽叾浠栭厤缃?# ========================================
# SiliconFlow API 閰嶇疆
# 鑾峰彇 API Key: https://siliconflow.cn
LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions
LLM_API_KEY=your_api_key_here
# 妯″瀷閰嶇疆
# 鎰忓浘璇嗗埆妯″瀷锛堟帹鑽愪娇鐢ㄥ皬妯″瀷锛岄€熷害蹇級
INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
# 浠g爜鐢熸垚妯″瀷锛堟帹鑽愪娇鐢ㄥぇ妯″瀷锛屾晥鏋滃ソ锛?GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct

View File

@@ -1,12 +1,14 @@
""" """
LLM 统一调用客户端 LLM 统一调用客户端
所有模型通过 SiliconFlow API 调用 所有模型通过 SiliconFlow API 调用
支持流式和非流式两种模式
""" """
import os import os
import json
import requests import requests
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Generator, Callable
from dotenv import load_dotenv from dotenv import load_dotenv
# 获取项目根目录 # 获取项目根目录
@@ -25,12 +27,19 @@ class LLMClient:
使用方式: 使用方式:
client = LLMClient() client = LLMClient()
# 非流式调用
response = client.chat( response = client.chat(
messages=[{"role": "user", "content": "你好"}], messages=[{"role": "user", "content": "你好"}],
model="Qwen/Qwen2.5-7B-Instruct", model="Qwen/Qwen2.5-7B-Instruct"
temperature=0.7,
max_tokens=1024
) )
# 流式调用
for chunk in client.chat_stream(
messages=[{"role": "user", "content": "你好"}],
model="Qwen/Qwen2.5-7B-Instruct"
):
print(chunk, end="", flush=True)
""" """
def __init__(self): def __init__(self):
@@ -49,22 +58,21 @@ class LLMClient:
messages: list[dict], messages: list[dict],
model: str, model: str,
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 1024 max_tokens: int = 1024,
timeout: int = 180
) -> str: ) -> str:
""" """
调用 LLM 进行对话 调用 LLM 进行对话(非流式)
Args: Args:
messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}] messages: 消息列表
model: 模型名称 model: 模型名称
temperature: 温度参数,控制随机性 temperature: 温度参数
max_tokens: 最大生成 token 数 max_tokens: 最大生成 token 数
timeout: 超时时间(秒),默认 180 秒
Returns: Returns:
LLM 生成的文本内容 LLM 生成的文本内容
Raises:
LLMClientError: 网络异常或 API 返回错误
""" """
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
@@ -84,10 +92,10 @@ class LLMClient:
self.api_url, self.api_url,
headers=headers, headers=headers,
json=payload, json=payload,
timeout=60 timeout=timeout
) )
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
raise LLMClientError("请求超时,请检查网络连接") raise LLMClientError(f"请求超时{timeout}秒),请检查网络连接或稍后重试")
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
raise LLMClientError("网络连接失败,请检查网络设置") raise LLMClientError("网络连接失败,请检查网络设置")
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
@@ -110,6 +118,121 @@ class LLMClient:
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
raise LLMClientError(f"解析 API 响应失败: {str(e)}") raise LLMClientError(f"解析 API 响应失败: {str(e)}")
def chat_stream(
self,
messages: list[dict],
model: str,
temperature: float = 0.7,
max_tokens: int = 2048,
timeout: int = 180
) -> Generator[str, None, None]:
"""
调用 LLM 进行对话(流式)
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大生成 token 数
timeout: 超时时间(秒)
Yields:
逐个返回生成的文本片段
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model,
"messages": messages,
"stream": True,
"temperature": temperature,
"max_tokens": max_tokens
}
try:
response = requests.post(
self.api_url,
headers=headers,
json=payload,
timeout=timeout,
stream=True
)
except requests.exceptions.Timeout:
raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
except requests.exceptions.ConnectionError:
raise LLMClientError("网络连接失败,请检查网络设置")
except requests.exceptions.RequestException as e:
raise LLMClientError(f"网络请求异常: {str(e)}")
if response.status_code != 200:
error_msg = f"API 返回错误 (状态码: {response.status_code})"
try:
error_detail = response.json()
if "error" in error_detail:
error_msg += f": {error_detail['error']}"
except:
error_msg += f": {response.text[:200]}"
raise LLMClientError(error_msg)
# 解析 SSE 流
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith('data: '):
data = line[6:] # 去掉 "data: " 前缀
if data == '[DONE]':
break
try:
chunk = json.loads(data)
if 'choices' in chunk and len(chunk['choices']) > 0:
delta = chunk['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
yield content
except json.JSONDecodeError:
continue
def chat_stream_collect(
self,
messages: list[dict],
model: str,
temperature: float = 0.7,
max_tokens: int = 2048,
timeout: int = 180,
on_chunk: Optional[Callable[[str], None]] = None
) -> str:
"""
流式调用并收集完整结果
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大生成 token 数
timeout: 超时时间(秒)
on_chunk: 每收到一个片段时的回调函数
Returns:
完整的生成文本
"""
full_content = []
for chunk in self.chat_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout
):
full_content.append(chunk)
if on_chunk:
on_chunk(chunk)
return ''.join(full_content)
# 全局单例(延迟初始化) # 全局单例(延迟初始化)
_client: Optional[LLMClient] = None _client: Optional[LLMClient] = None
@@ -121,4 +244,3 @@ def get_client() -> LLMClient:
if _client is None: if _client is None:
_client = LLMClient() _client = LLMClient()
return _client return _client

View File

@@ -155,17 +155,26 @@ CODE_GENERATION_USER = """执行计划:
# 安全审查 Prompt # 安全审查 Prompt
# ======================================== # ========================================
SAFETY_REVIEW_SYSTEM = """你是一个代码安全审查员。检查代码是否符合安全规范 SAFETY_REVIEW_SYSTEM = """你是一个代码安全审查员。你的任务是判断代码是否安全可执行
检查项: 【核心原则】
1. 是否只操作 workspace/input 和 workspace/output 目录 - 代码只应操作 workspace/input(读取)和 workspace/output(写入)
2. 是否有网络请求代码requests, socket, urllib - 不应有网络请求、执行系统命令等危险操作
3. 是否有危险的文件删除操作os.remove, shutil.rmtree - 代码逻辑应与用户需求一致
4. 是否有执行外部命令的代码subprocess, os.system
5. 代码逻辑是否与用户需求一致 【审查要点】
1. 路径安全:是否只访问 workspace 目录?是否有路径遍历风险?
2. 网络安全:是否有网络请求?(如果用户明确要求下载等网络操作,需拒绝)
3. 文件安全:删除操作是否合理?(如果是清理临时文件可以接受,删除用户文件需拒绝)
4. 逻辑一致:代码是否实现了用户的需求?
【判断标准】
- 如果代码安全且符合需求 → pass: true
- 如果有安全风险或不符合需求 → pass: false
- 对于边界情况,倾向于通过(用户已确认执行)
输出JSON格式 输出JSON格式
{"pass": true或false, "reason": "中文审查结论,一句话"}""" {"pass": true或false, "reason": "中文审查结论,简洁说明"}"""
SAFETY_REVIEW_USER = """用户需求:{user_input} SAFETY_REVIEW_USER = """用户需求:{user_input}

67
main.py
View File

@@ -171,30 +171,44 @@ class LocalAgentApp:
f"识别为对话模式 (原因: {intent_result.reason})", f"识别为对话模式 (原因: {intent_result.reason})",
'system' 'system'
) )
self.chat_view.add_message("正在生成回复...", 'system')
# 在后台线程调用 LLM # 开始流式消息
def do_chat(): self.chat_view.start_stream_message('assistant')
# 在后台线程调用 LLM流式
def do_chat_stream():
client = get_client() client = get_client()
model = os.getenv("GENERATION_MODEL_NAME") model = os.getenv("GENERATION_MODEL_NAME")
return client.chat(
full_response = []
for chunk in client.chat_stream(
messages=[{"role": "user", "content": user_input}], messages=[{"role": "user", "content": user_input}],
model=model, model=model,
temperature=0.7, temperature=0.7,
max_tokens=2048 max_tokens=2048,
) timeout=300
):
full_response.append(chunk)
# 通过队列发送 chunk 到主线程更新 UI
self.result_queue.put((self._on_chat_chunk, (chunk,)))
return ''.join(full_response)
self._run_in_thread( self._run_in_thread(
do_chat, do_chat_stream,
self._on_chat_result self._on_chat_complete
) )
def _on_chat_result(self, response: Optional[str], error: Optional[Exception]): def _on_chat_chunk(self, chunk: str):
"""收到对话片段回调(主线程)"""
self.chat_view.append_stream_chunk(chunk)
def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]):
"""对话完成回调""" """对话完成回调"""
self.chat_view.end_stream_message()
if error: if error:
self.chat_view.add_message(f"对话失败: {str(error)}", 'error') self.chat_view.add_message(f"对话失败: {str(error)}", 'error')
else:
self.chat_view.add_message(response, 'assistant')
self.chat_view.set_input_enabled(True) self.chat_view.set_input_enabled(True)
@@ -261,13 +275,18 @@ class LocalAgentApp:
self.current_task = None self.current_task = None
return return
# 保存警告信息,传递给 LLM 审查
self.current_task['warnings'] = rule_result.warnings
# 在后台线程进行 LLM 安全审查 # 在后台线程进行 LLM 安全审查
self._run_in_thread( self._run_in_thread(
review_code_safety, lambda: review_code_safety(
self._on_safety_reviewed, self.current_task['user_input'],
self.current_task['user_input'], self.current_task['execution_plan'],
self.current_task['execution_plan'], code,
code rule_result.warnings # 传递警告给 LLM
),
self._on_safety_reviewed
) )
def _on_safety_reviewed(self, review_result, error: Optional[Exception]): def _on_safety_reviewed(self, review_result, error: Optional[Exception]):
@@ -293,28 +312,31 @@ class LocalAgentApp:
self._show_task_guide() self._show_task_guide()
def _generate_execution_plan(self, user_input: str) -> str: def _generate_execution_plan(self, user_input: str) -> str:
"""生成执行计划""" """生成执行计划(使用流式传输)"""
client = get_client() client = get_client()
model = os.getenv("GENERATION_MODEL_NAME") model = os.getenv("GENERATION_MODEL_NAME")
response = client.chat( # 使用流式传输,避免超时
response = client.chat_stream_collect(
messages=[ messages=[
{"role": "system", "content": EXECUTION_PLAN_SYSTEM}, {"role": "system", "content": EXECUTION_PLAN_SYSTEM},
{"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)} {"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)}
], ],
model=model, model=model,
temperature=0.3, temperature=0.3,
max_tokens=1024 max_tokens=1024,
timeout=300 # 5分钟超时
) )
return response return response
def _generate_code(self, user_input: str, execution_plan: str) -> str: def _generate_code(self, user_input: str, execution_plan: str) -> str:
"""生成执行代码""" """生成执行代码(使用流式传输)"""
client = get_client() client = get_client()
model = os.getenv("GENERATION_MODEL_NAME") model = os.getenv("GENERATION_MODEL_NAME")
response = client.chat( # 使用流式传输,避免超时
response = client.chat_stream_collect(
messages=[ messages=[
{"role": "system", "content": CODE_GENERATION_SYSTEM}, {"role": "system", "content": CODE_GENERATION_SYSTEM},
{"role": "user", "content": CODE_GENERATION_USER.format( {"role": "user", "content": CODE_GENERATION_USER.format(
@@ -324,7 +346,8 @@ class LocalAgentApp:
], ],
model=model, model=model,
temperature=0.2, temperature=0.2,
max_tokens=2048 max_tokens=4096, # 代码可能较长
timeout=300 # 5分钟超时
) )
# 提取代码块 # 提取代码块

View File

@@ -5,7 +5,7 @@ LLM 软规则审查器
import os import os
import json import json
from typing import Optional from typing import Optional, List
from dataclasses import dataclass from dataclasses import dataclass
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -36,7 +36,8 @@ class LLMReviewer:
self, self,
user_input: str, user_input: str,
execution_plan: str, execution_plan: str,
code: str code: str,
warnings: Optional[List[str]] = None
) -> LLMReviewResult: ) -> LLMReviewResult:
""" """
审查代码安全性 审查代码安全性
@@ -45,6 +46,7 @@ class LLMReviewer:
user_input: 用户原始需求 user_input: 用户原始需求
execution_plan: 执行计划 execution_plan: 执行计划
code: 待审查的代码 code: 待审查的代码
warnings: 静态检查产生的警告列表
Returns: Returns:
LLMReviewResult: 审查结果 LLMReviewResult: 审查结果
@@ -52,20 +54,26 @@ class LLMReviewer:
try: try:
client = get_client() client = get_client()
# 构建警告信息
warning_text = ""
if warnings and len(warnings) > 0:
warning_text = "\n\n【静态检查警告】请重点审查以下内容:\n" + "\n".join(f"- {w}" for w in warnings)
messages = [ messages = [
{"role": "system", "content": SAFETY_REVIEW_SYSTEM}, {"role": "system", "content": SAFETY_REVIEW_SYSTEM},
{"role": "user", "content": SAFETY_REVIEW_USER.format( {"role": "user", "content": SAFETY_REVIEW_USER.format(
user_input=user_input, user_input=user_input,
execution_plan=execution_plan, execution_plan=execution_plan,
code=code code=code
)} ) + warning_text}
] ]
response = client.chat( response = client.chat(
messages=messages, messages=messages,
model=self.model_name, model=self.model_name,
temperature=0.1, temperature=0.1,
max_tokens=512 max_tokens=512,
timeout=120
) )
return self._parse_response(response) return self._parse_response(response)
@@ -124,9 +132,9 @@ class LLMReviewer:
def review_code_safety( def review_code_safety(
user_input: str, user_input: str,
execution_plan: str, execution_plan: str,
code: str code: str,
warnings: Optional[List[str]] = None
) -> LLMReviewResult: ) -> LLMReviewResult:
"""便捷函数:审查代码安全性""" """便捷函数:审查代码安全性"""
reviewer = LLMReviewer() reviewer = LLMReviewer()
return reviewer.review(user_input, execution_plan, code) return reviewer.review(user_input, execution_plan, code, warnings)

View File

@@ -1,11 +1,11 @@
""" """
硬规则安全检查器 硬规则安全检查器
静态扫描执行代码,检测危险操作 检测危险操作,其他交给 LLM 审查
""" """
import re import re
import ast import ast
from typing import List, Tuple from typing import List
from dataclasses import dataclass from dataclasses import dataclass
@@ -14,41 +14,41 @@ class RuleCheckResult:
"""规则检查结果""" """规则检查结果"""
passed: bool passed: bool
violations: List[str] # 违规项列表 violations: List[str] # 违规项列表
warnings: List[str] # 警告项(交给 LLM 审查)
class RuleChecker: class RuleChecker:
""" """
硬规则检查器 硬规则检查器
静态扫描代码,检测以下危险操作: 只硬性禁止最危险操作:
1. 网络请求: requests, socket, urllib, http.client 1. 网络模块: socket底层网络
2. 危险文件操作: os.remove, shutil.rmtree, os.unlink 2. 执行任意代码: eval, exec, compile
3. 执行外部命令: subprocess, os.system, os.popen 3. 执行系统命令: subprocess, os.system, os.popen
4. 访问非 workspace 路径 4. 动态导入: __import__
其他操作(如文件删除、路径访问等)生成警告,交给 LLM 审查
""" """
# 禁止导入的模块 # 【硬性禁止】最危险的模块 - 直接拒绝
FORBIDDEN_IMPORTS = { CRITICAL_FORBIDDEN_IMPORTS = {
'requests', 'socket', # 底层网络,可绑定端口、建立连接
'socket', 'subprocess', # 执行任意系统命令
'urllib', 'multiprocessing', # 可能绑定端口
'urllib.request', 'asyncio', # 可能包含网络操作
'urllib.parse', 'ctypes', # 可调用任意 C 函数
'urllib.error', 'cffi', # 外部函数接口
'http.client',
'httplib',
'ftplib',
'smtplib',
'telnetlib',
'subprocess',
} }
# 禁止调用的函数(模块.函数 或 单独函数名) # 【硬性禁止】最危险的函数调用 - 直接拒绝
FORBIDDEN_CALLS = { CRITICAL_FORBIDDEN_CALLS = {
'os.remove', # 执行任意代码
'os.unlink', 'eval',
'os.rmdir', 'exec',
'os.removedirs', 'compile',
'__import__',
# 执行系统命令
'os.system', 'os.system',
'os.popen', 'os.popen',
'os.spawn', 'os.spawn',
@@ -60,7 +60,6 @@ class RuleChecker:
'os.spawnve', 'os.spawnve',
'os.spawnvp', 'os.spawnvp',
'os.spawnvpe', 'os.spawnvpe',
'os.exec',
'os.execl', 'os.execl',
'os.execle', 'os.execle',
'os.execlp', 'os.execlp',
@@ -69,26 +68,28 @@ class RuleChecker:
'os.execve', 'os.execve',
'os.execvp', 'os.execvp',
'os.execvpe', 'os.execvpe',
'shutil.rmtree',
'shutil.move', # move 可能导致原文件丢失
'eval',
'exec',
'compile',
'__import__',
} }
# 危险路径模式(正则) # 【警告】需要 LLM 审查的模块
DANGEROUS_PATH_PATTERNS = [ WARNING_IMPORTS = {
r'[A-Za-z]:\\', # Windows 绝对路径 'requests', # HTTP 请求
r'\\\\', # UNC 路径 'urllib', # URL 处理
r'/etc/', 'http.client', # HTTP 客户端
r'/usr/', 'ftplib', # FTP
r'/bin/', 'smtplib', # 邮件
r'/home/', 'telnetlib', # Telnet
r'/root/', }
r'\.\./', # 父目录遍历
r'\.\.', # 父目录 # 【警告】需要 LLM 审查的函数调用
] WARNING_CALLS = {
'os.remove', # 删除文件
'os.unlink', # 删除文件
'os.rmdir', # 删除目录
'os.removedirs', # 递归删除目录
'shutil.rmtree', # 递归删除目录树
'shutil.move', # 移动文件(可能丢失原文件)
'open', # 打开文件(检查路径)
}
def check(self, code: str) -> RuleCheckResult: def check(self, code: str) -> RuleCheckResult:
""" """
@@ -100,27 +101,33 @@ class RuleChecker:
Returns: Returns:
RuleCheckResult: 检查结果 RuleCheckResult: 检查结果
""" """
violations = [] violations = [] # 硬性违规,直接拒绝
warnings = [] # 警告,交给 LLM 审查
# 1. 检查禁止的导入 # 1. 检查硬性禁止的导入
import_violations = self._check_imports(code) critical_import_violations = self._check_critical_imports(code)
violations.extend(import_violations) violations.extend(critical_import_violations)
# 2. 检查禁止的函数调用 # 2. 检查硬性禁止的函数调用
call_violations = self._check_calls(code) critical_call_violations = self._check_critical_calls(code)
violations.extend(call_violations) violations.extend(critical_call_violations)
# 3. 检查危险路径 # 3. 检查警告级别的导入
path_violations = self._check_paths(code) warning_imports = self._check_warning_imports(code)
violations.extend(path_violations) warnings.extend(warning_imports)
# 4. 检查警告级别的函数调用
warning_calls = self._check_warning_calls(code)
warnings.extend(warning_calls)
return RuleCheckResult( return RuleCheckResult(
passed=len(violations) == 0, passed=len(violations) == 0,
violations=violations violations=violations,
warnings=warnings
) )
def _check_imports(self, code: str) -> List[str]: def _check_critical_imports(self, code: str) -> List[str]:
"""检查禁止的导入""" """检查硬性禁止的导入"""
violations = [] violations = []
try: try:
@@ -130,26 +137,25 @@ class RuleChecker:
if isinstance(node, ast.Import): if isinstance(node, ast.Import):
for alias in node.names: for alias in node.names:
module_name = alias.name.split('.')[0] module_name = alias.name.split('.')[0]
if alias.name in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS: if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
violations.append(f"禁止导入模块: {alias.name}") violations.append(f"严禁使用模块: {alias.name}(可能执行危险操作)")
elif isinstance(node, ast.ImportFrom): elif isinstance(node, ast.ImportFrom):
if node.module: if node.module:
module_name = node.module.split('.')[0] module_name = node.module.split('.')[0]
if node.module in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS: if module_name in self.CRITICAL_FORBIDDEN_IMPORTS:
violations.append(f"禁止导入模块: {node.module}") violations.append(f"严禁使用模块: {node.module}(可能执行危险操作)")
except SyntaxError: except SyntaxError:
# 如果代码有语法错误,使用正则匹配 for module in self.CRITICAL_FORBIDDEN_IMPORTS:
for module in self.FORBIDDEN_IMPORTS:
pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b' pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b'
if re.search(pattern, code): if re.search(pattern, code):
violations.append(f"禁止导入模块: {module}") violations.append(f"严禁使用模块: {module}")
return violations return violations
def _check_calls(self, code: str) -> List[str]: def _check_critical_calls(self, code: str) -> List[str]:
"""检查禁止的函数调用""" """检查硬性禁止的函数调用"""
violations = [] violations = []
try: try:
@@ -158,18 +164,60 @@ class RuleChecker:
for node in ast.walk(tree): for node in ast.walk(tree):
if isinstance(node, ast.Call): if isinstance(node, ast.Call):
call_name = self._get_call_name(node) call_name = self._get_call_name(node)
if call_name in self.FORBIDDEN_CALLS: if call_name in self.CRITICAL_FORBIDDEN_CALLS:
violations.append(f"禁止调用函数: {call_name}") violations.append(f"严禁调用: {call_name}(可能执行任意代码或命令)")
except SyntaxError: except SyntaxError:
# 如果代码有语法错误,使用正则匹配 for func in self.CRITICAL_FORBIDDEN_CALLS:
for func in self.FORBIDDEN_CALLS:
pattern = rf'\b{re.escape(func)}\s*\(' pattern = rf'\b{re.escape(func)}\s*\('
if re.search(pattern, code): if re.search(pattern, code):
violations.append(f"禁止调用函数: {func}") violations.append(f"严禁调用: {func}")
return violations return violations
def _check_warning_imports(self, code: str) -> List[str]:
"""检查警告级别的导入"""
warnings = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split('.')[0]
if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS:
warnings.append(f"使用了网络相关模块: {alias.name}")
elif isinstance(node, ast.ImportFrom):
if node.module:
module_name = node.module.split('.')[0]
if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS:
warnings.append(f"使用了网络相关模块: {node.module}")
except SyntaxError:
pass
return warnings
def _check_warning_calls(self, code: str) -> List[str]:
"""检查警告级别的函数调用"""
warnings = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
call_name = self._get_call_name(node)
if call_name in self.WARNING_CALLS:
warnings.append(f"使用了敏感操作: {call_name}")
except SyntaxError:
pass
return warnings
def _get_call_name(self, node: ast.Call) -> str: def _get_call_name(self, node: ast.Call) -> str:
"""获取函数调用的完整名称""" """获取函数调用的完整名称"""
if isinstance(node.func, ast.Name): if isinstance(node.func, ast.Name):
@@ -185,24 +233,8 @@ class RuleChecker:
return '.'.join(reversed(parts)) return '.'.join(reversed(parts))
return '' return ''
def _check_paths(self, code: str) -> List[str]:
"""检查危险路径访问"""
violations = []
for pattern in self.DANGEROUS_PATH_PATTERNS:
matches = re.findall(pattern, code, re.IGNORECASE)
if matches:
# 排除 workspace 相关的合法路径
for match in matches:
if 'workspace' not in code[max(0, code.find(match)-50):code.find(match)+50].lower():
violations.append(f"检测到可疑路径模式: {match}")
break
return violations
def check_code_safety(code: str) -> RuleCheckResult: def check_code_safety(code: str) -> RuleCheckResult:
"""便捷函数:检查代码安全性""" """便捷函数:检查代码安全性"""
checker = RuleChecker() checker = RuleChecker()
return checker.check(code) return checker.check(code)

View File

@@ -1,6 +1,6 @@
""" """
聊天视图组件 聊天视图组件
处理普通对话的 UI 展示 处理普通对话的 UI 展示 - 支持流式消息
""" """
import tkinter as tk import tkinter as tk
@@ -16,6 +16,7 @@ class ChatView:
- 消息显示区域 - 消息显示区域
- 输入框 - 输入框
- 发送按钮 - 发送按钮
- 流式消息支持
""" """
def __init__( def __init__(
@@ -33,6 +34,10 @@ class ChatView:
self.parent = parent self.parent = parent
self.on_send = on_send self.on_send = on_send
# 流式消息状态
self._stream_active = False
self._stream_tag = None
self._create_widgets() self._create_widgets()
def _create_widgets(self): def _create_widgets(self):
@@ -71,6 +76,7 @@ class ChatView:
self.message_area.tag_configure('assistant', foreground='#81c784', font=('Microsoft YaHei UI', 11)) self.message_area.tag_configure('assistant', foreground='#81c784', font=('Microsoft YaHei UI', 11))
self.message_area.tag_configure('system', foreground='#ffb74d', font=('Microsoft YaHei UI', 10, 'italic')) self.message_area.tag_configure('system', foreground='#ffb74d', font=('Microsoft YaHei UI', 10, 'italic'))
self.message_area.tag_configure('error', foreground='#ef5350', font=('Microsoft YaHei UI', 10)) self.message_area.tag_configure('error', foreground='#ef5350', font=('Microsoft YaHei UI', 10))
self.message_area.tag_configure('streaming', foreground='#81c784', font=('Microsoft YaHei UI', 11))
# 输入区域框架 # 输入区域框架
input_frame = tk.Frame(self.frame, bg='#1e1e1e') input_frame = tk.Frame(self.frame, bg='#1e1e1e')
@@ -147,6 +153,55 @@ class ChatView:
self.message_area.see(tk.END) self.message_area.see(tk.END)
self.message_area.config(state=tk.DISABLED) self.message_area.config(state=tk.DISABLED)
def start_stream_message(self, tag: str = 'assistant'):
"""
开始流式消息
Args:
tag: 消息类型
"""
self._stream_active = True
self._stream_tag = tag
self.message_area.config(state=tk.NORMAL)
# 添加前缀
prefix_map = {
'user': '[你] ',
'assistant': '[助手] ',
'system': '[系统] ',
'error': '[错误] '
}
prefix = prefix_map.get(tag, '')
self.message_area.insert(tk.END, "\n" + prefix, tag)
self.message_area.see(tk.END)
# 保持 NORMAL 状态以便追加内容
def append_stream_chunk(self, chunk: str):
"""
追加流式消息片段
Args:
chunk: 消息片段
"""
if not self._stream_active:
return
self.message_area.insert(tk.END, chunk, self._stream_tag)
self.message_area.see(tk.END)
# 强制更新 UI
self.message_area.update_idletasks()
def end_stream_message(self):
"""结束流式消息"""
if self._stream_active:
self.message_area.insert(tk.END, "\n")
self.message_area.see(tk.END)
self.message_area.config(state=tk.DISABLED)
self._stream_active = False
self._stream_tag = None
def clear_messages(self): def clear_messages(self):
"""清空消息区域""" """清空消息区域"""
self.message_area.config(state=tk.NORMAL) self.message_area.config(state=tk.NORMAL)