feat: refactor API key configuration and enhance application initialization
- Renamed `check_environment` to `check_api_key_configured` for clarity, simplifying the API key validation logic. - Removed the blocking behavior of the API key check during application startup, allowing the app to run while providing a prompt for configuration. - Updated `LocalAgentApp` to accept an `api_configured` parameter, enabling conditional messaging for API key setup. - Enhanced the `SandboxRunner` to support backup management and improved execution result handling with detailed metrics. - Integrated data governance strategies into the `HistoryManager`, ensuring compliance and improved data management. - Added privacy settings and metrics tracking across various components to enhance user experience and application safety.
This commit is contained in:
100
tests/test_config_refresh.py
Normal file
100
tests/test_config_refresh.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
测试配置刷新功能
|
||||
验证配置变更后客户端单例是否正确重置
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
PROJECT_ROOT = Path(__file__).parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from dotenv import load_dotenv, set_key
|
||||
from llm.client import get_client, reset_client, test_connection, LLMClientError
|
||||
|
||||
|
||||
def test_config_refresh():
|
||||
"""测试配置刷新流程"""
|
||||
|
||||
env_path = PROJECT_ROOT / ".env"
|
||||
|
||||
print("=" * 60)
|
||||
print("测试配置刷新功能")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 加载初始配置
|
||||
print("\n[步骤 1] 加载初始配置...")
|
||||
load_dotenv(env_path)
|
||||
initial_api_key = os.getenv("LLM_API_KEY", "")
|
||||
print(f"初始 API Key: {initial_api_key[:10]}..." if initial_api_key else "未配置")
|
||||
|
||||
# 2. 获取客户端实例
|
||||
print("\n[步骤 2] 获取客户端实例...")
|
||||
try:
|
||||
client1 = get_client()
|
||||
print(f"✓ 客户端实例创建成功")
|
||||
print(f" API URL: {client1.api_url}")
|
||||
print(f" API Key: {client1.api_key[:10]}..." if client1.api_key else "未配置")
|
||||
except LLMClientError as e:
|
||||
print(f"✗ 客户端创建失败: {e}")
|
||||
return
|
||||
|
||||
# 3. 模拟配置变更(这里只是演示,不实际修改)
|
||||
print("\n[步骤 3] 模拟配置变更...")
|
||||
print(" (实际场景中,用户在设置页修改并保存配置)")
|
||||
|
||||
# 4. 重置客户端单例
|
||||
print("\n[步骤 4] 重置客户端单例...")
|
||||
reset_client()
|
||||
print("✓ 客户端单例已重置")
|
||||
|
||||
# 5. 重新获取客户端实例
|
||||
print("\n[步骤 5] 重新获取客户端实例...")
|
||||
try:
|
||||
client2 = get_client()
|
||||
print(f"✓ 新客户端实例创建成功")
|
||||
print(f" API URL: {client2.api_url}")
|
||||
print(f" API Key: {client2.api_key[:10]}..." if client2.api_key else "未配置")
|
||||
|
||||
# 验证是否是新实例
|
||||
if client1 is client2:
|
||||
print("✗ 警告: 客户端实例未更新(仍是旧实例)")
|
||||
else:
|
||||
print("✓ 确认: 客户端实例已更新(新实例)")
|
||||
except LLMClientError as e:
|
||||
print(f"✗ 新客户端创建失败: {e}")
|
||||
return
|
||||
|
||||
# 6. 测试连接
|
||||
print("\n[步骤 6] 测试 API 连接...")
|
||||
success, message = test_connection(timeout=10)
|
||||
if success:
|
||||
print(f"✓ {message}")
|
||||
else:
|
||||
print(f"✗ {message}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
|
||||
# 7. 显示度量统计
|
||||
print("\n[度量统计]")
|
||||
try:
|
||||
from llm.config_metrics import get_config_metrics
|
||||
workspace = PROJECT_ROOT / "workspace"
|
||||
metrics = get_config_metrics(workspace)
|
||||
stats = metrics.get_statistics()
|
||||
|
||||
print(f"配置变更总次数: {stats['total_config_changes']}")
|
||||
print(f"首次调用成功率: {stats['first_call_success_rate']:.1%}")
|
||||
print(f"平均重试次数: {stats['avg_retry_count']:.2f}")
|
||||
print(f"连接测试成功率: {stats['connection_test_success_rate']:.1%}")
|
||||
except Exception as e:
|
||||
print(f"无法获取度量统计: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_config_refresh()
|
||||
|
||||
326
tests/test_data_governance.py
Normal file
326
tests/test_data_governance.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
数据治理单元测试
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from history.data_sanitizer import DataSanitizer, SensitiveType
|
||||
from history.data_governance import DataGovernancePolicy, DataLevel
|
||||
from history.manager import HistoryManager
|
||||
|
||||
|
||||
class TestDataSanitizer(unittest.TestCase):
|
||||
"""测试数据脱敏器"""
|
||||
|
||||
def setUp(self):
|
||||
self.sanitizer = DataSanitizer()
|
||||
|
||||
def test_file_path_detection(self):
|
||||
"""测试文件路径检测"""
|
||||
text = "文件保存在 C:\\Users\\test\\document.txt 中"
|
||||
matches = self.sanitizer.find_sensitive_data(text)
|
||||
|
||||
self.assertTrue(any(m.type == SensitiveType.FILE_PATH for m in matches))
|
||||
|
||||
def test_email_detection(self):
|
||||
"""测试邮箱检测"""
|
||||
text = "联系邮箱: test@example.com"
|
||||
matches = self.sanitizer.find_sensitive_data(text)
|
||||
|
||||
self.assertTrue(any(m.type == SensitiveType.EMAIL for m in matches))
|
||||
|
||||
def test_phone_detection(self):
|
||||
"""测试电话号码检测"""
|
||||
text = "手机号: 13812345678"
|
||||
matches = self.sanitizer.find_sensitive_data(text)
|
||||
|
||||
self.assertTrue(any(m.type == SensitiveType.PHONE for m in matches))
|
||||
|
||||
def test_ip_detection(self):
|
||||
"""测试IP地址检测"""
|
||||
text = "服务器地址: 192.168.1.100"
|
||||
matches = self.sanitizer.find_sensitive_data(text)
|
||||
|
||||
self.assertTrue(any(m.type == SensitiveType.IP_ADDRESS for m in matches))
|
||||
|
||||
def test_sanitize_text(self):
|
||||
"""测试文本脱敏"""
|
||||
text = "邮箱 test@example.com 手机 13812345678"
|
||||
sanitized, matches = self.sanitizer.sanitize(text)
|
||||
|
||||
self.assertNotIn("test@example.com", sanitized)
|
||||
self.assertNotIn("13812345678", sanitized)
|
||||
self.assertEqual(len(matches), 2)
|
||||
|
||||
def test_sensitivity_score(self):
|
||||
"""测试敏感度评分"""
|
||||
# 低敏感度
|
||||
low_text = "这是一段普通文本"
|
||||
self.assertLess(self.sanitizer.get_sensitivity_score(low_text), 0.3)
|
||||
|
||||
# 高敏感度(使用更明显的敏感信息)
|
||||
high_text = "密码: password123, API密钥: sk-1234567890abcdefghijklmnopqrstuvwxyz123456789012, 邮箱: admin@company.com, 手机: 13812345678"
|
||||
self.assertGreater(self.sanitizer.get_sensitivity_score(high_text), 0.5)
|
||||
|
||||
|
||||
class TestDataGovernance(unittest.TestCase):
|
||||
"""测试数据治理策略"""
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.policy = DataGovernancePolicy(self.temp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_classify_low_sensitivity(self):
|
||||
"""测试低敏感度分类"""
|
||||
record = {
|
||||
'user_input': '计算1+1',
|
||||
'code': 'print(1+1)',
|
||||
'stdout': '2',
|
||||
'stderr': '',
|
||||
'execution_plan': '执行简单计算'
|
||||
}
|
||||
|
||||
classification = self.policy.classify_record(record)
|
||||
self.assertEqual(classification.level, DataLevel.FULL)
|
||||
self.assertLess(classification.sensitivity_score, 0.3)
|
||||
|
||||
def test_classify_high_sensitivity(self):
|
||||
"""测试高敏感度分类"""
|
||||
record = {
|
||||
'user_input': '读取配置文件 /etc/config.json',
|
||||
'code': 'password = "secret123"\napi_key = "sk-1234567890abcdefghijklmnopqrstuvwxyz123456789012"',
|
||||
'stdout': 'API_KEY=sk-1234567890abcdefghijklmnopqrstuvwxyz123456789012\nemail=admin@company.com\nphone=13812345678',
|
||||
'stderr': 'Error at /home/user/secret/config.json',
|
||||
'execution_plan': '读取敏感配置'
|
||||
}
|
||||
|
||||
classification = self.policy.classify_record(record)
|
||||
# 由于敏感信息较多,应该至少是脱敏级别
|
||||
self.assertGreater(classification.sensitivity_score, 0.2)
|
||||
|
||||
def test_apply_policy_minimal(self):
|
||||
"""测试最小化策略应用"""
|
||||
record = {
|
||||
'task_id': 'test-001',
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'user_input': 'password=secret123',
|
||||
'code': 'API_KEY="sk-test"',
|
||||
'stdout': 'token: abc123',
|
||||
'stderr': '',
|
||||
'execution_plan': '测试',
|
||||
'intent_label': 'test',
|
||||
'intent_confidence': 0.9,
|
||||
'success': True,
|
||||
'duration_ms': 100,
|
||||
'log_path': '',
|
||||
'task_summary': '测试任务'
|
||||
}
|
||||
|
||||
result = self.policy.apply_policy(record)
|
||||
|
||||
# 应该有治理元数据
|
||||
self.assertIn('_governance', result)
|
||||
self.assertIn('level', result['_governance'])
|
||||
|
||||
def test_expiration_check(self):
|
||||
"""测试过期检查"""
|
||||
# 未过期记录
|
||||
record_valid = {
|
||||
'_governance': {
|
||||
'expires_at': (datetime.now() + timedelta(days=1)).isoformat()
|
||||
}
|
||||
}
|
||||
self.assertFalse(self.policy.check_expiration(record_valid))
|
||||
|
||||
# 已过期记录
|
||||
record_expired = {
|
||||
'_governance': {
|
||||
'expires_at': (datetime.now() - timedelta(days=1)).isoformat()
|
||||
}
|
||||
}
|
||||
self.assertTrue(self.policy.check_expiration(record_expired))
|
||||
|
||||
def test_cleanup_expired(self):
|
||||
"""测试过期清理"""
|
||||
records = [
|
||||
{
|
||||
'task_id': '1',
|
||||
'_governance': {
|
||||
'level': DataLevel.FULL.value,
|
||||
'expires_at': (datetime.now() - timedelta(days=1)).isoformat(),
|
||||
'sensitive_fields': []
|
||||
}
|
||||
},
|
||||
{
|
||||
'task_id': '2',
|
||||
'_governance': {
|
||||
'level': DataLevel.SANITIZED.value,
|
||||
'expires_at': (datetime.now() - timedelta(days=1)).isoformat()
|
||||
}
|
||||
},
|
||||
{
|
||||
'task_id': '3',
|
||||
'_governance': {
|
||||
'level': DataLevel.MINIMAL.value,
|
||||
'expires_at': (datetime.now() - timedelta(days=1)).isoformat()
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
kept, archived, deleted = self.policy.cleanup_expired(records)
|
||||
|
||||
# 完整数据应降级,脱敏数据应归档,最小化数据应删除
|
||||
self.assertGreater(len(kept), 0)
|
||||
self.assertGreater(archived + deleted, 0)
|
||||
|
||||
|
||||
class TestHistoryManager(unittest.TestCase):
|
||||
"""测试历史记录管理器"""
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.manager = HistoryManager(self.temp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_add_record_with_governance(self):
|
||||
"""测试添加记录时应用治理策略"""
|
||||
record = self.manager.add_record(
|
||||
task_id='test-001',
|
||||
user_input='测试输入',
|
||||
intent_label='test',
|
||||
intent_confidence=0.9,
|
||||
execution_plan='测试计划',
|
||||
code='print("test")',
|
||||
success=True,
|
||||
duration_ms=100,
|
||||
stdout='test',
|
||||
stderr='',
|
||||
log_path='',
|
||||
task_summary='测试'
|
||||
)
|
||||
|
||||
self.assertIsNotNone(record)
|
||||
self.assertEqual(record.task_id, 'test-001')
|
||||
|
||||
def test_save_and_load_with_governance(self):
|
||||
"""测试保存和加载带治理元数据的记录"""
|
||||
self.manager.add_record(
|
||||
task_id='test-002',
|
||||
user_input='测试',
|
||||
intent_label='test',
|
||||
intent_confidence=0.9,
|
||||
execution_plan='测试',
|
||||
code='test',
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 重新加载
|
||||
new_manager = HistoryManager(self.temp_dir)
|
||||
records = new_manager.get_all()
|
||||
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(records[0].task_id, 'test-002')
|
||||
|
||||
def test_manual_cleanup(self):
|
||||
"""测试手动清理"""
|
||||
# 添加一条过期记录
|
||||
self.manager.add_record(
|
||||
task_id='test-003',
|
||||
user_input='测试',
|
||||
intent_label='test',
|
||||
intent_confidence=0.9,
|
||||
execution_plan='测试',
|
||||
code='test',
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 手动修改过期时间
|
||||
if self.manager._history:
|
||||
record_dict = {
|
||||
'task_id': 'test-004',
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'user_input': 'test',
|
||||
'intent_label': 'test',
|
||||
'intent_confidence': 0.9,
|
||||
'execution_plan': 'test',
|
||||
'code': 'test',
|
||||
'success': True,
|
||||
'duration_ms': 100,
|
||||
'stdout': '',
|
||||
'stderr': '',
|
||||
'log_path': '',
|
||||
'task_summary': '',
|
||||
'_governance': {
|
||||
'level': DataLevel.MINIMAL.value,
|
||||
'expires_at': (datetime.now() - timedelta(days=1)).isoformat()
|
||||
},
|
||||
'_sanitization': None
|
||||
}
|
||||
|
||||
from history.manager import TaskRecord
|
||||
self.manager._history.append(TaskRecord(**record_dict))
|
||||
self.manager._save()
|
||||
|
||||
stats = self.manager.manual_cleanup()
|
||||
|
||||
self.assertIn('archived', stats)
|
||||
self.assertIn('deleted', stats)
|
||||
self.assertIn('remaining', stats)
|
||||
|
||||
def test_export_sanitized(self):
|
||||
"""测试导出脱敏数据"""
|
||||
self.manager.add_record(
|
||||
task_id='test-005',
|
||||
user_input='测试邮箱 test@example.com',
|
||||
intent_label='test',
|
||||
intent_confidence=0.9,
|
||||
execution_plan='测试',
|
||||
code='test',
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
export_path = self.temp_dir / "export.json"
|
||||
count = self.manager.export_sanitized(export_path)
|
||||
|
||||
self.assertGreater(count, 0)
|
||||
self.assertTrue(export_path.exists())
|
||||
|
||||
# 验证导出内容
|
||||
with open(export_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
self.assertEqual(len(data), count)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""运行所有测试"""
|
||||
loader = unittest.TestLoader()
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
suite.addTests(loader.loadTestsFromTestCase(TestDataSanitizer))
|
||||
suite.addTests(loader.loadTestsFromTestCase(TestDataGovernance))
|
||||
suite.addTests(loader.loadTestsFromTestCase(TestHistoryManager))
|
||||
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
success = run_tests()
|
||||
exit(0 if success else 1)
|
||||
|
||||
654
tests/test_e2e_integration.py
Normal file
654
tests/test_e2e_integration.py
Normal file
@@ -0,0 +1,654 @@
|
||||
"""
|
||||
端到端集成测试
|
||||
测试关键主流程和安全回归场景
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from history.manager import HistoryManager
|
||||
from safety.rule_checker import RuleChecker
|
||||
from safety.llm_reviewer import LLMReviewer, LLMReviewResult
|
||||
from executor.sandbox_runner import SandboxRunner, ExecutionResult
|
||||
from intent.classifier import IntentClassifier, IntentResult
|
||||
from intent.labels import EXECUTION
|
||||
from llm.config_metrics import ConfigMetricsManager
|
||||
from history.reuse_metrics import ReuseMetrics
|
||||
|
||||
|
||||
class TestCodeReuseSecurityRegression(unittest.TestCase):
|
||||
"""
|
||||
测试场景:复用绕过安全
|
||||
验证历史代码复用时必须重新进行安全检查
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.history = HistoryManager(self.temp_dir)
|
||||
self.rule_checker = RuleChecker()
|
||||
self.reuse_metrics = ReuseMetrics(self.temp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_reuse_must_trigger_security_recheck(self):
|
||||
"""测试:复用代码必须触发安全复检"""
|
||||
# 1. 添加一条历史成功记录(包含潜在危险代码)
|
||||
dangerous_code = """
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
# 危险操作:删除文件
|
||||
for f in INPUT_DIR.glob('*.txt'):
|
||||
os.remove(f)
|
||||
"""
|
||||
|
||||
self.history.add_record(
|
||||
task_id="task_001",
|
||||
user_input="删除所有txt文件",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.95,
|
||||
execution_plan="遍历input目录删除txt文件",
|
||||
code=dangerous_code,
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 2. 查找相似任务(模拟复用场景)
|
||||
result = self.history.find_similar_success("删除txt文件", return_details=True)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
similar_record, similarity_score, differences = result
|
||||
|
||||
# 3. 记录复用指标
|
||||
self.reuse_metrics.record_reuse_offered(
|
||||
original_task_id="task_001",
|
||||
similarity_score=similarity_score,
|
||||
differences_count=len(differences),
|
||||
critical_differences=0
|
||||
)
|
||||
|
||||
# 4. 模拟用户接受复用
|
||||
self.reuse_metrics.record_reuse_accepted(
|
||||
original_task_id="task_001",
|
||||
similarity_score=similarity_score,
|
||||
differences_count=len(differences),
|
||||
critical_differences=0
|
||||
)
|
||||
|
||||
# 5. 强制安全复检(关键步骤)
|
||||
recheck_result = self.rule_checker.check(similar_record.code)
|
||||
|
||||
# 6. 验证:必须检测到危险操作
|
||||
self.assertTrue(len(recheck_result.warnings) > 0, "复用代码的安全复检必须检测到警告")
|
||||
self.assertTrue(
|
||||
any('os.remove' in w for w in recheck_result.warnings),
|
||||
"必须检测到 os.remove 警告"
|
||||
)
|
||||
|
||||
def test_reuse_blocked_by_security_check(self):
|
||||
"""测试:复用代码被安全检查拦截"""
|
||||
# 1. 添加包含硬性禁止操作的历史记录
|
||||
blocked_code = """
|
||||
import socket
|
||||
|
||||
# 硬性禁止:网络操作
|
||||
s = socket.socket()
|
||||
s.connect(('example.com', 80))
|
||||
"""
|
||||
|
||||
self.history.add_record(
|
||||
task_id="task_002",
|
||||
user_input="连接服务器",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.9,
|
||||
execution_plan="建立socket连接",
|
||||
code=blocked_code,
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 2. 查找并尝试复用
|
||||
result = self.history.find_similar_success("连接到服务器", return_details=True)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
similar_record, _, _ = result
|
||||
|
||||
# 3. 安全复检
|
||||
recheck_result = self.rule_checker.check(similar_record.code)
|
||||
|
||||
# 4. 验证:必须被拦截
|
||||
self.assertFalse(recheck_result.passed, "包含socket的复用代码必须被拦截")
|
||||
self.assertTrue(
|
||||
any('socket' in v for v in recheck_result.violations),
|
||||
"必须检测到socket违规"
|
||||
)
|
||||
|
||||
def test_reuse_metrics_tracking(self):
|
||||
"""测试:复用流程的指标追踪"""
|
||||
# 1. 添加历史记录
|
||||
safe_code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for f in INPUT_DIR.glob('*.png'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
|
||||
self.history.add_record(
|
||||
task_id="task_003",
|
||||
user_input="复制所有图片",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.95,
|
||||
execution_plan="复制png文件",
|
||||
code=safe_code,
|
||||
success=True,
|
||||
duration_ms=150
|
||||
)
|
||||
|
||||
# 2. 模拟完整的复用流程
|
||||
result = self.history.find_similar_success("复制图片文件", return_details=True)
|
||||
similar_record, similarity_score, differences = result
|
||||
|
||||
# 记录复用提供
|
||||
self.reuse_metrics.record_reuse_offered(
|
||||
original_task_id="task_003",
|
||||
similarity_score=similarity_score,
|
||||
differences_count=len(differences),
|
||||
critical_differences=0
|
||||
)
|
||||
|
||||
# 记录复用接受
|
||||
self.reuse_metrics.record_reuse_accepted(
|
||||
original_task_id="task_003",
|
||||
similarity_score=similarity_score,
|
||||
differences_count=len(differences),
|
||||
critical_differences=0
|
||||
)
|
||||
|
||||
# 安全复检通过
|
||||
recheck_result = self.rule_checker.check(similar_record.code)
|
||||
self.assertTrue(recheck_result.passed)
|
||||
|
||||
# 记录执行结果
|
||||
self.reuse_metrics.record_reuse_execution(
|
||||
original_task_id="task_003",
|
||||
new_task_id="task_004",
|
||||
success=True
|
||||
)
|
||||
|
||||
# 3. 验证指标
|
||||
stats = self.reuse_metrics.get_stats()
|
||||
self.assertEqual(stats['total_offered'], 1)
|
||||
self.assertEqual(stats['total_accepted'], 1)
|
||||
self.assertEqual(stats['total_executed'], 1)
|
||||
self.assertEqual(stats['success_count'], 1)
|
||||
self.assertAlmostEqual(stats['acceptance_rate'], 1.0)
|
||||
|
||||
|
||||
class TestConfigHotReloadRegression(unittest.TestCase):
|
||||
"""
|
||||
测试场景:设置热更新
|
||||
验证配置变更后首次调用的正确性
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.config_metrics = ConfigMetricsManager(self.temp_dir / "config_metrics.json")
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_config_change_triggers_first_call_tracking(self):
|
||||
"""测试:配置变更触发首次调用追踪"""
|
||||
# 1. 记录配置变更
|
||||
self.config_metrics.mark_config_changed(connection_test_success=True)
|
||||
|
||||
# 2. 验证首次调用标志
|
||||
self.assertTrue(
|
||||
self.config_metrics._config_changed,
|
||||
"配置变更后应标记为首次调用"
|
||||
)
|
||||
|
||||
# 3. 模拟首次调用成功
|
||||
self.config_metrics.record_first_call(success=True)
|
||||
|
||||
# 4. 验证标志已清除
|
||||
self.assertTrue(
|
||||
self.config_metrics._first_call_recorded,
|
||||
"首次调用后应记录标志"
|
||||
)
|
||||
|
||||
def test_config_change_first_call_failure(self):
|
||||
"""测试:配置变更后首次调用失败"""
|
||||
# 1. 记录配置变更
|
||||
self.config_metrics.mark_config_changed(connection_test_success=True)
|
||||
|
||||
# 2. 模拟首次调用失败
|
||||
self.config_metrics.record_first_call(
|
||||
success=False,
|
||||
error_message="Invalid API Key"
|
||||
)
|
||||
|
||||
# 3. 验证记录
|
||||
self.assertTrue(self.config_metrics._first_call_recorded)
|
||||
self.assertEqual(self.config_metrics._retry_count, 0)
|
||||
|
||||
@patch('llm.client.get_client')
|
||||
def test_intent_classification_after_config_change(self, mock_get_client):
|
||||
"""测试:配置变更后的意图分类调用"""
|
||||
# 1. Mock LLM 客户端
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.return_value = '{"label": "execution", "confidence": 0.95, "reason": "需要执行文件操作"}'
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# 2. 记录配置变更
|
||||
self.config_metrics.mark_config_changed(connection_test_success=True)
|
||||
|
||||
# 3. 执行意图分类(首次调用)
|
||||
from intent.classifier import classify_intent
|
||||
|
||||
try:
|
||||
result = classify_intent("复制所有文件")
|
||||
|
||||
# 4. 记录成功
|
||||
self.config_metrics.record_first_call(success=True)
|
||||
|
||||
# 5. 验证结果
|
||||
self.assertEqual(result.label, EXECUTION)
|
||||
self.assertGreater(result.confidence, 0.9)
|
||||
|
||||
except Exception as e:
|
||||
# 记录失败
|
||||
self.config_metrics.record_first_call(success=False, error_message=str(e))
|
||||
raise
|
||||
|
||||
|
||||
class TestExecutionResultThreeStateRegression(unittest.TestCase):
|
||||
"""
|
||||
测试场景:执行链三态结果
|
||||
验证 success/partial/failed 状态的正确流转
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.workspace = self.temp_dir / "workspace"
|
||||
self.workspace.mkdir()
|
||||
(self.workspace / "input").mkdir()
|
||||
(self.workspace / "output").mkdir()
|
||||
(self.workspace / "codes").mkdir()
|
||||
(self.workspace / "logs").mkdir()
|
||||
|
||||
self.runner = SandboxRunner(str(self.workspace))
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_execution_result_all_success(self):
|
||||
"""测试:全部成功状态"""
|
||||
# 创建测试输入文件
|
||||
input_dir = self.workspace / "input"
|
||||
(input_dir / "test1.txt").write_text("content1")
|
||||
(input_dir / "test2.txt").write_text("content2")
|
||||
|
||||
# 执行代码:复制所有文件
|
||||
code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
total_count = 0
|
||||
|
||||
for f in INPUT_DIR.glob('*.txt'):
|
||||
total_count += 1
|
||||
try:
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
success_count += 1
|
||||
print(f"成功: {f.name}")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"失败: {f.name} - {e}")
|
||||
|
||||
print(f"\\n总计: {total_count}, 成功: {success_count}, 失败: {failed_count}")
|
||||
"""
|
||||
|
||||
result = self.runner.execute(code, user_input="复制所有txt文件")
|
||||
|
||||
# 验证:全部成功
|
||||
self.assertEqual(result.status, 'success')
|
||||
self.assertEqual(result.total_count, 2)
|
||||
self.assertEqual(result.success_count, 2)
|
||||
self.assertEqual(result.failed_count, 0)
|
||||
self.assertAlmostEqual(result.success_rate, 1.0)
|
||||
self.assertTrue(result.success)
|
||||
|
||||
def test_execution_result_partial_success(self):
|
||||
"""测试:部分成功状态"""
|
||||
# 创建测试输入文件(一个正常,一个只读)
|
||||
input_dir = self.workspace / "input"
|
||||
normal_file = input_dir / "normal.txt"
|
||||
readonly_file = input_dir / "readonly.txt"
|
||||
|
||||
normal_file.write_text("normal content")
|
||||
readonly_file.write_text("readonly content")
|
||||
|
||||
# 设置只读(模拟失败场景)
|
||||
if os.name == 'nt': # Windows
|
||||
os.chmod(readonly_file, 0o444)
|
||||
|
||||
# 执行代码:尝试复制所有文件
|
||||
code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
total_count = 0
|
||||
|
||||
for f in INPUT_DIR.glob('*.txt'):
|
||||
total_count += 1
|
||||
try:
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
success_count += 1
|
||||
print(f"成功: {f.name}")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"失败: {f.name} - {e}")
|
||||
|
||||
print(f"\\n总计: {total_count}, 成功: {success_count}, 失败: {failed_count}")
|
||||
"""
|
||||
|
||||
result = self.runner.execute(code, user_input="复制所有txt文件")
|
||||
|
||||
# 验证:部分成功(至少有一个成功)
|
||||
self.assertEqual(result.total_count, 2)
|
||||
self.assertGreater(result.success_count, 0)
|
||||
self.assertGreater(result.failed_count, 0)
|
||||
|
||||
# 根据实际情况判断状态
|
||||
if result.success_count > 0 and result.failed_count > 0:
|
||||
self.assertEqual(result.status, 'partial')
|
||||
self.assertFalse(result.success) # partial 不算完全成功
|
||||
|
||||
# 恢复权限
|
||||
if os.name == 'nt':
|
||||
os.chmod(readonly_file, 0o666)
|
||||
|
||||
def test_execution_result_all_failed(self):
|
||||
"""测试:全部失败状态"""
|
||||
# 不创建输入文件,导致无文件可处理
|
||||
|
||||
# 执行代码:尝试处理不存在的文件
|
||||
code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
total_count = 0
|
||||
|
||||
files = list(INPUT_DIR.glob('*.txt'))
|
||||
if not files:
|
||||
print("错误: 没有找到任何txt文件")
|
||||
total_count = 1
|
||||
failed_count = 1
|
||||
else:
|
||||
for f in files:
|
||||
total_count += 1
|
||||
try:
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
success_count += 1
|
||||
print(f"成功: {f.name}")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"失败: {f.name} - {e}")
|
||||
|
||||
print(f"\\n总计: {total_count}, 成功: {success_count}, 失败: {failed_count}")
|
||||
"""
|
||||
|
||||
result = self.runner.execute(code, user_input="复制所有txt文件")
|
||||
|
||||
# 验证:全部失败
|
||||
self.assertEqual(result.status, 'failed')
|
||||
self.assertEqual(result.success_count, 0)
|
||||
self.assertFalse(result.success)
|
||||
|
||||
def test_execution_result_status_display(self):
|
||||
"""测试:状态显示文本"""
|
||||
# 测试各种状态的显示文本
|
||||
|
||||
# 成功状态
|
||||
success_result = ExecutionResult(
|
||||
task_id="test_001",
|
||||
success=True,
|
||||
stdout="output",
|
||||
stderr="",
|
||||
duration_ms=100,
|
||||
log_path="/path/to/log",
|
||||
status='success',
|
||||
total_count=5,
|
||||
success_count=5,
|
||||
failed_count=0
|
||||
)
|
||||
self.assertIn("✅", success_result.get_status_display())
|
||||
self.assertIn("全部成功", success_result.get_status_display())
|
||||
|
||||
# 部分成功状态
|
||||
partial_result = ExecutionResult(
|
||||
task_id="test_002",
|
||||
success=False,
|
||||
stdout="output",
|
||||
stderr="",
|
||||
duration_ms=100,
|
||||
log_path="/path/to/log",
|
||||
status='partial',
|
||||
total_count=5,
|
||||
success_count=3,
|
||||
failed_count=2
|
||||
)
|
||||
self.assertIn("⚠️", partial_result.get_status_display())
|
||||
self.assertIn("部分成功", partial_result.get_status_display())
|
||||
|
||||
# 失败状态
|
||||
failed_result = ExecutionResult(
|
||||
task_id="test_003",
|
||||
success=False,
|
||||
stdout="",
|
||||
stderr="error",
|
||||
duration_ms=100,
|
||||
log_path="/path/to/log",
|
||||
status='failed',
|
||||
total_count=5,
|
||||
success_count=0,
|
||||
failed_count=5
|
||||
)
|
||||
self.assertIn("❌", failed_result.get_status_display())
|
||||
self.assertIn("执行失败", failed_result.get_status_display())
|
||||
|
||||
|
||||
class TestEndToEndWorkflow(unittest.TestCase):
|
||||
"""
|
||||
端到端工作流测试
|
||||
模拟完整的用户任务执行流程
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.workspace = self.temp_dir / "workspace"
|
||||
self.workspace.mkdir()
|
||||
(self.workspace / "input").mkdir()
|
||||
(self.workspace / "output").mkdir()
|
||||
(self.workspace / "codes").mkdir()
|
||||
(self.workspace / "logs").mkdir()
|
||||
|
||||
self.history = HistoryManager(self.workspace)
|
||||
self.runner = SandboxRunner(str(self.workspace))
|
||||
self.rule_checker = RuleChecker()
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
@patch('llm.client.get_client')
|
||||
def test_complete_execution_workflow(self, mock_get_client):
|
||||
"""测试:完整的执行工作流"""
|
||||
# 1. Mock LLM 响应
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.return_value = '{"label": "execution", "confidence": 0.95, "reason": "需要复制文件"}'
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# 2. 意图分类
|
||||
from intent.classifier import classify_intent
|
||||
intent_result = classify_intent("复制所有图片到输出目录")
|
||||
self.assertEqual(intent_result.label, EXECUTION)
|
||||
|
||||
# 3. 生成代码(模拟)
|
||||
code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
success_count = 0
|
||||
total_count = 0
|
||||
|
||||
for f in INPUT_DIR.glob('*.png'):
|
||||
total_count += 1
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
success_count += 1
|
||||
print(f"已复制: {f.name}")
|
||||
|
||||
print(f"\\n总计: {total_count}, 成功: {success_count}")
|
||||
"""
|
||||
|
||||
# 4. 安全检查
|
||||
safety_result = self.rule_checker.check(code)
|
||||
self.assertTrue(safety_result.passed, "安全代码应该通过检查")
|
||||
|
||||
# 5. 准备输入文件
|
||||
input_dir = self.workspace / "input"
|
||||
(input_dir / "image1.png").write_bytes(b"fake png data 1")
|
||||
(input_dir / "image2.png").write_bytes(b"fake png data 2")
|
||||
|
||||
# 6. 执行代码
|
||||
exec_result = self.runner.execute(code, user_input="复制所有图片到输出目录")
|
||||
|
||||
# 7. 验证执行结果
|
||||
self.assertTrue(exec_result.success)
|
||||
self.assertEqual(exec_result.status, 'success')
|
||||
self.assertEqual(exec_result.total_count, 2)
|
||||
self.assertEqual(exec_result.success_count, 2)
|
||||
|
||||
# 8. 保存历史记录
|
||||
self.history.add_record(
|
||||
task_id=exec_result.task_id,
|
||||
user_input="复制所有图片到输出目录",
|
||||
intent_label=intent_result.label,
|
||||
intent_confidence=intent_result.confidence,
|
||||
execution_plan="复制png文件",
|
||||
code=code,
|
||||
success=exec_result.success,
|
||||
duration_ms=exec_result.duration_ms,
|
||||
stdout=exec_result.stdout,
|
||||
stderr=exec_result.stderr,
|
||||
log_path=exec_result.log_path,
|
||||
task_summary="复制图片"
|
||||
)
|
||||
|
||||
# 9. 验证历史记录
|
||||
records = self.history.get_all()
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertTrue(records[0].success)
|
||||
|
||||
def test_workflow_with_security_block(self):
|
||||
"""测试:安全检查拦截的工作流"""
|
||||
# 1. 生成危险代码
|
||||
dangerous_code = """
|
||||
import subprocess
|
||||
|
||||
# 危险操作:执行系统命令
|
||||
subprocess.run(['dir'], shell=True)
|
||||
"""
|
||||
|
||||
# 2. 安全检查
|
||||
safety_result = self.rule_checker.check(dangerous_code)
|
||||
|
||||
# 3. 验证:必须被拦截
|
||||
self.assertFalse(safety_result.passed)
|
||||
self.assertTrue(any('subprocess' in v for v in safety_result.violations))
|
||||
|
||||
# 4. 不应该执行代码
|
||||
# (在实际应用中,安全检查失败后会直接返回,不会执行)
|
||||
|
||||
|
||||
class TestSecurityMetricsTracking(unittest.TestCase):
|
||||
"""
|
||||
安全指标追踪测试
|
||||
验证安全相关的度量指标
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_security_metrics_reuse_tracking(self):
|
||||
"""测试:复用安全指标追踪"""
|
||||
from safety.security_metrics import SecurityMetrics
|
||||
|
||||
metrics = SecurityMetrics(workspace_path=self.temp_dir)
|
||||
|
||||
# 1. 记录复用复检
|
||||
metrics.add_reuse_recheck()
|
||||
metrics.add_reuse_recheck()
|
||||
|
||||
# 2. 记录复用拦截
|
||||
metrics.add_reuse_block()
|
||||
|
||||
# 3. 验证统计
|
||||
stats = metrics.get_stats()
|
||||
self.assertEqual(stats['reuse_recheck_count'], 2)
|
||||
self.assertEqual(stats['reuse_block_count'], 1)
|
||||
self.assertAlmostEqual(stats['reuse_block_rate'], 0.5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 运行测试并生成详细报告
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
204
tests/test_retry_fix.py
Normal file
204
tests/test_retry_fix.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
测试重试策略修复
|
||||
验证网络异常能够被正确识别并重试
|
||||
"""
|
||||
|
||||
import sys
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
# 设置标准输出为 UTF-8
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
|
||||
# 添加项目根目录到路径
|
||||
PROJECT_ROOT = Path(__file__).parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from llm.client import LLMClient, LLMClientError
|
||||
import requests
|
||||
|
||||
|
||||
def test_exception_classification():
|
||||
"""测试异常分类"""
|
||||
print("=" * 60)
|
||||
print("测试 1: 异常分类")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试网络异常
|
||||
network_error = LLMClientError(
|
||||
"网络连接失败",
|
||||
error_type=LLMClientError.TYPE_NETWORK,
|
||||
original_exception=requests.exceptions.ConnectionError()
|
||||
)
|
||||
print(f"✓ 网络错误类型: {network_error.error_type}")
|
||||
assert network_error.error_type == LLMClientError.TYPE_NETWORK
|
||||
|
||||
# 测试服务器异常
|
||||
server_error = LLMClientError(
|
||||
"服务器错误 500",
|
||||
error_type=LLMClientError.TYPE_SERVER
|
||||
)
|
||||
print(f"✓ 服务器错误类型: {server_error.error_type}")
|
||||
assert server_error.error_type == LLMClientError.TYPE_SERVER
|
||||
|
||||
# 测试客户端异常
|
||||
client_error = LLMClientError(
|
||||
"请求参数错误 400",
|
||||
error_type=LLMClientError.TYPE_CLIENT
|
||||
)
|
||||
print(f"✓ 客户端错误类型: {client_error.error_type}")
|
||||
assert client_error.error_type == LLMClientError.TYPE_CLIENT
|
||||
|
||||
print("\n✅ 异常分类测试通过\n")
|
||||
|
||||
|
||||
def test_should_retry_logic():
|
||||
"""测试重试判断逻辑"""
|
||||
print("=" * 60)
|
||||
print("测试 2: 重试判断逻辑")
|
||||
print("=" * 60)
|
||||
|
||||
client = LLMClient(max_retries=3)
|
||||
|
||||
# 测试网络错误应该重试
|
||||
network_error = LLMClientError(
|
||||
"网络连接失败",
|
||||
error_type=LLMClientError.TYPE_NETWORK,
|
||||
original_exception=requests.exceptions.ConnectionError()
|
||||
)
|
||||
should_retry = client._should_retry(network_error)
|
||||
print(f"✓ 网络错误应该重试: {should_retry}")
|
||||
assert should_retry == True, "网络错误应该重试"
|
||||
|
||||
# 测试超时错误应该重试
|
||||
timeout_error = LLMClientError(
|
||||
"请求超时",
|
||||
error_type=LLMClientError.TYPE_NETWORK,
|
||||
original_exception=requests.exceptions.Timeout()
|
||||
)
|
||||
should_retry = client._should_retry(timeout_error)
|
||||
print(f"✓ 超时错误应该重试: {should_retry}")
|
||||
assert should_retry == True, "超时错误应该重试"
|
||||
|
||||
# 测试服务器错误应该重试
|
||||
server_error = LLMClientError(
|
||||
"服务器错误 500",
|
||||
error_type=LLMClientError.TYPE_SERVER
|
||||
)
|
||||
should_retry = client._should_retry(server_error)
|
||||
print(f"✓ 服务器错误应该重试: {should_retry}")
|
||||
assert should_retry == True, "服务器错误应该重试"
|
||||
|
||||
# 测试客户端错误不应该重试
|
||||
client_error = LLMClientError(
|
||||
"请求参数错误 400",
|
||||
error_type=LLMClientError.TYPE_CLIENT
|
||||
)
|
||||
should_retry = client._should_retry(client_error)
|
||||
print(f"✓ 客户端错误不应该重试: {should_retry}")
|
||||
assert should_retry == False, "客户端错误不应该重试"
|
||||
|
||||
# 测试解析错误不应该重试
|
||||
parse_error = LLMClientError(
|
||||
"解析响应失败",
|
||||
error_type=LLMClientError.TYPE_PARSE
|
||||
)
|
||||
should_retry = client._should_retry(parse_error)
|
||||
print(f"✓ 解析错误不应该重试: {should_retry}")
|
||||
assert should_retry == False, "解析错误不应该重试"
|
||||
|
||||
# 测试配置错误不应该重试
|
||||
config_error = LLMClientError(
|
||||
"未配置 API Key",
|
||||
error_type=LLMClientError.TYPE_CONFIG
|
||||
)
|
||||
should_retry = client._should_retry(config_error)
|
||||
print(f"✓ 配置错误不应该重试: {should_retry}")
|
||||
assert should_retry == False, "配置错误不应该重试"
|
||||
|
||||
# 测试原始异常检查
|
||||
error_with_original = LLMClientError(
|
||||
"网络请求异常",
|
||||
error_type=LLMClientError.TYPE_NETWORK,
|
||||
original_exception=requests.exceptions.ConnectionError("Connection refused")
|
||||
)
|
||||
should_retry = client._should_retry(error_with_original)
|
||||
print(f"✓ 带原始异常的网络错误应该重试: {should_retry}")
|
||||
assert should_retry == True, "带原始异常的网络错误应该重试"
|
||||
|
||||
print("\n✅ 重试判断逻辑测试通过\n")
|
||||
|
||||
|
||||
def test_error_type_preservation():
|
||||
"""测试错误类型保留"""
|
||||
print("=" * 60)
|
||||
print("测试 3: 错误类型保留")
|
||||
print("=" * 60)
|
||||
|
||||
# 模拟不同状态码的错误
|
||||
test_cases = [
|
||||
(500, LLMClientError.TYPE_SERVER, "服务器错误"),
|
||||
(502, LLMClientError.TYPE_SERVER, "网关错误"),
|
||||
(503, LLMClientError.TYPE_SERVER, "服务不可用"),
|
||||
(504, LLMClientError.TYPE_SERVER, "网关超时"),
|
||||
(429, LLMClientError.TYPE_SERVER, "限流错误"),
|
||||
(400, LLMClientError.TYPE_CLIENT, "请求错误"),
|
||||
(401, LLMClientError.TYPE_CLIENT, "未授权"),
|
||||
(403, LLMClientError.TYPE_CLIENT, "禁止访问"),
|
||||
(404, LLMClientError.TYPE_CLIENT, "未找到"),
|
||||
]
|
||||
|
||||
for status_code, expected_type, description in test_cases:
|
||||
if status_code >= 500:
|
||||
error_type = LLMClientError.TYPE_SERVER
|
||||
elif status_code == 429:
|
||||
error_type = LLMClientError.TYPE_SERVER
|
||||
else:
|
||||
error_type = LLMClientError.TYPE_CLIENT
|
||||
|
||||
print(f"✓ 状态码 {status_code} ({description}): {error_type}")
|
||||
assert error_type == expected_type, f"状态码 {status_code} 的错误类型不正确"
|
||||
|
||||
print("\n✅ 错误类型保留测试通过\n")
|
||||
|
||||
|
||||
def main():
|
||||
"""运行所有测试"""
|
||||
print("\n" + "=" * 60)
|
||||
print("重试策略修复验证测试")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
test_exception_classification()
|
||||
test_should_retry_logic()
|
||||
test_error_type_preservation()
|
||||
|
||||
print("=" * 60)
|
||||
print("✅ 所有测试通过!")
|
||||
print("=" * 60)
|
||||
print("\n修复总结:")
|
||||
print("1. ✅ 为 LLMClientError 添加了错误类型分类")
|
||||
print("2. ✅ 保留了原始异常信息")
|
||||
print("3. ✅ 统一了 _should_retry 判断逻辑")
|
||||
print("4. ✅ 网络异常(超时、连接失败)现在可以正确重试")
|
||||
print("5. ✅ 服务器错误(5xx)和限流(429)可以重试")
|
||||
print("6. ✅ 客户端错误(4xx)、解析错误、配置错误不会重试")
|
||||
print("7. ✅ 增强了重试度量指标记录")
|
||||
print("\n预期效果:")
|
||||
print("- 弱网环境下的稳定性显著提升")
|
||||
print("- 意图识别、生成计划、代码生成的成功率提高")
|
||||
print("- 网络抖动时自动重试并恢复")
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试出错: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
336
tests/test_runner.py
Normal file
336
tests/test_runner.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
测试运行器
|
||||
提供统一的测试执行和报告生成
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
class TestMetricsCollector(unittest.TestResult):
|
||||
"""
|
||||
测试指标收集器
|
||||
收集测试执行的详细指标
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.test_metrics = []
|
||||
self.start_time = None
|
||||
self.current_test_start = None
|
||||
|
||||
def startTest(self, test):
|
||||
super().startTest(test)
|
||||
self.current_test_start = datetime.now()
|
||||
|
||||
def stopTest(self, test):
|
||||
super().stopTest(test)
|
||||
duration = (datetime.now() - self.current_test_start).total_seconds()
|
||||
|
||||
# 确定测试状态
|
||||
status = 'passed'
|
||||
error_msg = None
|
||||
|
||||
if test in [t[0] for t in self.failures]:
|
||||
status = 'failed'
|
||||
error_msg = [e[1] for e in self.failures if e[0] == test][0]
|
||||
elif test in [t[0] for t in self.errors]:
|
||||
status = 'error'
|
||||
error_msg = [e[1] for e in self.errors if e[0] == test][0]
|
||||
elif test in self.skipped:
|
||||
status = 'skipped'
|
||||
|
||||
# 记录指标
|
||||
self.test_metrics.append({
|
||||
'test_name': str(test),
|
||||
'test_class': test.__class__.__name__,
|
||||
'test_method': test._testMethodName,
|
||||
'status': status,
|
||||
'duration_seconds': duration,
|
||||
'error_message': error_msg
|
||||
})
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""获取测试摘要"""
|
||||
total = self.testsRun
|
||||
passed = len([m for m in self.test_metrics if m['status'] == 'passed'])
|
||||
failed = len(self.failures)
|
||||
errors = len(self.errors)
|
||||
skipped = len(self.skipped)
|
||||
|
||||
total_duration = sum(m['duration_seconds'] for m in self.test_metrics)
|
||||
|
||||
return {
|
||||
'total_tests': total,
|
||||
'passed': passed,
|
||||
'failed': failed,
|
||||
'errors': errors,
|
||||
'skipped': skipped,
|
||||
'success_rate': passed / total if total > 0 else 0,
|
||||
'total_duration_seconds': total_duration,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def run_test_suite(test_modules: List[str], output_dir: Path = None) -> Dict[str, Any]:
|
||||
"""
|
||||
运行测试套件并生成报告
|
||||
|
||||
Args:
|
||||
test_modules: 测试模块名称列表
|
||||
output_dir: 报告输出目录
|
||||
|
||||
Returns:
|
||||
测试结果摘要
|
||||
"""
|
||||
# 创建测试套件
|
||||
loader = unittest.TestLoader()
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
for module_name in test_modules:
|
||||
try:
|
||||
module = __import__(module_name, fromlist=[''])
|
||||
suite.addTests(loader.loadTestsFromModule(module))
|
||||
except ImportError as e:
|
||||
print(f"警告: 无法加载测试模块 {module_name}: {e}")
|
||||
|
||||
# 运行测试
|
||||
print(f"\n{'='*70}")
|
||||
print(f"开始运行测试套件 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"{'='*70}\n")
|
||||
|
||||
result = TestMetricsCollector()
|
||||
suite.run(result)
|
||||
|
||||
# 生成摘要
|
||||
summary = result.get_summary()
|
||||
|
||||
# 打印结果
|
||||
print(f"\n{'='*70}")
|
||||
print("测试执行摘要")
|
||||
print(f"{'='*70}")
|
||||
print(f"总测试数: {summary['total_tests']}")
|
||||
print(f"通过: {summary['passed']} ✅")
|
||||
print(f"失败: {summary['failed']} ❌")
|
||||
print(f"错误: {summary['errors']} ⚠️")
|
||||
print(f"跳过: {summary['skipped']} ⏭️")
|
||||
print(f"成功率: {summary['success_rate']:.1%}")
|
||||
print(f"总耗时: {summary['total_duration_seconds']:.2f}秒")
|
||||
print(f"{'='*70}\n")
|
||||
|
||||
# 显示失败的测试
|
||||
if result.failures:
|
||||
print("失败的测试:")
|
||||
for test, traceback in result.failures:
|
||||
print(f" ❌ {test}")
|
||||
print(f" {traceback.split(chr(10))[0]}")
|
||||
|
||||
# 显示错误的测试
|
||||
if result.errors:
|
||||
print("\n错误的测试:")
|
||||
for test, traceback in result.errors:
|
||||
print(f" ⚠️ {test}")
|
||||
print(f" {traceback.split(chr(10))[0]}")
|
||||
|
||||
# 保存详细报告
|
||||
if output_dir:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# JSON报告
|
||||
report_data = {
|
||||
'summary': summary,
|
||||
'test_details': result.test_metrics,
|
||||
'failures': [
|
||||
{
|
||||
'test': str(test),
|
||||
'traceback': traceback
|
||||
}
|
||||
for test, traceback in result.failures
|
||||
],
|
||||
'errors': [
|
||||
{
|
||||
'test': str(test),
|
||||
'traceback': traceback
|
||||
}
|
||||
for test, traceback in result.errors
|
||||
]
|
||||
}
|
||||
|
||||
report_file = output_dir / f"test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
with open(report_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(report_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"\n详细报告已保存到: {report_file}")
|
||||
|
||||
# Markdown报告
|
||||
md_report = generate_markdown_report(summary, result)
|
||||
md_file = output_dir / f"test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
|
||||
with open(md_file, 'w', encoding='utf-8') as f:
|
||||
f.write(md_report)
|
||||
|
||||
print(f"Markdown报告已保存到: {md_file}")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def generate_markdown_report(summary: Dict[str, Any], result: TestMetricsCollector) -> str:
|
||||
"""生成Markdown格式的测试报告"""
|
||||
md = f"""# 测试执行报告
|
||||
|
||||
**生成时间**: {summary['timestamp']}
|
||||
|
||||
## 执行摘要
|
||||
|
||||
| 指标 | 数值 |
|
||||
|------|------|
|
||||
| 总测试数 | {summary['total_tests']} |
|
||||
| 通过 | {summary['passed']} ✅ |
|
||||
| 失败 | {summary['failed']} ❌ |
|
||||
| 错误 | {summary['errors']} ⚠️ |
|
||||
| 跳过 | {summary['skipped']} ⏭️ |
|
||||
| 成功率 | {summary['success_rate']:.1%} |
|
||||
| 总耗时 | {summary['total_duration_seconds']:.2f}秒 |
|
||||
|
||||
## 测试覆盖矩阵
|
||||
|
||||
### 关键路径覆盖
|
||||
|
||||
"""
|
||||
|
||||
# 按测试类分组
|
||||
test_by_class = {}
|
||||
for metric in result.test_metrics:
|
||||
class_name = metric['test_class']
|
||||
if class_name not in test_by_class:
|
||||
test_by_class[class_name] = []
|
||||
test_by_class[class_name].append(metric)
|
||||
|
||||
for class_name, tests in test_by_class.items():
|
||||
passed = len([t for t in tests if t['status'] == 'passed'])
|
||||
total = len(tests)
|
||||
md += f"\n#### {class_name}\n\n"
|
||||
md += f"- 覆盖率: {passed}/{total} ({passed/total:.1%})\n"
|
||||
md += f"- 测试用例:\n"
|
||||
|
||||
for test in tests:
|
||||
status_icon = {
|
||||
'passed': '✅',
|
||||
'failed': '❌',
|
||||
'error': '⚠️',
|
||||
'skipped': '⏭️'
|
||||
}.get(test['status'], '❓')
|
||||
|
||||
md += f" - {status_icon} `{test['test_method']}` ({test['duration_seconds']:.3f}s)\n"
|
||||
|
||||
# 失败详情
|
||||
if result.failures or result.errors:
|
||||
md += "\n## 失败详情\n\n"
|
||||
|
||||
if result.failures:
|
||||
md += "### 失败的测试\n\n"
|
||||
for test, traceback in result.failures:
|
||||
md += f"#### {test}\n\n"
|
||||
md += "```\n"
|
||||
md += traceback
|
||||
md += "\n```\n\n"
|
||||
|
||||
if result.errors:
|
||||
md += "### 错误的测试\n\n"
|
||||
for test, traceback in result.errors:
|
||||
md += f"#### {test}\n\n"
|
||||
md += "```\n"
|
||||
md += traceback
|
||||
md += "\n```\n\n"
|
||||
|
||||
# 建议
|
||||
md += "\n## 改进建议\n\n"
|
||||
|
||||
if summary['success_rate'] < 1.0:
|
||||
md += "- ⚠️ 存在失败的测试,需要修复\n"
|
||||
|
||||
if summary['success_rate'] >= 0.95:
|
||||
md += "- ✅ 测试覆盖率良好\n"
|
||||
elif summary['success_rate'] >= 0.8:
|
||||
md += "- ⚠️ 建议提高测试覆盖率\n"
|
||||
else:
|
||||
md += "- ❌ 测试覆盖率较低,需要补充测试用例\n"
|
||||
|
||||
return md
|
||||
|
||||
|
||||
def run_critical_path_tests():
|
||||
"""运行关键路径测试"""
|
||||
test_modules = [
|
||||
'test_e2e_integration',
|
||||
'test_security_regression',
|
||||
]
|
||||
|
||||
workspace_path = Path(__file__).parent.parent / "workspace"
|
||||
output_dir = workspace_path / "test_reports"
|
||||
|
||||
summary = run_test_suite(test_modules, output_dir)
|
||||
|
||||
# 返回退出码
|
||||
return 0 if summary['failed'] == 0 and summary['errors'] == 0 else 1
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""运行所有测试"""
|
||||
test_modules = [
|
||||
'test_intent_classifier',
|
||||
'test_rule_checker',
|
||||
'test_history_manager',
|
||||
'test_task_features',
|
||||
'test_data_governance',
|
||||
'test_config_refresh',
|
||||
'test_retry_fix',
|
||||
'test_e2e_integration',
|
||||
'test_security_regression',
|
||||
]
|
||||
|
||||
workspace_path = Path(__file__).parent.parent / "workspace"
|
||||
output_dir = workspace_path / "test_reports"
|
||||
|
||||
summary = run_test_suite(test_modules, output_dir)
|
||||
|
||||
# 返回退出码
|
||||
return 0 if summary['failed'] == 0 and summary['errors'] == 0 else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='LocalAgent 测试运行器')
|
||||
parser.add_argument(
|
||||
'--mode',
|
||||
choices=['all', 'critical', 'unit'],
|
||||
default='critical',
|
||||
help='测试模式: all(全部), critical(关键路径), unit(单元测试)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == 'all':
|
||||
exit_code = run_all_tests()
|
||||
elif args.mode == 'critical':
|
||||
exit_code = run_critical_path_tests()
|
||||
else: # unit
|
||||
test_modules = [
|
||||
'test_intent_classifier',
|
||||
'test_rule_checker',
|
||||
'test_history_manager',
|
||||
]
|
||||
workspace_path = Path(__file__).parent.parent / "workspace"
|
||||
output_dir = workspace_path / "test_reports"
|
||||
summary = run_test_suite(test_modules, output_dir)
|
||||
exit_code = 0 if summary['failed'] == 0 and summary['errors'] == 0 else 1
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
570
tests/test_security_regression.py
Normal file
570
tests/test_security_regression.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""
|
||||
安全回归测试矩阵
|
||||
专注于安全相关的回归场景
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from safety.rule_checker import RuleChecker, RuleCheckResult
|
||||
from safety.llm_reviewer import LLMReviewer, LLMReviewResult
|
||||
from history.manager import HistoryManager
|
||||
from intent.labels import EXECUTION
|
||||
|
||||
|
||||
class TestSecurityRegressionMatrix(unittest.TestCase):
|
||||
"""
|
||||
安全回归测试矩阵
|
||||
覆盖所有已知的安全风险场景
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.checker = RuleChecker()
|
||||
|
||||
# ========== 硬性禁止回归测试 ==========
|
||||
|
||||
def test_regression_network_operations(self):
|
||||
"""回归测试:网络操作必须被拦截"""
|
||||
test_cases = [
|
||||
("import socket\ns = socket.socket()", "socket模块"),
|
||||
("import requests\nrequests.get('http://example.com')", "requests模块"),
|
||||
("import urllib\nurllib.request.urlopen('http://example.com')", "urllib模块"),
|
||||
("import http.client\nconn = http.client.HTTPConnection('example.com')", "http.client模块"),
|
||||
]
|
||||
|
||||
for code, description in test_cases:
|
||||
with self.subTest(description=description):
|
||||
result = self.checker.check(code)
|
||||
# requests 是警告,其他是硬性拦截
|
||||
if 'requests' in code:
|
||||
self.assertTrue(result.passed, f"{description}应该通过但产生警告")
|
||||
self.assertTrue(len(result.warnings) > 0, f"{description}应该产生警告")
|
||||
else:
|
||||
self.assertFalse(result.passed, f"{description}必须被拦截")
|
||||
|
||||
def test_regression_command_execution(self):
|
||||
"""回归测试:命令执行必须被拦截"""
|
||||
test_cases = [
|
||||
("import subprocess\nsubprocess.run(['ls'])", "subprocess.run"),
|
||||
("import subprocess\nsubprocess.Popen(['dir'])", "subprocess.Popen"),
|
||||
("import subprocess\nsubprocess.call(['echo', 'test'])", "subprocess.call"),
|
||||
("import os\nos.system('dir')", "os.system"),
|
||||
("import os\nos.popen('ls')", "os.popen"),
|
||||
("eval('1+1')", "eval函数"),
|
||||
("exec('print(1)')", "exec函数"),
|
||||
("__import__('os').system('ls')", "__import__动态导入"),
|
||||
]
|
||||
|
||||
for code, description in test_cases:
|
||||
with self.subTest(description=description):
|
||||
result = self.checker.check(code)
|
||||
self.assertFalse(result.passed, f"{description}必须被拦截")
|
||||
self.assertTrue(len(result.violations) > 0, f"{description}必须产生违规记录")
|
||||
|
||||
def test_regression_file_system_warnings(self):
|
||||
"""回归测试:危险文件操作产生警告"""
|
||||
test_cases = [
|
||||
("import os\nos.remove('file.txt')", "os.remove"),
|
||||
("import os\nos.unlink('file.txt')", "os.unlink"),
|
||||
("import shutil\nshutil.rmtree('folder')", "shutil.rmtree"),
|
||||
("from pathlib import Path\nPath('file.txt').unlink()", "Path.unlink"),
|
||||
]
|
||||
|
||||
for code, description in test_cases:
|
||||
with self.subTest(description=description):
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed, f"{description}应该通过检查")
|
||||
self.assertTrue(len(result.warnings) > 0, f"{description}应该产生警告")
|
||||
|
||||
def test_regression_safe_operations(self):
|
||||
"""回归测试:安全操作不应被误拦截"""
|
||||
safe_codes = [
|
||||
# 文件复制
|
||||
"""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
for f in INPUT_DIR.glob('*.txt'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
""",
|
||||
# 图片处理
|
||||
"""
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
for img_path in INPUT_DIR.glob('*.png'):
|
||||
img = Image.open(img_path)
|
||||
img = img.resize((100, 100))
|
||||
img.save(OUTPUT_DIR / img_path.name)
|
||||
""",
|
||||
# Excel处理
|
||||
"""
|
||||
import openpyxl
|
||||
from pathlib import Path
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
for xlsx_path in INPUT_DIR.glob('*.xlsx'):
|
||||
wb = openpyxl.load_workbook(xlsx_path)
|
||||
ws = wb.active
|
||||
ws['A1'] = 'Modified'
|
||||
wb.save(OUTPUT_DIR / xlsx_path.name)
|
||||
""",
|
||||
# JSON处理
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
for json_path in INPUT_DIR.glob('*.json'):
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
data['processed'] = True
|
||||
with open(OUTPUT_DIR / json_path.name, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
""",
|
||||
]
|
||||
|
||||
for i, code in enumerate(safe_codes):
|
||||
with self.subTest(case=f"安全代码{i+1}"):
|
||||
result = self.checker.check(code)
|
||||
self.assertTrue(result.passed, f"安全代码{i+1}不应被拦截")
|
||||
self.assertEqual(len(result.violations), 0, f"安全代码{i+1}不应有违规")
|
||||
|
||||
|
||||
class TestLLMReviewerRegression(unittest.TestCase):
|
||||
"""
|
||||
LLM审查器回归测试
|
||||
验证软规则审查的稳定性
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.reviewer = LLMReviewer()
|
||||
|
||||
def test_llm_review_response_parsing(self):
|
||||
"""测试:LLM响应解析的鲁棒性"""
|
||||
test_cases = [
|
||||
# 标准JSON格式
|
||||
('{"pass": true, "reason": "代码安全"}', True),
|
||||
('{"pass": false, "reason": "存在风险"}', False),
|
||||
|
||||
# 带代码块的JSON
|
||||
('```json\n{"pass": true, "reason": "安全"}\n```', True),
|
||||
('```\n{"pass": false, "reason": "危险"}\n```', False),
|
||||
|
||||
# 带前缀文本
|
||||
('分析结果如下:{"pass": true, "reason": "通过"}', True),
|
||||
|
||||
# 字符串形式的布尔值
|
||||
('{"pass": "true", "reason": "安全"}', True),
|
||||
('{"pass": "false", "reason": "危险"}', False),
|
||||
|
||||
# 无效JSON(应该保守判定为不通过)
|
||||
('这不是JSON', False),
|
||||
('{"incomplete": true', False),
|
||||
]
|
||||
|
||||
for response, expected_pass in test_cases:
|
||||
with self.subTest(response=response[:30]):
|
||||
result = self.reviewer._parse_response(response)
|
||||
self.assertEqual(result.passed, expected_pass,
|
||||
f"响应 '{response[:30]}...' 解析错误")
|
||||
|
||||
@patch('llm.client.get_client')
|
||||
def test_llm_review_failure_handling(self, mock_get_client):
|
||||
"""测试:LLM调用失败时的降级处理"""
|
||||
# Mock LLM客户端抛出异常
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.side_effect = Exception("API调用失败")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# 执行审查
|
||||
result = self.reviewer.review(
|
||||
user_input="测试任务",
|
||||
execution_plan="测试计划",
|
||||
code="print('test')",
|
||||
warnings=[]
|
||||
)
|
||||
|
||||
# 验证:失败时应保守判定为不通过
|
||||
self.assertFalse(result.passed, "LLM调用失败时应拒绝执行")
|
||||
self.assertIn("失败", result.reason, "应包含失败原因")
|
||||
|
||||
@patch('llm.client.get_client')
|
||||
def test_llm_review_with_warnings(self, mock_get_client):
|
||||
"""测试:带警告的LLM审查"""
|
||||
# Mock LLM客户端
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.return_value = '{"pass": true, "reason": "警告已审查,风险可控"}'
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# 执行审查(带警告)
|
||||
warnings = ["使用了 os.remove", "使用了 requests"]
|
||||
result = self.reviewer.review(
|
||||
user_input="删除文件并上传",
|
||||
execution_plan="删除本地文件后上传到服务器",
|
||||
code="import os\nimport requests\nos.remove('file.txt')\nrequests.post('http://api.example.com')",
|
||||
warnings=warnings
|
||||
)
|
||||
|
||||
# 验证:调用参数应包含警告信息
|
||||
call_args = mock_client.chat.call_args
|
||||
messages = call_args[1]['messages']
|
||||
user_message = messages[1]['content']
|
||||
|
||||
self.assertIn("静态检查警告", user_message, "应传递警告信息给LLM")
|
||||
self.assertIn("os.remove", user_message, "应包含具体警告内容")
|
||||
|
||||
|
||||
class TestHistoryReuseSecurityRegression(unittest.TestCase):
|
||||
"""
|
||||
历史复用安全回归测试
|
||||
确保复用流程不会绕过安全检查
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
self.history = HistoryManager(self.temp_dir)
|
||||
self.checker = RuleChecker()
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_reuse_security_bypass_prevention(self):
|
||||
"""测试:防止通过复用绕过安全检查"""
|
||||
# 场景:历史记录中存在一个"曾经通过"但现在应该被拦截的代码
|
||||
|
||||
# 1. 添加历史记录(模拟旧版本允许的代码)
|
||||
old_dangerous_code = """
|
||||
import socket
|
||||
|
||||
# 旧版本可能允许的网络操作
|
||||
s = socket.socket()
|
||||
"""
|
||||
|
||||
self.history.add_record(
|
||||
task_id="old_task_001",
|
||||
user_input="建立网络连接",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.9,
|
||||
execution_plan="创建socket连接",
|
||||
code=old_dangerous_code,
|
||||
success=True, # 历史上标记为成功
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 2. 尝试复用
|
||||
result = self.history.find_similar_success("创建网络连接", return_details=True)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
similar_record, _, _ = result
|
||||
|
||||
# 3. 强制安全复检(关键步骤)
|
||||
recheck_result = self.checker.check(similar_record.code)
|
||||
|
||||
# 4. 验证:必须被当前规则拦截
|
||||
self.assertFalse(recheck_result.passed,
|
||||
"历史代码复用时必须被当前安全规则拦截")
|
||||
self.assertTrue(any('socket' in v for v in recheck_result.violations),
|
||||
"必须检测到socket违规")
|
||||
|
||||
def test_reuse_with_modified_dangerous_code(self):
|
||||
"""测试:复用后修改为危险代码的检测"""
|
||||
# 1. 添加安全的历史记录
|
||||
safe_code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for f in INPUT_DIR.glob('*.txt'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
|
||||
self.history.add_record(
|
||||
task_id="safe_task_001",
|
||||
user_input="复制文件",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.95,
|
||||
execution_plan="复制txt文件",
|
||||
code=safe_code,
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 2. 模拟用户修改代码(添加危险操作)
|
||||
modified_dangerous_code = safe_code + """
|
||||
# 用户添加的危险操作
|
||||
import subprocess
|
||||
subprocess.run(['dir'], shell=True)
|
||||
"""
|
||||
|
||||
# 3. 安全检查修改后的代码
|
||||
check_result = self.checker.check(modified_dangerous_code)
|
||||
|
||||
# 4. 验证:必须检测到新增的危险操作
|
||||
self.assertFalse(check_result.passed, "修改后的危险代码必须被拦截")
|
||||
self.assertTrue(any('subprocess' in v for v in check_result.violations))
|
||||
|
||||
def test_reuse_multiple_security_layers(self):
|
||||
"""测试:复用时的多层安全检查"""
|
||||
# 1. 添加包含警告操作的历史记录
|
||||
warning_code = """
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
# 先删除旧文件
|
||||
for f in OUTPUT_DIR.glob('*.txt'):
|
||||
os.remove(f)
|
||||
|
||||
# 再复制新文件
|
||||
for f in INPUT_DIR.glob('*.txt'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
|
||||
self.history.add_record(
|
||||
task_id="warning_task_001",
|
||||
user_input="清空并复制文件",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.9,
|
||||
execution_plan="删除旧文件并复制新文件",
|
||||
code=warning_code,
|
||||
success=True,
|
||||
duration_ms=150
|
||||
)
|
||||
|
||||
# 2. 复用并进行安全检查
|
||||
result = self.history.find_similar_success("清空目录并复制", return_details=True)
|
||||
similar_record, _, _ = result
|
||||
|
||||
# 3. 第一层:硬规则检查
|
||||
rule_result = self.checker.check(similar_record.code)
|
||||
self.assertTrue(rule_result.passed, "应该通过硬规则检查")
|
||||
self.assertTrue(len(rule_result.warnings) > 0, "应该产生警告")
|
||||
|
||||
# 4. 第二层:LLM审查(Mock)
|
||||
with patch('llm.client.get_client') as mock_get_client:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.return_value = '{"pass": true, "reason": "删除操作在workspace内,风险可控"}'
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
reviewer = LLMReviewer()
|
||||
llm_result = reviewer.review(
|
||||
user_input=similar_record.user_input,
|
||||
execution_plan=similar_record.execution_plan,
|
||||
code=similar_record.code,
|
||||
warnings=rule_result.warnings
|
||||
)
|
||||
|
||||
# 验证:LLM收到了警告信息
|
||||
call_args = mock_client.chat.call_args
|
||||
messages = call_args[1]['messages']
|
||||
user_message = messages[1]['content']
|
||||
self.assertIn("静态检查警告", user_message)
|
||||
|
||||
|
||||
class TestSecurityMetricsRegression(unittest.TestCase):
|
||||
"""
|
||||
安全指标回归测试
|
||||
确保安全相关的度量指标正确记录
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""创建测试环境"""
|
||||
self.temp_dir = Path(tempfile.mkdtemp())
|
||||
|
||||
def tearDown(self):
|
||||
"""清理测试环境"""
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_security_metrics_persistence(self):
|
||||
"""测试:安全指标的持久化"""
|
||||
from safety.security_metrics import SecurityMetrics
|
||||
|
||||
# 1. 创建指标实例并记录数据
|
||||
metrics1 = SecurityMetrics(self.temp_dir)
|
||||
metrics1.add_reuse_recheck()
|
||||
metrics1.add_reuse_recheck()
|
||||
metrics1.add_reuse_block()
|
||||
|
||||
# 2. 创建新实例(模拟重启)
|
||||
metrics2 = SecurityMetrics(self.temp_dir)
|
||||
|
||||
# 3. 验证:数据应该被持久化
|
||||
stats = metrics2.get_stats()
|
||||
self.assertEqual(stats['reuse_recheck_count'], 2)
|
||||
self.assertEqual(stats['reuse_block_count'], 1)
|
||||
|
||||
def test_security_metrics_accuracy(self):
|
||||
"""测试:安全指标计算的准确性"""
|
||||
from safety.security_metrics import SecurityMetrics
|
||||
|
||||
metrics = SecurityMetrics(self.temp_dir)
|
||||
|
||||
# 记录10次复检,3次拦截
|
||||
for _ in range(10):
|
||||
metrics.add_reuse_recheck()
|
||||
|
||||
for _ in range(3):
|
||||
metrics.add_reuse_block()
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
# 验证计数
|
||||
self.assertEqual(stats['reuse_recheck_count'], 10)
|
||||
self.assertEqual(stats['reuse_block_count'], 3)
|
||||
|
||||
# 验证拦截率
|
||||
expected_rate = 3 / 10
|
||||
self.assertAlmostEqual(stats['reuse_block_rate'], expected_rate, places=2)
|
||||
|
||||
|
||||
class TestCriticalPathCoverage(unittest.TestCase):
|
||||
"""
|
||||
关键路径覆盖测试
|
||||
确保所有关键安全路径都被测试覆盖
|
||||
"""
|
||||
|
||||
def test_critical_path_new_code_generation(self):
|
||||
"""关键路径:新代码生成 -> 安全检查 -> 执行"""
|
||||
checker = RuleChecker()
|
||||
|
||||
# 1. 生成新代码(模拟)
|
||||
new_code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for f in INPUT_DIR.glob('*.png'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
|
||||
# 2. 硬规则检查
|
||||
rule_result = checker.check(new_code)
|
||||
self.assertTrue(rule_result.passed)
|
||||
|
||||
# 3. LLM审查(Mock)
|
||||
with patch('llm.client.get_client') as mock_get_client:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.return_value = '{"pass": true, "reason": "代码安全"}'
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
reviewer = LLMReviewer()
|
||||
llm_result = reviewer.review(
|
||||
user_input="复制图片",
|
||||
execution_plan="复制png文件",
|
||||
code=new_code,
|
||||
warnings=rule_result.warnings
|
||||
)
|
||||
|
||||
self.assertTrue(llm_result.passed)
|
||||
|
||||
def test_critical_path_code_reuse(self):
|
||||
"""关键路径:代码复用 -> 安全复检 -> 执行"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
history = HistoryManager(temp_dir)
|
||||
checker = RuleChecker()
|
||||
|
||||
# 1. 添加历史记录
|
||||
reuse_code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
for f in INPUT_DIR.glob('*.jpg'):
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
|
||||
history.add_record(
|
||||
task_id="reuse_001",
|
||||
user_input="复制jpg图片",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.95,
|
||||
execution_plan="复制jpg文件",
|
||||
code=reuse_code,
|
||||
success=True,
|
||||
duration_ms=100
|
||||
)
|
||||
|
||||
# 2. 查找相似任务
|
||||
result = history.find_similar_success("复制jpeg图片", return_details=True)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
similar_record, _, _ = result
|
||||
|
||||
# 3. 安全复检(关键步骤)
|
||||
recheck_result = checker.check(similar_record.code)
|
||||
self.assertTrue(recheck_result.passed, "复用代码必须通过安全复检")
|
||||
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
def test_critical_path_code_fix_retry(self):
|
||||
"""关键路径:失败重试 -> 代码修复 -> 安全检查 -> 执行"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
history = HistoryManager(temp_dir)
|
||||
checker = RuleChecker()
|
||||
|
||||
# 1. 添加失败的历史记录
|
||||
failed_code = """
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path('workspace/input')
|
||||
OUTPUT_DIR = Path('workspace/output')
|
||||
|
||||
# 错误:路径拼写错误
|
||||
for f in INPUT_DIR.glob('*.pngg'): # 注意:pngg是错误的
|
||||
shutil.copy(f, OUTPUT_DIR / f.name)
|
||||
"""
|
||||
|
||||
history.add_record(
|
||||
task_id="failed_001",
|
||||
user_input="复制png图片",
|
||||
intent_label=EXECUTION,
|
||||
intent_confidence=0.95,
|
||||
execution_plan="复制png文件",
|
||||
code=failed_code,
|
||||
success=False,
|
||||
duration_ms=50,
|
||||
stderr="没有找到文件"
|
||||
)
|
||||
|
||||
# 2. 修复代码(模拟AI修复)
|
||||
fixed_code = failed_code.replace('*.pngg', '*.png')
|
||||
|
||||
# 3. 安全检查修复后的代码
|
||||
check_result = checker.check(fixed_code)
|
||||
self.assertTrue(check_result.passed, "修复后的代码必须通过安全检查")
|
||||
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 运行测试并生成详细报告
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
142
tests/test_task_features.py
Normal file
142
tests/test_task_features.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
任务特征提取与匹配的测试用例
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from history.task_features import TaskFeatureExtractor, TaskMatcher
|
||||
|
||||
|
||||
def test_feature_extraction():
|
||||
"""测试特征提取"""
|
||||
print("=" * 60)
|
||||
print("测试 1: 特征提取")
|
||||
print("=" * 60)
|
||||
|
||||
extractor = TaskFeatureExtractor()
|
||||
|
||||
# 测试用例 1
|
||||
input1 = "将 D:/photos 目录下的所有 .jpg 图片按日期重命名"
|
||||
features1 = extractor.extract(input1)
|
||||
|
||||
print(f"\n输入: {input1}")
|
||||
print(f"文件格式: {features1.file_formats}")
|
||||
print(f"目录路径: {features1.directory_paths}")
|
||||
print(f"命名规则: {features1.naming_patterns}")
|
||||
print(f"操作类型: {features1.operations}")
|
||||
print(f"数量信息: {features1.quantities}")
|
||||
|
||||
# 测试用例 2
|
||||
input2 = "批量转换 C:/documents 下的 100 个 .docx 文件为 .pdf"
|
||||
features2 = extractor.extract(input2)
|
||||
|
||||
print(f"\n输入: {input2}")
|
||||
print(f"文件格式: {features2.file_formats}")
|
||||
print(f"目录路径: {features2.directory_paths}")
|
||||
print(f"命名规则: {features2.naming_patterns}")
|
||||
print(f"操作类型: {features2.operations}")
|
||||
print(f"数量信息: {features2.quantities}")
|
||||
|
||||
|
||||
def test_similarity_matching():
|
||||
"""测试相似度匹配"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 2: 相似度匹配")
|
||||
print("=" * 60)
|
||||
|
||||
matcher = TaskMatcher()
|
||||
|
||||
# 测试场景 1: 高度相似(仅目录不同)
|
||||
print("\n场景 1: 高度相似任务(仅目录不同)")
|
||||
current1 = "将 D:/photos 目录下的所有 .jpg 图片按日期重命名"
|
||||
history1 = "将 C:/images 目录下的所有 .jpg 图片按日期重命名"
|
||||
|
||||
score1, diffs1 = matcher.calculate_similarity(current1, history1)
|
||||
print(f"当前任务: {current1}")
|
||||
print(f"历史任务: {history1}")
|
||||
print(f"相似度: {score1:.2%}")
|
||||
print(f"差异数量: {len(diffs1)}")
|
||||
for diff in diffs1:
|
||||
print(f" - {diff.category} [{diff.importance}]: 当前={diff.current_value}, 历史={diff.history_value}")
|
||||
|
||||
# 测试场景 2: 中等相似(格式和操作不同)
|
||||
print("\n场景 2: 中等相似任务(格式和操作不同)")
|
||||
current2 = "将 D:/photos 目录下的所有 .jpg 图片转换为 .png"
|
||||
history2 = "将 D:/photos 目录下的所有 .jpg 图片按日期重命名"
|
||||
|
||||
score2, diffs2 = matcher.calculate_similarity(current2, history2)
|
||||
print(f"当前任务: {current2}")
|
||||
print(f"历史任务: {history2}")
|
||||
print(f"相似度: {score2:.2%}")
|
||||
print(f"差异数量: {len(diffs2)}")
|
||||
for diff in diffs2:
|
||||
print(f" - {diff.category} [{diff.importance}]: 当前={diff.current_value}, 历史={diff.history_value}")
|
||||
|
||||
# 测试场景 3: 低相似度(完全不同的任务)
|
||||
print("\n场景 3: 低相似度任务(完全不同)")
|
||||
current3 = "将 D:/photos 目录下的所有 .jpg 图片按日期重命名"
|
||||
history3 = "统计 C:/documents 下所有 .txt 文件的行数"
|
||||
|
||||
score3, diffs3 = matcher.calculate_similarity(current3, history3)
|
||||
print(f"当前任务: {current3}")
|
||||
print(f"历史任务: {history3}")
|
||||
print(f"相似度: {score3:.2%}")
|
||||
print(f"差异数量: {len(diffs3)}")
|
||||
for diff in diffs3:
|
||||
print(f" - {diff.category} [{diff.importance}]: 当前={diff.current_value}, 历史={diff.history_value}")
|
||||
|
||||
# 测试场景 4: 关键参数差异(数量不同)
|
||||
print("\n场景 4: 关键参数差异(数量不同)")
|
||||
current4 = "批量转换 100 个 .docx 文件为 .pdf"
|
||||
history4 = "批量转换所有 .docx 文件为 .pdf"
|
||||
|
||||
score4, diffs4 = matcher.calculate_similarity(current4, history4)
|
||||
print(f"当前任务: {current4}")
|
||||
print(f"历史任务: {history4}")
|
||||
print(f"相似度: {score4:.2%}")
|
||||
print(f"差异数量: {len(diffs4)}")
|
||||
for diff in diffs4:
|
||||
print(f" - {diff.category} [{diff.importance}]: 当前={diff.current_value}, 历史={diff.history_value}")
|
||||
|
||||
|
||||
def test_edge_cases():
|
||||
"""测试边界情况"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 3: 边界情况")
|
||||
print("=" * 60)
|
||||
|
||||
matcher = TaskMatcher()
|
||||
|
||||
# 空输入
|
||||
print("\n边界 1: 空输入")
|
||||
score, diffs = matcher.calculate_similarity("", "")
|
||||
print(f"相似度: {score:.2%}, 差异数: {len(diffs)}")
|
||||
|
||||
# 完全相同
|
||||
print("\n边界 2: 完全相同")
|
||||
same_input = "将 D:/photos 目录下的所有 .jpg 图片按日期重命名"
|
||||
score, diffs = matcher.calculate_similarity(same_input, same_input)
|
||||
print(f"相似度: {score:.2%}, 差异数: {len(diffs)}")
|
||||
|
||||
# 仅标点不同
|
||||
print("\n边界 3: 仅标点不同")
|
||||
input_a = "将D:/photos目录下的所有.jpg图片按日期重命名"
|
||||
input_b = "将 D:/photos 目录下的所有 .jpg 图片按日期重命名"
|
||||
score, diffs = matcher.calculate_similarity(input_a, input_b)
|
||||
print(f"相似度: {score:.2%}, 差异数: {len(diffs)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_feature_extraction()
|
||||
test_similarity_matching()
|
||||
test_edge_cases()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有测试完成!")
|
||||
print("=" * 60)
|
||||
|
||||
191
tests/verify_tests.py
Normal file
191
tests/verify_tests.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
快速验证脚本
|
||||
验证新增测试的基本功能
|
||||
"""
|
||||
|
||||
import sys
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
# 设置标准输出编码为UTF-8
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
def test_imports():
|
||||
"""测试所有测试模块是否可以正常导入"""
|
||||
print("=" * 70)
|
||||
print("测试模块导入验证")
|
||||
print("=" * 70)
|
||||
|
||||
modules = [
|
||||
'tests.test_e2e_integration',
|
||||
'tests.test_security_regression',
|
||||
'tests.test_runner',
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
failed_modules = []
|
||||
|
||||
for module_name in modules:
|
||||
try:
|
||||
__import__(module_name)
|
||||
print(f"✅ {module_name} - 导入成功")
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"❌ {module_name} - 导入失败: {e}")
|
||||
failed_modules.append((module_name, str(e)))
|
||||
|
||||
print(f"\n导入结果: {success_count}/{len(modules)} 成功")
|
||||
|
||||
if failed_modules:
|
||||
print("\n失败详情:")
|
||||
for module, error in failed_modules:
|
||||
print(f" - {module}: {error}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_test_classes():
|
||||
"""测试关键测试类是否存在"""
|
||||
print("\n" + "=" * 70)
|
||||
print("测试类验证")
|
||||
print("=" * 70)
|
||||
|
||||
test_classes = [
|
||||
('tests.test_e2e_integration', 'TestCodeReuseSecurityRegression'),
|
||||
('tests.test_e2e_integration', 'TestConfigHotReloadRegression'),
|
||||
('tests.test_e2e_integration', 'TestExecutionResultThreeStateRegression'),
|
||||
('tests.test_security_regression', 'TestSecurityRegressionMatrix'),
|
||||
('tests.test_security_regression', 'TestLLMReviewerRegression'),
|
||||
('tests.test_security_regression', 'TestCriticalPathCoverage'),
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
|
||||
for module_name, class_name in test_classes:
|
||||
try:
|
||||
module = __import__(module_name, fromlist=[class_name])
|
||||
test_class = getattr(module, class_name)
|
||||
print(f"✅ {module_name}.{class_name} - 存在")
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"❌ {module_name}.{class_name} - 不存在: {e}")
|
||||
|
||||
print(f"\n验证结果: {success_count}/{len(test_classes)} 成功")
|
||||
|
||||
return success_count == len(test_classes)
|
||||
|
||||
|
||||
def test_runner_functionality():
|
||||
"""测试测试运行器的基本功能"""
|
||||
print("\n" + "=" * 70)
|
||||
print("测试运行器功能验证")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
from tests.test_runner import TestMetricsCollector
|
||||
|
||||
# 创建指标收集器
|
||||
collector = TestMetricsCollector()
|
||||
print("✅ TestMetricsCollector 创建成功")
|
||||
|
||||
# 测试摘要生成
|
||||
summary = collector.get_summary()
|
||||
print("✅ 摘要生成功能正常")
|
||||
|
||||
# 验证摘要字段
|
||||
required_fields = ['total_tests', 'passed', 'failed', 'errors', 'skipped', 'success_rate']
|
||||
for field in required_fields:
|
||||
if field in summary:
|
||||
print(f" ✅ 摘要包含字段: {field}")
|
||||
else:
|
||||
print(f" ❌ 摘要缺少字段: {field}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 测试运行器验证失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def count_test_methods():
|
||||
"""统计测试方法数量"""
|
||||
print("\n" + "=" * 70)
|
||||
print("测试方法统计")
|
||||
print("=" * 70)
|
||||
|
||||
import unittest
|
||||
|
||||
modules = [
|
||||
'tests.test_e2e_integration',
|
||||
'tests.test_security_regression',
|
||||
]
|
||||
|
||||
total_tests = 0
|
||||
|
||||
for module_name in modules:
|
||||
try:
|
||||
module = __import__(module_name, fromlist=[''])
|
||||
loader = unittest.TestLoader()
|
||||
suite = loader.loadTestsFromModule(module)
|
||||
count = suite.countTestCases()
|
||||
print(f"📊 {module_name}: {count} 个测试方法")
|
||||
total_tests += count
|
||||
except Exception as e:
|
||||
print(f"❌ {module_name}: 统计失败 - {e}")
|
||||
|
||||
print(f"\n总计: {total_tests} 个测试方法")
|
||||
return total_tests
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("\n" + "=" * 70)
|
||||
print("LocalAgent 测试验证工具")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
results = []
|
||||
|
||||
# 1. 测试导入
|
||||
results.append(("模块导入", test_imports()))
|
||||
|
||||
# 2. 测试类验证
|
||||
results.append(("测试类验证", test_test_classes()))
|
||||
|
||||
# 3. 测试运行器功能
|
||||
results.append(("测试运行器", test_runner_functionality()))
|
||||
|
||||
# 4. 统计测试方法
|
||||
test_count = count_test_methods()
|
||||
|
||||
# 总结
|
||||
print("\n" + "=" * 70)
|
||||
print("验证总结")
|
||||
print("=" * 70)
|
||||
|
||||
for name, result in results:
|
||||
status = "✅ 通过" if result else "❌ 失败"
|
||||
print(f"{name}: {status}")
|
||||
|
||||
all_passed = all(result for _, result in results)
|
||||
|
||||
if all_passed:
|
||||
print(f"\n🎉 所有验证通过!共 {test_count} 个测试方法可用。")
|
||||
print("\n下一步:")
|
||||
print(" 1. 运行关键路径测试: python tests/test_runner.py --mode critical")
|
||||
print(" 2. 运行所有测试: python tests/test_runner.py --mode all")
|
||||
print(" 3. 使用批处理脚本: run_tests.bat")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ 部分验证失败,请检查错误信息。")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
||||
|
||||
Reference in New Issue
Block a user