""" 安全检查器单元测试 """ import unittest import sys from pathlib import Path # 添加项目根目录到路径 sys.path.insert(0, str(Path(__file__).parent.parent)) from safety.rule_checker import RuleChecker, check_code_safety class TestRuleChecker(unittest.TestCase): """规则检查器测试""" def setUp(self): self.checker = RuleChecker() # ========== 硬性禁止测试 ========== def test_block_socket_import(self): """测试禁止 socket 模块""" code = "import socket\ns = socket.socket()" result = self.checker.check(code) self.assertFalse(result.passed) self.assertTrue(any('socket' in v for v in result.violations)) def test_block_subprocess_import(self): """测试禁止 subprocess 模块""" code = "import subprocess\nsubprocess.run(['ls'])" result = self.checker.check(code) self.assertFalse(result.passed) self.assertTrue(any('subprocess' in v for v in result.violations)) def test_block_eval(self): """测试禁止 eval""" code = "result = eval('1+1')" result = self.checker.check(code) self.assertFalse(result.passed) self.assertTrue(any('eval' in v for v in result.violations)) def test_block_exec(self): """测试禁止 exec""" code = "exec('print(1)')" result = self.checker.check(code) self.assertFalse(result.passed) self.assertTrue(any('exec' in v for v in result.violations)) def test_block_os_system(self): """测试禁止 os.system""" code = "import os\nos.system('dir')" result = self.checker.check(code) self.assertFalse(result.passed) self.assertTrue(any('os.system' in v for v in result.violations)) def test_block_os_popen(self): """测试禁止 os.popen""" code = "import os\nos.popen('dir')" result = self.checker.check(code) self.assertFalse(result.passed) self.assertTrue(any('os.popen' in v for v in result.violations)) # ========== 警告测试 ========== def test_warn_requests_import(self): """测试 requests 模块产生警告""" code = "import requests\nresponse = requests.get('http://example.com')" result = self.checker.check(code) self.assertTrue(result.passed) # 不应该被阻止 self.assertTrue(any('requests' in w for w in result.warnings)) def test_warn_os_remove(self): """测试 os.remove 产生警告""" code = "import os\nos.remove('file.txt')" result = self.checker.check(code) self.assertTrue(result.passed) # 不应该被阻止 self.assertTrue(any('os.remove' in w for w in result.warnings)) def test_warn_shutil_rmtree(self): """测试 shutil.rmtree 产生警告""" code = "import shutil\nshutil.rmtree('folder')" result = self.checker.check(code) self.assertTrue(result.passed) # 不应该被阻止 self.assertTrue(any('shutil.rmtree' in w for w in result.warnings)) # ========== 安全代码测试 ========== def test_safe_file_copy(self): """测试安全的文件复制代码""" code = """ import shutil from pathlib import Path INPUT_DIR = Path('workspace/input') OUTPUT_DIR = Path('workspace/output') for f in INPUT_DIR.glob('*'): shutil.copy(f, OUTPUT_DIR / f.name) """ result = self.checker.check(code) self.assertTrue(result.passed) self.assertEqual(len(result.violations), 0) def test_safe_image_processing(self): """测试安全的图片处理代码""" code = """ from PIL import Image from pathlib import Path INPUT_DIR = Path('workspace/input') OUTPUT_DIR = Path('workspace/output') for img_path in INPUT_DIR.glob('*.png'): img = Image.open(img_path) img = img.resize((100, 100)) img.save(OUTPUT_DIR / img_path.name) """ result = self.checker.check(code) self.assertTrue(result.passed) self.assertEqual(len(result.violations), 0) def test_safe_excel_processing(self): """测试安全的 Excel 处理代码""" code = """ import openpyxl from pathlib import Path INPUT_DIR = Path('workspace/input') OUTPUT_DIR = Path('workspace/output') for xlsx_path in INPUT_DIR.glob('*.xlsx'): wb = openpyxl.load_workbook(xlsx_path) ws = wb.active ws['A1'] = 'Modified' wb.save(OUTPUT_DIR / xlsx_path.name) """ result = self.checker.check(code) self.assertTrue(result.passed) self.assertEqual(len(result.violations), 0) class TestCheckCodeSafety(unittest.TestCase): """便捷函数测试""" def test_convenience_function(self): """测试便捷函数""" result = check_code_safety("print('hello')") self.assertTrue(result.passed) def test_convenience_function_block(self): """测试便捷函数阻止危险代码""" result = check_code_safety("import socket") self.assertFalse(result.passed) if __name__ == '__main__': unittest.main()