diff --git a/.env.example b/.env.example index 7b4d2ff..dff2964 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,16 @@ -LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions -LLM_API_KEY=sk-fxsxbgatrjjhsnjpkdfgfngukqoqqgitjpxfqfxifcipaqpc +# ======================================== +# LocalAgent 閰嶇疆鏂囦欢绀轰緥 +# ======================================== +# 浣跨敤鏂规硶锛?# 1. 澶嶅埗姝ゆ枃浠朵负 .env +# 2. 濉叆浣犵殑 API Key 鍜屽叾浠栭厤缃?# ======================================== + +# SiliconFlow API 閰嶇疆 +# 鑾峰彇 API Key: https://siliconflow.cn +LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions +LLM_API_KEY=your_api_key_here + +# 妯″瀷閰嶇疆 +# 鎰忓浘璇嗗埆妯″瀷锛堟帹鑽愪娇鐢ㄥ皬妯″瀷锛岄€熷害蹇級 INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct -GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \ No newline at end of file + +# 浠g爜鐢熸垚妯″瀷锛堟帹鑽愪娇鐢ㄥぇ妯″瀷锛屾晥鏋滃ソ锛?GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct diff --git a/llm/client.py b/llm/client.py index 9b9a754..941b821 100644 --- a/llm/client.py +++ b/llm/client.py @@ -1,12 +1,14 @@ """ LLM 统一调用客户端 所有模型通过 SiliconFlow API 调用 +支持流式和非流式两种模式 """ import os +import json import requests from pathlib import Path -from typing import Optional +from typing import Optional, Generator, Callable from dotenv import load_dotenv # 获取项目根目录 @@ -25,12 +27,19 @@ class LLMClient: 使用方式: client = LLMClient() + + # 非流式调用 response = client.chat( messages=[{"role": "user", "content": "你好"}], - model="Qwen/Qwen2.5-7B-Instruct", - temperature=0.7, - max_tokens=1024 + model="Qwen/Qwen2.5-7B-Instruct" ) + + # 流式调用 + for chunk in client.chat_stream( + messages=[{"role": "user", "content": "你好"}], + model="Qwen/Qwen2.5-7B-Instruct" + ): + print(chunk, end="", flush=True) """ def __init__(self): @@ -49,22 +58,21 @@ class LLMClient: messages: list[dict], model: str, temperature: float = 0.7, - max_tokens: int = 1024 + max_tokens: int = 1024, + timeout: int = 180 ) -> str: """ - 调用 LLM 进行对话 + 调用 LLM 进行对话(非流式) Args: - messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}] + messages: 消息列表 model: 模型名称 - temperature: 温度参数,控制随机性 + temperature: 温度参数 max_tokens: 最大生成 token 数 + timeout: 超时时间(秒),默认 180 秒 Returns: LLM 生成的文本内容 - - Raises: - LLMClientError: 网络异常或 API 返回错误 """ headers = { "Authorization": f"Bearer {self.api_key}", @@ -84,10 +92,10 @@ class LLMClient: self.api_url, headers=headers, json=payload, - timeout=60 + timeout=timeout ) except requests.exceptions.Timeout: - raise LLMClientError("请求超时,请检查网络连接") + raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") except requests.exceptions.ConnectionError: raise LLMClientError("网络连接失败,请检查网络设置") except requests.exceptions.RequestException as e: @@ -109,6 +117,121 @@ class LLMClient: return content except (KeyError, IndexError, TypeError) as e: raise LLMClientError(f"解析 API 响应失败: {str(e)}") + + def chat_stream( + self, + messages: list[dict], + model: str, + temperature: float = 0.7, + max_tokens: int = 2048, + timeout: int = 180 + ) -> Generator[str, None, None]: + """ + 调用 LLM 进行对话(流式) + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大生成 token 数 + timeout: 超时时间(秒) + + Yields: + 逐个返回生成的文本片段 + """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + payload = { + "model": model, + "messages": messages, + "stream": True, + "temperature": temperature, + "max_tokens": max_tokens + } + + try: + response = requests.post( + self.api_url, + headers=headers, + json=payload, + timeout=timeout, + stream=True + ) + except requests.exceptions.Timeout: + raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") + except requests.exceptions.ConnectionError: + raise LLMClientError("网络连接失败,请检查网络设置") + except requests.exceptions.RequestException as e: + raise LLMClientError(f"网络请求异常: {str(e)}") + + if response.status_code != 200: + error_msg = f"API 返回错误 (状态码: {response.status_code})" + try: + error_detail = response.json() + if "error" in error_detail: + error_msg += f": {error_detail['error']}" + except: + error_msg += f": {response.text[:200]}" + raise LLMClientError(error_msg) + + # 解析 SSE 流 + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + data = line[6:] # 去掉 "data: " 前缀 + if data == '[DONE]': + break + try: + chunk = json.loads(data) + if 'choices' in chunk and len(chunk['choices']) > 0: + delta = chunk['choices'][0].get('delta', {}) + content = delta.get('content', '') + if content: + yield content + except json.JSONDecodeError: + continue + + def chat_stream_collect( + self, + messages: list[dict], + model: str, + temperature: float = 0.7, + max_tokens: int = 2048, + timeout: int = 180, + on_chunk: Optional[Callable[[str], None]] = None + ) -> str: + """ + 流式调用并收集完整结果 + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大生成 token 数 + timeout: 超时时间(秒) + on_chunk: 每收到一个片段时的回调函数 + + Returns: + 完整的生成文本 + """ + full_content = [] + + for chunk in self.chat_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + timeout=timeout + ): + full_content.append(chunk) + if on_chunk: + on_chunk(chunk) + + return ''.join(full_content) # 全局单例(延迟初始化) @@ -121,4 +244,3 @@ def get_client() -> LLMClient: if _client is None: _client = LLMClient() return _client - diff --git a/llm/prompts.py b/llm/prompts.py index 91ca255..4a2ff19 100644 --- a/llm/prompts.py +++ b/llm/prompts.py @@ -155,17 +155,26 @@ CODE_GENERATION_USER = """执行计划: # 安全审查 Prompt # ======================================== -SAFETY_REVIEW_SYSTEM = """你是一个代码安全审查员。检查代码是否符合安全规范。 +SAFETY_REVIEW_SYSTEM = """你是一个代码安全审查员。你的任务是判断代码是否安全可执行。 -检查项: -1. 是否只操作 workspace/input 和 workspace/output 目录 -2. 是否有网络请求代码(requests, socket, urllib) -3. 是否有危险的文件删除操作(os.remove, shutil.rmtree) -4. 是否有执行外部命令的代码(subprocess, os.system) -5. 代码逻辑是否与用户需求一致 +【核心原则】 +- 代码只应操作 workspace/input(读取)和 workspace/output(写入) +- 不应有网络请求、执行系统命令等危险操作 +- 代码逻辑应与用户需求一致 + +【审查要点】 +1. 路径安全:是否只访问 workspace 目录?是否有路径遍历风险? +2. 网络安全:是否有网络请求?(如果用户明确要求下载等网络操作,需拒绝) +3. 文件安全:删除操作是否合理?(如果是清理临时文件可以接受,删除用户文件需拒绝) +4. 逻辑一致:代码是否实现了用户的需求? + +【判断标准】 +- 如果代码安全且符合需求 → pass: true +- 如果有安全风险或不符合需求 → pass: false +- 对于边界情况,倾向于通过(用户已确认执行) 输出JSON格式: -{"pass": true或false, "reason": "中文审查结论,一句话"}""" +{"pass": true或false, "reason": "中文审查结论,简洁说明"}""" SAFETY_REVIEW_USER = """用户需求:{user_input} diff --git a/main.py b/main.py index 3dd14e2..d031df4 100644 --- a/main.py +++ b/main.py @@ -171,30 +171,44 @@ class LocalAgentApp: f"识别为对话模式 (原因: {intent_result.reason})", 'system' ) - self.chat_view.add_message("正在生成回复...", 'system') - # 在后台线程调用 LLM - def do_chat(): + # 开始流式消息 + self.chat_view.start_stream_message('assistant') + + # 在后台线程调用 LLM(流式) + def do_chat_stream(): client = get_client() model = os.getenv("GENERATION_MODEL_NAME") - return client.chat( + + full_response = [] + for chunk in client.chat_stream( messages=[{"role": "user", "content": user_input}], model=model, temperature=0.7, - max_tokens=2048 - ) + max_tokens=2048, + timeout=300 + ): + full_response.append(chunk) + # 通过队列发送 chunk 到主线程更新 UI + self.result_queue.put((self._on_chat_chunk, (chunk,))) + + return ''.join(full_response) self._run_in_thread( - do_chat, - self._on_chat_result + do_chat_stream, + self._on_chat_complete ) - def _on_chat_result(self, response: Optional[str], error: Optional[Exception]): + def _on_chat_chunk(self, chunk: str): + """收到对话片段回调(主线程)""" + self.chat_view.append_stream_chunk(chunk) + + def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]): """对话完成回调""" + self.chat_view.end_stream_message() + if error: self.chat_view.add_message(f"对话失败: {str(error)}", 'error') - else: - self.chat_view.add_message(response, 'assistant') self.chat_view.set_input_enabled(True) @@ -261,13 +275,18 @@ class LocalAgentApp: self.current_task = None return + # 保存警告信息,传递给 LLM 审查 + self.current_task['warnings'] = rule_result.warnings + # 在后台线程进行 LLM 安全审查 self._run_in_thread( - review_code_safety, - self._on_safety_reviewed, - self.current_task['user_input'], - self.current_task['execution_plan'], - code + lambda: review_code_safety( + self.current_task['user_input'], + self.current_task['execution_plan'], + code, + rule_result.warnings # 传递警告给 LLM + ), + self._on_safety_reviewed ) def _on_safety_reviewed(self, review_result, error: Optional[Exception]): @@ -293,28 +312,31 @@ class LocalAgentApp: self._show_task_guide() def _generate_execution_plan(self, user_input: str) -> str: - """生成执行计划""" + """生成执行计划(使用流式传输)""" client = get_client() model = os.getenv("GENERATION_MODEL_NAME") - response = client.chat( + # 使用流式传输,避免超时 + response = client.chat_stream_collect( messages=[ {"role": "system", "content": EXECUTION_PLAN_SYSTEM}, {"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)} ], model=model, temperature=0.3, - max_tokens=1024 + max_tokens=1024, + timeout=300 # 5分钟超时 ) return response def _generate_code(self, user_input: str, execution_plan: str) -> str: - """生成执行代码""" + """生成执行代码(使用流式传输)""" client = get_client() model = os.getenv("GENERATION_MODEL_NAME") - response = client.chat( + # 使用流式传输,避免超时 + response = client.chat_stream_collect( messages=[ {"role": "system", "content": CODE_GENERATION_SYSTEM}, {"role": "user", "content": CODE_GENERATION_USER.format( @@ -324,7 +346,8 @@ class LocalAgentApp: ], model=model, temperature=0.2, - max_tokens=2048 + max_tokens=4096, # 代码可能较长 + timeout=300 # 5分钟超时 ) # 提取代码块 diff --git a/safety/llm_reviewer.py b/safety/llm_reviewer.py index b745154..ab0f5cb 100644 --- a/safety/llm_reviewer.py +++ b/safety/llm_reviewer.py @@ -5,7 +5,7 @@ LLM 软规则审查器 import os import json -from typing import Optional +from typing import Optional, List from dataclasses import dataclass from dotenv import load_dotenv @@ -36,7 +36,8 @@ class LLMReviewer: self, user_input: str, execution_plan: str, - code: str + code: str, + warnings: Optional[List[str]] = None ) -> LLMReviewResult: """ 审查代码安全性 @@ -45,6 +46,7 @@ class LLMReviewer: user_input: 用户原始需求 execution_plan: 执行计划 code: 待审查的代码 + warnings: 静态检查产生的警告列表 Returns: LLMReviewResult: 审查结果 @@ -52,20 +54,26 @@ class LLMReviewer: try: client = get_client() + # 构建警告信息 + warning_text = "" + if warnings and len(warnings) > 0: + warning_text = "\n\n【静态检查警告】请重点审查以下内容:\n" + "\n".join(f"- {w}" for w in warnings) + messages = [ {"role": "system", "content": SAFETY_REVIEW_SYSTEM}, {"role": "user", "content": SAFETY_REVIEW_USER.format( user_input=user_input, execution_plan=execution_plan, code=code - )} + ) + warning_text} ] response = client.chat( messages=messages, model=self.model_name, temperature=0.1, - max_tokens=512 + max_tokens=512, + timeout=120 ) return self._parse_response(response) @@ -124,9 +132,9 @@ class LLMReviewer: def review_code_safety( user_input: str, execution_plan: str, - code: str + code: str, + warnings: Optional[List[str]] = None ) -> LLMReviewResult: """便捷函数:审查代码安全性""" reviewer = LLMReviewer() - return reviewer.review(user_input, execution_plan, code) - + return reviewer.review(user_input, execution_plan, code, warnings) diff --git a/safety/rule_checker.py b/safety/rule_checker.py index be53481..bebf38a 100644 --- a/safety/rule_checker.py +++ b/safety/rule_checker.py @@ -1,11 +1,11 @@ """ 硬规则安全检查器 -静态扫描执行代码,检测危险操作 +只检测最危险的操作,其他交给 LLM 审查 """ import re import ast -from typing import List, Tuple +from typing import List from dataclasses import dataclass @@ -14,41 +14,41 @@ class RuleCheckResult: """规则检查结果""" passed: bool violations: List[str] # 违规项列表 - + warnings: List[str] # 警告项(交给 LLM 审查) + class RuleChecker: """ 硬规则检查器 - 静态扫描代码,检测以下危险操作: - 1. 网络请求: requests, socket, urllib, http.client - 2. 危险文件操作: os.remove, shutil.rmtree, os.unlink - 3. 执行外部命令: subprocess, os.system, os.popen - 4. 访问非 workspace 路径 + 只硬性禁止最危险的操作: + 1. 网络模块: socket(底层网络) + 2. 执行任意代码: eval, exec, compile + 3. 执行系统命令: subprocess, os.system, os.popen + 4. 动态导入: __import__ + + 其他操作(如文件删除、路径访问等)生成警告,交给 LLM 审查 """ - # 禁止导入的模块 - FORBIDDEN_IMPORTS = { - 'requests', - 'socket', - 'urllib', - 'urllib.request', - 'urllib.parse', - 'urllib.error', - 'http.client', - 'httplib', - 'ftplib', - 'smtplib', - 'telnetlib', - 'subprocess', + # 【硬性禁止】最危险的模块 - 直接拒绝 + CRITICAL_FORBIDDEN_IMPORTS = { + 'socket', # 底层网络,可绑定端口、建立连接 + 'subprocess', # 执行任意系统命令 + 'multiprocessing', # 可能绑定端口 + 'asyncio', # 可能包含网络操作 + 'ctypes', # 可调用任意 C 函数 + 'cffi', # 外部函数接口 } - # 禁止调用的函数(模块.函数 或 单独函数名) - FORBIDDEN_CALLS = { - 'os.remove', - 'os.unlink', - 'os.rmdir', - 'os.removedirs', + # 【硬性禁止】最危险的函数调用 - 直接拒绝 + CRITICAL_FORBIDDEN_CALLS = { + # 执行任意代码 + 'eval', + 'exec', + 'compile', + '__import__', + + # 执行系统命令 'os.system', 'os.popen', 'os.spawn', @@ -60,7 +60,6 @@ class RuleChecker: 'os.spawnve', 'os.spawnvp', 'os.spawnvpe', - 'os.exec', 'os.execl', 'os.execle', 'os.execlp', @@ -69,26 +68,28 @@ class RuleChecker: 'os.execve', 'os.execvp', 'os.execvpe', - 'shutil.rmtree', - 'shutil.move', # move 可能导致原文件丢失 - 'eval', - 'exec', - 'compile', - '__import__', } - # 危险路径模式(正则) - DANGEROUS_PATH_PATTERNS = [ - r'[A-Za-z]:\\', # Windows 绝对路径 - r'\\\\', # UNC 路径 - r'/etc/', - r'/usr/', - r'/bin/', - r'/home/', - r'/root/', - r'\.\./', # 父目录遍历 - r'\.\.', # 父目录 - ] + # 【警告】需要 LLM 审查的模块 + WARNING_IMPORTS = { + 'requests', # HTTP 请求 + 'urllib', # URL 处理 + 'http.client', # HTTP 客户端 + 'ftplib', # FTP + 'smtplib', # 邮件 + 'telnetlib', # Telnet + } + + # 【警告】需要 LLM 审查的函数调用 + WARNING_CALLS = { + 'os.remove', # 删除文件 + 'os.unlink', # 删除文件 + 'os.rmdir', # 删除目录 + 'os.removedirs', # 递归删除目录 + 'shutil.rmtree', # 递归删除目录树 + 'shutil.move', # 移动文件(可能丢失原文件) + 'open', # 打开文件(检查路径) + } def check(self, code: str) -> RuleCheckResult: """ @@ -100,27 +101,33 @@ class RuleChecker: Returns: RuleCheckResult: 检查结果 """ - violations = [] + violations = [] # 硬性违规,直接拒绝 + warnings = [] # 警告,交给 LLM 审查 - # 1. 检查禁止的导入 - import_violations = self._check_imports(code) - violations.extend(import_violations) + # 1. 检查硬性禁止的导入 + critical_import_violations = self._check_critical_imports(code) + violations.extend(critical_import_violations) - # 2. 检查禁止的函数调用 - call_violations = self._check_calls(code) - violations.extend(call_violations) + # 2. 检查硬性禁止的函数调用 + critical_call_violations = self._check_critical_calls(code) + violations.extend(critical_call_violations) - # 3. 检查危险路径 - path_violations = self._check_paths(code) - violations.extend(path_violations) + # 3. 检查警告级别的导入 + warning_imports = self._check_warning_imports(code) + warnings.extend(warning_imports) + + # 4. 检查警告级别的函数调用 + warning_calls = self._check_warning_calls(code) + warnings.extend(warning_calls) return RuleCheckResult( passed=len(violations) == 0, - violations=violations + violations=violations, + warnings=warnings ) - def _check_imports(self, code: str) -> List[str]: - """检查禁止的导入""" + def _check_critical_imports(self, code: str) -> List[str]: + """检查硬性禁止的导入""" violations = [] try: @@ -130,26 +137,25 @@ class RuleChecker: if isinstance(node, ast.Import): for alias in node.names: module_name = alias.name.split('.')[0] - if alias.name in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS: - violations.append(f"禁止导入模块: {alias.name}") + if module_name in self.CRITICAL_FORBIDDEN_IMPORTS: + violations.append(f"严禁使用模块: {alias.name}(可能执行危险操作)") elif isinstance(node, ast.ImportFrom): if node.module: module_name = node.module.split('.')[0] - if node.module in self.FORBIDDEN_IMPORTS or module_name in self.FORBIDDEN_IMPORTS: - violations.append(f"禁止导入模块: {node.module}") + if module_name in self.CRITICAL_FORBIDDEN_IMPORTS: + violations.append(f"严禁使用模块: {node.module}(可能执行危险操作)") except SyntaxError: - # 如果代码有语法错误,使用正则匹配 - for module in self.FORBIDDEN_IMPORTS: + for module in self.CRITICAL_FORBIDDEN_IMPORTS: pattern = rf'\bimport\s+{re.escape(module)}\b|\bfrom\s+{re.escape(module)}\b' if re.search(pattern, code): - violations.append(f"禁止导入模块: {module}") + violations.append(f"严禁使用模块: {module}") return violations - def _check_calls(self, code: str) -> List[str]: - """检查禁止的函数调用""" + def _check_critical_calls(self, code: str) -> List[str]: + """检查硬性禁止的函数调用""" violations = [] try: @@ -158,18 +164,60 @@ class RuleChecker: for node in ast.walk(tree): if isinstance(node, ast.Call): call_name = self._get_call_name(node) - if call_name in self.FORBIDDEN_CALLS: - violations.append(f"禁止调用函数: {call_name}") + if call_name in self.CRITICAL_FORBIDDEN_CALLS: + violations.append(f"严禁调用: {call_name}(可能执行任意代码或命令)") except SyntaxError: - # 如果代码有语法错误,使用正则匹配 - for func in self.FORBIDDEN_CALLS: + for func in self.CRITICAL_FORBIDDEN_CALLS: pattern = rf'\b{re.escape(func)}\s*\(' if re.search(pattern, code): - violations.append(f"禁止调用函数: {func}") + violations.append(f"严禁调用: {func}") return violations + def _check_warning_imports(self, code: str) -> List[str]: + """检查警告级别的导入""" + warnings = [] + + try: + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name.split('.')[0] + if module_name in self.WARNING_IMPORTS or alias.name in self.WARNING_IMPORTS: + warnings.append(f"使用了网络相关模块: {alias.name}") + + elif isinstance(node, ast.ImportFrom): + if node.module: + module_name = node.module.split('.')[0] + if module_name in self.WARNING_IMPORTS or node.module in self.WARNING_IMPORTS: + warnings.append(f"使用了网络相关模块: {node.module}") + + except SyntaxError: + pass + + return warnings + + def _check_warning_calls(self, code: str) -> List[str]: + """检查警告级别的函数调用""" + warnings = [] + + try: + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.Call): + call_name = self._get_call_name(node) + if call_name in self.WARNING_CALLS: + warnings.append(f"使用了敏感操作: {call_name}") + + except SyntaxError: + pass + + return warnings + def _get_call_name(self, node: ast.Call) -> str: """获取函数调用的完整名称""" if isinstance(node.func, ast.Name): @@ -184,25 +232,9 @@ class RuleChecker: parts.append(current.id) return '.'.join(reversed(parts)) return '' - - def _check_paths(self, code: str) -> List[str]: - """检查危险路径访问""" - violations = [] - - for pattern in self.DANGEROUS_PATH_PATTERNS: - matches = re.findall(pattern, code, re.IGNORECASE) - if matches: - # 排除 workspace 相关的合法路径 - for match in matches: - if 'workspace' not in code[max(0, code.find(match)-50):code.find(match)+50].lower(): - violations.append(f"检测到可疑路径模式: {match}") - break - - return violations def check_code_safety(code: str) -> RuleCheckResult: """便捷函数:检查代码安全性""" checker = RuleChecker() return checker.check(code) - diff --git a/ui/chat_view.py b/ui/chat_view.py index 951de95..f686e94 100644 --- a/ui/chat_view.py +++ b/ui/chat_view.py @@ -1,6 +1,6 @@ """ 聊天视图组件 -处理普通对话的 UI 展示 +处理普通对话的 UI 展示 - 支持流式消息 """ import tkinter as tk @@ -16,6 +16,7 @@ class ChatView: - 消息显示区域 - 输入框 - 发送按钮 + - 流式消息支持 """ def __init__( @@ -33,6 +34,10 @@ class ChatView: self.parent = parent self.on_send = on_send + # 流式消息状态 + self._stream_active = False + self._stream_tag = None + self._create_widgets() def _create_widgets(self): @@ -71,6 +76,7 @@ class ChatView: self.message_area.tag_configure('assistant', foreground='#81c784', font=('Microsoft YaHei UI', 11)) self.message_area.tag_configure('system', foreground='#ffb74d', font=('Microsoft YaHei UI', 10, 'italic')) self.message_area.tag_configure('error', foreground='#ef5350', font=('Microsoft YaHei UI', 10)) + self.message_area.tag_configure('streaming', foreground='#81c784', font=('Microsoft YaHei UI', 11)) # 输入区域框架 input_frame = tk.Frame(self.frame, bg='#1e1e1e') @@ -147,6 +153,55 @@ class ChatView: self.message_area.see(tk.END) self.message_area.config(state=tk.DISABLED) + def start_stream_message(self, tag: str = 'assistant'): + """ + 开始流式消息 + + Args: + tag: 消息类型 + """ + self._stream_active = True + self._stream_tag = tag + + self.message_area.config(state=tk.NORMAL) + + # 添加前缀 + prefix_map = { + 'user': '[你] ', + 'assistant': '[助手] ', + 'system': '[系统] ', + 'error': '[错误] ' + } + prefix = prefix_map.get(tag, '') + + self.message_area.insert(tk.END, "\n" + prefix, tag) + self.message_area.see(tk.END) + # 保持 NORMAL 状态以便追加内容 + + def append_stream_chunk(self, chunk: str): + """ + 追加流式消息片段 + + Args: + chunk: 消息片段 + """ + if not self._stream_active: + return + + self.message_area.insert(tk.END, chunk, self._stream_tag) + self.message_area.see(tk.END) + # 强制更新 UI + self.message_area.update_idletasks() + + def end_stream_message(self): + """结束流式消息""" + if self._stream_active: + self.message_area.insert(tk.END, "\n") + self.message_area.see(tk.END) + self.message_area.config(state=tk.DISABLED) + self._stream_active = False + self._stream_tag = None + def clear_messages(self): """清空消息区域""" self.message_area.config(state=tk.NORMAL)