feat: enhance LocalAgent configuration and UI components
- Updated .env.example to provide clearer configuration instructions and API key setup. - Removed debug_env.py as it was no longer needed. - Refactored main.py to streamline application initialization and workspace setup. - Introduced a new HistoryManager for managing task execution history. - Enhanced UI components in chat_view.py and task_guide_view.py to improve user interaction and code preview functionality. - Added loading indicators and improved task history display in the UI. - Implemented unit tests for history management and intent classification.
This commit is contained in:
2
tests/__init__.py
Normal file
2
tests/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# 测试模块
|
||||
|
||||
235
tests/test_history_manager.py
Normal file
235
tests/test_history_manager.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
历史记录管理器单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from history.manager import HistoryManager, TaskRecord
|
||||
|
||||
|
||||
class TestHistoryManager(unittest.TestCase):
|
||||
"""历史记录管理器测试"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建临时目录用于测试"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.manager = HistoryManager(self.temp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""清理临时目录"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_add_record(self):
|
||||
"""测试添加记录"""
|
||||
record = self.manager.add_record(
|
||||
task_id="test_001",
|
||||
user_input="复制文件",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.95,
|
||||
execution_plan="复制所有文件",
|
||||
code="shutil.copy(...)",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
self.assertEqual(record.task_id, "test_001")
|
||||
self.assertEqual(record.user_input, "复制文件")
|
||||
self.assertTrue(record.success)
|
||||
|
||||
def test_get_all(self):
|
||||
"""测试获取所有记录"""
|
||||
# 添加多条记录
|
||||
for i in range(3):
|
||||
self.manager.add_record(
|
||||
task_id=f"test_{i:03d}",
|
||||
user_input=f"任务 {i}",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
records = self.manager.get_all()
|
||||
self.assertEqual(len(records), 3)
|
||||
|
||||
def test_get_recent(self):
|
||||
"""测试获取最近记录"""
|
||||
# 添加 5 条记录
|
||||
for i in range(5):
|
||||
self.manager.add_record(
|
||||
task_id=f"test_{i:03d}",
|
||||
user_input=f"任务 {i}",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 获取最近 3 条
|
||||
recent = self.manager.get_recent(3)
|
||||
self.assertEqual(len(recent), 3)
|
||||
# 最新的在前
|
||||
self.assertEqual(recent[0].task_id, "test_004")
|
||||
|
||||
def test_get_by_id(self):
|
||||
"""测试根据 ID 获取记录"""
|
||||
self.manager.add_record(
|
||||
task_id="unique_id",
|
||||
user_input="测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
record = self.manager.get_by_id("unique_id")
|
||||
self.assertIsNotNone(record)
|
||||
self.assertEqual(record.task_id, "unique_id")
|
||||
|
||||
# 不存在的 ID
|
||||
not_found = self.manager.get_by_id("not_exist")
|
||||
self.assertIsNone(not_found)
|
||||
|
||||
def test_clear(self):
|
||||
"""测试清空记录"""
|
||||
# 添加记录
|
||||
self.manager.add_record(
|
||||
task_id="test",
|
||||
user_input="测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
self.assertEqual(len(self.manager.get_all()), 1)
|
||||
|
||||
# 清空
|
||||
self.manager.clear()
|
||||
self.assertEqual(len(self.manager.get_all()), 0)
|
||||
|
||||
def test_get_stats(self):
|
||||
"""测试统计信息"""
|
||||
# 添加成功和失败的记录
|
||||
self.manager.add_record(
|
||||
task_id="success_1",
|
||||
user_input="成功任务",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
self.manager.add_record(
|
||||
task_id="success_2",
|
||||
user_input="成功任务2",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=200
|
||||
)
|
||||
self.manager.add_record(
|
||||
task_id="failed_1",
|
||||
user_input="失败任务",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=False,
|
||||
duration_ms=50
|
||||
)
|
||||
|
||||
stats = self.manager.get_stats()
|
||||
self.assertEqual(stats['total'], 3)
|
||||
self.assertEqual(stats['success'], 2)
|
||||
self.assertEqual(stats['failed'], 1)
|
||||
self.assertAlmostEqual(stats['success_rate'], 2/3)
|
||||
|
||||
def test_persistence(self):
|
||||
"""测试持久化"""
|
||||
# 添加记录
|
||||
self.manager.add_record(
|
||||
task_id="persist_test",
|
||||
user_input="持久化测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 创建新的管理器实例(模拟重启)
|
||||
new_manager = HistoryManager(self.temp_dir)
|
||||
|
||||
# 应该能读取到之前的记录
|
||||
records = new_manager.get_all()
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(records[0].task_id, "persist_test")
|
||||
|
||||
def test_max_history_size(self):
|
||||
"""测试历史记录数量限制"""
|
||||
# 添加超过限制的记录
|
||||
for i in range(HistoryManager.MAX_HISTORY_SIZE + 10):
|
||||
self.manager.add_record(
|
||||
task_id=f"test_{i:03d}",
|
||||
user_input=f"任务 {i}",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 应该只保留最大数量
|
||||
records = self.manager.get_all()
|
||||
self.assertEqual(len(records), HistoryManager.MAX_HISTORY_SIZE)
|
||||
|
||||
|
||||
class TestTaskRecord(unittest.TestCase):
|
||||
"""任务记录数据类测试"""
|
||||
|
||||
def test_create_record(self):
|
||||
"""测试创建记录"""
|
||||
record = TaskRecord(
|
||||
task_id="test",
|
||||
timestamp="2024-01-01 12:00:00",
|
||||
user_input="测试",
|
||||
intent_label="execution",
|
||||
intent_confidence=0.9,
|
||||
execution_plan="计划",
|
||||
code="代码",
|
||||
success=True,
|
||||
duration_ms=100,
|
||||
stdout="输出",
|
||||
stderr="",
|
||||
log_path="/path/to/log"
|
||||
)
|
||||
|
||||
self.assertEqual(record.task_id, "test")
|
||||
self.assertTrue(record.success)
|
||||
self.assertEqual(record.duration_ms, 100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
94
tests/test_intent_classifier.py
Normal file
94
tests/test_intent_classifier.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
意图分类器单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from intent.labels import CHAT, EXECUTION, VALID_LABELS, EXECUTION_CONFIDENCE_THRESHOLD
|
||||
|
||||
|
||||
class TestIntentLabels(unittest.TestCase):
|
||||
"""意图标签测试"""
|
||||
|
||||
def test_labels_defined(self):
|
||||
"""测试标签已定义"""
|
||||
self.assertEqual(CHAT, "chat")
|
||||
self.assertEqual(EXECUTION, "execution")
|
||||
|
||||
def test_valid_labels(self):
|
||||
"""测试有效标签集合"""
|
||||
self.assertIn(CHAT, VALID_LABELS)
|
||||
self.assertIn(EXECUTION, VALID_LABELS)
|
||||
self.assertEqual(len(VALID_LABELS), 2)
|
||||
|
||||
def test_confidence_threshold(self):
|
||||
"""测试置信度阈值"""
|
||||
self.assertGreater(EXECUTION_CONFIDENCE_THRESHOLD, 0)
|
||||
self.assertLessEqual(EXECUTION_CONFIDENCE_THRESHOLD, 1)
|
||||
|
||||
|
||||
class TestIntentClassifierParsing(unittest.TestCase):
|
||||
"""意图分类器解析测试(不需要 API)"""
|
||||
|
||||
def setUp(self):
|
||||
from intent.classifier import IntentClassifier
|
||||
self.classifier = IntentClassifier()
|
||||
|
||||
def test_parse_valid_chat_response(self):
|
||||
"""测试解析有效的 chat 响应"""
|
||||
response = '{"label": "chat", "confidence": 0.95, "reason": "这是一个问答"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, CHAT)
|
||||
self.assertEqual(result.confidence, 0.95)
|
||||
self.assertEqual(result.reason, "这是一个问答")
|
||||
|
||||
def test_parse_valid_execution_response(self):
|
||||
"""测试解析有效的 execution 响应"""
|
||||
response = '{"label": "execution", "confidence": 0.9, "reason": "需要复制文件"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, EXECUTION)
|
||||
self.assertEqual(result.confidence, 0.9)
|
||||
|
||||
def test_parse_low_confidence_execution(self):
|
||||
"""测试低置信度的 execution 降级为 chat"""
|
||||
response = '{"label": "execution", "confidence": 0.5, "reason": "不太确定"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
# 低于阈值应该降级为 chat
|
||||
self.assertEqual(result.label, CHAT)
|
||||
|
||||
def test_parse_invalid_label(self):
|
||||
"""测试无效标签降级为 chat"""
|
||||
response = '{"label": "unknown", "confidence": 0.9, "reason": "测试"}'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, CHAT)
|
||||
|
||||
def test_parse_invalid_json(self):
|
||||
"""测试无效 JSON 降级为 chat"""
|
||||
response = 'not a json'
|
||||
result = self.classifier._parse_response(response)
|
||||
self.assertEqual(result.label, CHAT)
|
||||
self.assertEqual(result.confidence, 0.0)
|
||||
|
||||
def test_extract_json_with_prefix(self):
|
||||
"""测试从带前缀的文本中提取 JSON"""
|
||||
text = 'Here is the result: {"label": "chat", "confidence": 0.8, "reason": "test"}'
|
||||
json_str = self.classifier._extract_json(text)
|
||||
self.assertTrue(json_str.startswith('{'))
|
||||
self.assertTrue(json_str.endswith('}'))
|
||||
|
||||
def test_extract_json_with_suffix(self):
|
||||
"""测试从带后缀的文本中提取 JSON"""
|
||||
text = '{"label": "chat", "confidence": 0.8, "reason": "test"} That is my answer.'
|
||||
json_str = self.classifier._extract_json(text)
|
||||
self.assertTrue(json_str.startswith('{'))
|
||||
self.assertTrue(json_str.endswith('}'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
160
tests/test_rule_checker.py
Normal file
160
tests/test_rule_checker.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
安全检查器单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from safety.rule_checker import RuleChecker, check_code_safety
|
||||
|
||||
|
||||
class TestRuleChecker(unittest.TestCase):
|
||||
"""规则检查器测试"""
|
||||
|
||||
def setUp(self):
|
||||
self.checker = RuleChecker()
|
||||
|
||||
# ========== 硬性禁止测试 ==========
|
||||
|
||||
def test_block_socket_import(self):
|
||||
"""测试禁止 socket 模块"""
|
||||
code = "import socket\ns = socket.socket()"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('socket' in v for v in result.violations))
|
||||
|
||||
def test_block_subprocess_import(self):
|
||||
"""测试禁止 subprocess 模块"""
|
||||
code = "import subprocess\nsubprocess.run(['ls'])"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('subprocess' in v for v in result.violations))
|
||||
|
||||
def test_block_eval(self):
|
||||
"""测试禁止 eval"""
|
||||
code = "result = eval('1+1')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('eval' in v for v in result.violations))
|
||||
|
||||
def test_block_exec(self):
|
||||
"""测试禁止 exec"""
|
||||
code = "exec('print(1)')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('exec' in v for v in result.violations))
|
||||
|
||||
def test_block_os_system(self):
|
||||
"""测试禁止 os.system"""
|
||||
code = "import os\nos.system('dir')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('os.system' in v for v in result.violations))
|
||||
|
||||
def test_block_os_popen(self):
|
||||
"""测试禁止 os.popen"""
|
||||
code = "import os\nos.popen('dir')"
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed)
|
||||
self.assertTrue(any('os.popen' in v for v in result.violations))
|
||||
|
||||
# ========== 警告测试 ==========
|
||||
|
||||
def test_warn_requests_import(self):
|
||||
"""测试 requests 模块产生警告"""
|
||||
code = "import requests\nresponse = requests.get('http://example.com')"
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed) # 不应该被阻止
|
||||
self.assertTrue(any('requests' in w for w in result.warnings))
|
||||
|
||||
def test_warn_os_remove(self):
|
||||
"""测试 os.remove 产生警告"""
|
||||
code = "import os\nos.remove('file.txt')"
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed) # 不应该被阻止
|
||||
self.assertTrue(any('os.remove' in w for w in result.warnings))
|
||||
|
||||
def test_warn_shutil_rmtree(self):
|
||||
"""测试 shutil.rmtree 产生警告"""
|
||||
code = "import shutil\nshutil.rmtree('folder')"
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed) # 不应该被阻止
|
||||
self.assertTrue(any('shutil.rmtree' in w for w in result.warnings))
|
||||
|
||||
# ========== 安全代码测试 ==========
|
||||
|
||||
def test_safe_file_copy(self):
|
||||
"""测试安全的文件复制代码"""
|
||||
code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for f in INPUT_DIR.glob('*'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed)
|
||||
self.assertEqual(len(result.violations), 0)
|
||||
|
||||
def test_safe_image_processing(self):
|
||||
"""测试安全的图片处理代码"""
|
||||
code = """
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for img_path in INPUT_DIR.glob('*.png'):
|
||||
img = Image.open(img_path)
|
||||
img = img.resize((100, 100))
|
||||
img.save(OUTPUT_DIR / img_path.name)
|
||||
"""
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed)
|
||||
self.assertEqual(len(result.violations), 0)
|
||||
|
||||
def test_safe_excel_processing(self):
|
||||
"""测试安全的 Excel 处理代码"""
|
||||
code = """
|
||||
import openpyxl
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for xlsx_path in INPUT_DIR.glob('*.xlsx'):
|
||||
wb = openpyxl.load_workbook(xlsx_path)
|
||||
ws = wb.active
|
||||
ws['A1'] = 'Modified'
|
||||
wb.save(OUTPUT_DIR / xlsx_path.name)
|
||||
"""
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed)
|
||||
self.assertEqual(len(result.violations), 0)
|
||||
|
||||
|
||||
class TestCheckCodeSafety(unittest.TestCase):
|
||||
"""便捷函数测试"""
|
||||
|
||||
def test_convenience_function(self):
|
||||
"""测试便捷函数"""
|
||||
result = check_code_safety("print('hello')")
|
||||
self.assertTrue(result.passed)
|
||||
|
||||
def test_convenience_function_block(self):
|
||||
"""测试便捷函数阻止危险代码"""
|
||||
result = check_code_safety("import socket")
|
||||
self.assertFalse(result.passed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user