Files
context-gatekeeper/tests/test_gatekeeper.py
Elaina 071f9ef418 feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 可在 2 核 CPU、2 GB 内存的环境中运行
- 话题门控 + 稀疏召回 + 最小覆盖选择
2026-04-22 01:09:35 +08:00

136 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
上下文门控器 - 轻量级选择器测试
"""
import sys
import os
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
class TestAnchorExtraction:
    """Tests for the anchor extraction performed when a turn is added."""

    def test_chinese_ngram(self):
        """A mixed Chinese/English turn yields one block whose anchors mention Redis."""
        gatekeeper = ContextGatekeeper()
        # add_turn() is what triggers anchor extraction.
        gatekeeper.add_turn("Redis 分布式锁如何实现", "可以使用 Redlock 算法...")
        assert len(gatekeeper.blocks) == 1

        first_block = gatekeeper.blocks[0]
        # Anchors such as "redis" and n-grams of 分布式锁 are expected; only the
        # Redis one is checked here, case-insensitively.
        lowered_anchors = (anchor.lower() for anchor in first_block.anchors)
        assert any("redis" in text for text in lowered_anchors)

    def test_empty_turn(self):
        """An empty question/answer pair still produces exactly one block."""
        gatekeeper = ContextGatekeeper()
        gatekeeper.add_turn("", "")
        assert len(gatekeeper.blocks) == 1
class TestTopicGate:
    """Tests for the topic gate (switch vs. continuation decisions)."""

    def test_obvious_switch(self):
        """Moving from Redis to Python is an obvious topic switch."""
        gatekeeper = ContextGatekeeper()
        conversation = (
            ("Redis 锁怎么用", "用 Redlock..."),
            ("Python 怎么写快速排序", "可以用递归..."),
        )
        for question, answer in conversation:
            gatekeeper.add_turn(question, answer)

        # The query acts as a third, clearly Python-related question.
        selected = gatekeeper.select("Python 列表推导式怎么写")
        selected_ids = [block['turn_id'] for block in selected]
        # The Python content should be selected.
        # NOTE(review): only two turns are added above, so turn 3 can match only
        # if select() itself records the query as a turn -- confirm intent.
        assert any(wanted in selected_ids for wanted in (2, 3))

    def test_continuation_with_deictic(self):
        """A deictic pronoun ("它", "it") forces the earlier topic to be inherited."""
        gatekeeper = ContextGatekeeper()
        gatekeeper.add_turn("Redis 锁是什么", "是分布式锁...")
        gatekeeper.add_turn("它有哪些实现方式", "有 Redlock...")

        # The query contains "它", so turn 1 (which defines the referent) must be kept.
        selected = gatekeeper.select("它有哪些实现方式")
        selected_ids = [block['turn_id'] for block in selected]
        assert 1 in selected_ids
class TestSparseRetrieval:
    """Tests for sparse (lexical) recall."""

    def test_exact_match_boost(self):
        """An exact match on a term like "v3.0" should pull its turn into the selection."""
        gatekeeper = ContextGatekeeper()
        history = (
            ("使用 Redis v3.0 集群", "Redis 集群配置..."),
            ("Python 装饰器用法", "装饰器是..."),
        )
        for question, answer in history:
            gatekeeper.add_turn(question, answer)

        recalled = gatekeeper.select("Redis v3.0 集群如何搭建")
        recalled_ids = [block['turn_id'] for block in recalled]
        # The first (Redis v3.0) turn must be recalled.
        assert 1 in recalled_ids
class TestMinimumCoverage:
    """Tests for minimum-coverage selection under a token budget."""

    def test_token_budget(self):
        """select() must not return more blocks than token_budget allows.

        The original author estimated that a block built from a 100-char question
        plus a 100-char answer costs roughly 150 tokens. Under that estimate, a
        300-token budget fits at most two blocks.
        NOTE(review): the estimate depends on the gatekeeper's token estimator;
        confirm against its implementation.

        Four turns are added so the ``<= 2`` bound can actually fail. With only
        two turns, as before, the assertion held by construction and never
        exercised the budget.
        """
        gate = ContextGatekeeper(token_budget=300)
        for question_char, answer_char in (("A", "B"), ("C", "D"), ("E", "F"), ("G", "H")):
            gate.add_turn(question_char * 100, answer_char * 100)
        blocks = gate.select("A")
        assert len(blocks) <= 2, f"预算300应该只选1-2个block实际选了{len(blocks)}"
class TestIntegration:
    """End-to-end tests exercising add_turn/select, constraints and reset together."""

    def test_multi_turn_conversation(self):
        """Simulate a chat that drifts from Redis to Python, then query both topics."""
        gatekeeper = ContextGatekeeper()
        history = (
            ("Redis 锁续租为什么会脑裂", "因为锁过期时间设置不合理..."),  # turn 1
            ("如何避免脑裂", "可以增加时钟偏移检测..."),  # turn 2
            ("Python 异步编程怎么做", "用 asyncio 模块..."),  # turn 3: topic switch
        )
        for question, answer in history:
            gatekeeper.add_turn(question, answer)

        # A Redis question must still recall the original Redis turn.
        redis_selected = gatekeeper.select("Redis 锁的 TTL 怎么设")
        redis_ids = [block['turn_id'] for block in redis_selected]
        assert 1 in redis_ids, f"第1轮 Redis 内容应该被召回,实际: {redis_ids}"
        print(f" Redis 查询召回: {redis_ids}")

        # After the switch, a Python question must recall the asyncio turn.
        python_selected = gatekeeper.select("asyncio 怎么用")
        python_ids = [block['turn_id'] for block in python_selected]
        assert 3 in python_ids, f"第3轮 Python 内容应该被召回,实际: {python_ids}"
        print(f" Python 查询召回: {python_ids}")

    def test_constraints_preserved(self):
        """Constraints stored on the gatekeeper are returned unchanged."""
        gatekeeper = ContextGatekeeper()
        expected = {"language": "中文", "style": "简洁"}
        for key, value in expected.items():
            gatekeeper.set_constraint(key, value)

        stored = gatekeeper.get_constraints()
        for key, value in expected.items():
            assert stored[key] == value

    def test_reset(self):
        """reset() discards every accumulated block."""
        gatekeeper = ContextGatekeeper()
        gatekeeper.add_turn("test", "test")
        assert len(gatekeeper.blocks) == 1

        gatekeeper.reset()
        assert len(gatekeeper.blocks) == 0
if __name__ == "__main__":
    # pytest was never imported at module level, so running this file directly
    # raised NameError. It is imported here because it is only needed in script mode.
    import pytest

    # Propagate pytest's exit code so that script runs report failures.
    sys.exit(pytest.main([__file__, "-v"]))