diff --git a/.env.example b/.env.example index dff2964..9ab0822 100644 --- a/.env.example +++ b/.env.example @@ -1,16 +1,19 @@ -# ======================================== -# LocalAgent 閰嶇疆鏂囦欢绀轰緥 # ======================================== -# 浣跨敤鏂规硶锛?# 1. 澶嶅埗姝ゆ枃浠朵负 .env -# 2. 濉叆浣犵殑 API Key 鍜屽叾浠栭厤缃?# ======================================== +# LocalAgent Configuration Example +# ======================================== +# Usage: +# 1. Copy this file to .env +# 2. Fill in your API Key and other settings +# ======================================== -# SiliconFlow API 閰嶇疆 -# 鑾峰彇 API Key: https://siliconflow.cn +# SiliconFlow API Configuration +# Get API Key: https://siliconflow.cn LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions LLM_API_KEY=your_api_key_here -# 妯″瀷閰嶇疆 -# 鎰忓浘璇嗗埆妯″瀷锛堟帹鑽愪娇鐢ㄥ皬妯″瀷锛岄€熷害蹇級 +# Model Configuration +# Intent recognition model (small model recommended for speed) INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct -# 浠g爜鐢熸垚妯″瀷锛堟帹鑽愪娇鐢ㄥぇ妯″瀷锛屾晥鏋滃ソ锛?GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct +# Code generation model (large model recommended for quality) +GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct diff --git a/README.md b/README.md new file mode 100644 index 0000000..89e6f75 --- /dev/null +++ b/README.md @@ -0,0 +1,169 @@ +# LocalAgent - Windows Local AI Execution Assistant + +A Windows-based local AI assistant that can understand natural language commands and execute file processing tasks safely in a sandboxed environment. + +## Features + +- **Intent Recognition**: Automatically distinguishes between chat conversations and execution tasks +- **Code Generation**: Generates Python code based on user requirements +- **Safety Checks**: Multi-layer security with static analysis and LLM review +- **Sandbox Execution**: Runs generated code in an isolated environment +- **Task History**: Records all executed tasks for review +- **Streaming Responses**: Real-time display of LLM responses + +## Project Structure + +``` +LocalAgent/ +├── app/ # Main application +│ └── agent.py # Core application class +├── llm/ # LLM integration +│ ├── client.py # API client with retry support +│ └── prompts.py # Prompt templates +├── intent/ # Intent classification +│ ├── classifier.py # Intent classifier +│ └── labels.py # Intent labels +├── safety/ # Security checks +│ ├── rule_checker.py # Static rule checker +│ └── llm_reviewer.py # LLM-based code review +├── executor/ # Code execution +│ └── sandbox_runner.py # Sandbox executor +├── history/ # Task history +│ └── manager.py # History manager +├── ui/ # User interface +│ ├── chat_view.py # Chat interface +│ ├── task_guide_view.py # Task confirmation view +│ └── history_view.py # History view +├── tests/ # Unit tests +├── workspace/ # Working directory (auto-created) +│ ├── input/ # Input files +│ ├── output/ # Output files +│ ├── codes/ # Generated code +│ └── logs/ # Execution logs +├── main.py # Entry point +├── requirements.txt # Dependencies +└── .env.example # Configuration template +``` + +## Installation + +### Prerequisites + +- Python 3.10+ +- Windows OS +- SiliconFlow API Key ([Get one here](https://siliconflow.cn)) + +### Setup + +1. **Clone the repository** + ```bash + git clone + cd LocalAgent + ``` + +2. **Create virtual environment** (recommended using Anaconda) + ```bash + conda create -n localagent python=3.10 + conda activate localagent + ``` + +3. **Install dependencies** + ```bash + pip install -r requirements.txt + ``` + +4. **Configure environment** + ```bash + cp .env.example .env + # Edit .env and add your API key + ``` + +5. **Run the application** + ```bash + python main.py + ``` + +## Configuration + +Edit `.env` file with your settings: + +```env +# SiliconFlow API Configuration +LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions +LLM_API_KEY=your_api_key_here + +# Model Configuration +INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct +GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct +``` + +## Usage + +### Chat Mode +Simply type questions or have conversations: +- "What is Python?" +- "Explain machine learning" + +### Execution Mode +Describe file processing tasks: +- "Copy all files from input to output" +- "Convert all PNG images to JPG format" +- "Rename files with today's date prefix" + +### Workflow +1. Place input files in `workspace/input/` +2. Describe your task in the chat +3. Review the execution plan and generated code +4. Click "Execute" to run +5. Find results in `workspace/output/` + +## Security + +LocalAgent implements multiple security layers: + +1. **Hard Rules** - Blocks dangerous operations: + - Network modules (socket, subprocess) + - Code execution (eval, exec) + - System commands (os.system, os.popen) + +2. **Soft Rules** - Warns about sensitive operations: + - File deletion + - Network requests (requests, urllib) + +3. **LLM Review** - Semantic analysis of generated code + +4. **Sandbox Execution** - Isolated subprocess with limited permissions + +## Testing + +Run unit tests: +```bash +python -m pytest tests/ -v +``` + +## Supported File Operations + +The generated code can use these libraries: + +**Standard Library:** +- os, sys, pathlib - Path operations +- shutil - File copy/move +- json, csv - Data formats +- zipfile, tarfile - Compression +- And more... + +**Third-party Libraries:** +- Pillow - Image processing +- openpyxl - Excel files +- python-docx - Word documents +- PyPDF2 - PDF files +- chardet - Encoding detection + +## License + +MIT License + +## Contributing + +Contributions are welcome! Please feel free to submit issues and pull requests. + diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..1ddde60 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,2 @@ +# 应用模块 + diff --git a/app/agent.py b/app/agent.py new file mode 100644 index 0000000..43ab0c6 --- /dev/null +++ b/app/agent.py @@ -0,0 +1,526 @@ +""" +LocalAgent 主应用类 +管理 UI 状态切换和协调各模块工作流程 +""" + +import os +import tkinter as tk +from tkinter import messagebox +from pathlib import Path +from typing import Optional, Dict, Any, Tuple +import threading +import queue + +from llm.client import get_client, LLMClientError +from llm.prompts import ( + EXECUTION_PLAN_SYSTEM, EXECUTION_PLAN_USER, + CODE_GENERATION_SYSTEM, CODE_GENERATION_USER +) +from intent.classifier import classify_intent, IntentResult +from intent.labels import CHAT, EXECUTION +from safety.rule_checker import check_code_safety +from safety.llm_reviewer import review_code_safety, LLMReviewResult +from executor.sandbox_runner import SandboxRunner, ExecutionResult +from ui.chat_view import ChatView +from ui.task_guide_view import TaskGuideView +from ui.history_view import HistoryView +from history.manager import get_history_manager, HistoryManager + + +class LocalAgentApp: + """ + LocalAgent 主应用 + + 职责: + 1. 管理 UI 状态切换 + 2. 协调各模块工作流程 + 3. 处理用户交互 + """ + + def __init__(self, project_root: Path): + self.project_root: Path = project_root + self.workspace: Path = project_root / "workspace" + self.runner: SandboxRunner = SandboxRunner(str(self.workspace)) + self.history: HistoryManager = get_history_manager(self.workspace) + + # 当前任务状态 + self.current_task: Optional[Dict[str, Any]] = None + + # 线程通信队列 + self.result_queue: queue.Queue = queue.Queue() + + # UI 组件 + self.root: Optional[tk.Tk] = None + self.main_container: Optional[tk.Frame] = None + self.chat_view: Optional[ChatView] = None + self.task_view: Optional[TaskGuideView] = None + self.history_view: Optional[HistoryView] = None + + # 初始化 UI + self._init_ui() + + def _init_ui(self) -> None: + """初始化 UI""" + self.root = tk.Tk() + self.root.title("LocalAgent - 本地 AI 助手") + self.root.geometry("800x700") + self.root.configure(bg='#1e1e1e') + + # 设置窗口图标(如果有的话) + try: + self.root.iconbitmap(self.project_root / "icon.ico") + except: + pass + + # 主容器 + self.main_container = tk.Frame(self.root, bg='#1e1e1e') + self.main_container.pack(fill=tk.BOTH, expand=True) + + # 聊天视图 + self.chat_view = ChatView( + self.main_container, + self._on_user_input, + on_show_history=self._show_history + ) + + # 定期检查后台任务结果 + self._check_queue() + + def _check_queue(self) -> None: + """检查后台任务队列""" + try: + while True: + callback, args = self.result_queue.get_nowait() + callback(*args) + except queue.Empty: + pass + + # 每 100ms 检查一次 + self.root.after(100, self._check_queue) + + def _run_in_thread(self, func: callable, callback: callable, *args) -> None: + """在后台线程运行函数,完成后回调""" + def wrapper(): + try: + result = func(*args) + self.result_queue.put((callback, (result, None))) + except Exception as e: + self.result_queue.put((callback, (None, e))) + + thread = threading.Thread(target=wrapper, daemon=True) + thread.start() + + def _on_user_input(self, user_input: str) -> None: + """处理用户输入""" + # 显示用户消息 + self.chat_view.add_message(user_input, 'user') + self.chat_view.set_input_enabled(False) + self.chat_view.show_loading("正在分析您的需求") + + # 在后台线程进行意图识别 + self._run_in_thread( + classify_intent, + lambda result, error: self._on_intent_result(user_input, result, error), + user_input + ) + + def _on_intent_result(self, user_input: str, intent_result: Optional[IntentResult], error: Optional[Exception]) -> None: + """意图识别完成回调""" + self.chat_view.hide_loading() + + if error: + self.chat_view.add_message(f"意图识别失败: {str(error)}", 'error') + self.chat_view.set_input_enabled(True) + return + + if intent_result.label == CHAT: + # 对话模式 + self._handle_chat(user_input, intent_result) + else: + # 执行模式 + self._handle_execution(user_input, intent_result) + + def _handle_chat(self, user_input: str, intent_result: IntentResult) -> None: + """处理对话任务""" + self.chat_view.add_message( + f"识别为对话模式 (原因: {intent_result.reason})", + 'system' + ) + + # 开始流式消息 + self.chat_view.start_stream_message('assistant') + + # 在后台线程调用 LLM(流式) + def do_chat_stream(): + client = get_client() + model = os.getenv("GENERATION_MODEL_NAME") + + full_response = [] + for chunk in client.chat_stream( + messages=[{"role": "user", "content": user_input}], + model=model, + temperature=0.7, + max_tokens=2048, + timeout=300 + ): + full_response.append(chunk) + # 通过队列发送 chunk 到主线程更新 UI + self.result_queue.put((self._on_chat_chunk, (chunk,))) + + return ''.join(full_response) + + self._run_in_thread( + do_chat_stream, + self._on_chat_complete + ) + + def _on_chat_chunk(self, chunk: str): + """收到对话片段回调(主线程)""" + self.chat_view.append_stream_chunk(chunk) + + def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]): + """对话完成回调""" + self.chat_view.end_stream_message() + + if error: + self.chat_view.add_message(f"对话失败: {str(error)}", 'error') + + self.chat_view.set_input_enabled(True) + + def _handle_execution(self, user_input: str, intent_result: IntentResult): + """处理执行任务""" + self.chat_view.add_message( + f"识别为执行任务 (置信度: {intent_result.confidence:.0%})\n原因: {intent_result.reason}", + 'system' + ) + self.chat_view.show_loading("正在生成执行计划") + + # 保存用户输入和意图结果 + self.current_task = { + 'user_input': user_input, + 'intent_result': intent_result + } + + # 在后台线程生成执行计划 + self._run_in_thread( + self._generate_execution_plan, + self._on_plan_generated, + user_input + ) + + def _on_plan_generated(self, plan: Optional[str], error: Optional[Exception]): + """执行计划生成完成回调""" + if error: + self.chat_view.hide_loading() + self.chat_view.add_message(f"生成执行计划失败: {str(error)}", 'error') + self.chat_view.set_input_enabled(True) + self.current_task = None + return + + self.current_task['execution_plan'] = plan + self.chat_view.update_loading_text("正在生成执行代码") + + # 在后台线程生成代码 + self._run_in_thread( + self._generate_code, + self._on_code_generated, + self.current_task['user_input'], + plan + ) + + def _on_code_generated(self, result: tuple, error: Optional[Exception]): + """代码生成完成回调""" + if error: + self.chat_view.hide_loading() + self.chat_view.add_message(f"生成代码失败: {str(error)}", 'error') + self.chat_view.set_input_enabled(True) + self.current_task = None + return + + # result 可能是 (code, extract_error) 元组 + if isinstance(result, tuple): + code, extract_error = result + if extract_error: + self.chat_view.hide_loading() + self.chat_view.add_message(f"代码提取失败: {str(extract_error)}", 'error') + self.chat_view.set_input_enabled(True) + self.current_task = None + return + else: + code = result + + self.current_task['code'] = code + self.chat_view.update_loading_text("正在进行安全检查") + + # 硬规则检查(同步,很快) + rule_result = check_code_safety(code) + if not rule_result.passed: + self.chat_view.hide_loading() + violations = "\n".join(f" • {v}" for v in rule_result.violations) + self.chat_view.add_message( + f"安全检查未通过,任务已取消:\n{violations}", + 'error' + ) + self.chat_view.set_input_enabled(True) + self.current_task = None + return + + # 保存警告信息,传递给 LLM 审查 + self.current_task['warnings'] = rule_result.warnings + + # 在后台线程进行 LLM 安全审查 + self._run_in_thread( + lambda: review_code_safety( + self.current_task['user_input'], + self.current_task['execution_plan'], + code, + rule_result.warnings # 传递警告给 LLM + ), + self._on_safety_reviewed + ) + + def _on_safety_reviewed(self, review_result, error: Optional[Exception]): + """安全审查完成回调""" + self.chat_view.hide_loading() + + if error: + self.chat_view.add_message(f"安全审查失败: {str(error)}", 'error') + self.chat_view.set_input_enabled(True) + self.current_task = None + return + + if not review_result.passed: + self.chat_view.add_message( + f"安全审查未通过: {review_result.reason}", + 'error' + ) + self.chat_view.set_input_enabled(True) + self.current_task = None + return + + self.chat_view.add_message("安全检查通过,请确认执行", 'system') + + # 显示任务引导视图 + self._show_task_guide() + + def _generate_execution_plan(self, user_input: str) -> str: + """生成执行计划(使用流式传输)""" + client = get_client() + model = os.getenv("GENERATION_MODEL_NAME") + + # 使用流式传输,避免超时 + response = client.chat_stream_collect( + messages=[ + {"role": "system", "content": EXECUTION_PLAN_SYSTEM}, + {"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)} + ], + model=model, + temperature=0.3, + max_tokens=1024, + timeout=300 # 5分钟超时 + ) + + return response + + def _generate_code(self, user_input: str, execution_plan: str) -> tuple: + """生成执行代码(使用流式传输)""" + client = get_client() + model = os.getenv("GENERATION_MODEL_NAME") + + # 使用流式传输,避免超时 + response = client.chat_stream_collect( + messages=[ + {"role": "system", "content": CODE_GENERATION_SYSTEM}, + {"role": "user", "content": CODE_GENERATION_USER.format( + user_input=user_input, + execution_plan=execution_plan + )} + ], + model=model, + temperature=0.2, + max_tokens=4096, # 代码可能较长 + timeout=300 # 5分钟超时 + ) + + # 提取代码块,捕获可能的异常 + try: + code = self._extract_code(response) + return (code, None) + except ValueError as e: + return (None, e) + + def _extract_code(self, response: str) -> str: + """从 LLM 响应中提取代码""" + import re + + # 尝试提取 ```python ... ``` 代码块 + pattern = r'```python\s*(.*?)\s*```' + matches = re.findall(pattern, response, re.DOTALL) + + if matches: + return matches[0] + + # 尝试提取 ``` ... ``` 代码块 + pattern = r'```\s*(.*?)\s*```' + matches = re.findall(pattern, response, re.DOTALL) + + if matches: + return matches[0] + + # 如果没有代码块,检查是否看起来像 Python 代码 + # 简单检查:是否包含 def 或 import 语句 + if 'import ' in response or 'def ' in response: + return response + + # 无法提取代码,抛出异常 + raise ValueError( + "无法从 LLM 响应中提取代码块。\n" + f"响应内容预览: {response[:200]}..." + ) + + def _show_task_guide(self): + """显示任务引导视图""" + if not self.current_task: + return + + # 隐藏聊天视图 + self.chat_view.get_frame().pack_forget() + + # 创建任务引导视图 + self.task_view = TaskGuideView( + self.main_container, + on_execute=self._on_execute_task, + on_cancel=self._on_cancel_task, + workspace_path=self.workspace + ) + + # 设置内容 + self.task_view.set_intent_result( + self.current_task['intent_result'].reason, + self.current_task['intent_result'].confidence + ) + self.task_view.set_execution_plan(self.current_task['execution_plan']) + self.task_view.set_code(self.current_task['code']) + + # 显示 + self.task_view.show() + + def _on_execute_task(self): + """执行任务""" + if not self.current_task: + return + + self.task_view.set_buttons_enabled(False) + + # 在后台线程执行 + def do_execute(): + return self.runner.execute(self.current_task['code']) + + self._run_in_thread( + do_execute, + self._on_execution_complete + ) + + def _on_execution_complete(self, result: Optional[ExecutionResult], error: Optional[Exception]): + """执行完成回调""" + if error: + messagebox.showerror("执行错误", f"执行失败: {str(error)}") + else: + # 保存历史记录 + if self.current_task: + self.history.add_record( + task_id=result.task_id, + user_input=self.current_task['user_input'], + intent_label=self.current_task['intent_result'].label, + intent_confidence=self.current_task['intent_result'].confidence, + execution_plan=self.current_task['execution_plan'], + code=self.current_task['code'], + success=result.success, + duration_ms=result.duration_ms, + stdout=result.stdout, + stderr=result.stderr, + log_path=result.log_path + ) + + self._show_execution_result(result) + # 刷新输出文件列表 + if self.task_view: + self.task_view.refresh_output() + + self._back_to_chat() + + def _show_execution_result(self, result: ExecutionResult): + """显示执行结果""" + if result.success: + status = "执行成功" + else: + status = "执行失败" + + message = f"""{status} + +任务 ID: {result.task_id} +耗时: {result.duration_ms} ms + +输出: +{result.stdout if result.stdout else '(无输出)'} + +{f'错误信息: {result.stderr}' if result.stderr else ''} +""" + + if result.success: + # 成功时显示结果并询问是否打开输出目录 + open_output = messagebox.askyesno( + "执行结果", + message + "\n\n是否打开输出文件夹?" + ) + if open_output: + os.startfile(str(self.workspace / "output")) + else: + # 失败时显示结果并询问是否打开日志 + open_log = messagebox.askyesno( + "执行结果", + message + "\n\n是否打开日志文件查看详情?" + ) + if open_log and result.log_path: + os.startfile(result.log_path) + + def _on_cancel_task(self): + """取消任务""" + self.current_task = None + self._back_to_chat() + + def _back_to_chat(self): + """返回聊天视图""" + if self.task_view: + self.task_view.hide() + self.task_view = None + + self.chat_view.get_frame().pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + self.chat_view.set_input_enabled(True) + self.current_task = None + + def _show_history(self): + """显示历史记录视图""" + # 隐藏聊天视图 + self.chat_view.get_frame().pack_forget() + + # 创建历史记录视图 + self.history_view = HistoryView( + self.main_container, + self.history, + on_back=self._hide_history + ) + self.history_view.show() + + def _hide_history(self): + """隐藏历史记录视图,返回聊天""" + if self.history_view: + self.history_view.hide() + self.history_view = None + + self.chat_view.get_frame().pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + def run(self): + """运行应用""" + self.root.mainloop() + diff --git a/debug_env.py b/debug_env.py deleted file mode 100644 index 29d0eee..0000000 --- a/debug_env.py +++ /dev/null @@ -1,25 +0,0 @@ -"""调试脚本""" -from pathlib import Path -from dotenv import load_dotenv -import os - -ENV_PATH = Path(__file__).parent / ".env" - -print(f"ENV_PATH: {ENV_PATH}") -print(f"ENV_PATH exists: {ENV_PATH.exists()}") - -# 读取文件内容 -if ENV_PATH.exists(): - print(f"File content:") - print(ENV_PATH.read_text(encoding='utf-8')) - -# 加载环境变量 -result = load_dotenv(ENV_PATH) -print(f"load_dotenv result: {result}") - -# 检查环境变量 -print(f"LLM_API_URL: {os.getenv('LLM_API_URL')}") -print(f"LLM_API_KEY: {os.getenv('LLM_API_KEY')}") -print(f"INTENT_MODEL_NAME: {os.getenv('INTENT_MODEL_NAME')}") -print(f"GENERATION_MODEL_NAME: {os.getenv('GENERATION_MODEL_NAME')}") - diff --git a/history/__init__.py b/history/__init__.py new file mode 100644 index 0000000..dbd40ba --- /dev/null +++ b/history/__init__.py @@ -0,0 +1,2 @@ +# 历史记录模块 + diff --git a/history/manager.py b/history/manager.py new file mode 100644 index 0000000..373ad1d --- /dev/null +++ b/history/manager.py @@ -0,0 +1,189 @@ +""" +任务历史记录管理器 +保存和加载任务执行历史 +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Optional, List +from dataclasses import dataclass, asdict + + +@dataclass +class TaskRecord: + """任务记录""" + task_id: str + timestamp: str + user_input: str + intent_label: str + intent_confidence: float + execution_plan: str + code: str + success: bool + duration_ms: int + stdout: str + stderr: str + log_path: str + + +class HistoryManager: + """ + 历史记录管理器 + + 将任务历史保存为 JSON 文件 + """ + + MAX_HISTORY_SIZE = 100 # 最多保存 100 条记录 + + def __init__(self, workspace_path: Optional[Path] = None): + if workspace_path: + self.workspace = workspace_path + else: + self.workspace = Path(__file__).parent.parent / "workspace" + + self.history_file = self.workspace / "history.json" + self._history: List[TaskRecord] = [] + self._load() + + def _load(self): + """从文件加载历史记录""" + if self.history_file.exists(): + try: + with open(self.history_file, 'r', encoding='utf-8') as f: + data = json.load(f) + self._history = [TaskRecord(**record) for record in data] + except (json.JSONDecodeError, TypeError, KeyError) as e: + print(f"[警告] 加载历史记录失败: {e}") + self._history = [] + else: + self._history = [] + + def _save(self): + """保存历史记录到文件""" + try: + # 确保目录存在 + self.history_file.parent.mkdir(parents=True, exist_ok=True) + + with open(self.history_file, 'w', encoding='utf-8') as f: + data = [asdict(record) for record in self._history] + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"[警告] 保存历史记录失败: {e}") + + def add_record( + self, + task_id: str, + user_input: str, + intent_label: str, + intent_confidence: float, + execution_plan: str, + code: str, + success: bool, + duration_ms: int, + stdout: str = "", + stderr: str = "", + log_path: str = "" + ) -> TaskRecord: + """ + 添加一条任务记录 + + Args: + task_id: 任务 ID + user_input: 用户输入 + intent_label: 意图标签 + intent_confidence: 意图置信度 + execution_plan: 执行计划 + code: 生成的代码 + success: 是否执行成功 + duration_ms: 执行耗时(毫秒) + stdout: 标准输出 + stderr: 标准错误 + log_path: 日志文件路径 + + Returns: + TaskRecord: 创建的记录 + """ + record = TaskRecord( + task_id=task_id, + timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + user_input=user_input, + intent_label=intent_label, + intent_confidence=intent_confidence, + execution_plan=execution_plan, + code=code, + success=success, + duration_ms=duration_ms, + stdout=stdout, + stderr=stderr, + log_path=log_path + ) + + # 添加到列表开头(最新的在前) + self._history.insert(0, record) + + # 限制历史记录数量 + if len(self._history) > self.MAX_HISTORY_SIZE: + self._history = self._history[:self.MAX_HISTORY_SIZE] + + # 保存 + self._save() + + return record + + def get_all(self) -> List[TaskRecord]: + """获取所有历史记录""" + return self._history.copy() + + def get_recent(self, count: int = 10) -> List[TaskRecord]: + """获取最近的 N 条记录""" + return self._history[:count] + + def get_by_id(self, task_id: str) -> Optional[TaskRecord]: + """根据任务 ID 获取记录""" + for record in self._history: + if record.task_id == task_id: + return record + return None + + def clear(self): + """清空历史记录""" + self._history = [] + self._save() + + def get_stats(self) -> dict: + """获取统计信息""" + if not self._history: + return { + 'total': 0, + 'success': 0, + 'failed': 0, + 'success_rate': 0.0, + 'avg_duration_ms': 0 + } + + total = len(self._history) + success = sum(1 for r in self._history if r.success) + failed = total - success + avg_duration = sum(r.duration_ms for r in self._history) / total + + return { + 'total': total, + 'success': success, + 'failed': failed, + 'success_rate': success / total if total > 0 else 0.0, + 'avg_duration_ms': int(avg_duration) + } + + +# 全局单例 +_manager: Optional[HistoryManager] = None + + +def get_history_manager(workspace_path: Optional[Path] = None) -> HistoryManager: + """获取历史记录管理器单例""" + global _manager + if _manager is None: + _manager = HistoryManager(workspace_path) + return _manager + diff --git a/llm/client.py b/llm/client.py index 941b821..94f04b9 100644 --- a/llm/client.py +++ b/llm/client.py @@ -2,13 +2,15 @@ LLM 统一调用客户端 所有模型通过 SiliconFlow API 调用 支持流式和非流式两种模式 +支持自动重试机制 """ import os import json +import time import requests from pathlib import Path -from typing import Optional, Generator, Callable +from typing import Optional, Generator, Callable, List, Dict, Any from dotenv import load_dotenv # 获取项目根目录 @@ -40,29 +42,78 @@ class LLMClient: model="Qwen/Qwen2.5-7B-Instruct" ): print(chunk, end="", flush=True) + + 特性: + - 自动重试:网络错误时自动重试(默认3次) + - 指数退避:重试间隔逐渐增加 """ - def __init__(self): + # 重试配置 + DEFAULT_MAX_RETRIES = 3 + DEFAULT_RETRY_DELAY = 1.0 # 初始重试延迟(秒) + DEFAULT_RETRY_BACKOFF = 2.0 # 退避倍数 + + def __init__(self, max_retries: int = DEFAULT_MAX_RETRIES): load_dotenv(ENV_PATH) self.api_url = os.getenv("LLM_API_URL") self.api_key = os.getenv("LLM_API_KEY") + self.max_retries = max_retries if not self.api_url: raise LLMClientError("未配置 LLM_API_URL,请检查 .env 文件") if not self.api_key or self.api_key == "your_api_key_here": raise LLMClientError("未配置有效的 LLM_API_KEY,请检查 .env 文件") + def _should_retry(self, exception: Exception) -> bool: + """判断是否应该重试""" + # 网络连接错误、超时错误可以重试 + if isinstance(exception, (requests.exceptions.ConnectionError, + requests.exceptions.Timeout)): + return True + # 服务器错误(5xx)可以重试 + if isinstance(exception, LLMClientError): + error_msg = str(exception) + if "状态码: 5" in error_msg or "502" in error_msg or "503" in error_msg or "504" in error_msg: + return True + return False + + def _do_request_with_retry( + self, + request_func: Callable, + operation_name: str = "请求" + ): + """带重试的请求执行""" + last_exception = None + + for attempt in range(self.max_retries + 1): + try: + return request_func() + except Exception as e: + last_exception = e + + # 判断是否应该重试 + if attempt < self.max_retries and self._should_retry(e): + delay = self.DEFAULT_RETRY_DELAY * (self.DEFAULT_RETRY_BACKOFF ** attempt) + print(f"[重试] {operation_name}失败,{delay:.1f}秒后重试 ({attempt + 1}/{self.max_retries})...") + time.sleep(delay) + continue + else: + raise + + # 所有重试都失败 + raise last_exception + def chat( self, - messages: list[dict], + messages: List[Dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int = 1024, timeout: int = 180 ) -> str: """ - 调用 LLM 进行对话(非流式) + 调用 LLM 进行对话(非流式,带自动重试) Args: messages: 消息列表 @@ -74,60 +125,63 @@ class LLMClient: Returns: LLM 生成的文本内容 """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - payload = { - "model": model, - "messages": messages, - "stream": False, - "temperature": temperature, - "max_tokens": max_tokens - } - - try: - response = requests.post( - self.api_url, - headers=headers, - json=payload, - timeout=timeout - ) - except requests.exceptions.Timeout: - raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") - except requests.exceptions.ConnectionError: - raise LLMClientError("网络连接失败,请检查网络设置") - except requests.exceptions.RequestException as e: - raise LLMClientError(f"网络请求异常: {str(e)}") - - if response.status_code != 200: - error_msg = f"API 返回错误 (状态码: {response.status_code})" + def do_request(): + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + payload = { + "model": model, + "messages": messages, + "stream": False, + "temperature": temperature, + "max_tokens": max_tokens + } + try: - error_detail = response.json() - if "error" in error_detail: - error_msg += f": {error_detail['error']}" - except: - error_msg += f": {response.text[:200]}" - raise LLMClientError(error_msg) + response = requests.post( + self.api_url, + headers=headers, + json=payload, + timeout=timeout + ) + except requests.exceptions.Timeout: + raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") + except requests.exceptions.ConnectionError: + raise LLMClientError("网络连接失败,请检查网络设置") + except requests.exceptions.RequestException as e: + raise LLMClientError(f"网络请求异常: {str(e)}") + + if response.status_code != 200: + error_msg = f"API 返回错误 (状态码: {response.status_code})" + try: + error_detail = response.json() + if "error" in error_detail: + error_msg += f": {error_detail['error']}" + except: + error_msg += f": {response.text[:200]}" + raise LLMClientError(error_msg) + + try: + result = response.json() + content = result["choices"][0]["message"]["content"] + return content + except (KeyError, IndexError, TypeError) as e: + raise LLMClientError(f"解析 API 响应失败: {str(e)}") - try: - result = response.json() - content = result["choices"][0]["message"]["content"] - return content - except (KeyError, IndexError, TypeError) as e: - raise LLMClientError(f"解析 API 响应失败: {str(e)}") + return self._do_request_with_retry(do_request, "LLM调用") def chat_stream( self, - messages: list[dict], + messages: List[Dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int = 2048, timeout: int = 180 ) -> Generator[str, None, None]: """ - 调用 LLM 进行对话(流式) + 调用 LLM 进行对话(流式,带自动重试) Args: messages: 消息列表 @@ -139,43 +193,49 @@ class LLMClient: Yields: 逐个返回生成的文本片段 """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - payload = { - "model": model, - "messages": messages, - "stream": True, - "temperature": temperature, - "max_tokens": max_tokens - } - - try: - response = requests.post( - self.api_url, - headers=headers, - json=payload, - timeout=timeout, - stream=True - ) - except requests.exceptions.Timeout: - raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") - except requests.exceptions.ConnectionError: - raise LLMClientError("网络连接失败,请检查网络设置") - except requests.exceptions.RequestException as e: - raise LLMClientError(f"网络请求异常: {str(e)}") - - if response.status_code != 200: - error_msg = f"API 返回错误 (状态码: {response.status_code})" + def do_request(): + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + payload = { + "model": model, + "messages": messages, + "stream": True, + "temperature": temperature, + "max_tokens": max_tokens + } + try: - error_detail = response.json() - if "error" in error_detail: - error_msg += f": {error_detail['error']}" - except: - error_msg += f": {response.text[:200]}" - raise LLMClientError(error_msg) + response = requests.post( + self.api_url, + headers=headers, + json=payload, + timeout=timeout, + stream=True + ) + except requests.exceptions.Timeout: + raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") + except requests.exceptions.ConnectionError: + raise LLMClientError("网络连接失败,请检查网络设置") + except requests.exceptions.RequestException as e: + raise LLMClientError(f"网络请求异常: {str(e)}") + + if response.status_code != 200: + error_msg = f"API 返回错误 (状态码: {response.status_code})" + try: + error_detail = response.json() + if "error" in error_detail: + error_msg += f": {error_detail['error']}" + except: + error_msg += f": {response.text[:200]}" + raise LLMClientError(error_msg) + + return response + + # 流式请求的重试只在建立连接阶段 + response = self._do_request_with_retry(do_request, "流式LLM调用") # 解析 SSE 流 for line in response.iter_lines(): @@ -197,7 +257,7 @@ class LLMClient: def chat_stream_collect( self, - messages: list[dict], + messages: List[Dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int = 2048, diff --git a/llm/prompts.py b/llm/prompts.py index 4a2ff19..0132def 100644 --- a/llm/prompts.py +++ b/llm/prompts.py @@ -108,7 +108,8 @@ import shutil from pathlib import Path # 工作目录(固定,不要修改) -WORKSPACE = Path(__file__).parent +# 代码保存在 workspace/codes/ 目录,向上一级是 workspace +WORKSPACE = Path(__file__).parent.parent INPUT_DIR = WORKSPACE / "input" OUTPUT_DIR = WORKSPACE / "output" diff --git a/main.py b/main.py index d031df4..6b4e3f6 100644 --- a/main.py +++ b/main.py @@ -37,10 +37,7 @@ import sys import tkinter as tk from tkinter import messagebox from pathlib import Path -from typing import Optional from dotenv import load_dotenv -import threading -import queue # 确保项目根目录在 Python 路径中 PROJECT_ROOT = Path(__file__).parent @@ -50,435 +47,11 @@ sys.path.insert(0, str(PROJECT_ROOT)) # 在导入其他模块之前先加载环境变量 load_dotenv(ENV_PATH) -from llm.client import get_client, LLMClientError -from llm.prompts import ( - EXECUTION_PLAN_SYSTEM, EXECUTION_PLAN_USER, - CODE_GENERATION_SYSTEM, CODE_GENERATION_USER -) -from intent.classifier import classify_intent, IntentResult -from intent.labels import CHAT, EXECUTION -from safety.rule_checker import check_code_safety -from safety.llm_reviewer import review_code_safety -from executor.sandbox_runner import SandboxRunner, ExecutionResult -from ui.chat_view import ChatView -from ui.task_guide_view import TaskGuideView +from app.agent import LocalAgentApp -class LocalAgentApp: - """ - LocalAgent 主应用 - - 职责: - 1. 管理 UI 状态切换 - 2. 协调各模块工作流程 - 3. 处理用户交互 - """ - - def __init__(self): - self.workspace = PROJECT_ROOT / "workspace" - self.runner = SandboxRunner(str(self.workspace)) - - # 当前任务状态 - self.current_task: Optional[dict] = None - - # 线程通信队列 - self.result_queue = queue.Queue() - - # 初始化 UI - self._init_ui() - - def _init_ui(self): - """初始化 UI""" - self.root = tk.Tk() - self.root.title("LocalAgent - 本地 AI 助手") - self.root.geometry("800x700") - self.root.configure(bg='#1e1e1e') - - # 设置窗口图标(如果有的话) - try: - self.root.iconbitmap(PROJECT_ROOT / "icon.ico") - except: - pass - - # 主容器 - self.main_container = tk.Frame(self.root, bg='#1e1e1e') - self.main_container.pack(fill=tk.BOTH, expand=True) - - # 聊天视图 - self.chat_view = ChatView(self.main_container, self._on_user_input) - - # 任务引导视图(初始隐藏) - self.task_view: Optional[TaskGuideView] = None - - # 定期检查后台任务结果 - self._check_queue() - - def _check_queue(self): - """检查后台任务队列""" - try: - while True: - callback, args = self.result_queue.get_nowait() - callback(*args) - except queue.Empty: - pass - - # 每 100ms 检查一次 - self.root.after(100, self._check_queue) - - def _run_in_thread(self, func, callback, *args): - """在后台线程运行函数,完成后回调""" - def wrapper(): - try: - result = func(*args) - self.result_queue.put((callback, (result, None))) - except Exception as e: - self.result_queue.put((callback, (None, e))) - - thread = threading.Thread(target=wrapper, daemon=True) - thread.start() - - def _on_user_input(self, user_input: str): - """处理用户输入""" - # 显示用户消息 - self.chat_view.add_message(user_input, 'user') - self.chat_view.set_input_enabled(False) - self.chat_view.add_message("正在分析您的需求...", 'system') - - # 在后台线程进行意图识别 - self._run_in_thread( - classify_intent, - lambda result, error: self._on_intent_result(user_input, result, error), - user_input - ) - - def _on_intent_result(self, user_input: str, intent_result: Optional[IntentResult], error: Optional[Exception]): - """意图识别完成回调""" - if error: - self.chat_view.add_message(f"意图识别失败: {str(error)}", 'error') - self.chat_view.set_input_enabled(True) - return - - if intent_result.label == CHAT: - # 对话模式 - self._handle_chat(user_input, intent_result) - else: - # 执行模式 - self._handle_execution(user_input, intent_result) - - def _handle_chat(self, user_input: str, intent_result: IntentResult): - """处理对话任务""" - self.chat_view.add_message( - f"识别为对话模式 (原因: {intent_result.reason})", - 'system' - ) - - # 开始流式消息 - self.chat_view.start_stream_message('assistant') - - # 在后台线程调用 LLM(流式) - def do_chat_stream(): - client = get_client() - model = os.getenv("GENERATION_MODEL_NAME") - - full_response = [] - for chunk in client.chat_stream( - messages=[{"role": "user", "content": user_input}], - model=model, - temperature=0.7, - max_tokens=2048, - timeout=300 - ): - full_response.append(chunk) - # 通过队列发送 chunk 到主线程更新 UI - self.result_queue.put((self._on_chat_chunk, (chunk,))) - - return ''.join(full_response) - - self._run_in_thread( - do_chat_stream, - self._on_chat_complete - ) - - def _on_chat_chunk(self, chunk: str): - """收到对话片段回调(主线程)""" - self.chat_view.append_stream_chunk(chunk) - - def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]): - """对话完成回调""" - self.chat_view.end_stream_message() - - if error: - self.chat_view.add_message(f"对话失败: {str(error)}", 'error') - - self.chat_view.set_input_enabled(True) - - def _handle_execution(self, user_input: str, intent_result: IntentResult): - """处理执行任务""" - self.chat_view.add_message( - f"识别为执行任务 (置信度: {intent_result.confidence:.0%})\n原因: {intent_result.reason}", - 'system' - ) - self.chat_view.add_message("正在生成执行计划...", 'system') - - # 保存用户输入和意图结果 - self.current_task = { - 'user_input': user_input, - 'intent_result': intent_result - } - - # 在后台线程生成执行计划 - self._run_in_thread( - self._generate_execution_plan, - self._on_plan_generated, - user_input - ) - - def _on_plan_generated(self, plan: Optional[str], error: Optional[Exception]): - """执行计划生成完成回调""" - if error: - self.chat_view.add_message(f"生成执行计划失败: {str(error)}", 'error') - self.chat_view.set_input_enabled(True) - self.current_task = None - return - - self.current_task['execution_plan'] = plan - self.chat_view.add_message("正在生成执行代码...", 'system') - - # 在后台线程生成代码 - self._run_in_thread( - self._generate_code, - self._on_code_generated, - self.current_task['user_input'], - plan - ) - - def _on_code_generated(self, code: Optional[str], error: Optional[Exception]): - """代码生成完成回调""" - if error: - self.chat_view.add_message(f"生成代码失败: {str(error)}", 'error') - self.chat_view.set_input_enabled(True) - self.current_task = None - return - - self.current_task['code'] = code - self.chat_view.add_message("正在进行安全检查...", 'system') - - # 硬规则检查(同步,很快) - rule_result = check_code_safety(code) - if not rule_result.passed: - violations = "\n".join(f" • {v}" for v in rule_result.violations) - self.chat_view.add_message( - f"安全检查未通过,任务已取消:\n{violations}", - 'error' - ) - self.chat_view.set_input_enabled(True) - self.current_task = None - return - - # 保存警告信息,传递给 LLM 审查 - self.current_task['warnings'] = rule_result.warnings - - # 在后台线程进行 LLM 安全审查 - self._run_in_thread( - lambda: review_code_safety( - self.current_task['user_input'], - self.current_task['execution_plan'], - code, - rule_result.warnings # 传递警告给 LLM - ), - self._on_safety_reviewed - ) - - def _on_safety_reviewed(self, review_result, error: Optional[Exception]): - """安全审查完成回调""" - if error: - self.chat_view.add_message(f"安全审查失败: {str(error)}", 'error') - self.chat_view.set_input_enabled(True) - self.current_task = None - return - - if not review_result.passed: - self.chat_view.add_message( - f"安全审查未通过: {review_result.reason}", - 'error' - ) - self.chat_view.set_input_enabled(True) - self.current_task = None - return - - self.chat_view.add_message("安全检查通过,请确认执行", 'system') - - # 显示任务引导视图 - self._show_task_guide() - - def _generate_execution_plan(self, user_input: str) -> str: - """生成执行计划(使用流式传输)""" - client = get_client() - model = os.getenv("GENERATION_MODEL_NAME") - - # 使用流式传输,避免超时 - response = client.chat_stream_collect( - messages=[ - {"role": "system", "content": EXECUTION_PLAN_SYSTEM}, - {"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)} - ], - model=model, - temperature=0.3, - max_tokens=1024, - timeout=300 # 5分钟超时 - ) - - return response - - def _generate_code(self, user_input: str, execution_plan: str) -> str: - """生成执行代码(使用流式传输)""" - client = get_client() - model = os.getenv("GENERATION_MODEL_NAME") - - # 使用流式传输,避免超时 - response = client.chat_stream_collect( - messages=[ - {"role": "system", "content": CODE_GENERATION_SYSTEM}, - {"role": "user", "content": CODE_GENERATION_USER.format( - user_input=user_input, - execution_plan=execution_plan - )} - ], - model=model, - temperature=0.2, - max_tokens=4096, # 代码可能较长 - timeout=300 # 5分钟超时 - ) - - # 提取代码块 - code = self._extract_code(response) - return code - - def _extract_code(self, response: str) -> str: - """从 LLM 响应中提取代码""" - import re - - # 尝试提取 ```python ... ``` 代码块 - pattern = r'```python\s*(.*?)\s*```' - matches = re.findall(pattern, response, re.DOTALL) - - if matches: - return matches[0] - - # 尝试提取 ``` ... ``` 代码块 - pattern = r'```\s*(.*?)\s*```' - matches = re.findall(pattern, response, re.DOTALL) - - if matches: - return matches[0] - - # 如果没有代码块,返回原始响应 - return response - - def _show_task_guide(self): - """显示任务引导视图""" - if not self.current_task: - return - - # 隐藏聊天视图 - self.chat_view.get_frame().pack_forget() - - # 创建任务引导视图 - self.task_view = TaskGuideView( - self.main_container, - on_execute=self._on_execute_task, - on_cancel=self._on_cancel_task, - workspace_path=self.workspace - ) - - # 设置内容 - self.task_view.set_intent_result( - self.current_task['intent_result'].reason, - self.current_task['intent_result'].confidence - ) - self.task_view.set_execution_plan(self.current_task['execution_plan']) - - # 显示 - self.task_view.show() - - def _on_execute_task(self): - """执行任务""" - if not self.current_task: - return - - self.task_view.set_buttons_enabled(False) - - # 在后台线程执行 - def do_execute(): - return self.runner.execute(self.current_task['code']) - - self._run_in_thread( - do_execute, - self._on_execution_complete - ) - - def _on_execution_complete(self, result: Optional[ExecutionResult], error: Optional[Exception]): - """执行完成回调""" - if error: - messagebox.showerror("执行错误", f"执行失败: {str(error)}") - else: - self._show_execution_result(result) - # 刷新输出文件列表 - if self.task_view: - self.task_view.refresh_output() - - self._back_to_chat() - - def _show_execution_result(self, result: ExecutionResult): - """显示执行结果""" - if result.success: - status = "执行成功" - else: - status = "执行失败" - - message = f"""{status} - -任务 ID: {result.task_id} -耗时: {result.duration_ms} ms -日志文件: {result.log_path} - -输出: -{result.stdout if result.stdout else '(无输出)'} - -{f'错误信息: {result.stderr}' if result.stderr else ''} -""" - - if result.success: - messagebox.showinfo("执行结果", message) - # 打开 output 目录 - os.startfile(str(self.workspace / "output")) - else: - messagebox.showerror("执行结果", message) - - def _on_cancel_task(self): - """取消任务""" - self.current_task = None - self._back_to_chat() - - def _back_to_chat(self): - """返回聊天视图""" - if self.task_view: - self.task_view.hide() - self.task_view = None - - self.chat_view.get_frame().pack(fill=tk.BOTH, expand=True, padx=10, pady=10) - self.chat_view.set_input_enabled(True) - self.current_task = None - - def run(self): - """运行应用""" - self.root.mainloop() - - -def check_environment(): +def check_environment() -> bool: """检查运行环境""" - load_dotenv(ENV_PATH) - api_key = os.getenv("LLM_API_KEY") if not api_key or api_key == "your_api_key_here": @@ -510,6 +83,17 @@ def check_environment(): return True +def setup_workspace(): + """创建工作目录""" + workspace = PROJECT_ROOT / "workspace" + (workspace / "input").mkdir(parents=True, exist_ok=True) + (workspace / "output").mkdir(parents=True, exist_ok=True) + (workspace / "logs").mkdir(parents=True, exist_ok=True) + (workspace / "codes").mkdir(parents=True, exist_ok=True) + + return workspace + + def main(): """主入口""" print("=" * 50) @@ -521,19 +105,17 @@ def main(): sys.exit(1) # 创建工作目录 - workspace = PROJECT_ROOT / "workspace" - (workspace / "input").mkdir(parents=True, exist_ok=True) - (workspace / "output").mkdir(parents=True, exist_ok=True) - (workspace / "logs").mkdir(parents=True, exist_ok=True) + workspace = setup_workspace() print(f"工作目录: {workspace}") print(f"输入目录: {workspace / 'input'}") print(f"输出目录: {workspace / 'output'}") print(f"日志目录: {workspace / 'logs'}") + print(f"代码目录: {workspace / 'codes'}") print("=" * 50) # 启动应用 - app = LocalAgentApp() + app = LocalAgentApp(PROJECT_ROOT) app.run() diff --git a/requirements.txt b/requirements.txt index 9c2eec5..4d77158 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,6 @@ openpyxl>=3.1.0 # Excel 处理 python-docx>=1.0.0 # Word 文档处理 PyPDF2>=3.0.0 # PDF 处理 chardet>=5.0.0 # 文件编码检测 + +# 测试依赖(可选) +pytest>=7.0.0 # 单元测试框架 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..57f5731 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# 测试模块 + diff --git a/tests/test_history_manager.py b/tests/test_history_manager.py new file mode 100644 index 0000000..ab8f58d --- /dev/null +++ b/tests/test_history_manager.py @@ -0,0 +1,235 @@ +""" +历史记录管理器单元测试 +""" + +import unittest +import sys +import tempfile +import shutil +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from history.manager import HistoryManager, TaskRecord + + +class TestHistoryManager(unittest.TestCase): + """历史记录管理器测试""" + + def setUp(self): + """创建临时目录用于测试""" + self.temp_dir = Path(tempfile.mkdtemp()) + self.manager = HistoryManager(self.temp_dir) + + def tearDown(self): + """清理临时目录""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_add_record(self): + """测试添加记录""" + record = self.manager.add_record( + task_id="test_001", + user_input="复制文件", + intent_label="execution", + intent_confidence=0.95, + execution_plan="复制所有文件", + code="shutil.copy(...)", + success=True, + duration_ms=100 + ) + + self.assertEqual(record.task_id, "test_001") + self.assertEqual(record.user_input, "复制文件") + self.assertTrue(record.success) + + def test_get_all(self): + """测试获取所有记录""" + # 添加多条记录 + for i in range(3): + self.manager.add_record( + task_id=f"test_{i:03d}", + user_input=f"任务 {i}", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + + records = self.manager.get_all() + self.assertEqual(len(records), 3) + + def test_get_recent(self): + """测试获取最近记录""" + # 添加 5 条记录 + for i in range(5): + self.manager.add_record( + task_id=f"test_{i:03d}", + user_input=f"任务 {i}", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + + # 获取最近 3 条 + recent = self.manager.get_recent(3) + self.assertEqual(len(recent), 3) + # 最新的在前 + self.assertEqual(recent[0].task_id, "test_004") + + def test_get_by_id(self): + """测试根据 ID 获取记录""" + self.manager.add_record( + task_id="unique_id", + user_input="测试", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + + record = self.manager.get_by_id("unique_id") + self.assertIsNotNone(record) + self.assertEqual(record.task_id, "unique_id") + + # 不存在的 ID + not_found = self.manager.get_by_id("not_exist") + self.assertIsNone(not_found) + + def test_clear(self): + """测试清空记录""" + # 添加记录 + self.manager.add_record( + task_id="test", + user_input="测试", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + + self.assertEqual(len(self.manager.get_all()), 1) + + # 清空 + self.manager.clear() + self.assertEqual(len(self.manager.get_all()), 0) + + def test_get_stats(self): + """测试统计信息""" + # 添加成功和失败的记录 + self.manager.add_record( + task_id="success_1", + user_input="成功任务", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + self.manager.add_record( + task_id="success_2", + user_input="成功任务2", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=200 + ) + self.manager.add_record( + task_id="failed_1", + user_input="失败任务", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=False, + duration_ms=50 + ) + + stats = self.manager.get_stats() + self.assertEqual(stats['total'], 3) + self.assertEqual(stats['success'], 2) + self.assertEqual(stats['failed'], 1) + self.assertAlmostEqual(stats['success_rate'], 2/3) + + def test_persistence(self): + """测试持久化""" + # 添加记录 + self.manager.add_record( + task_id="persist_test", + user_input="持久化测试", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + + # 创建新的管理器实例(模拟重启) + new_manager = HistoryManager(self.temp_dir) + + # 应该能读取到之前的记录 + records = new_manager.get_all() + self.assertEqual(len(records), 1) + self.assertEqual(records[0].task_id, "persist_test") + + def test_max_history_size(self): + """测试历史记录数量限制""" + # 添加超过限制的记录 + for i in range(HistoryManager.MAX_HISTORY_SIZE + 10): + self.manager.add_record( + task_id=f"test_{i:03d}", + user_input=f"任务 {i}", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100 + ) + + # 应该只保留最大数量 + records = self.manager.get_all() + self.assertEqual(len(records), HistoryManager.MAX_HISTORY_SIZE) + + +class TestTaskRecord(unittest.TestCase): + """任务记录数据类测试""" + + def test_create_record(self): + """测试创建记录""" + record = TaskRecord( + task_id="test", + timestamp="2024-01-01 12:00:00", + user_input="测试", + intent_label="execution", + intent_confidence=0.9, + execution_plan="计划", + code="代码", + success=True, + duration_ms=100, + stdout="输出", + stderr="", + log_path="/path/to/log" + ) + + self.assertEqual(record.task_id, "test") + self.assertTrue(record.success) + self.assertEqual(record.duration_ms, 100) + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_intent_classifier.py b/tests/test_intent_classifier.py new file mode 100644 index 0000000..89df003 --- /dev/null +++ b/tests/test_intent_classifier.py @@ -0,0 +1,94 @@ +""" +意图分类器单元测试 +""" + +import unittest +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from intent.labels import CHAT, EXECUTION, VALID_LABELS, EXECUTION_CONFIDENCE_THRESHOLD + + +class TestIntentLabels(unittest.TestCase): + """意图标签测试""" + + def test_labels_defined(self): + """测试标签已定义""" + self.assertEqual(CHAT, "chat") + self.assertEqual(EXECUTION, "execution") + + def test_valid_labels(self): + """测试有效标签集合""" + self.assertIn(CHAT, VALID_LABELS) + self.assertIn(EXECUTION, VALID_LABELS) + self.assertEqual(len(VALID_LABELS), 2) + + def test_confidence_threshold(self): + """测试置信度阈值""" + self.assertGreater(EXECUTION_CONFIDENCE_THRESHOLD, 0) + self.assertLessEqual(EXECUTION_CONFIDENCE_THRESHOLD, 1) + + +class TestIntentClassifierParsing(unittest.TestCase): + """意图分类器解析测试(不需要 API)""" + + def setUp(self): + from intent.classifier import IntentClassifier + self.classifier = IntentClassifier() + + def test_parse_valid_chat_response(self): + """测试解析有效的 chat 响应""" + response = '{"label": "chat", "confidence": 0.95, "reason": "这是一个问答"}' + result = self.classifier._parse_response(response) + self.assertEqual(result.label, CHAT) + self.assertEqual(result.confidence, 0.95) + self.assertEqual(result.reason, "这是一个问答") + + def test_parse_valid_execution_response(self): + """测试解析有效的 execution 响应""" + response = '{"label": "execution", "confidence": 0.9, "reason": "需要复制文件"}' + result = self.classifier._parse_response(response) + self.assertEqual(result.label, EXECUTION) + self.assertEqual(result.confidence, 0.9) + + def test_parse_low_confidence_execution(self): + """测试低置信度的 execution 降级为 chat""" + response = '{"label": "execution", "confidence": 0.5, "reason": "不太确定"}' + result = self.classifier._parse_response(response) + # 低于阈值应该降级为 chat + self.assertEqual(result.label, CHAT) + + def test_parse_invalid_label(self): + """测试无效标签降级为 chat""" + response = '{"label": "unknown", "confidence": 0.9, "reason": "测试"}' + result = self.classifier._parse_response(response) + self.assertEqual(result.label, CHAT) + + def test_parse_invalid_json(self): + """测试无效 JSON 降级为 chat""" + response = 'not a json' + result = self.classifier._parse_response(response) + self.assertEqual(result.label, CHAT) + self.assertEqual(result.confidence, 0.0) + + def test_extract_json_with_prefix(self): + """测试从带前缀的文本中提取 JSON""" + text = 'Here is the result: {"label": "chat", "confidence": 0.8, "reason": "test"}' + json_str = self.classifier._extract_json(text) + self.assertTrue(json_str.startswith('{')) + self.assertTrue(json_str.endswith('}')) + + def test_extract_json_with_suffix(self): + """测试从带后缀的文本中提取 JSON""" + text = '{"label": "chat", "confidence": 0.8, "reason": "test"} That is my answer.' + json_str = self.classifier._extract_json(text) + self.assertTrue(json_str.startswith('{')) + self.assertTrue(json_str.endswith('}')) + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_rule_checker.py b/tests/test_rule_checker.py new file mode 100644 index 0000000..40f0ef6 --- /dev/null +++ b/tests/test_rule_checker.py @@ -0,0 +1,160 @@ +""" +安全检查器单元测试 +""" + +import unittest +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from safety.rule_checker import RuleChecker, check_code_safety + + +class TestRuleChecker(unittest.TestCase): + """规则检查器测试""" + + def setUp(self): + self.checker = RuleChecker() + + # ========== 硬性禁止测试 ========== + + def test_block_socket_import(self): + """测试禁止 socket 模块""" + code = "import socket\ns = socket.socket()" + result = self.checker.check(code) + self.assertFalse(result.passed) + self.assertTrue(any('socket' in v for v in result.violations)) + + def test_block_subprocess_import(self): + """测试禁止 subprocess 模块""" + code = "import subprocess\nsubprocess.run(['ls'])" + result = self.checker.check(code) + self.assertFalse(result.passed) + self.assertTrue(any('subprocess' in v for v in result.violations)) + + def test_block_eval(self): + """测试禁止 eval""" + code = "result = eval('1+1')" + result = self.checker.check(code) + self.assertFalse(result.passed) + self.assertTrue(any('eval' in v for v in result.violations)) + + def test_block_exec(self): + """测试禁止 exec""" + code = "exec('print(1)')" + result = self.checker.check(code) + self.assertFalse(result.passed) + self.assertTrue(any('exec' in v for v in result.violations)) + + def test_block_os_system(self): + """测试禁止 os.system""" + code = "import os\nos.system('dir')" + result = self.checker.check(code) + self.assertFalse(result.passed) + self.assertTrue(any('os.system' in v for v in result.violations)) + + def test_block_os_popen(self): + """测试禁止 os.popen""" + code = "import os\nos.popen('dir')" + result = self.checker.check(code) + self.assertFalse(result.passed) + self.assertTrue(any('os.popen' in v for v in result.violations)) + + # ========== 警告测试 ========== + + def test_warn_requests_import(self): + """测试 requests 模块产生警告""" + code = "import requests\nresponse = requests.get('http://example.com')" + result = self.checker.check(code) + self.assertTrue(result.passed) # 不应该被阻止 + self.assertTrue(any('requests' in w for w in result.warnings)) + + def test_warn_os_remove(self): + """测试 os.remove 产生警告""" + code = "import os\nos.remove('file.txt')" + result = self.checker.check(code) + self.assertTrue(result.passed) # 不应该被阻止 + self.assertTrue(any('os.remove' in w for w in result.warnings)) + + def test_warn_shutil_rmtree(self): + """测试 shutil.rmtree 产生警告""" + code = "import shutil\nshutil.rmtree('folder')" + result = self.checker.check(code) + self.assertTrue(result.passed) # 不应该被阻止 + self.assertTrue(any('shutil.rmtree' in w for w in result.warnings)) + + # ========== 安全代码测试 ========== + + def test_safe_file_copy(self): + """测试安全的文件复制代码""" + code = """ +import shutil +from pathlib import Path + +INPUT_DIR = Path('workspace/input') +OUTPUT_DIR = Path('workspace/output') + +for f in INPUT_DIR.glob('*'): + shutil.copy(f, OUTPUT_DIR / f.name) +""" + result = self.checker.check(code) + self.assertTrue(result.passed) + self.assertEqual(len(result.violations), 0) + + def test_safe_image_processing(self): + """测试安全的图片处理代码""" + code = """ +from PIL import Image +from pathlib import Path + +INPUT_DIR = Path('workspace/input') +OUTPUT_DIR = Path('workspace/output') + +for img_path in INPUT_DIR.glob('*.png'): + img = Image.open(img_path) + img = img.resize((100, 100)) + img.save(OUTPUT_DIR / img_path.name) +""" + result = self.checker.check(code) + self.assertTrue(result.passed) + self.assertEqual(len(result.violations), 0) + + def test_safe_excel_processing(self): + """测试安全的 Excel 处理代码""" + code = """ +import openpyxl +from pathlib import Path + +INPUT_DIR = Path('workspace/input') +OUTPUT_DIR = Path('workspace/output') + +for xlsx_path in INPUT_DIR.glob('*.xlsx'): + wb = openpyxl.load_workbook(xlsx_path) + ws = wb.active + ws['A1'] = 'Modified' + wb.save(OUTPUT_DIR / xlsx_path.name) +""" + result = self.checker.check(code) + self.assertTrue(result.passed) + self.assertEqual(len(result.violations), 0) + + +class TestCheckCodeSafety(unittest.TestCase): + """便捷函数测试""" + + def test_convenience_function(self): + """测试便捷函数""" + result = check_code_safety("print('hello')") + self.assertTrue(result.passed) + + def test_convenience_function_block(self): + """测试便捷函数阻止危险代码""" + result = check_code_safety("import socket") + self.assertFalse(result.passed) + + +if __name__ == '__main__': + unittest.main() + diff --git a/ui/chat_view.py b/ui/chat_view.py index f686e94..0988d32 100644 --- a/ui/chat_view.py +++ b/ui/chat_view.py @@ -1,6 +1,6 @@ """ 聊天视图组件 -处理普通对话的 UI 展示 - 支持流式消息 +处理普通对话的 UI 展示 - 支持流式消息和加载动画 """ import tkinter as tk @@ -8,6 +8,58 @@ from tkinter import scrolledtext from typing import Callable, Optional +class LoadingIndicator: + """加载动画指示器""" + + FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + + def __init__(self, parent: tk.Widget, text: str = "处理中"): + self.parent = parent + self.text = text + self.frame_index = 0 + self.running = False + self.after_id = None + + # 创建标签 + self.label = tk.Label( + parent, + text="", + font=('Microsoft YaHei UI', 10), + fg='#ffd54f', + bg='#1e1e1e' + ) + + def start(self, text: str = None): + """开始动画""" + if text: + self.text = text + self.running = True + self.label.pack(pady=5) + self._animate() + + def stop(self): + """停止动画""" + self.running = False + if self.after_id: + self.parent.after_cancel(self.after_id) + self.after_id = None + self.label.pack_forget() + + def update_text(self, text: str): + """更新提示文字""" + self.text = text + + def _animate(self): + """动画帧更新""" + if not self.running: + return + + frame = self.FRAMES[self.frame_index] + self.label.config(text=f"{frame} {self.text}...") + self.frame_index = (self.frame_index + 1) % len(self.FRAMES) + self.after_id = self.parent.after(100, self._animate) + + class ChatView: """ 聊天视图 @@ -22,7 +74,8 @@ class ChatView: def __init__( self, parent: tk.Widget, - on_send: Callable[[str], None] + on_send: Callable[[str], None], + on_show_history: Optional[Callable[[], None]] = None ): """ 初始化聊天视图 @@ -30,14 +83,19 @@ class ChatView: Args: parent: 父容器 on_send: 发送消息回调函数 + on_show_history: 显示历史记录回调函数 """ self.parent = parent self.on_send = on_send + self.on_show_history = on_show_history # 流式消息状态 self._stream_active = False self._stream_tag = None + # 加载指示器 + self.loading: Optional[LoadingIndicator] = None + self._create_widgets() def _create_widgets(self): @@ -46,15 +104,37 @@ class ChatView: self.frame = tk.Frame(self.parent, bg='#1e1e1e') self.frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + # 标题栏(包含标题和历史按钮) + title_frame = tk.Frame(self.frame, bg='#1e1e1e') + title_frame.pack(fill=tk.X, pady=(0, 10)) + # 标题 title_label = tk.Label( - self.frame, + title_frame, text="LocalAgent - 本地 AI 助手", font=('Microsoft YaHei UI', 16, 'bold'), fg='#61dafb', bg='#1e1e1e' ) - title_label.pack(pady=(0, 10)) + title_label.pack(side=tk.LEFT, expand=True) + + # 历史记录按钮 + if self.on_show_history: + self.history_btn = tk.Button( + title_frame, + text="📜 历史", + font=('Microsoft YaHei UI', 10), + bg='#424242', + fg='#ce93d8', + activebackground='#616161', + activeforeground='#ce93d8', + relief=tk.FLAT, + padx=10, + pady=3, + cursor='hand2', + command=self.on_show_history + ) + self.history_btn.pack(side=tk.RIGHT) # 消息显示区域 self.message_area = scrolledtext.ScrolledText( @@ -118,6 +198,9 @@ class ChatView: "- 输入文件处理需求(如\"复制文件\"、\"整理图片\")将触发执行模式" ) self.add_message(welcome_msg, 'system') + + # 创建加载指示器(放在消息区域下方) + self.loading = LoadingIndicator(self.frame) def _on_enter_pressed(self, event): """回车键处理""" @@ -214,6 +297,21 @@ class ChatView: self.input_entry.config(state=state) self.send_button.config(state=state) + def show_loading(self, text: str = "处理中"): + """显示加载动画""" + if self.loading: + self.loading.start(text) + + def hide_loading(self): + """隐藏加载动画""" + if self.loading: + self.loading.stop() + + def update_loading_text(self, text: str): + """更新加载提示文字""" + if self.loading: + self.loading.update_text(text) + def get_frame(self) -> tk.Frame: """获取主框架""" return self.frame diff --git a/ui/history_view.py b/ui/history_view.py new file mode 100644 index 0000000..d3118e6 --- /dev/null +++ b/ui/history_view.py @@ -0,0 +1,335 @@ +""" +历史记录视图组件 +显示任务执行历史 +""" + +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Callable, List, Optional +from pathlib import Path + +from history.manager import TaskRecord, HistoryManager + + +class HistoryView: + """ + 历史记录视图 + + 显示任务执行历史列表,支持查看详情 + """ + + def __init__( + self, + parent: tk.Widget, + history_manager: HistoryManager, + on_back: Callable[[], None] + ): + self.parent = parent + self.history = history_manager + self.on_back = on_back + + self._selected_record: Optional[TaskRecord] = None + self._create_widgets() + + def _create_widgets(self): + """创建 UI 组件""" + self.frame = tk.Frame(self.parent, bg='#1e1e1e') + + # 标题栏 + title_frame = tk.Frame(self.frame, bg='#1e1e1e') + title_frame.pack(fill=tk.X, padx=10, pady=10) + + # 返回按钮 + back_btn = tk.Button( + title_frame, + text="← 返回", + font=('Microsoft YaHei UI', 10), + bg='#424242', + fg='white', + activebackground='#616161', + activeforeground='white', + relief=tk.FLAT, + padx=10, + cursor='hand2', + command=self.on_back + ) + back_btn.pack(side=tk.LEFT) + + # 标题 + title_label = tk.Label( + title_frame, + text="📜 任务历史记录", + font=('Microsoft YaHei UI', 14, 'bold'), + fg='#ce93d8', + bg='#1e1e1e' + ) + title_label.pack(side=tk.LEFT, padx=20) + + # 统计信息 + stats = self.history.get_stats() + stats_text = f"共 {stats['total']} 条 | 成功 {stats['success']} | 失败 {stats['failed']} | 成功率 {stats['success_rate']:.0%}" + stats_label = tk.Label( + title_frame, + text=stats_text, + font=('Microsoft YaHei UI', 9), + fg='#888888', + bg='#1e1e1e' + ) + stats_label.pack(side=tk.RIGHT) + + # 主内容区域(左右分栏) + content_frame = tk.Frame(self.frame, bg='#1e1e1e') + content_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=(0, 10)) + + # 左侧:历史列表 + list_frame = tk.LabelFrame( + content_frame, + text=" 任务列表 ", + font=('Microsoft YaHei UI', 10, 'bold'), + fg='#4fc3f7', + bg='#1e1e1e', + relief=tk.GROOVE + ) + list_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 5)) + + # 列表框 + list_container = tk.Frame(list_frame, bg='#2d2d2d') + list_container.pack(fill=tk.BOTH, expand=True, padx=3, pady=3) + + # 使用 Treeview 显示列表 + columns = ('time', 'input', 'status', 'duration') + self.tree = ttk.Treeview(list_container, columns=columns, show='headings', height=15) + + # 配置列 + self.tree.heading('time', text='时间') + self.tree.heading('input', text='任务描述') + self.tree.heading('status', text='状态') + self.tree.heading('duration', text='耗时') + + self.tree.column('time', width=120, minwidth=100) + self.tree.column('input', width=250, minwidth=150) + self.tree.column('status', width=60, minwidth=50) + self.tree.column('duration', width=70, minwidth=50) + + # 滚动条 + scrollbar = ttk.Scrollbar(list_container, orient=tk.VERTICAL, command=self.tree.yview) + self.tree.configure(yscrollcommand=scrollbar.set) + + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # 绑定选择事件 + self.tree.bind('<>', self._on_select) + + # 右侧:详情面板 + detail_frame = tk.LabelFrame( + content_frame, + text=" 任务详情 ", + font=('Microsoft YaHei UI', 10, 'bold'), + fg='#81c784', + bg='#1e1e1e', + relief=tk.GROOVE + ) + detail_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=(5, 0)) + + # 详情文本框 + detail_container = tk.Frame(detail_frame, bg='#2d2d2d') + detail_container.pack(fill=tk.BOTH, expand=True, padx=3, pady=3) + + self.detail_text = tk.Text( + detail_container, + wrap=tk.WORD, + font=('Microsoft YaHei UI', 10), + bg='#2d2d2d', + fg='#d4d4d4', + relief=tk.FLAT, + padx=10, + pady=10, + state=tk.DISABLED + ) + + detail_scrollbar = ttk.Scrollbar(detail_container, orient=tk.VERTICAL, command=self.detail_text.yview) + self.detail_text.configure(yscrollcommand=detail_scrollbar.set) + + detail_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.detail_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # 配置详情文本样式 + self.detail_text.tag_configure('title', font=('Microsoft YaHei UI', 11, 'bold'), foreground='#ffd54f') + self.detail_text.tag_configure('label', font=('Microsoft YaHei UI', 10, 'bold'), foreground='#4fc3f7') + self.detail_text.tag_configure('success', foreground='#81c784') + self.detail_text.tag_configure('error', foreground='#ef5350') + self.detail_text.tag_configure('code', font=('Consolas', 9), foreground='#ce93d8') + + # 底部按钮 + btn_frame = tk.Frame(self.frame, bg='#1e1e1e') + btn_frame.pack(fill=tk.X, padx=10, pady=(0, 10)) + + # 打开日志按钮 + self.open_log_btn = tk.Button( + btn_frame, + text="📄 打开日志", + font=('Microsoft YaHei UI', 10), + bg='#424242', + fg='white', + activebackground='#616161', + activeforeground='white', + relief=tk.FLAT, + padx=15, + cursor='hand2', + state=tk.DISABLED, + command=self._open_log + ) + self.open_log_btn.pack(side=tk.LEFT) + + # 清空历史按钮 + clear_btn = tk.Button( + btn_frame, + text="🗑️ 清空历史", + font=('Microsoft YaHei UI', 10), + bg='#d32f2f', + fg='white', + activebackground='#f44336', + activeforeground='white', + relief=tk.FLAT, + padx=15, + cursor='hand2', + command=self._clear_history + ) + clear_btn.pack(side=tk.RIGHT) + + # 加载数据 + self._load_data() + + def _load_data(self): + """加载历史数据到列表""" + # 清空现有数据 + for item in self.tree.get_children(): + self.tree.delete(item) + + # 加载历史记录 + records = self.history.get_all() + + for record in records: + # 截断过长的输入 + input_text = record.user_input + if len(input_text) > 30: + input_text = input_text[:30] + "..." + + status = "✓ 成功" if record.success else "✗ 失败" + duration = f"{record.duration_ms}ms" + + # 提取时间(只显示时分秒) + time_parts = record.timestamp.split(' ') + time_str = time_parts[1] if len(time_parts) > 1 else record.timestamp + date_str = time_parts[0] if len(time_parts) > 0 else "" + display_time = f"{date_str}\n{time_str}" + + self.tree.insert('', tk.END, iid=record.task_id, values=( + record.timestamp, + input_text, + status, + duration + )) + + # 显示空状态提示 + if not records: + self._show_detail("暂无历史记录\n\n执行任务后,记录将显示在这里。") + + def _on_select(self, event): + """选择记录事件""" + selection = self.tree.selection() + if not selection: + return + + task_id = selection[0] + record = self.history.get_by_id(task_id) + + if record: + self._selected_record = record + self._show_record_detail(record) + self.open_log_btn.config(state=tk.NORMAL) + + def _show_record_detail(self, record: TaskRecord): + """显示记录详情""" + self.detail_text.config(state=tk.NORMAL) + self.detail_text.delete(1.0, tk.END) + + # 标题 + self.detail_text.insert(tk.END, f"任务 ID: {record.task_id}\n", 'title') + self.detail_text.insert(tk.END, f"时间: {record.timestamp}\n\n") + + # 用户输入 + self.detail_text.insert(tk.END, "用户输入:\n", 'label') + self.detail_text.insert(tk.END, f"{record.user_input}\n\n") + + # 执行状态 + self.detail_text.insert(tk.END, "执行状态: ", 'label') + if record.success: + self.detail_text.insert(tk.END, "成功 ✓\n", 'success') + else: + self.detail_text.insert(tk.END, "失败 ✗\n", 'error') + + self.detail_text.insert(tk.END, f"耗时: {record.duration_ms}ms\n\n") + + # 执行计划 + self.detail_text.insert(tk.END, "执行计划:\n", 'label') + plan_preview = record.execution_plan[:500] + "..." if len(record.execution_plan) > 500 else record.execution_plan + self.detail_text.insert(tk.END, f"{plan_preview}\n\n") + + # 输出 + if record.stdout: + self.detail_text.insert(tk.END, "输出:\n", 'label') + self.detail_text.insert(tk.END, f"{record.stdout}\n\n") + + # 错误 + if record.stderr: + self.detail_text.insert(tk.END, "错误:\n", 'label') + self.detail_text.insert(tk.END, f"{record.stderr}\n", 'error') + + self.detail_text.config(state=tk.DISABLED) + + def _show_detail(self, text: str): + """显示详情文本""" + self.detail_text.config(state=tk.NORMAL) + self.detail_text.delete(1.0, tk.END) + self.detail_text.insert(tk.END, text) + self.detail_text.config(state=tk.DISABLED) + + def _open_log(self): + """打开日志文件""" + if self._selected_record and self._selected_record.log_path: + import os + log_path = Path(self._selected_record.log_path) + if log_path.exists(): + os.startfile(str(log_path)) + else: + messagebox.showwarning("提示", f"日志文件不存在:\n{log_path}") + + def _clear_history(self): + """清空历史记录""" + result = messagebox.askyesno( + "确认清空", + "确定要清空所有历史记录吗?\n此操作不可恢复。", + icon='warning' + ) + + if result: + self.history.clear() + self._load_data() + self._show_detail("历史记录已清空") + self.open_log_btn.config(state=tk.DISABLED) + + def show(self): + """显示视图""" + self._load_data() # 刷新数据 + self.frame.pack(fill=tk.BOTH, expand=True) + + def hide(self): + """隐藏视图""" + self.frame.pack_forget() + + def get_frame(self) -> tk.Frame: + """获取主框架""" + return self.frame + diff --git a/ui/task_guide_view.py b/ui/task_guide_view.py index 50bcfa3..09f2adc 100644 --- a/ui/task_guide_view.py +++ b/ui/task_guide_view.py @@ -243,6 +243,9 @@ class TaskGuideView: # 执行计划区域(Markdown) self._create_plan_section() + # 代码预览区域(可折叠) + self._create_code_section() + # 风险提示区域 self._create_risk_section() @@ -306,6 +309,148 @@ class TaskGuideView: scrollbar.pack(side=tk.RIGHT, fill=tk.Y) self.plan_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + def _create_code_section(self): + """创建代码预览区域(可折叠)""" + # 折叠状态 + self._code_expanded = False + + # 外层框架 + self.code_section = tk.LabelFrame( + self.frame, + text=" 💻 生成的代码 ", + font=('Microsoft YaHei UI', 10, 'bold'), + fg='#64b5f6', + bg='#1e1e1e', + relief=tk.GROOVE + ) + self.code_section.pack(fill=tk.X, padx=10, pady=3) + + # 展开/折叠按钮 + self.toggle_code_btn = tk.Button( + self.code_section, + text="▶ 点击展开代码预览", + font=('Microsoft YaHei UI', 9), + bg='#2d2d2d', + fg='#64b5f6', + activebackground='#3d3d3d', + activeforeground='#64b5f6', + relief=tk.FLAT, + cursor='hand2', + command=self._toggle_code_view + ) + self.toggle_code_btn.pack(fill=tk.X, padx=5, pady=5) + + # 代码显示区域(初始隐藏) + self.code_frame = tk.Frame(self.code_section, bg='#1e1e1e') + + # 代码文本框 + self.code_text = tk.Text( + self.code_frame, + wrap=tk.NONE, + font=('Consolas', 10), + bg='#1e1e1e', + fg='#d4d4d4', + insertbackground='white', + relief=tk.FLAT, + height=12, + padx=8, + pady=5 + ) + + # 配置代码高亮标签 + self.code_text.tag_configure('keyword', foreground='#569cd6') + self.code_text.tag_configure('string', foreground='#ce9178') + self.code_text.tag_configure('comment', foreground='#6a9955') + self.code_text.tag_configure('function', foreground='#dcdcaa') + self.code_text.tag_configure('number', foreground='#b5cea8') + + # 滚动条 + code_scrollbar_y = ttk.Scrollbar(self.code_frame, orient=tk.VERTICAL, command=self.code_text.yview) + code_scrollbar_x = ttk.Scrollbar(self.code_frame, orient=tk.HORIZONTAL, command=self.code_text.xview) + self.code_text.configure(yscrollcommand=code_scrollbar_y.set, xscrollcommand=code_scrollbar_x.set) + + code_scrollbar_y.pack(side=tk.RIGHT, fill=tk.Y) + code_scrollbar_x.pack(side=tk.BOTTOM, fill=tk.X) + self.code_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # 复制按钮 + self.copy_code_btn = tk.Button( + self.code_frame, + text="📋 复制代码", + font=('Microsoft YaHei UI', 9), + bg='#424242', + fg='white', + activebackground='#616161', + activeforeground='white', + relief=tk.FLAT, + cursor='hand2', + command=self._copy_code + ) + + def _toggle_code_view(self): + """切换代码预览的展开/折叠状态""" + self._code_expanded = not self._code_expanded + + if self._code_expanded: + self.toggle_code_btn.config(text="▼ 点击折叠代码预览") + self.code_frame.pack(fill=tk.BOTH, expand=True, padx=3, pady=(0, 5)) + self.copy_code_btn.pack(pady=5) + else: + self.toggle_code_btn.config(text="▶ 点击展开代码预览") + self.copy_code_btn.pack_forget() + self.code_frame.pack_forget() + + def _copy_code(self): + """复制代码到剪贴板""" + code = self.code_text.get(1.0, tk.END).strip() + self.frame.clipboard_clear() + self.frame.clipboard_append(code) + + # 显示复制成功提示 + original_text = self.copy_code_btn.cget('text') + self.copy_code_btn.config(text="✓ 已复制!") + self.frame.after(1500, lambda: self.copy_code_btn.config(text=original_text)) + + def _apply_syntax_highlight(self, code: str): + """应用简单的语法高亮""" + import re + + # 关键字 + keywords = r'\b(import|from|def|class|if|else|elif|for|while|try|except|finally|with|as|return|yield|raise|pass|break|continue|and|or|not|in|is|None|True|False|lambda|global|nonlocal)\b' + # 字符串 + strings = r'(\"\"\"[\s\S]*?\"\"\"|\'\'\'[\s\S]*?\'\'\'|\"[^\"]*\"|\'[^\']*\')' + # 注释 + comments = r'(#.*$)' + # 函数调用 + functions = r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(' + # 数字 + numbers = r'\b(\d+\.?\d*)\b' + + # 先插入纯文本 + self.code_text.delete(1.0, tk.END) + self.code_text.insert(1.0, code) + + # 应用高亮 + for match in re.finditer(keywords, code, re.MULTILINE): + start = f"1.0+{match.start()}c" + end = f"1.0+{match.end()}c" + self.code_text.tag_add('keyword', start, end) + + for match in re.finditer(strings, code, re.MULTILINE): + start = f"1.0+{match.start()}c" + end = f"1.0+{match.end()}c" + self.code_text.tag_add('string', start, end) + + for match in re.finditer(comments, code, re.MULTILINE): + start = f"1.0+{match.start()}c" + end = f"1.0+{match.end()}c" + self.code_text.tag_add('comment', start, end) + + for match in re.finditer(numbers, code, re.MULTILINE): + start = f"1.0+{match.start(1)}c" + end = f"1.0+{match.end(1)}c" + self.code_text.tag_add('number', start, end) + def _create_risk_section(self): """创建风险提示区域""" section = tk.LabelFrame( @@ -423,6 +568,12 @@ class TaskGuideView: """设置执行计划(Markdown 格式)""" self.plan_text.set_markdown(plan) + def set_code(self, code: str): + """设置生成的代码""" + self.code_text.config(state=tk.NORMAL) + self._apply_syntax_highlight(code) + self.code_text.config(state=tk.DISABLED) + def set_risk_info(self, info: str): """设置风险提示""" self.risk_label.config(text=info)