feat: enhance LocalAgent configuration and UI components
- Updated .env.example to provide clearer configuration instructions and API key setup. - Removed debug_env.py as it was no longer needed. - Refactored main.py to streamline application initialization and workspace setup. - Introduced a new HistoryManager for managing task execution history. - Enhanced UI components in chat_view.py and task_guide_view.py to improve user interaction and code preview functionality. - Added loading indicators and improved task history display in the UI. - Implemented unit tests for history management and intent classification.
This commit is contained in:
21
.env.example
21
.env.example
@@ -1,16 +1,19 @@
|
||||
# ========================================
|
||||
# LocalAgent 閰嶇疆鏂囦欢绀轰緥
|
||||
# ========================================
|
||||
# 浣跨敤鏂规硶锛?# 1. 澶嶅埗姝ゆ枃浠朵负 .env
|
||||
# 2. 濉叆浣犵殑 API Key 鍜屽叾浠栭厤缃?# ========================================
|
||||
# LocalAgent Configuration Example
|
||||
# ========================================
|
||||
# Usage:
|
||||
# 1. Copy this file to .env
|
||||
# 2. Fill in your API Key and other settings
|
||||
# ========================================
|
||||
|
||||
# SiliconFlow API 閰嶇疆
|
||||
# 鑾峰彇 API Key: https://siliconflow.cn
|
||||
# SiliconFlow API Configuration
|
||||
# Get API Key: https://siliconflow.cn
|
||||
LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions
|
||||
LLM_API_KEY=your_api_key_here
|
||||
|
||||
# 妯″瀷閰嶇疆
|
||||
# 鎰忓浘璇嗗埆妯″瀷锛堟帹鑽愪娇鐢ㄥ皬妯″瀷锛岄€熷害蹇級
|
||||
# Model Configuration
|
||||
# Intent recognition model (small model recommended for speed)
|
||||
INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
|
||||
|
||||
# 浠g爜鐢熸垚妯″瀷锛堟帹鑽愪娇鐢ㄥぇ妯″瀷锛屾晥鏋滃ソ锛?GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
||||
# Code generation model (large model recommended for quality)
|
||||
GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
||||
|
||||
169
README.md
Normal file
169
README.md
Normal file
@@ -0,0 +1,169 @@
|
||||
# LocalAgent - Windows Local AI Execution Assistant
|
||||
|
||||
A Windows-based local AI assistant that can understand natural language commands and execute file processing tasks safely in a sandboxed environment.
|
||||
|
||||
## Features
|
||||
|
||||
- **Intent Recognition**: Automatically distinguishes between chat conversations and execution tasks
|
||||
- **Code Generation**: Generates Python code based on user requirements
|
||||
- **Safety Checks**: Multi-layer security with static analysis and LLM review
|
||||
- **Sandbox Execution**: Runs generated code in an isolated environment
|
||||
- **Task History**: Records all executed tasks for review
|
||||
- **Streaming Responses**: Real-time display of LLM responses
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
LocalAgent/
|
||||
├── app/ # Main application
|
||||
│ └── agent.py # Core application class
|
||||
├── llm/ # LLM integration
|
||||
│ ├── client.py # API client with retry support
|
||||
│ └── prompts.py # Prompt templates
|
||||
├── intent/ # Intent classification
|
||||
│ ├── classifier.py # Intent classifier
|
||||
│ └── labels.py # Intent labels
|
||||
├── safety/ # Security checks
|
||||
│ ├── rule_checker.py # Static rule checker
|
||||
│ └── llm_reviewer.py # LLM-based code review
|
||||
├── executor/ # Code execution
|
||||
│ └── sandbox_runner.py # Sandbox executor
|
||||
├── history/ # Task history
|
||||
│ └── manager.py # History manager
|
||||
├── ui/ # User interface
|
||||
│ ├── chat_view.py # Chat interface
|
||||
│ ├── task_guide_view.py # Task confirmation view
|
||||
│ └── history_view.py # History view
|
||||
├── tests/ # Unit tests
|
||||
├── workspace/ # Working directory (auto-created)
|
||||
│ ├── input/ # Input files
|
||||
│ ├── output/ # Output files
|
||||
│ ├── codes/ # Generated code
|
||||
│ └── logs/ # Execution logs
|
||||
├── main.py # Entry point
|
||||
├── requirements.txt # Dependencies
|
||||
└── .env.example # Configuration template
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.10+
|
||||
- Windows OS
|
||||
- SiliconFlow API Key ([Get one here](https://siliconflow.cn))
|
||||
|
||||
### Setup
|
||||
|
||||
1. **Clone the repository**
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd LocalAgent
|
||||
```
|
||||
|
||||
2. **Create virtual environment** (recommended using Anaconda)
|
||||
```bash
|
||||
conda create -n localagent python=3.10
|
||||
conda activate localagent
|
||||
```
|
||||
|
||||
3. **Install dependencies**
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
4. **Configure environment**
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your API key
|
||||
```
|
||||
|
||||
5. **Run the application**
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Edit `.env` file with your settings:
|
||||
|
||||
```env
|
||||
# SiliconFlow API Configuration
|
||||
LLM_API_URL=https://api.siliconflow.cn/v1/chat/completions
|
||||
LLM_API_KEY=your_api_key_here
|
||||
|
||||
# Model Configuration
|
||||
INTENT_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
|
||||
GENERATION_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Chat Mode
|
||||
Simply type questions or have conversations:
|
||||
- "What is Python?"
|
||||
- "Explain machine learning"
|
||||
|
||||
### Execution Mode
|
||||
Describe file processing tasks:
|
||||
- "Copy all files from input to output"
|
||||
- "Convert all PNG images to JPG format"
|
||||
- "Rename files with today's date prefix"
|
||||
|
||||
### Workflow
|
||||
1. Place input files in `workspace/input/`
|
||||
2. Describe your task in the chat
|
||||
3. Review the execution plan and generated code
|
||||
4. Click "Execute" to run
|
||||
5. Find results in `workspace/output/`
|
||||
|
||||
## Security
|
||||
|
||||
LocalAgent implements multiple security layers:
|
||||
|
||||
1. **Hard Rules** - Blocks dangerous operations:
|
||||
- Network modules (socket, subprocess)
|
||||
- Code execution (eval, exec)
|
||||
- System commands (os.system, os.popen)
|
||||
|
||||
2. **Soft Rules** - Warns about sensitive operations:
|
||||
- File deletion
|
||||
- Network requests (requests, urllib)
|
||||
|
||||
3. **LLM Review** - Semantic analysis of generated code
|
||||
|
||||
4. **Sandbox Execution** - Isolated subprocess with limited permissions
|
||||
|
||||
## Testing
|
||||
|
||||
Run unit tests:
|
||||
```bash
|
||||
python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
## Supported File Operations
|
||||
|
||||
The generated code can use these libraries:
|
||||
|
||||
**Standard Library:**
|
||||
- os, sys, pathlib - Path operations
|
||||
- shutil - File copy/move
|
||||
- json, csv - Data formats
|
||||
- zipfile, tarfile - Compression
|
||||
- And more...
|
||||
|
||||
**Third-party Libraries:**
|
||||
- Pillow - Image processing
|
||||
- openpyxl - Excel files
|
||||
- python-docx - Word documents
|
||||
- PyPDF2 - PDF files
|
||||
- chardet - Encoding detection
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||
|
||||
2
app/__init__.py
Normal file
2
app/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# 应用模块
|
||||
|
||||
526
app/agent.py
Normal file
526
app/agent.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
LocalAgent 主应用类
|
||||
管理 UI 状态切换和协调各模块工作流程
|
||||
"""
|
||||
|
||||
import os
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
import threading
|
||||
import queue
|
||||
|
||||
from llm.client import get_client, LLMClientError
|
||||
from llm.prompts import (
|
||||
EXECUTION_PLAN_SYSTEM, EXECUTION_PLAN_USER,
|
||||
CODE_GENERATION_SYSTEM, CODE_GENERATION_USER
|
||||
)
|
||||
from intent.classifier import classify_intent, IntentResult
|
||||
from intent.labels import CHAT, EXECUTION
|
||||
from safety.rule_checker import check_code_safety
|
||||
from safety.llm_reviewer import review_code_safety, LLMReviewResult
|
||||
from executor.sandbox_runner import SandboxRunner, ExecutionResult
|
||||
from ui.chat_view import ChatView
|
||||
from ui.task_guide_view import TaskGuideView
|
||||
from ui.history_view import HistoryView
|
||||
from history.manager import get_history_manager, HistoryManager
|
||||
|
||||
|
||||
class LocalAgentApp:
|
||||
"""
|
||||
LocalAgent 主应用
|
||||
|
||||
职责:
|
||||
1. 管理 UI 状态切换
|
||||
2. 协调各模块工作流程
|
||||
3. 处理用户交互
|
||||
"""
|
||||
|
||||
def __init__(self, project_root: Path):
|
||||
self.project_root: Path = project_root
|
||||
self.workspace: Path = project_root / "workspace"
|
||||
self.runner: SandboxRunner = SandboxRunner(str(self.workspace))
|
||||
self.history: HistoryManager = get_history_manager(self.workspace)
|
||||
|
||||
# 当前任务状态
|
||||
self.current_task: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 线程通信队列
|
||||
self.result_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# UI 组件
|
||||
self.root: Optional[tk.Tk] = None
|
||||
self.main_container: Optional[tk.Frame] = None
|
||||
self.chat_view: Optional[ChatView] = None
|
||||
self.task_view: Optional[TaskGuideView] = None
|
||||
self.history_view: Optional[HistoryView] = None
|
||||
|
||||
# 初始化 UI
|
||||
self._init_ui()
|
||||
|
||||
def _init_ui(self) -> None:
|
||||
"""初始化 UI"""
|
||||
self.root = tk.Tk()
|
||||
self.root.title("LocalAgent - 本地 AI 助手")
|
||||
self.root.geometry("800x700")
|
||||
self.root.configure(bg='#1e1e1e')
|
||||
|
||||
# 设置窗口图标(如果有的话)
|
||||
try:
|
||||
self.root.iconbitmap(self.project_root / "icon.ico")
|
||||
except:
|
||||
pass
|
||||
|
||||
# 主容器
|
||||
self.main_container = tk.Frame(self.root, bg='#1e1e1e')
|
||||
self.main_container.pack(fill=tk.BOTH, expand=True)
|
||||
|
||||
# 聊天视图
|
||||
self.chat_view = ChatView(
|
||||
self.main_container,
|
||||
self._on_user_input,
|
||||
on_show_history=self._show_history
|
||||
)
|
||||
|
||||
# 定期检查后台任务结果
|
||||
self._check_queue()
|
||||
|
||||
def _check_queue(self) -> None:
|
||||
"""检查后台任务队列"""
|
||||
try:
|
||||
while True:
|
||||
callback, args = self.result_queue.get_nowait()
|
||||
callback(*args)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
# 每 100ms 检查一次
|
||||
self.root.after(100, self._check_queue)
|
||||
|
||||
def _run_in_thread(self, func: callable, callback: callable, *args) -> None:
|
||||
"""在后台线程运行函数,完成后回调"""
|
||||
def wrapper():
|
||||
try:
|
||||
result = func(*args)
|
||||
self.result_queue.put((callback, (result, None)))
|
||||
except Exception as e:
|
||||
self.result_queue.put((callback, (None, e)))
|
||||
|
||||
thread = threading.Thread(target=wrapper, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _on_user_input(self, user_input: str) -> None:
|
||||
"""处理用户输入"""
|
||||
# 显示用户消息
|
||||
self.chat_view.add_message(user_input, 'user')
|
||||
self.chat_view.set_input_enabled(False)
|
||||
self.chat_view.show_loading("正在分析您的需求")
|
||||
|
||||
# 在后台线程进行意图识别
|
||||
self._run_in_thread(
|
||||
classify_intent,
|
||||
lambda result, error: self._on_intent_result(user_input, result, error),
|
||||
user_input
|
||||
)
|
||||
|
||||
def _on_intent_result(self, user_input: str, intent_result: Optional[IntentResult], error: Optional[Exception]) -> None:
|
||||
"""意图识别完成回调"""
|
||||
self.chat_view.hide_loading()
|
||||
|
||||
if error:
|
||||
self.chat_view.add_message(f"意图识别失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
return
|
||||
|
||||
if intent_result.label == CHAT:
|
||||
# 对话模式
|
||||
self._handle_chat(user_input, intent_result)
|
||||
else:
|
||||
# 执行模式
|
||||
self._handle_execution(user_input, intent_result)
|
||||
|
||||
def _handle_chat(self, user_input: str, intent_result: IntentResult) -> None:
|
||||
"""处理对话任务"""
|
||||
self.chat_view.add_message(
|
||||
f"识别为对话模式 (原因: {intent_result.reason})",
|
||||
'system'
|
||||
)
|
||||
|
||||
# 开始流式消息
|
||||
self.chat_view.start_stream_message('assistant')
|
||||
|
||||
# 在后台线程调用 LLM(流式)
|
||||
def do_chat_stream():
|
||||
client = get_client()
|
||||
model = os.getenv("GENERATION_MODEL_NAME")
|
||||
|
||||
full_response = []
|
||||
for chunk in client.chat_stream(
|
||||
messages=[{"role": "user", "content": user_input}],
|
||||
model=model,
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
timeout=300
|
||||
):
|
||||
full_response.append(chunk)
|
||||
# 通过队列发送 chunk 到主线程更新 UI
|
||||
self.result_queue.put((self._on_chat_chunk, (chunk,)))
|
||||
|
||||
return ''.join(full_response)
|
||||
|
||||
self._run_in_thread(
|
||||
do_chat_stream,
|
||||
self._on_chat_complete
|
||||
)
|
||||
|
||||
def _on_chat_chunk(self, chunk: str):
|
||||
"""收到对话片段回调(主线程)"""
|
||||
self.chat_view.append_stream_chunk(chunk)
|
||||
|
||||
def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]):
|
||||
"""对话完成回调"""
|
||||
self.chat_view.end_stream_message()
|
||||
|
||||
if error:
|
||||
self.chat_view.add_message(f"对话失败: {str(error)}", 'error')
|
||||
|
||||
self.chat_view.set_input_enabled(True)
|
||||
|
||||
def _handle_execution(self, user_input: str, intent_result: IntentResult):
|
||||
"""处理执行任务"""
|
||||
self.chat_view.add_message(
|
||||
f"识别为执行任务 (置信度: {intent_result.confidence:.0%})\n原因: {intent_result.reason}",
|
||||
'system'
|
||||
)
|
||||
self.chat_view.show_loading("正在生成执行计划")
|
||||
|
||||
# 保存用户输入和意图结果
|
||||
self.current_task = {
|
||||
'user_input': user_input,
|
||||
'intent_result': intent_result
|
||||
}
|
||||
|
||||
# 在后台线程生成执行计划
|
||||
self._run_in_thread(
|
||||
self._generate_execution_plan,
|
||||
self._on_plan_generated,
|
||||
user_input
|
||||
)
|
||||
|
||||
def _on_plan_generated(self, plan: Optional[str], error: Optional[Exception]):
|
||||
"""执行计划生成完成回调"""
|
||||
if error:
|
||||
self.chat_view.hide_loading()
|
||||
self.chat_view.add_message(f"生成执行计划失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
self.current_task['execution_plan'] = plan
|
||||
self.chat_view.update_loading_text("正在生成执行代码")
|
||||
|
||||
# 在后台线程生成代码
|
||||
self._run_in_thread(
|
||||
self._generate_code,
|
||||
self._on_code_generated,
|
||||
self.current_task['user_input'],
|
||||
plan
|
||||
)
|
||||
|
||||
def _on_code_generated(self, result: tuple, error: Optional[Exception]):
|
||||
"""代码生成完成回调"""
|
||||
if error:
|
||||
self.chat_view.hide_loading()
|
||||
self.chat_view.add_message(f"生成代码失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
# result 可能是 (code, extract_error) 元组
|
||||
if isinstance(result, tuple):
|
||||
code, extract_error = result
|
||||
if extract_error:
|
||||
self.chat_view.hide_loading()
|
||||
self.chat_view.add_message(f"代码提取失败: {str(extract_error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
else:
|
||||
code = result
|
||||
|
||||
self.current_task['code'] = code
|
||||
self.chat_view.update_loading_text("正在进行安全检查")
|
||||
|
||||
# 硬规则检查(同步,很快)
|
||||
rule_result = check_code_safety(code)
|
||||
if not rule_result.passed:
|
||||
self.chat_view.hide_loading()
|
||||
violations = "\n".join(f" • {v}" for v in rule_result.violations)
|
||||
self.chat_view.add_message(
|
||||
f"安全检查未通过,任务已取消:\n{violations}",
|
||||
'error'
|
||||
)
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
# 保存警告信息,传递给 LLM 审查
|
||||
self.current_task['warnings'] = rule_result.warnings
|
||||
|
||||
# 在后台线程进行 LLM 安全审查
|
||||
self._run_in_thread(
|
||||
lambda: review_code_safety(
|
||||
self.current_task['user_input'],
|
||||
self.current_task['execution_plan'],
|
||||
code,
|
||||
rule_result.warnings # 传递警告给 LLM
|
||||
),
|
||||
self._on_safety_reviewed
|
||||
)
|
||||
|
||||
def _on_safety_reviewed(self, review_result, error: Optional[Exception]):
|
||||
"""安全审查完成回调"""
|
||||
self.chat_view.hide_loading()
|
||||
|
||||
if error:
|
||||
self.chat_view.add_message(f"安全审查失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
if not review_result.passed:
|
||||
self.chat_view.add_message(
|
||||
f"安全审查未通过: {review_result.reason}",
|
||||
'error'
|
||||
)
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
self.chat_view.add_message("安全检查通过,请确认执行", 'system')
|
||||
|
||||
# 显示任务引导视图
|
||||
self._show_task_guide()
|
||||
|
||||
def _generate_execution_plan(self, user_input: str) -> str:
|
||||
"""生成执行计划(使用流式传输)"""
|
||||
client = get_client()
|
||||
model = os.getenv("GENERATION_MODEL_NAME")
|
||||
|
||||
# 使用流式传输,避免超时
|
||||
response = client.chat_stream_collect(
|
||||
messages=[
|
||||
{"role": "system", "content": EXECUTION_PLAN_SYSTEM},
|
||||
{"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)}
|
||||
],
|
||||
model=model,
|
||||
temperature=0.3,
|
||||
max_tokens=1024,
|
||||
timeout=300 # 5分钟超时
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _generate_code(self, user_input: str, execution_plan: str) -> tuple:
|
||||
"""生成执行代码(使用流式传输)"""
|
||||
client = get_client()
|
||||
model = os.getenv("GENERATION_MODEL_NAME")
|
||||
|
||||
# 使用流式传输,避免超时
|
||||
response = client.chat_stream_collect(
|
||||
messages=[
|
||||
{"role": "system", "content": CODE_GENERATION_SYSTEM},
|
||||
{"role": "user", "content": CODE_GENERATION_USER.format(
|
||||
user_input=user_input,
|
||||
execution_plan=execution_plan
|
||||
)}
|
||||
],
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
max_tokens=4096, # 代码可能较长
|
||||
timeout=300 # 5分钟超时
|
||||
)
|
||||
|
||||
# 提取代码块,捕获可能的异常
|
||||
try:
|
||||
code = self._extract_code(response)
|
||||
return (code, None)
|
||||
except ValueError as e:
|
||||
return (None, e)
|
||||
|
||||
def _extract_code(self, response: str) -> str:
|
||||
"""从 LLM 响应中提取代码"""
|
||||
import re
|
||||
|
||||
# 尝试提取 ```python ... ``` 代码块
|
||||
pattern = r'```python\s*(.*?)\s*```'
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
# 尝试提取 ``` ... ``` 代码块
|
||||
pattern = r'```\s*(.*?)\s*```'
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
# 如果没有代码块,检查是否看起来像 Python 代码
|
||||
# 简单检查:是否包含 def 或 import 语句
|
||||
if 'import ' in response or 'def ' in response:
|
||||
return response
|
||||
|
||||
# 无法提取代码,抛出异常
|
||||
raise ValueError(
|
||||
"无法从 LLM 响应中提取代码块。\n"
|
||||
f"响应内容预览: {response[:200]}..."
|
||||
)
|
||||
|
||||
def _show_task_guide(self):
|
||||
"""显示任务引导视图"""
|
||||
if not self.current_task:
|
||||
return
|
||||
|
||||
# 隐藏聊天视图
|
||||
self.chat_view.get_frame().pack_forget()
|
||||
|
||||
# 创建任务引导视图
|
||||
self.task_view = TaskGuideView(
|
||||
self.main_container,
|
||||
on_execute=self._on_execute_task,
|
||||
on_cancel=self._on_cancel_task,
|
||||
workspace_path=self.workspace
|
||||
)
|
||||
|
||||
# 设置内容
|
||||
self.task_view.set_intent_result(
|
||||
self.current_task['intent_result'].reason,
|
||||
self.current_task['intent_result'].confidence
|
||||
)
|
||||
self.task_view.set_execution_plan(self.current_task['execution_plan'])
|
||||
self.task_view.set_code(self.current_task['code'])
|
||||
|
||||
# 显示
|
||||
self.task_view.show()
|
||||
|
||||
def _on_execute_task(self):
|
||||
"""执行任务"""
|
||||
if not self.current_task:
|
||||
return
|
||||
|
||||
self.task_view.set_buttons_enabled(False)
|
||||
|
||||
# 在后台线程执行
|
||||
def do_execute():
|
||||
return self.runner.execute(self.current_task['code'])
|
||||
|
||||
self._run_in_thread(
|
||||
do_execute,
|
||||
self._on_execution_complete
|
||||
)
|
||||
|
||||
def _on_execution_complete(self, result: Optional[ExecutionResult], error: Optional[Exception]):
|
||||
"""执行完成回调"""
|
||||
if error:
|
||||
messagebox.showerror("执行错误", f"执行失败: {str(error)}")
|
||||
else:
|
||||
# 保存历史记录
|
||||
if self.current_task:
|
||||
self.history.add_record(
|
||||
task_id=result.task_id,
|
||||
user_input=self.current_task['user_input'],
|
||||
intent_label=self.current_task['intent_result'].label,
|
||||
intent_confidence=self.current_task['intent_result'].confidence,
|
||||
execution_plan=self.current_task['execution_plan'],
|
||||
code=self.current_task['code'],
|
||||
success=result.success,
|
||||
duration_ms=result.duration_ms,
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
log_path=result.log_path
|
||||
)
|
||||
|
||||
self._show_execution_result(result)
|
||||
# 刷新输出文件列表
|
||||
if self.task_view:
|
||||
self.task_view.refresh_output()
|
||||
|
||||
self._back_to_chat()
|
||||
|
||||
def _show_execution_result(self, result: ExecutionResult):
|
||||
"""显示执行结果"""
|
||||
if result.success:
|
||||
status = "执行成功"
|
||||
else:
|
||||
status = "执行失败"
|
||||
|
||||
message = f"""{status}
|
||||
|
||||
任务 ID: {result.task_id}
|
||||
耗时: {result.duration_ms} ms
|
||||
|
||||
输出:
|
||||
{result.stdout if result.stdout else '(无输出)'}
|
||||
|
||||
{f'错误信息: {result.stderr}' if result.stderr else ''}
|
||||
"""
|
||||
|
||||
if result.success:
|
||||
# 成功时显示结果并询问是否打开输出目录
|
||||
open_output = messagebox.askyesno(
|
||||
"执行结果",
|
||||
message + "\n\n是否打开输出文件夹?"
|
||||
)
|
||||
if open_output:
|
||||
os.startfile(str(self.workspace / "output"))
|
||||
else:
|
||||
# 失败时显示结果并询问是否打开日志
|
||||
open_log = messagebox.askyesno(
|
||||
"执行结果",
|
||||
message + "\n\n是否打开日志文件查看详情?"
|
||||
)
|
||||
if open_log and result.log_path:
|
||||
os.startfile(result.log_path)
|
||||
|
||||
def _on_cancel_task(self):
|
||||
"""取消任务"""
|
||||
self.current_task = None
|
||||
self._back_to_chat()
|
||||
|
||||
def _back_to_chat(self):
|
||||
"""返回聊天视图"""
|
||||
if self.task_view:
|
||||
self.task_view.hide()
|
||||
self.task_view = None
|
||||
|
||||
self.chat_view.get_frame().pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
|
||||
def _show_history(self):
|
||||
"""显示历史记录视图"""
|
||||
# 隐藏聊天视图
|
||||
self.chat_view.get_frame().pack_forget()
|
||||
|
||||
# 创建历史记录视图
|
||||
self.history_view = HistoryView(
|
||||
self.main_container,
|
||||
self.history,
|
||||
on_back=self._hide_history
|
||||
)
|
||||
self.history_view.show()
|
||||
|
||||
def _hide_history(self):
|
||||
"""隐藏历史记录视图,返回聊天"""
|
||||
if self.history_view:
|
||||
self.history_view.hide()
|
||||
self.history_view = None
|
||||
|
||||
self.chat_view.get_frame().pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||
|
||||
def run(self):
|
||||
"""运行应用"""
|
||||
self.root.mainloop()
|
||||
|
||||
25
debug_env.py
25
debug_env.py
@@ -1,25 +0,0 @@
|
||||
"""调试脚本"""
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
ENV_PATH = Path(__file__).parent / ".env"
|
||||
|
||||
print(f"ENV_PATH: {ENV_PATH}")
|
||||
print(f"ENV_PATH exists: {ENV_PATH.exists()}")
|
||||
|
||||
# 读取文件内容
|
||||
if ENV_PATH.exists():
|
||||
print(f"File content:")
|
||||
print(ENV_PATH.read_text(encoding='utf-8'))
|
||||
|
||||
# 加载环境变量
|
||||
result = load_dotenv(ENV_PATH)
|
||||
print(f"load_dotenv result: {result}")
|
||||
|
||||
# 检查环境变量
|
||||
print(f"LLM_API_URL: {os.getenv('LLM_API_URL')}")
|
||||
print(f"LLM_API_KEY: {os.getenv('LLM_API_KEY')}")
|
||||
print(f"INTENT_MODEL_NAME: {os.getenv('INTENT_MODEL_NAME')}")
|
||||
print(f"GENERATION_MODEL_NAME: {os.getenv('GENERATION_MODEL_NAME')}")
|
||||
|
||||
2
history/__init__.py
Normal file
2
history/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# 历史记录模块
|
||||
|
||||
189
history/manager.py
Normal file
189
history/manager.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
任务历史记录管理器
|
||||
保存和加载任务执行历史
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRecord:
|
||||
"""任务记录"""
|
||||
task_id: str
|
||||
timestamp: str
|
||||
user_input: str
|
||||
intent_label: str
|
||||
intent_confidence: float
|
||||
execution_plan: str
|
||||
code: str
|
||||
success: bool
|
||||
duration_ms: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
log_path: str
|
||||
|
||||
|
||||
class HistoryManager:
|
||||
"""
|
||||
历史记录管理器
|
||||
|
||||
将任务历史保存为 JSON 文件
|
||||
"""
|
||||
|
||||
MAX_HISTORY_SIZE = 100 # 最多保存 100 条记录
|
||||
|
||||
def __init__(self, workspace_path: Optional[Path] = None):
|
||||
if workspace_path:
|
||||
self.workspace = workspace_path
|
||||
else:
|
||||
self.workspace = Path(__file__).parent.parent / "workspace"
|
||||
|
||||
self.history_file = self.workspace / "history.json"
|
||||
self._history: List[TaskRecord] = []
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
"""从文件加载历史记录"""
|
||||
if self.history_file.exists():
|
||||
try:
|
||||
with open(self.history_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self._history = [TaskRecord(**record) for record in data]
|
||||
except (json.JSONDecodeError, TypeError, KeyError) as e:
|
||||
print(f"[警告] 加载历史记录失败: {e}")
|
||||
self._history = []
|
||||
else:
|
||||
self._history = []
|
||||
|
||||
def _save(self):
|
||||
"""保存历史记录到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.history_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(self.history_file, 'w', encoding='utf-8') as f:
|
||||
data = [asdict(record) for record in self._history]
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"[警告] 保存历史记录失败: {e}")
|
||||
|
||||
def add_record(
|
||||
self,
|
||||
task_id: str,
|
||||
user_input: str,
|
||||
intent_label: str,
|
||||
intent_confidence: float,
|
||||
execution_plan: str,
|
||||
code: str,
|
||||
success: bool,
|
||||
duration_ms: int,
|
||||
stdout: str = "",
|
||||
stderr: str = "",
|
||||
log_path: str = ""
|
||||
) -> TaskRecord:
|
||||
"""
|
||||
添加一条任务记录
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
user_input: 用户输入
|
||||
intent_label: 意图标签
|
||||
intent_confidence: 意图置信度
|
||||
execution_plan: 执行计划
|
||||
code: 生成的代码
|
||||
success: 是否执行成功
|
||||
duration_ms: 执行耗时(毫秒)
|
||||
stdout: 标准输出
|
||||
stderr: 标准错误
|
||||
log_path: 日志文件路径
|
||||
|
||||
Returns:
|
||||
TaskRecord: 创建的记录
|
||||
"""
|
||||
record = TaskRecord(
|
||||
task_id=task_id,
|
||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
user_input=user_input,
|
||||
intent_label=intent_label,
|
||||
intent_confidence=intent_confidence,
|
||||
execution_plan=execution_plan,
|
||||
code=code,
|
||||
success=success,
|
||||
duration_ms=duration_ms,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
log_path=log_path
|
||||
)
|
||||
|
||||
# 添加到列表开头(最新的在前)
|
||||
self._history.insert(0, record)
|
||||
|
||||
# 限制历史记录数量
|
||||
if len(self._history) > self.MAX_HISTORY_SIZE:
|
||||
self._history = self._history[:self.MAX_HISTORY_SIZE]
|
||||
|
||||
# 保存
|
||||
self._save()
|
||||
|
||||
return record
|
||||
|
||||
def get_all(self) -> List[TaskRecord]:
|
||||
"""获取所有历史记录"""
|
||||
return self._history.copy()
|
||||
|
||||
def get_recent(self, count: int = 10) -> List[TaskRecord]:
|
||||
"""获取最近的 N 条记录"""
|
||||
return self._history[:count]
|
||||
|
||||
def get_by_id(self, task_id: str) -> Optional[TaskRecord]:
|
||||
"""根据任务 ID 获取记录"""
|
||||
for record in self._history:
|
||||
if record.task_id == task_id:
|
||||
return record
|
||||
return None
|
||||
|
||||
def clear(self):
|
||||
"""清空历史记录"""
|
||||
self._history = []
|
||||
self._save()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
if not self._history:
|
||||
return {
|
||||
'total': 0,
|
||||
'success': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0,
|
||||
'avg_duration_ms': 0
|
||||
}
|
||||
|
||||
total = len(self._history)
|
||||
success = sum(1 for r in self._history if r.success)
|
||||
failed = total - success
|
||||
avg_duration = sum(r.duration_ms for r in self._history) / total
|
||||
|
||||
return {
|
||||
'total': total,
|
||||
'success': success,
|
||||
'failed': failed,
|
||||
'success_rate': success / total if total > 0 else 0.0,
|
||||
'avg_duration_ms': int(avg_duration)
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_manager: Optional[HistoryManager] = None
|
||||
|
||||
|
||||
def get_history_manager(workspace_path: Optional[Path] = None) -> HistoryManager:
|
||||
"""获取历史记录管理器单例"""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = HistoryManager(workspace_path)
|
||||
return _manager
|
||||
|
||||
228
llm/client.py
228
llm/client.py
@@ -2,13 +2,15 @@
|
||||
LLM 统一调用客户端
|
||||
所有模型通过 SiliconFlow API 调用
|
||||
支持流式和非流式两种模式
|
||||
支持自动重试机制
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional, Generator, Callable
|
||||
from typing import Optional, Generator, Callable, List, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 获取项目根目录
|
||||
@@ -40,29 +42,78 @@ class LLMClient:
|
||||
model="Qwen/Qwen2.5-7B-Instruct"
|
||||
):
|
||||
print(chunk, end="", flush=True)
|
||||
|
||||
特性:
|
||||
- 自动重试:网络错误时自动重试(默认3次)
|
||||
- 指数退避:重试间隔逐渐增加
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 重试配置
|
||||
DEFAULT_MAX_RETRIES = 3
|
||||
DEFAULT_RETRY_DELAY = 1.0 # 初始重试延迟(秒)
|
||||
DEFAULT_RETRY_BACKOFF = 2.0 # 退避倍数
|
||||
|
||||
def __init__(self, max_retries: int = DEFAULT_MAX_RETRIES):
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
self.api_url = os.getenv("LLM_API_URL")
|
||||
self.api_key = os.getenv("LLM_API_KEY")
|
||||
self.max_retries = max_retries
|
||||
|
||||
if not self.api_url:
|
||||
raise LLMClientError("未配置 LLM_API_URL,请检查 .env 文件")
|
||||
if not self.api_key or self.api_key == "your_api_key_here":
|
||||
raise LLMClientError("未配置有效的 LLM_API_KEY,请检查 .env 文件")
|
||||
|
||||
def _should_retry(self, exception: Exception) -> bool:
|
||||
"""判断是否应该重试"""
|
||||
# 网络连接错误、超时错误可以重试
|
||||
if isinstance(exception, (requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout)):
|
||||
return True
|
||||
# 服务器错误(5xx)可以重试
|
||||
if isinstance(exception, LLMClientError):
|
||||
error_msg = str(exception)
|
||||
if "状态码: 5" in error_msg or "502" in error_msg or "503" in error_msg or "504" in error_msg:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _do_request_with_retry(
|
||||
self,
|
||||
request_func: Callable,
|
||||
operation_name: str = "请求"
|
||||
):
|
||||
"""带重试的请求执行"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return request_func()
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# 判断是否应该重试
|
||||
if attempt < self.max_retries and self._should_retry(e):
|
||||
delay = self.DEFAULT_RETRY_DELAY * (self.DEFAULT_RETRY_BACKOFF ** attempt)
|
||||
print(f"[重试] {operation_name}失败,{delay:.1f}秒后重试 ({attempt + 1}/{self.max_retries})...")
|
||||
time.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
# 所有重试都失败
|
||||
raise last_exception
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 1024,
|
||||
timeout: int = 180
|
||||
) -> str:
|
||||
"""
|
||||
调用 LLM 进行对话(非流式)
|
||||
调用 LLM 进行对话(非流式,带自动重试)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
@@ -74,60 +125,63 @@ class LLMClient:
|
||||
Returns:
|
||||
LLM 生成的文本内容
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=timeout
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise LLMClientError("网络连接失败,请检查网络设置")
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise LLMClientError(f"网络请求异常: {str(e)}")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"API 返回错误 (状态码: {response.status_code})"
|
||||
def do_request():
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
try:
|
||||
error_detail = response.json()
|
||||
if "error" in error_detail:
|
||||
error_msg += f": {error_detail['error']}"
|
||||
except:
|
||||
error_msg += f": {response.text[:200]}"
|
||||
raise LLMClientError(error_msg)
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=timeout
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise LLMClientError("网络连接失败,请检查网络设置")
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise LLMClientError(f"网络请求异常: {str(e)}")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"API 返回错误 (状态码: {response.status_code})"
|
||||
try:
|
||||
error_detail = response.json()
|
||||
if "error" in error_detail:
|
||||
error_msg += f": {error_detail['error']}"
|
||||
except:
|
||||
error_msg += f": {response.text[:200]}"
|
||||
raise LLMClientError(error_msg)
|
||||
|
||||
try:
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
return content
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise LLMClientError(f"解析 API 响应失败: {str(e)}")
|
||||
|
||||
try:
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
return content
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise LLMClientError(f"解析 API 响应失败: {str(e)}")
|
||||
return self._do_request_with_retry(do_request, "LLM调用")
|
||||
|
||||
def chat_stream(
|
||||
self,
|
||||
messages: list[dict],
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
timeout: int = 180
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
调用 LLM 进行对话(流式)
|
||||
调用 LLM 进行对话(流式,带自动重试)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
@@ -139,43 +193,49 @@ class LLMClient:
|
||||
Yields:
|
||||
逐个返回生成的文本片段
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
stream=True
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise LLMClientError("网络连接失败,请检查网络设置")
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise LLMClientError(f"网络请求异常: {str(e)}")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"API 返回错误 (状态码: {response.status_code})"
|
||||
def do_request():
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
try:
|
||||
error_detail = response.json()
|
||||
if "error" in error_detail:
|
||||
error_msg += f": {error_detail['error']}"
|
||||
except:
|
||||
error_msg += f": {response.text[:200]}"
|
||||
raise LLMClientError(error_msg)
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
stream=True
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试")
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise LLMClientError("网络连接失败,请检查网络设置")
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise LLMClientError(f"网络请求异常: {str(e)}")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"API 返回错误 (状态码: {response.status_code})"
|
||||
try:
|
||||
error_detail = response.json()
|
||||
if "error" in error_detail:
|
||||
error_msg += f": {error_detail['error']}"
|
||||
except:
|
||||
error_msg += f": {response.text[:200]}"
|
||||
raise LLMClientError(error_msg)
|
||||
|
||||
return response
|
||||
|
||||
# 流式请求的重试只在建立连接阶段
|
||||
response = self._do_request_with_retry(do_request, "流式LLM调用")
|
||||
|
||||
# 解析 SSE 流
|
||||
for line in response.iter_lines():
|
||||
@@ -197,7 +257,7 @@ class LLMClient:
|
||||
|
||||
def chat_stream_collect(
|
||||
self,
|
||||
messages: list[dict],
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
|
||||
@@ -108,7 +108,8 @@ import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# 工作目录(固定,不要修改)
|
||||
WORKSPACE = Path(__file__).parent
|
||||
# 代码保存在 workspace/codes/ 目录,向上一级是 workspace
|
||||
WORKSPACE = Path(__file__).parent.parent
|
||||
INPUT_DIR = WORKSPACE / "input"
|
||||
OUTPUT_DIR = WORKSPACE / "output"
|
||||
|
||||
|
||||
450
main.py
450
main.py
@@ -37,10 +37,7 @@ import sys
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dotenv import load_dotenv
|
||||
import threading
|
||||
import queue
|
||||
|
||||
# 确保项目根目录在 Python 路径中
|
||||
PROJECT_ROOT = Path(__file__).parent
|
||||
@@ -50,435 +47,11 @@ sys.path.insert(0, str(PROJECT_ROOT))
|
||||
# 在导入其他模块之前先加载环境变量
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
from llm.client import get_client, LLMClientError
|
||||
from llm.prompts import (
|
||||
EXECUTION_PLAN_SYSTEM, EXECUTION_PLAN_USER,
|
||||
CODE_GENERATION_SYSTEM, CODE_GENERATION_USER
|
||||
)
|
||||
from intent.classifier import classify_intent, IntentResult
|
||||
from intent.labels import CHAT, EXECUTION
|
||||
from safety.rule_checker import check_code_safety
|
||||
from safety.llm_reviewer import review_code_safety
|
||||
from executor.sandbox_runner import SandboxRunner, ExecutionResult
|
||||
from ui.chat_view import ChatView
|
||||
from ui.task_guide_view import TaskGuideView
|
||||
from app.agent import LocalAgentApp
|
||||
|
||||
|
||||
class LocalAgentApp:
|
||||
"""
|
||||
LocalAgent 主应用
|
||||
|
||||
职责:
|
||||
1. 管理 UI 状态切换
|
||||
2. 协调各模块工作流程
|
||||
3. 处理用户交互
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.workspace = PROJECT_ROOT / "workspace"
|
||||
self.runner = SandboxRunner(str(self.workspace))
|
||||
|
||||
# 当前任务状态
|
||||
self.current_task: Optional[dict] = None
|
||||
|
||||
# 线程通信队列
|
||||
self.result_queue = queue.Queue()
|
||||
|
||||
# 初始化 UI
|
||||
self._init_ui()
|
||||
|
||||
def _init_ui(self):
|
||||
"""初始化 UI"""
|
||||
self.root = tk.Tk()
|
||||
self.root.title("LocalAgent - 本地 AI 助手")
|
||||
self.root.geometry("800x700")
|
||||
self.root.configure(bg='#1e1e1e')
|
||||
|
||||
# 设置窗口图标(如果有的话)
|
||||
try:
|
||||
self.root.iconbitmap(PROJECT_ROOT / "icon.ico")
|
||||
except:
|
||||
pass
|
||||
|
||||
# 主容器
|
||||
self.main_container = tk.Frame(self.root, bg='#1e1e1e')
|
||||
self.main_container.pack(fill=tk.BOTH, expand=True)
|
||||
|
||||
# 聊天视图
|
||||
self.chat_view = ChatView(self.main_container, self._on_user_input)
|
||||
|
||||
# 任务引导视图(初始隐藏)
|
||||
self.task_view: Optional[TaskGuideView] = None
|
||||
|
||||
# 定期检查后台任务结果
|
||||
self._check_queue()
|
||||
|
||||
def _check_queue(self):
|
||||
"""检查后台任务队列"""
|
||||
try:
|
||||
while True:
|
||||
callback, args = self.result_queue.get_nowait()
|
||||
callback(*args)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
# 每 100ms 检查一次
|
||||
self.root.after(100, self._check_queue)
|
||||
|
||||
def _run_in_thread(self, func, callback, *args):
|
||||
"""在后台线程运行函数,完成后回调"""
|
||||
def wrapper():
|
||||
try:
|
||||
result = func(*args)
|
||||
self.result_queue.put((callback, (result, None)))
|
||||
except Exception as e:
|
||||
self.result_queue.put((callback, (None, e)))
|
||||
|
||||
thread = threading.Thread(target=wrapper, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _on_user_input(self, user_input: str):
|
||||
"""处理用户输入"""
|
||||
# 显示用户消息
|
||||
self.chat_view.add_message(user_input, 'user')
|
||||
self.chat_view.set_input_enabled(False)
|
||||
self.chat_view.add_message("正在分析您的需求...", 'system')
|
||||
|
||||
# 在后台线程进行意图识别
|
||||
self._run_in_thread(
|
||||
classify_intent,
|
||||
lambda result, error: self._on_intent_result(user_input, result, error),
|
||||
user_input
|
||||
)
|
||||
|
||||
def _on_intent_result(self, user_input: str, intent_result: Optional[IntentResult], error: Optional[Exception]):
|
||||
"""意图识别完成回调"""
|
||||
if error:
|
||||
self.chat_view.add_message(f"意图识别失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
return
|
||||
|
||||
if intent_result.label == CHAT:
|
||||
# 对话模式
|
||||
self._handle_chat(user_input, intent_result)
|
||||
else:
|
||||
# 执行模式
|
||||
self._handle_execution(user_input, intent_result)
|
||||
|
||||
def _handle_chat(self, user_input: str, intent_result: IntentResult):
|
||||
"""处理对话任务"""
|
||||
self.chat_view.add_message(
|
||||
f"识别为对话模式 (原因: {intent_result.reason})",
|
||||
'system'
|
||||
)
|
||||
|
||||
# 开始流式消息
|
||||
self.chat_view.start_stream_message('assistant')
|
||||
|
||||
# 在后台线程调用 LLM(流式)
|
||||
def do_chat_stream():
|
||||
client = get_client()
|
||||
model = os.getenv("GENERATION_MODEL_NAME")
|
||||
|
||||
full_response = []
|
||||
for chunk in client.chat_stream(
|
||||
messages=[{"role": "user", "content": user_input}],
|
||||
model=model,
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
timeout=300
|
||||
):
|
||||
full_response.append(chunk)
|
||||
# 通过队列发送 chunk 到主线程更新 UI
|
||||
self.result_queue.put((self._on_chat_chunk, (chunk,)))
|
||||
|
||||
return ''.join(full_response)
|
||||
|
||||
self._run_in_thread(
|
||||
do_chat_stream,
|
||||
self._on_chat_complete
|
||||
)
|
||||
|
||||
def _on_chat_chunk(self, chunk: str):
|
||||
"""收到对话片段回调(主线程)"""
|
||||
self.chat_view.append_stream_chunk(chunk)
|
||||
|
||||
def _on_chat_complete(self, response: Optional[str], error: Optional[Exception]):
|
||||
"""对话完成回调"""
|
||||
self.chat_view.end_stream_message()
|
||||
|
||||
if error:
|
||||
self.chat_view.add_message(f"对话失败: {str(error)}", 'error')
|
||||
|
||||
self.chat_view.set_input_enabled(True)
|
||||
|
||||
def _handle_execution(self, user_input: str, intent_result: IntentResult):
|
||||
"""处理执行任务"""
|
||||
self.chat_view.add_message(
|
||||
f"识别为执行任务 (置信度: {intent_result.confidence:.0%})\n原因: {intent_result.reason}",
|
||||
'system'
|
||||
)
|
||||
self.chat_view.add_message("正在生成执行计划...", 'system')
|
||||
|
||||
# 保存用户输入和意图结果
|
||||
self.current_task = {
|
||||
'user_input': user_input,
|
||||
'intent_result': intent_result
|
||||
}
|
||||
|
||||
# 在后台线程生成执行计划
|
||||
self._run_in_thread(
|
||||
self._generate_execution_plan,
|
||||
self._on_plan_generated,
|
||||
user_input
|
||||
)
|
||||
|
||||
def _on_plan_generated(self, plan: Optional[str], error: Optional[Exception]):
|
||||
"""执行计划生成完成回调"""
|
||||
if error:
|
||||
self.chat_view.add_message(f"生成执行计划失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
self.current_task['execution_plan'] = plan
|
||||
self.chat_view.add_message("正在生成执行代码...", 'system')
|
||||
|
||||
# 在后台线程生成代码
|
||||
self._run_in_thread(
|
||||
self._generate_code,
|
||||
self._on_code_generated,
|
||||
self.current_task['user_input'],
|
||||
plan
|
||||
)
|
||||
|
||||
def _on_code_generated(self, code: Optional[str], error: Optional[Exception]):
|
||||
"""代码生成完成回调"""
|
||||
if error:
|
||||
self.chat_view.add_message(f"生成代码失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
self.current_task['code'] = code
|
||||
self.chat_view.add_message("正在进行安全检查...", 'system')
|
||||
|
||||
# 硬规则检查(同步,很快)
|
||||
rule_result = check_code_safety(code)
|
||||
if not rule_result.passed:
|
||||
violations = "\n".join(f" • {v}" for v in rule_result.violations)
|
||||
self.chat_view.add_message(
|
||||
f"安全检查未通过,任务已取消:\n{violations}",
|
||||
'error'
|
||||
)
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
# 保存警告信息,传递给 LLM 审查
|
||||
self.current_task['warnings'] = rule_result.warnings
|
||||
|
||||
# 在后台线程进行 LLM 安全审查
|
||||
self._run_in_thread(
|
||||
lambda: review_code_safety(
|
||||
self.current_task['user_input'],
|
||||
self.current_task['execution_plan'],
|
||||
code,
|
||||
rule_result.warnings # 传递警告给 LLM
|
||||
),
|
||||
self._on_safety_reviewed
|
||||
)
|
||||
|
||||
def _on_safety_reviewed(self, review_result, error: Optional[Exception]):
|
||||
"""安全审查完成回调"""
|
||||
if error:
|
||||
self.chat_view.add_message(f"安全审查失败: {str(error)}", 'error')
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
if not review_result.passed:
|
||||
self.chat_view.add_message(
|
||||
f"安全审查未通过: {review_result.reason}",
|
||||
'error'
|
||||
)
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
return
|
||||
|
||||
self.chat_view.add_message("安全检查通过,请确认执行", 'system')
|
||||
|
||||
# 显示任务引导视图
|
||||
self._show_task_guide()
|
||||
|
||||
def _generate_execution_plan(self, user_input: str) -> str:
|
||||
"""生成执行计划(使用流式传输)"""
|
||||
client = get_client()
|
||||
model = os.getenv("GENERATION_MODEL_NAME")
|
||||
|
||||
# 使用流式传输,避免超时
|
||||
response = client.chat_stream_collect(
|
||||
messages=[
|
||||
{"role": "system", "content": EXECUTION_PLAN_SYSTEM},
|
||||
{"role": "user", "content": EXECUTION_PLAN_USER.format(user_input=user_input)}
|
||||
],
|
||||
model=model,
|
||||
temperature=0.3,
|
||||
max_tokens=1024,
|
||||
timeout=300 # 5分钟超时
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _generate_code(self, user_input: str, execution_plan: str) -> str:
|
||||
"""生成执行代码(使用流式传输)"""
|
||||
client = get_client()
|
||||
model = os.getenv("GENERATION_MODEL_NAME")
|
||||
|
||||
# 使用流式传输,避免超时
|
||||
response = client.chat_stream_collect(
|
||||
messages=[
|
||||
{"role": "system", "content": CODE_GENERATION_SYSTEM},
|
||||
{"role": "user", "content": CODE_GENERATION_USER.format(
|
||||
user_input=user_input,
|
||||
execution_plan=execution_plan
|
||||
)}
|
||||
],
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
max_tokens=4096, # 代码可能较长
|
||||
timeout=300 # 5分钟超时
|
||||
)
|
||||
|
||||
# 提取代码块
|
||||
code = self._extract_code(response)
|
||||
return code
|
||||
|
||||
def _extract_code(self, response: str) -> str:
|
||||
"""从 LLM 响应中提取代码"""
|
||||
import re
|
||||
|
||||
# 尝试提取 ```python ... ``` 代码块
|
||||
pattern = r'```python\s*(.*?)\s*```'
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
# 尝试提取 ``` ... ``` 代码块
|
||||
pattern = r'```\s*(.*?)\s*```'
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
# 如果没有代码块,返回原始响应
|
||||
return response
|
||||
|
||||
def _show_task_guide(self):
|
||||
"""显示任务引导视图"""
|
||||
if not self.current_task:
|
||||
return
|
||||
|
||||
# 隐藏聊天视图
|
||||
self.chat_view.get_frame().pack_forget()
|
||||
|
||||
# 创建任务引导视图
|
||||
self.task_view = TaskGuideView(
|
||||
self.main_container,
|
||||
on_execute=self._on_execute_task,
|
||||
on_cancel=self._on_cancel_task,
|
||||
workspace_path=self.workspace
|
||||
)
|
||||
|
||||
# 设置内容
|
||||
self.task_view.set_intent_result(
|
||||
self.current_task['intent_result'].reason,
|
||||
self.current_task['intent_result'].confidence
|
||||
)
|
||||
self.task_view.set_execution_plan(self.current_task['execution_plan'])
|
||||
|
||||
# 显示
|
||||
self.task_view.show()
|
||||
|
||||
def _on_execute_task(self):
|
||||
"""执行任务"""
|
||||
if not self.current_task:
|
||||
return
|
||||
|
||||
self.task_view.set_buttons_enabled(False)
|
||||
|
||||
# 在后台线程执行
|
||||
def do_execute():
|
||||
return self.runner.execute(self.current_task['code'])
|
||||
|
||||
self._run_in_thread(
|
||||
do_execute,
|
||||
self._on_execution_complete
|
||||
)
|
||||
|
||||
def _on_execution_complete(self, result: Optional[ExecutionResult], error: Optional[Exception]):
|
||||
"""执行完成回调"""
|
||||
if error:
|
||||
messagebox.showerror("执行错误", f"执行失败: {str(error)}")
|
||||
else:
|
||||
self._show_execution_result(result)
|
||||
# 刷新输出文件列表
|
||||
if self.task_view:
|
||||
self.task_view.refresh_output()
|
||||
|
||||
self._back_to_chat()
|
||||
|
||||
def _show_execution_result(self, result: ExecutionResult):
|
||||
"""显示执行结果"""
|
||||
if result.success:
|
||||
status = "执行成功"
|
||||
else:
|
||||
status = "执行失败"
|
||||
|
||||
message = f"""{status}
|
||||
|
||||
任务 ID: {result.task_id}
|
||||
耗时: {result.duration_ms} ms
|
||||
日志文件: {result.log_path}
|
||||
|
||||
输出:
|
||||
{result.stdout if result.stdout else '(无输出)'}
|
||||
|
||||
{f'错误信息: {result.stderr}' if result.stderr else ''}
|
||||
"""
|
||||
|
||||
if result.success:
|
||||
messagebox.showinfo("执行结果", message)
|
||||
# 打开 output 目录
|
||||
os.startfile(str(self.workspace / "output"))
|
||||
else:
|
||||
messagebox.showerror("执行结果", message)
|
||||
|
||||
def _on_cancel_task(self):
|
||||
"""取消任务"""
|
||||
self.current_task = None
|
||||
self._back_to_chat()
|
||||
|
||||
def _back_to_chat(self):
|
||||
"""返回聊天视图"""
|
||||
if self.task_view:
|
||||
self.task_view.hide()
|
||||
self.task_view = None
|
||||
|
||||
self.chat_view.get_frame().pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||
self.chat_view.set_input_enabled(True)
|
||||
self.current_task = None
|
||||
|
||||
def run(self):
|
||||
"""运行应用"""
|
||||
self.root.mainloop()
|
||||
|
||||
|
||||
def check_environment():
|
||||
def check_environment() -> bool:
|
||||
"""检查运行环境"""
|
||||
load_dotenv(ENV_PATH)
|
||||
|
||||
api_key = os.getenv("LLM_API_KEY")
|
||||
|
||||
if not api_key or api_key == "your_api_key_here":
|
||||
@@ -510,6 +83,17 @@ def check_environment():
|
||||
return True
|
||||
|
||||
|
||||
def setup_workspace():
|
||||
"""创建工作目录"""
|
||||
workspace = PROJECT_ROOT / "workspace"
|
||||
(workspace / "input").mkdir(parents=True, exist_ok=True)
|
||||
(workspace / "output").mkdir(parents=True, exist_ok=True)
|
||||
(workspace / "logs").mkdir(parents=True, exist_ok=True)
|
||||
(workspace / "codes").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return workspace
|
||||
|
||||
|
||||
def main():
|
||||
"""主入口"""
|
||||
print("=" * 50)
|
||||
@@ -521,19 +105,17 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
# 创建工作目录
|
||||
workspace = PROJECT_ROOT / "workspace"
|
||||
(workspace / "input").mkdir(parents=True, exist_ok=True)
|
||||
(workspace / "output").mkdir(parents=True, exist_ok=True)
|
||||
(workspace / "logs").mkdir(parents=True, exist_ok=True)
|
||||
workspace = setup_workspace()
|
||||
|
||||
print(f"工作目录: {workspace}")
|
||||
print(f"输入目录: {workspace / 'input'}")
|
||||
print(f"输出目录: {workspace / 'output'}")
|
||||
print(f"日志目录: {workspace / 'logs'}")
|
||||
print(f"代码目录: {workspace / 'codes'}")
|
||||
print("=" * 50)
|
||||
|
||||
# 启动应用
|
||||
app = LocalAgentApp()
|
||||
app = LocalAgentApp(PROJECT_ROOT)
|
||||
app.run()
|
||||
|
||||
|
||||
|
||||
@@ -14,3 +14,6 @@ openpyxl>=3.1.0 # Excel 处理
|
||||
python-docx>=1.0.0 # Word 文档处理
|
||||
PyPDF2>=3.0.0 # PDF 处理
|
||||
chardet>=5.0.0 # 文件编码检测
|
||||
|
||||
# 测试依赖(可选)
|
||||
pytest>=7.0.0 # 单元测试框架
|
||||
|
||||
2
tests/__init__.py
Normal file
2
tests/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# 测试模块
|
||||
|
||||
235
tests/test_history_manager.py
Normal file
235
tests/test_history_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
历史记录管理器单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from history.manager import HistoryManager, TaskRecord
|
||||
|
||||
|
||||
class TestHistoryManager(unittest.TestCase):
|
||||
"""历史记录管理器测试"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建临时目录用于测试"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.manager = HistoryManager(self.temp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""清理临时目录"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_add_record(self):
|
||||
"""测试添加记录"""
|
||||
record = self.manager.add_record(
|
||||
task_id="test_001",
|
||||
user_input="复制文件",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.95,
|
||||
execution_plan="复制所有文件",
|
||||
code="shutil.copy(...)",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
self.assertEqual(record.task_id, "test_001")
|
||||
self.assertEqual(record.user_input, "复制文件")
|
||||
self.assertTrue(record.success)
|
||||
|
||||
def test_get_all(self):
|
||||
"""测试获取所有记录"""
|
||||
# 添加多条记录
|
||||
for i in range(3):
|
||||
self.manager.add_record(
|
||||
task_id=f"test_{i:03d}",
|
||||
user_input=f"任务 {i}",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
records = self.manager.get_all()
|
||||
self.assertEqual(len(records), 3)
|
||||
|
||||
def test_get_recent(self):
|
||||
"""测试获取最近记录"""
|
||||
# 添加 5 条记录
|
||||
for i in range(5):
|
||||
self.manager.add_record(
|
||||
task_id=f"test_{i:03d}",
|
||||
user_input=f"任务 {i}",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 获取最近 3 条
|
||||
recent = self.manager.get_recent(3)
|
||||
self.assertEqual(len(recent), 3)
|
||||
# 最新的在前
|
||||
self.assertEqual(recent[0].task_id, "test_004")
|
||||
|
||||
def test_get_by_id(self):
|
||||
"""测试根据 ID 获取记录"""
|
||||
self.manager.add_record(
|
||||
task_id="unique_id",
|
||||
user_input="测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
record = self.manager.get_by_id("unique_id")
|
||||
self.assertIsNotNone(record)
|
||||
self.assertEqual(record.task_id, "unique_id")
|
||||
|
||||
# 不存在的 ID
|
||||
not_found = self.manager.get_by_id("not_exist")
|
||||
self.assertIsNone(not_found)
|
||||
|
||||
def test_clear(self):
|
||||
"""测试清空记录"""
|
||||
# 添加记录
|
||||
self.manager.add_record(
|
||||
task_id="test",
|
||||
user_input="测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
self.assertEqual(len(self.manager.get_all()), 1)
|
||||
|
||||
# 清空
|
||||
self.manager.clear()
|
||||
self.assertEqual(len(self.manager.get_all()), 0)
|
||||
|
||||
def test_get_stats(self):
|
||||
"""测试统计信息"""
|
||||
# 添加成功和失败的记录
|
||||
self.manager.add_record(
|
||||
task_id="success_1",
|
||||
user_input="成功任务",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
self.manager.add_record(
|
||||
task_id="success_2",
|
||||
user_input="成功任务2",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=200
|
||||
)
|
||||
self.manager.add_record(
|
||||
task_id="failed_1",
|
||||
user_input="失败任务",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=False,
|
||||
duration_ms=50
|
||||
)
|
||||
|
||||
stats = self.manager.get_stats()
|
||||
self.assertEqual(stats['total'], 3)
|
||||
self.assertEqual(stats['success'], 2)
|
||||
self.assertEqual(stats['failed'], 1)
|
||||
self.assertAlmostEqual(stats['success_rate'], 2/3)
|
||||
|
||||
def test_persistence(self):
|
||||
"""测试持久化"""
|
||||
# 添加记录
|
||||
self.manager.add_record(
|
||||
task_id="persist_test",
|
||||
user_input="持久化测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 创建新的管理器实例(模拟重启)
|
||||
new_manager = HistoryManager(self.temp_dir)
|
||||
|
||||
# 应该能读取到之前的记录
|
||||
records = new_manager.get_all()
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(records[0].task_id, "persist_test")
|
||||
|
||||
def test_max_history_size(self):
|
||||
"""测试历史记录数量限制"""
|
||||
# 添加超过限制的记录
|
||||
for i in range(HistoryManager.MAX_HISTORY_SIZE + 10):
|
||||
self.manager.add_record(
|
||||
task_id=f"test_{i:03d}",
|
||||
user_input=f"任务 {i}",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 应该只保留最大数量
|
||||
records = self.manager.get_all()
|
||||
self.assertEqual(len(records), HistoryManager.MAX_HISTORY_SIZE)
|
||||
|
||||
|
||||
class TestTaskRecord(unittest.TestCase):
|
||||
"""任务记录数据类测试"""
|
||||
|
||||
def test_create_record(self):
|
||||
"""测试创建记录"""
|
||||
record = TaskRecord(
|
||||
task_id="test",
|
||||
timestamp="2024-01-01 12:00:00",
|
||||
user_input="测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100,
|
||||
stdout="输出",
|
||||
stderr="",
|
||||
log_path="/path/to/log"
|
||||
)
|
||||
|
||||
self.assertEqual(record.task_id, "test")
|
||||
self.assertTrue(record.success)
|
||||
self.assertEqual(record.duration_ms, 100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
94
tests/test_intent_classifier.py
Normal file
94
tests/test_intent_classifier.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
意图分类器单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from intent.labels import CHAT, EXECUTION, VALID_LABELS, EXECUTION_CONFIDENCE_THRESHOLD
|
||||
|
||||
|
||||
class TestIntentLabels(unittest.TestCase):
|
||||
"""意图标签测试"""
|
||||
|
||||
def test_labels_defined(self):
|
||||
"""测试标签已定义"""
|
||||
self.assertEqual(CHAT, "chat")
|
||||
self.assertEqual(EXECUTION, "execution")
|
||||
|
||||
def test_valid_labels(self):
|
||||
"""测试有效标签集合"""
|
||||
self.assertIn(CHAT, VALID_LABELS)
|
||||
self.assertIn(EXECUTION, VALID_LABELS)
|
||||
self.assertEqual(len(VALID_LABELS), 2)
|
||||
|
||||
def test_confidence_threshold(self):
|
||||
"""测试置信度阈值"""
|
||||
self.assertGreater(EXECUTION_CONFIDENCE_THRESHOLD, 0)
|
||||
self.assertLessEqual(EXECUTION_CONFIDENCE_THRESHOLD, 1)
|
||||
|
||||
|
||||
class TestIntentClassifierParsing(unittest.TestCase):
|
||||
"""意图分类器解析测试(不需要 API)"""
|
||||
|
||||
def setUp(self):
|
||||
from intent.classifier import IntentClassifier
|
||||
self.classifier = IntentClassifier()
|
||||
|
||||
def test_parse_valid_chat_response(self):
|
||||
"""测试解析有效的 chat 响应"""
|
||||
response = '{"label": "chat", "confidence": 0.95, "reason": "这是一个问答"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, CHAT)
|
||||
self.assertEqual(result.confidence, 0.95)
|
||||
self.assertEqual(result.reason, "这是一个问答")
|
||||
|
||||
def test_parse_valid_execution_response(self):
|
||||
"""测试解析有效的 execution 响应"""
|
||||
response = '{"label": "execution", "confidence": 0.9, "reason": "需要复制文件"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, EXECUTION)
|
||||
self.assertEqual(result.confidence, 0.9)
|
||||
|
||||
def test_parse_low_confidence_execution(self):
|
||||
"""测试低置信度的 execution 降级为 chat"""
|
||||
response = '{"label": "execution", "confidence": 0.5, "reason": "不太确定"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
# 低于阈值应该降级为 chat
|
||||
self.assertEqual(result.label, CHAT)
|
||||
|
||||
def test_parse_invalid_label(self):
|
||||
"""测试无效标签降级为 chat"""
|
||||
response = '{"label": "unknown", "confidence": 0.9, "reason": "测试"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, CHAT)
|
||||
|
||||
def test_parse_invalid_json(self):
|
||||
"""测试无效 JSON 降级为 chat"""
|
||||
response = 'not a json'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, CHAT)
|
||||
self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
def test_extract_json_with_prefix(self):
|
||||
"""测试从带前缀的文本中提取 JSON"""
|
||||
text = 'Here is the result: {"label": "chat", "confidence": 0.8, "reason": "test"}'
|
||||
json_str = self.classifier._extract_json(text)
|
||||
self.assertTrue(json_str.startswith('{'))
|
||||
self.assertTrue(json_str.endswith('}'))
|
||||
|
||||
def test_extract_json_with_suffix(self):
|
||||
"""测试从带后缀的文本中提取 JSON"""
|
||||
text = '{"label": "chat", "confidence": 0.8, "reason": "test"} That is my answer.'
|
||||
json_str = self.classifier._extract_json(text)
|
||||
self.assertTrue(json_str.startswith('{'))
|
||||
self.assertTrue(json_str.endswith('}'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
160
tests/test_rule_checker.py
Normal file
160
tests/test_rule_checker.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
安全检查器单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from safety.rule_checker import RuleChecker, check_code_safety
|
||||
|
||||
|
||||
class TestRuleChecker(unittest.TestCase):
|
||||
"""规则检查器测试"""
|
||||
|
||||
def setUp(self):
|
||||
self.checker = RuleChecker()
|
||||
|
||||
# ========== 硬性禁止测试 ==========
|
||||
|
||||
def test_block_socket_import(self):
|
||||
"""测试禁止 socket 模块"""
|
||||
code = "import socket\ns = socket.socket()"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('socket' in v for v in result.violations))
|
||||
|
||||
def test_block_subprocess_import(self):
|
||||
"""测试禁止 subprocess 模块"""
|
||||
code = "import subprocess\nsubprocess.run(['ls'])"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('subprocess' in v for v in result.violations))
|
||||
|
||||
def test_block_eval(self):
|
||||
"""测试禁止 eval"""
|
||||
code = "result = eval('1+1')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('eval' in v for v in result.violations))
|
||||
|
||||
def test_block_exec(self):
|
||||
"""测试禁止 exec"""
|
||||
code = "exec('print(1)')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('exec' in v for v in result.violations))
|
||||
|
||||
def test_block_os_system(self):
|
||||
"""测试禁止 os.system"""
|
||||
code = "import os\nos.system('dir')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('os.system' in v for v in result.violations))
|
||||
|
||||
def test_block_os_popen(self):
|
||||
"""测试禁止 os.popen"""
|
||||
code = "import os\nos.popen('dir')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('os.popen' in v for v in result.violations))
|
||||
|
||||
# ========== 警告测试 ==========
|
||||
|
||||
def test_warn_requests_import(self):
|
||||
"""测试 requests 模块产生警告"""
|
||||
code = "import requests\nresponse = requests.get('http://example.com')"
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed) # 不应该被阻止
|
||||
self.assertTrue(any('requests' in w for w in result.warnings))
|
||||
|
||||
def test_warn_os_remove(self):
|
||||
"""测试 os.remove 产生警告"""
|
||||
code = "import os\nos.remove('file.txt')"
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed) # 不应该被阻止
|
||||
self.assertTrue(any('os.remove' in w for w in result.warnings))
|
||||
|
||||
def test_warn_shutil_rmtree(self):
|
||||
"""测试 shutil.rmtree 产生警告"""
|
||||
code = "import shutil\nshutil.rmtree('folder')"
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed) # 不应该被阻止
|
||||
self.assertTrue(any('shutil.rmtree' in w for w in result.warnings))
|
||||
|
||||
# ========== 安全代码测试 ==========
|
||||
|
||||
def test_safe_file_copy(self):
|
||||
"""测试安全的文件复制代码"""
|
||||
code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for f in INPUT_DIR.glob('*'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed)
|
||||
self.assertEqual(len(result.violations), 0)
|
||||
|
||||
def test_safe_image_processing(self):
|
||||
"""测试安全的图片处理代码"""
|
||||
code = """
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for img_path in INPUT_DIR.glob('*.png'):
|
||||
img = Image.open(img_path)
|
||||
img = img.resize((100, 100))
|
||||
img.save(OUTPUT_DIR / img_path.name)
|
||||
"""
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed)
|
||||
self.assertEqual(len(result.violations), 0)
|
||||
|
||||
def test_safe_excel_processing(self):
|
||||
"""测试安全的 Excel 处理代码"""
|
||||
code = """
|
||||
import openpyxl
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for xlsx_path in INPUT_DIR.glob('*.xlsx'):
|
||||
wb = openpyxl.load_workbook(xlsx_path)
|
||||
ws = wb.active
|
||||
ws['A1'] = 'Modified'
|
||||
wb.save(OUTPUT_DIR / xlsx_path.name)
|
||||
"""
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed)
|
||||
self.assertEqual(len(result.violations), 0)
|
||||
|
||||
|
||||
class TestCheckCodeSafety(unittest.TestCase):
|
||||
"""便捷函数测试"""
|
||||
|
||||
def test_convenience_function(self):
|
||||
"""测试便捷函数"""
|
||||
result = check_code_safety("print('hello')")
|
||||
self.assertTrue(result.passed)
|
||||
|
||||
def test_convenience_function_block(self):
|
||||
"""测试便捷函数阻止危险代码"""
|
||||
result = check_code_safety("import socket")
|
||||
self.assertFalse(result.passed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
106
ui/chat_view.py
106
ui/chat_view.py
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
聊天视图组件
|
||||
处理普通对话的 UI 展示 - 支持流式消息
|
||||
处理普通对话的 UI 展示 - 支持流式消息和加载动画
|
||||
"""
|
||||
|
||||
import tkinter as tk
|
||||
@@ -8,6 +8,58 @@ from tkinter import scrolledtext
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
class LoadingIndicator:
|
||||
"""加载动画指示器"""
|
||||
|
||||
FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
||||
|
||||
def __init__(self, parent: tk.Widget, text: str = "处理中"):
|
||||
self.parent = parent
|
||||
self.text = text
|
||||
self.frame_index = 0
|
||||
self.running = False
|
||||
self.after_id = None
|
||||
|
||||
# 创建标签
|
||||
self.label = tk.Label(
|
||||
parent,
|
||||
text="",
|
||||
font=('Microsoft YaHei UI', 10),
|
||||
fg='#ffd54f',
|
||||
bg='#1e1e1e'
|
||||
)
|
||||
|
||||
def start(self, text: str = None):
|
||||
"""开始动画"""
|
||||
if text:
|
||||
self.text = text
|
||||
self.running = True
|
||||
self.label.pack(pady=5)
|
||||
self._animate()
|
||||
|
||||
def stop(self):
|
||||
"""停止动画"""
|
||||
self.running = False
|
||||
if self.after_id:
|
||||
self.parent.after_cancel(self.after_id)
|
||||
self.after_id = None
|
||||
self.label.pack_forget()
|
||||
|
||||
def update_text(self, text: str):
|
||||
"""更新提示文字"""
|
||||
self.text = text
|
||||
|
||||
def _animate(self):
|
||||
"""动画帧更新"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
frame = self.FRAMES[self.frame_index]
|
||||
self.label.config(text=f"{frame} {self.text}...")
|
||||
self.frame_index = (self.frame_index + 1) % len(self.FRAMES)
|
||||
self.after_id = self.parent.after(100, self._animate)
|
||||
|
||||
|
||||
class ChatView:
|
||||
"""
|
||||
聊天视图
|
||||
@@ -22,7 +74,8 @@ class ChatView:
|
||||
def __init__(
|
||||
self,
|
||||
parent: tk.Widget,
|
||||
on_send: Callable[[str], None]
|
||||
on_send: Callable[[str], None],
|
||||
on_show_history: Optional[Callable[[], None]] = None
|
||||
):
|
||||
"""
|
||||
初始化聊天视图
|
||||
@@ -30,14 +83,19 @@ class ChatView:
|
||||
Args:
|
||||
parent: 父容器
|
||||
on_send: 发送消息回调函数
|
||||
on_show_history: 显示历史记录回调函数
|
||||
"""
|
||||
self.parent = parent
|
||||
self.on_send = on_send
|
||||
self.on_show_history = on_show_history
|
||||
|
||||
# 流式消息状态
|
||||
self._stream_active = False
|
||||
self._stream_tag = None
|
||||
|
||||
# 加载指示器
|
||||
self.loading: Optional[LoadingIndicator] = None
|
||||
|
||||
self._create_widgets()
|
||||
|
||||
def _create_widgets(self):
|
||||
@@ -46,15 +104,37 @@ class ChatView:
|
||||
self.frame = tk.Frame(self.parent, bg='#1e1e1e')
|
||||
self.frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||
|
||||
# 标题栏(包含标题和历史按钮)
|
||||
title_frame = tk.Frame(self.frame, bg='#1e1e1e')
|
||||
title_frame.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
# 标题
|
||||
title_label = tk.Label(
|
||||
self.frame,
|
||||
title_frame,
|
||||
text="LocalAgent - 本地 AI 助手",
|
||||
font=('Microsoft YaHei UI', 16, 'bold'),
|
||||
fg='#61dafb',
|
||||
bg='#1e1e1e'
|
||||
)
|
||||
title_label.pack(pady=(0, 10))
|
||||
title_label.pack(side=tk.LEFT, expand=True)
|
||||
|
||||
# 历史记录按钮
|
||||
if self.on_show_history:
|
||||
self.history_btn = tk.Button(
|
||||
title_frame,
|
||||
text="📜 历史",
|
||||
font=('Microsoft YaHei UI', 10),
|
||||
bg='#424242',
|
||||
fg='#ce93d8',
|
||||
activebackground='#616161',
|
||||
activeforeground='#ce93d8',
|
||||
relief=tk.FLAT,
|
||||
padx=10,
|
||||
pady=3,
|
||||
cursor='hand2',
|
||||
command=self.on_show_history
|
||||
)
|
||||
self.history_btn.pack(side=tk.RIGHT)
|
||||
|
||||
# 消息显示区域
|
||||
self.message_area = scrolledtext.ScrolledText(
|
||||
@@ -118,6 +198,9 @@ class ChatView:
|
||||
"- 输入文件处理需求(如\"复制文件\"、\"整理图片\")将触发执行模式"
|
||||
)
|
||||
self.add_message(welcome_msg, 'system')
|
||||
|
||||
# 创建加载指示器(放在消息区域下方)
|
||||
self.loading = LoadingIndicator(self.frame)
|
||||
|
||||
def _on_enter_pressed(self, event):
|
||||
"""回车键处理"""
|
||||
@@ -214,6 +297,21 @@ class ChatView:
|
||||
self.input_entry.config(state=state)
|
||||
self.send_button.config(state=state)
|
||||
|
||||
def show_loading(self, text: str = "处理中"):
|
||||
"""显示加载动画"""
|
||||
if self.loading:
|
||||
self.loading.start(text)
|
||||
|
||||
def hide_loading(self):
|
||||
"""隐藏加载动画"""
|
||||
if self.loading:
|
||||
self.loading.stop()
|
||||
|
||||
def update_loading_text(self, text: str):
|
||||
"""更新加载提示文字"""
|
||||
if self.loading:
|
||||
self.loading.update_text(text)
|
||||
|
||||
def get_frame(self) -> tk.Frame:
|
||||
"""获取主框架"""
|
||||
return self.frame
|
||||
|
||||
335
ui/history_view.py
Normal file
335
ui/history_view.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
历史记录视图组件
|
||||
显示任务执行历史
|
||||
"""
|
||||
|
||||
import tkinter as tk
|
||||
from tkinter import ttk, messagebox
|
||||
from typing import Callable, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from history.manager import TaskRecord, HistoryManager
|
||||
|
||||
|
||||
class HistoryView:
|
||||
"""
|
||||
历史记录视图
|
||||
|
||||
显示任务执行历史列表,支持查看详情
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent: tk.Widget,
|
||||
history_manager: HistoryManager,
|
||||
on_back: Callable[[], None]
|
||||
):
|
||||
self.parent = parent
|
||||
self.history = history_manager
|
||||
self.on_back = on_back
|
||||
|
||||
self._selected_record: Optional[TaskRecord] = None
|
||||
self._create_widgets()
|
||||
|
||||
def _create_widgets(self):
|
||||
"""创建 UI 组件"""
|
||||
self.frame = tk.Frame(self.parent, bg='#1e1e1e')
|
||||
|
||||
# 标题栏
|
||||
title_frame = tk.Frame(self.frame, bg='#1e1e1e')
|
||||
title_frame.pack(fill=tk.X, padx=10, pady=10)
|
||||
|
||||
# 返回按钮
|
||||
back_btn = tk.Button(
|
||||
title_frame,
|
||||
text="← 返回",
|
||||
font=('Microsoft YaHei UI', 10),
|
||||
bg='#424242',
|
||||
fg='white',
|
||||
activebackground='#616161',
|
||||
activeforeground='white',
|
||||
relief=tk.FLAT,
|
||||
padx=10,
|
||||
cursor='hand2',
|
||||
command=self.on_back
|
||||
)
|
||||
back_btn.pack(side=tk.LEFT)
|
||||
|
||||
# 标题
|
||||
title_label = tk.Label(
|
||||
title_frame,
|
||||
text="📜 任务历史记录",
|
||||
font=('Microsoft YaHei UI', 14, 'bold'),
|
||||
fg='#ce93d8',
|
||||
bg='#1e1e1e'
|
||||
)
|
||||
title_label.pack(side=tk.LEFT, padx=20)
|
||||
|
||||
# 统计信息
|
||||
stats = self.history.get_stats()
|
||||
stats_text = f"共 {stats['total']} 条 | 成功 {stats['success']} | 失败 {stats['failed']} | 成功率 {stats['success_rate']:.0%}"
|
||||
stats_label = tk.Label(
|
||||
title_frame,
|
||||
text=stats_text,
|
||||
font=('Microsoft YaHei UI', 9),
|
||||
fg='#888888',
|
||||
bg='#1e1e1e'
|
||||
)
|
||||
stats_label.pack(side=tk.RIGHT)
|
||||
|
||||
# 主内容区域(左右分栏)
|
||||
content_frame = tk.Frame(self.frame, bg='#1e1e1e')
|
||||
content_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=(0, 10))
|
||||
|
||||
# 左侧:历史列表
|
||||
list_frame = tk.LabelFrame(
|
||||
content_frame,
|
||||
text=" 任务列表 ",
|
||||
font=('Microsoft YaHei UI', 10, 'bold'),
|
||||
fg='#4fc3f7',
|
||||
bg='#1e1e1e',
|
||||
relief=tk.GROOVE
|
||||
)
|
||||
list_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 5))
|
||||
|
||||
# 列表框
|
||||
list_container = tk.Frame(list_frame, bg='#2d2d2d')
|
||||
list_container.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)
|
||||
|
||||
# 使用 Treeview 显示列表
|
||||
columns = ('time', 'input', 'status', 'duration')
|
||||
self.tree = ttk.Treeview(list_container, columns=columns, show='headings', height=15)
|
||||
|
||||
# 配置列
|
||||
self.tree.heading('time', text='时间')
|
||||
self.tree.heading('input', text='任务描述')
|
||||
self.tree.heading('status', text='状态')
|
||||
self.tree.heading('duration', text='耗时')
|
||||
|
||||
self.tree.column('time', width=120, minwidth=100)
|
||||
self.tree.column('input', width=250, minwidth=150)
|
||||
self.tree.column('status', width=60, minwidth=50)
|
||||
self.tree.column('duration', width=70, minwidth=50)
|
||||
|
||||
# 滚动条
|
||||
scrollbar = ttk.Scrollbar(list_container, orient=tk.VERTICAL, command=self.tree.yview)
|
||||
self.tree.configure(yscrollcommand=scrollbar.set)
|
||||
|
||||
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
||||
self.tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||
|
||||
# 绑定选择事件
|
||||
self.tree.bind('<<TreeviewSelect>>', self._on_select)
|
||||
|
||||
# 右侧:详情面板
|
||||
detail_frame = tk.LabelFrame(
|
||||
content_frame,
|
||||
text=" 任务详情 ",
|
||||
font=('Microsoft YaHei UI', 10, 'bold'),
|
||||
fg='#81c784',
|
||||
bg='#1e1e1e',
|
||||
relief=tk.GROOVE
|
||||
)
|
||||
detail_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=(5, 0))
|
||||
|
||||
# 详情文本框
|
||||
detail_container = tk.Frame(detail_frame, bg='#2d2d2d')
|
||||
detail_container.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)
|
||||
|
||||
self.detail_text = tk.Text(
|
||||
detail_container,
|
||||
wrap=tk.WORD,
|
||||
font=('Microsoft YaHei UI', 10),
|
||||
bg='#2d2d2d',
|
||||
fg='#d4d4d4',
|
||||
relief=tk.FLAT,
|
||||
padx=10,
|
||||
pady=10,
|
||||
state=tk.DISABLED
|
||||
)
|
||||
|
||||
detail_scrollbar = ttk.Scrollbar(detail_container, orient=tk.VERTICAL, command=self.detail_text.yview)
|
||||
self.detail_text.configure(yscrollcommand=detail_scrollbar.set)
|
||||
|
||||
detail_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
||||
self.detail_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||
|
||||
# 配置详情文本样式
|
||||
self.detail_text.tag_configure('title', font=('Microsoft YaHei UI', 11, 'bold'), foreground='#ffd54f')
|
||||
self.detail_text.tag_configure('label', font=('Microsoft YaHei UI', 10, 'bold'), foreground='#4fc3f7')
|
||||
self.detail_text.tag_configure('success', foreground='#81c784')
|
||||
self.detail_text.tag_configure('error', foreground='#ef5350')
|
||||
self.detail_text.tag_configure('code', font=('Consolas', 9), foreground='#ce93d8')
|
||||
|
||||
# 底部按钮
|
||||
btn_frame = tk.Frame(self.frame, bg='#1e1e1e')
|
||||
btn_frame.pack(fill=tk.X, padx=10, pady=(0, 10))
|
||||
|
||||
# 打开日志按钮
|
||||
self.open_log_btn = tk.Button(
|
||||
btn_frame,
|
||||
text="📄 打开日志",
|
||||
font=('Microsoft YaHei UI', 10),
|
||||
bg='#424242',
|
||||
fg='white',
|
||||
activebackground='#616161',
|
||||
activeforeground='white',
|
||||
relief=tk.FLAT,
|
||||
padx=15,
|
||||
cursor='hand2',
|
||||
state=tk.DISABLED,
|
||||
command=self._open_log
|
||||
)
|
||||
self.open_log_btn.pack(side=tk.LEFT)
|
||||
|
||||
# 清空历史按钮
|
||||
clear_btn = tk.Button(
|
||||
btn_frame,
|
||||
text="🗑️ 清空历史",
|
||||
font=('Microsoft YaHei UI', 10),
|
||||
bg='#d32f2f',
|
||||
fg='white',
|
||||
activebackground='#f44336',
|
||||
activeforeground='white',
|
||||
relief=tk.FLAT,
|
||||
padx=15,
|
||||
cursor='hand2',
|
||||
command=self._clear_history
|
||||
)
|
||||
clear_btn.pack(side=tk.RIGHT)
|
||||
|
||||
# 加载数据
|
||||
self._load_data()
|
||||
|
||||
def _load_data(self):
|
||||
"""加载历史数据到列表"""
|
||||
# 清空现有数据
|
||||
for item in self.tree.get_children():
|
||||
self.tree.delete(item)
|
||||
|
||||
# 加载历史记录
|
||||
records = self.history.get_all()
|
||||
|
||||
for record in records:
|
||||
# 截断过长的输入
|
||||
input_text = record.user_input
|
||||
if len(input_text) > 30:
|
||||
input_text = input_text[:30] + "..."
|
||||
|
||||
status = "✓ 成功" if record.success else "✗ 失败"
|
||||
duration = f"{record.duration_ms}ms"
|
||||
|
||||
# 提取时间(只显示时分秒)
|
||||
time_parts = record.timestamp.split(' ')
|
||||
time_str = time_parts[1] if len(time_parts) > 1 else record.timestamp
|
||||
date_str = time_parts[0] if len(time_parts) > 0 else ""
|
||||
display_time = f"{date_str}\n{time_str}"
|
||||
|
||||
self.tree.insert('', tk.END, iid=record.task_id, values=(
|
||||
record.timestamp,
|
||||
input_text,
|
||||
status,
|
||||
duration
|
||||
))
|
||||
|
||||
# 显示空状态提示
|
||||
if not records:
|
||||
self._show_detail("暂无历史记录\n\n执行任务后,记录将显示在这里。")
|
||||
|
||||
def _on_select(self, event):
|
||||
"""选择记录事件"""
|
||||
selection = self.tree.selection()
|
||||
if not selection:
|
||||
return
|
||||
|
||||
task_id = selection[0]
|
||||
record = self.history.get_by_id(task_id)
|
||||
|
||||
if record:
|
||||
self._selected_record = record
|
||||
self._show_record_detail(record)
|
||||
self.open_log_btn.config(state=tk.NORMAL)
|
||||
|
||||
def _show_record_detail(self, record: TaskRecord):
|
||||
"""显示记录详情"""
|
||||
self.detail_text.config(state=tk.NORMAL)
|
||||
self.detail_text.delete(1.0, tk.END)
|
||||
|
||||
# 标题
|
||||
self.detail_text.insert(tk.END, f"任务 ID: {record.task_id}\n", 'title')
|
||||
self.detail_text.insert(tk.END, f"时间: {record.timestamp}\n\n")
|
||||
|
||||
# 用户输入
|
||||
self.detail_text.insert(tk.END, "用户输入:\n", 'label')
|
||||
self.detail_text.insert(tk.END, f"{record.user_input}\n\n")
|
||||
|
||||
# 执行状态
|
||||
self.detail_text.insert(tk.END, "执行状态: ", 'label')
|
||||
if record.success:
|
||||
self.detail_text.insert(tk.END, "成功 ✓\n", 'success')
|
||||
else:
|
||||
self.detail_text.insert(tk.END, "失败 ✗\n", 'error')
|
||||
|
||||
self.detail_text.insert(tk.END, f"耗时: {record.duration_ms}ms\n\n")
|
||||
|
||||
# 执行计划
|
||||
self.detail_text.insert(tk.END, "执行计划:\n", 'label')
|
||||
plan_preview = record.execution_plan[:500] + "..." if len(record.execution_plan) > 500 else record.execution_plan
|
||||
self.detail_text.insert(tk.END, f"{plan_preview}\n\n")
|
||||
|
||||
# 输出
|
||||
if record.stdout:
|
||||
self.detail_text.insert(tk.END, "输出:\n", 'label')
|
||||
self.detail_text.insert(tk.END, f"{record.stdout}\n\n")
|
||||
|
||||
# 错误
|
||||
if record.stderr:
|
||||
self.detail_text.insert(tk.END, "错误:\n", 'label')
|
||||
self.detail_text.insert(tk.END, f"{record.stderr}\n", 'error')
|
||||
|
||||
self.detail_text.config(state=tk.DISABLED)
|
||||
|
||||
def _show_detail(self, text: str):
|
||||
"""显示详情文本"""
|
||||
self.detail_text.config(state=tk.NORMAL)
|
||||
self.detail_text.delete(1.0, tk.END)
|
||||
self.detail_text.insert(tk.END, text)
|
||||
self.detail_text.config(state=tk.DISABLED)
|
||||
|
||||
def _open_log(self):
|
||||
"""打开日志文件"""
|
||||
if self._selected_record and self._selected_record.log_path:
|
||||
import os
|
||||
log_path = Path(self._selected_record.log_path)
|
||||
if log_path.exists():
|
||||
os.startfile(str(log_path))
|
||||
else:
|
||||
messagebox.showwarning("提示", f"日志文件不存在:\n{log_path}")
|
||||
|
||||
def _clear_history(self):
|
||||
"""清空历史记录"""
|
||||
result = messagebox.askyesno(
|
||||
"确认清空",
|
||||
"确定要清空所有历史记录吗?\n此操作不可恢复。",
|
||||
icon='warning'
|
||||
)
|
||||
|
||||
if result:
|
||||
self.history.clear()
|
||||
self._load_data()
|
||||
self._show_detail("历史记录已清空")
|
||||
self.open_log_btn.config(state=tk.DISABLED)
|
||||
|
||||
def show(self):
|
||||
"""显示视图"""
|
||||
self._load_data() # 刷新数据
|
||||
self.frame.pack(fill=tk.BOTH, expand=True)
|
||||
|
||||
def hide(self):
|
||||
"""隐藏视图"""
|
||||
self.frame.pack_forget()
|
||||
|
||||
def get_frame(self) -> tk.Frame:
|
||||
"""获取主框架"""
|
||||
return self.frame
|
||||
|
||||
@@ -243,6 +243,9 @@ class TaskGuideView:
|
||||
# 执行计划区域(Markdown)
|
||||
self._create_plan_section()
|
||||
|
||||
# 代码预览区域(可折叠)
|
||||
self._create_code_section()
|
||||
|
||||
# 风险提示区域
|
||||
self._create_risk_section()
|
||||
|
||||
@@ -306,6 +309,148 @@ class TaskGuideView:
|
||||
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
||||
self.plan_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||
|
||||
def _create_code_section(self):
|
||||
"""创建代码预览区域(可折叠)"""
|
||||
# 折叠状态
|
||||
self._code_expanded = False
|
||||
|
||||
# 外层框架
|
||||
self.code_section = tk.LabelFrame(
|
||||
self.frame,
|
||||
text=" 💻 生成的代码 ",
|
||||
font=('Microsoft YaHei UI', 10, 'bold'),
|
||||
fg='#64b5f6',
|
||||
bg='#1e1e1e',
|
||||
relief=tk.GROOVE
|
||||
)
|
||||
self.code_section.pack(fill=tk.X, padx=10, pady=3)
|
||||
|
||||
# 展开/折叠按钮
|
||||
self.toggle_code_btn = tk.Button(
|
||||
self.code_section,
|
||||
text="▶ 点击展开代码预览",
|
||||
font=('Microsoft YaHei UI', 9),
|
||||
bg='#2d2d2d',
|
||||
fg='#64b5f6',
|
||||
activebackground='#3d3d3d',
|
||||
activeforeground='#64b5f6',
|
||||
relief=tk.FLAT,
|
||||
cursor='hand2',
|
||||
command=self._toggle_code_view
|
||||
)
|
||||
self.toggle_code_btn.pack(fill=tk.X, padx=5, pady=5)
|
||||
|
||||
# 代码显示区域(初始隐藏)
|
||||
self.code_frame = tk.Frame(self.code_section, bg='#1e1e1e')
|
||||
|
||||
# 代码文本框
|
||||
self.code_text = tk.Text(
|
||||
self.code_frame,
|
||||
wrap=tk.NONE,
|
||||
font=('Consolas', 10),
|
||||
bg='#1e1e1e',
|
||||
fg='#d4d4d4',
|
||||
insertbackground='white',
|
||||
relief=tk.FLAT,
|
||||
height=12,
|
||||
padx=8,
|
||||
pady=5
|
||||
)
|
||||
|
||||
# 配置代码高亮标签
|
||||
self.code_text.tag_configure('keyword', foreground='#569cd6')
|
||||
self.code_text.tag_configure('string', foreground='#ce9178')
|
||||
self.code_text.tag_configure('comment', foreground='#6a9955')
|
||||
self.code_text.tag_configure('function', foreground='#dcdcaa')
|
||||
self.code_text.tag_configure('number', foreground='#b5cea8')
|
||||
|
||||
# 滚动条
|
||||
code_scrollbar_y = ttk.Scrollbar(self.code_frame, orient=tk.VERTICAL, command=self.code_text.yview)
|
||||
code_scrollbar_x = ttk.Scrollbar(self.code_frame, orient=tk.HORIZONTAL, command=self.code_text.xview)
|
||||
self.code_text.configure(yscrollcommand=code_scrollbar_y.set, xscrollcommand=code_scrollbar_x.set)
|
||||
|
||||
code_scrollbar_y.pack(side=tk.RIGHT, fill=tk.Y)
|
||||
code_scrollbar_x.pack(side=tk.BOTTOM, fill=tk.X)
|
||||
self.code_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||
|
||||
# 复制按钮
|
||||
self.copy_code_btn = tk.Button(
|
||||
self.code_frame,
|
||||
text="📋 复制代码",
|
||||
font=('Microsoft YaHei UI', 9),
|
||||
bg='#424242',
|
||||
fg='white',
|
||||
activebackground='#616161',
|
||||
activeforeground='white',
|
||||
relief=tk.FLAT,
|
||||
cursor='hand2',
|
||||
command=self._copy_code
|
||||
)
|
||||
|
||||
def _toggle_code_view(self):
|
||||
"""切换代码预览的展开/折叠状态"""
|
||||
self._code_expanded = not self._code_expanded
|
||||
|
||||
if self._code_expanded:
|
||||
self.toggle_code_btn.config(text="▼ 点击折叠代码预览")
|
||||
self.code_frame.pack(fill=tk.BOTH, expand=True, padx=3, pady=(0, 5))
|
||||
self.copy_code_btn.pack(pady=5)
|
||||
else:
|
||||
self.toggle_code_btn.config(text="▶ 点击展开代码预览")
|
||||
self.copy_code_btn.pack_forget()
|
||||
self.code_frame.pack_forget()
|
||||
|
||||
def _copy_code(self):
|
||||
"""复制代码到剪贴板"""
|
||||
code = self.code_text.get(1.0, tk.END).strip()
|
||||
self.frame.clipboard_clear()
|
||||
self.frame.clipboard_append(code)
|
||||
|
||||
# 显示复制成功提示
|
||||
original_text = self.copy_code_btn.cget('text')
|
||||
self.copy_code_btn.config(text="✓ 已复制!")
|
||||
self.frame.after(1500, lambda: self.copy_code_btn.config(text=original_text))
|
||||
|
||||
def _apply_syntax_highlight(self, code: str):
|
||||
"""应用简单的语法高亮"""
|
||||
import re
|
||||
|
||||
# 关键字
|
||||
keywords = r'\b(import|from|def|class|if|else|elif|for|while|try|except|finally|with|as|return|yield|raise|pass|break|continue|and|or|not|in|is|None|True|False|lambda|global|nonlocal)\b'
|
||||
# 字符串
|
||||
strings = r'(\"\"\"[\s\S]*?\"\"\"|\'\'\'[\s\S]*?\'\'\'|\"[^\"]*\"|\'[^\']*\')'
|
||||
# 注释
|
||||
comments = r'(#.*$)'
|
||||
# 函数调用
|
||||
functions = r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\('
|
||||
# 数字
|
||||
numbers = r'\b(\d+\.?\d*)\b'
|
||||
|
||||
# 先插入纯文本
|
||||
self.code_text.delete(1.0, tk.END)
|
||||
self.code_text.insert(1.0, code)
|
||||
|
||||
# 应用高亮
|
||||
for match in re.finditer(keywords, code, re.MULTILINE):
|
||||
start = f"1.0+{match.start()}c"
|
||||
end = f"1.0+{match.end()}c"
|
||||
self.code_text.tag_add('keyword', start, end)
|
||||
|
||||
for match in re.finditer(strings, code, re.MULTILINE):
|
||||
start = f"1.0+{match.start()}c"
|
||||
end = f"1.0+{match.end()}c"
|
||||
self.code_text.tag_add('string', start, end)
|
||||
|
||||
for match in re.finditer(comments, code, re.MULTILINE):
|
||||
start = f"1.0+{match.start()}c"
|
||||
end = f"1.0+{match.end()}c"
|
||||
self.code_text.tag_add('comment', start, end)
|
||||
|
||||
for match in re.finditer(numbers, code, re.MULTILINE):
|
||||
start = f"1.0+{match.start(1)}c"
|
||||
end = f"1.0+{match.end(1)}c"
|
||||
self.code_text.tag_add('number', start, end)
|
||||
|
||||
def _create_risk_section(self):
|
||||
"""创建风险提示区域"""
|
||||
section = tk.LabelFrame(
|
||||
@@ -423,6 +568,12 @@ class TaskGuideView:
|
||||
"""设置执行计划(Markdown 格式)"""
|
||||
self.plan_text.set_markdown(plan)
|
||||
|
||||
def set_code(self, code: str):
|
||||
"""设置生成的代码"""
|
||||
self.code_text.config(state=tk.NORMAL)
|
||||
self._apply_syntax_highlight(code)
|
||||
self.code_text.config(state=tk.DISABLED)
|
||||
|
||||
def set_risk_info(self, info: str):
|
||||
"""设置风险提示"""
|
||||
self.risk_label.config(text=info)
|
||||
|
||||
Reference in New Issue
Block a user