feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符) - block.py: 对话块数据结构 - topic_gate.py: 话题门控(overlap/new_ratio 判断切换) - sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分) - selector.py: 最小覆盖贪心选择 - gatekeeper.py: 完整流程封装 - tests/: 单元测试 + 端到端测试(含 MiniMax API 验证) 特性: - 纯 Python,无额外模型依赖 - 支持 2 核 2G 环境 - 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
147
tests/test_e2e.py
Normal file
147
tests/test_e2e.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
端到端测试 - 使用 MiniMax API 验证上下文门控器效果
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载 .env
|
||||
load_dotenv()
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
# 获取 API Key
|
||||
API_KEY = os.getenv("MINIMAX_API_KEY")
|
||||
if not API_KEY:
|
||||
print("❌ 未找到 MINIMAX_API_KEY,请检查 .env 文件")
|
||||
sys.exit(1)
|
||||
|
||||
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"
|
||||
|
||||
|
||||
def call_minimax(prompt: str) -> str:
|
||||
"""调用 MiniMax API"""
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
payload = {
|
||||
"model": "MiniMax-M2.7",
|
||||
"messages": [{"role": "user", "content": prompt[:2000]}], # 限制长度
|
||||
"max_tokens": 500,
|
||||
"temperature": 0.7
|
||||
}
|
||||
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
BASE_URL,
|
||||
data=data,
|
||||
headers={
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
method="POST"
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
result = json.loads(resp.read().decode("utf-8"))
|
||||
return result["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
return f"API 调用失败: {e}"
|
||||
|
||||
|
||||
def test_e2e_conversation():
|
||||
"""端到端对话测试"""
|
||||
print("=" * 60)
|
||||
print("端到端对话测试 - 验证上下文门控器")
|
||||
print("=" * 60)
|
||||
|
||||
gate = ContextGatekeeper(token_budget=2000)
|
||||
|
||||
# === 第一阶段:Redis 分布式锁话题 ===
|
||||
print("\n📌 第1轮:Redis 分布式锁话题")
|
||||
user1 = "Redis 锁续租为什么会脑裂"
|
||||
assistant1 = call_minimax(f"请用 2-3 句话回答: {user1}")
|
||||
gate.add_turn(user1, assistant1)
|
||||
print(f"用户: {user1}")
|
||||
print(f"助手: {assistant1}")
|
||||
|
||||
print("\n📌 第2轮:继续 Redis 话题")
|
||||
user2 = "如何避免脑裂?"
|
||||
assistant2 = call_minimax(f"请用 2-3 句话回答: {user2}")
|
||||
gate.add_turn(user2, assistant2)
|
||||
print(f"用户: {user2}")
|
||||
print(f"助手: {assistant2}")
|
||||
|
||||
# 验证:第3轮问 Redis 相关,应该召回第1轮(可能不召回第2轮,取决于锚点重叠度)
|
||||
print("\n📌 第3轮:问 Redis 相关问题")
|
||||
query3 = "锁的 TTL 怎么设置才合理"
|
||||
selected = gate.select(query3)
|
||||
print(f"用户查询: {query3}")
|
||||
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected]}")
|
||||
turn_ids = [b['turn_id'] for b in selected]
|
||||
assert 1 in turn_ids, "❌ 应召回第1轮 Redis 内容"
|
||||
print("✅ Redis 相关问题召回第1轮内容")
|
||||
|
||||
# === 第二阶段:切换到 Python 话题 ===
|
||||
print("\n" + "=" * 60)
|
||||
print("📌 第4轮:切换到 Python 话题")
|
||||
user4 = "Python 异步编程怎么做?用 asyncio 举例子"
|
||||
assistant4 = call_minimax(f"请用 3-4 句话回答: {user4}")
|
||||
gate.add_turn(user4, assistant4)
|
||||
print(f"用户: {user4}")
|
||||
print(f"助手: {assistant4}")
|
||||
|
||||
# 验证:问 Python 相关,应该召回 Python 内容(3或4)
|
||||
print("\n📌 第5轮:问 Python asyncio 相关")
|
||||
query5 = "asyncio 怎么用?举一个爬虫的例子"
|
||||
selected5 = gate.select(query5)
|
||||
print(f"用户查询: {query5}")
|
||||
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected5]}")
|
||||
turn_ids5 = [b['turn_id'] for b in selected5]
|
||||
# Python 话题应该召回 3 或 4
|
||||
has_python_topic = (3 in turn_ids5 or 4 in turn_ids5)
|
||||
assert has_python_topic, f"❌ 应召回 Python 内容,实际: {turn_ids5}"
|
||||
print(f"✅ Python 相关问题召回正确轮次: {turn_ids5}")
|
||||
|
||||
# === 第三阶段:指代词测试 ===
|
||||
print("\n" + "=" * 60)
|
||||
print("📌 第6轮:指代词测试")
|
||||
user6 = "它的并发性能怎么样"
|
||||
assistant6 = call_minimax(f"请用 2-3 句话回答: {user6}")
|
||||
gate.add_turn(user6, assistant6)
|
||||
print(f"用户: {user6}")
|
||||
print(f"助手: {assistant6}")
|
||||
|
||||
# 验证:有指代词时,应该强制继承最近轮次
|
||||
print("\n📌 第7轮:指代词强制继承验证")
|
||||
query7 = "它和 ThreadPool 比哪个更好"
|
||||
selected7 = gate.select(query7)
|
||||
print(f"用户查询: {query7}")
|
||||
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected7]}")
|
||||
# 应该有指代词强制继承
|
||||
assert len(selected7) >= 2, "❌ 有指代词时应强制继承最近 2 个 block"
|
||||
print(f"✅ 指代词触发强制继承,召回了 {len(selected7)} 个 block")
|
||||
|
||||
# === 验证 Token 预算控制 ===
|
||||
print("\n" + "=" * 60)
|
||||
print("📌 Token 预算验证")
|
||||
total_context_tokens = sum(
|
||||
len(b['user']) * 1.5 + len(b['assistant']) * 1.5
|
||||
for b in selected
|
||||
)
|
||||
print(f"当前上下文 token 估算: {total_context_tokens:.0f} / {gate.token_budget}")
|
||||
assert total_context_tokens <= gate.token_budget * 1.5, "❌ 超出 token 预算"
|
||||
print("✅ Token 预算控制正常")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ 所有端到端测试通过!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_e2e_conversation()
|
||||
136
tests/test_gatekeeper.py
Normal file
136
tests/test_gatekeeper.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
上下文门控器 - 轻量级选择器测试
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
|
||||
class TestAnchorExtraction:
|
||||
"""锚点提取测试"""
|
||||
|
||||
def test_chinese_ngram(self):
|
||||
gate = ContextGatekeeper()
|
||||
# 触发锚点提取
|
||||
gate.add_turn("Redis 分布式锁如何实现", "可以使用 Redlock 算法...")
|
||||
assert len(gate.blocks) == 1
|
||||
anchors = gate.blocks[0].anchors
|
||||
# 应该有 redis、分布式锁等锚点
|
||||
assert any('redis' in a.lower() for a in anchors)
|
||||
|
||||
def test_empty_turn(self):
|
||||
gate = ContextGatekeeper()
|
||||
gate.add_turn("", "")
|
||||
assert len(gate.blocks) == 1
|
||||
|
||||
|
||||
class TestTopicGate:
|
||||
"""话题门控测试"""
|
||||
|
||||
def test_obvious_switch(self):
|
||||
"""从 Redis 切换到 Python,明显切换"""
|
||||
gate = ContextGatekeeper()
|
||||
gate.add_turn("Redis 锁怎么用", "用 Redlock...")
|
||||
gate.add_turn("Python 怎么写快速排序", "可以用递归...")
|
||||
|
||||
# 第三个问题,明显是 Python 话题
|
||||
blocks = gate.select("Python 列表推导式怎么写")
|
||||
# 应该主要返回第二、三轮的 Python 内容
|
||||
turn_ids = [b['turn_id'] for b in blocks]
|
||||
assert 2 in turn_ids or 3 in turn_ids
|
||||
|
||||
def test_continuation_with_deictic(self):
|
||||
"""指代词触发强制继承"""
|
||||
gate = ContextGatekeeper()
|
||||
gate.add_turn("Redis 锁是什么", "是分布式锁...")
|
||||
gate.add_turn("它有哪些实现方式", "有 Redlock...")
|
||||
|
||||
# 第二轮有"它",应该强制继承第一轮
|
||||
blocks = gate.select("它有哪些实现方式")
|
||||
turn_ids = [b['turn_id'] for b in blocks]
|
||||
# 应该包含第1轮
|
||||
assert 1 in turn_ids
|
||||
|
||||
|
||||
class TestSparseRetrieval:
|
||||
"""稀疏召回测试"""
|
||||
|
||||
def test_exact_match_boost(self):
|
||||
"""exact match 应该提升得分"""
|
||||
gate = ContextGatekeeper()
|
||||
gate.add_turn("使用 Redis v3.0 集群", "Redis 集群配置...")
|
||||
gate.add_turn("Python 装饰器用法", "装饰器是...")
|
||||
|
||||
blocks = gate.select("Redis v3.0 集群如何搭建")
|
||||
turn_ids = [b['turn_id'] for b in blocks]
|
||||
# 应该召回第一轮
|
||||
assert 1 in turn_ids
|
||||
|
||||
|
||||
class TestMinimumCoverage:
|
||||
"""最小覆盖选择测试"""
|
||||
|
||||
def test_token_budget(self):
|
||||
"""token 预算限制"""
|
||||
gate = ContextGatekeeper(token_budget=300) # 调整预算,100字符block约150tokens
|
||||
gate.add_turn("A" * 100, "B" * 100)
|
||||
gate.add_turn("C" * 100, "D" * 100)
|
||||
|
||||
blocks = gate.select("A")
|
||||
# 预算300,约等于2个block的代价,应该只能选1个
|
||||
assert len(blocks) <= 2, f"预算300应该只选1-2个block,实际选了{len(blocks)}"
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""端到端集成测试"""
|
||||
|
||||
def test_multi_turn_conversation(self):
|
||||
"""模拟多轮对话"""
|
||||
gate = ContextGatekeeper()
|
||||
|
||||
# 第1轮
|
||||
gate.add_turn("Redis 锁续租为什么会脑裂", "因为锁过期时间设置不合理...")
|
||||
# 第2轮
|
||||
gate.add_turn("如何避免脑裂", "可以增加时钟偏移检测...")
|
||||
# 第3轮:切换话题
|
||||
gate.add_turn("Python 异步编程怎么做", "用 asyncio 模块...")
|
||||
|
||||
# 问 Redis 相关问题,验证能召回 Redis 内容
|
||||
redis_blocks = gate.select("Redis 锁的 TTL 怎么设")
|
||||
turn_ids = [b['turn_id'] for b in redis_blocks]
|
||||
assert 1 in turn_ids, f"第1轮 Redis 内容应该被召回,实际: {turn_ids}"
|
||||
print(f" Redis 查询召回: {turn_ids}")
|
||||
|
||||
# 问 Python 相关问题,验证话题切换后召回正确内容
|
||||
py_blocks = gate.select("asyncio 怎么用")
|
||||
turn_ids_py = [b['turn_id'] for b in py_blocks]
|
||||
assert 3 in turn_ids_py, f"第3轮 Python 内容应该被召回,实际: {turn_ids_py}"
|
||||
print(f" Python 查询召回: {turn_ids_py}")
|
||||
|
||||
def test_constraints_preserved(self):
|
||||
"""约束持久化测试"""
|
||||
gate = ContextGatekeeper()
|
||||
gate.set_constraint("language", "中文")
|
||||
gate.set_constraint("style", "简洁")
|
||||
|
||||
constraints = gate.get_constraints()
|
||||
assert constraints["language"] == "中文"
|
||||
assert constraints["style"] == "简洁"
|
||||
|
||||
def test_reset(self):
|
||||
"""重置测试"""
|
||||
gate = ContextGatekeeper()
|
||||
gate.add_turn("test", "test")
|
||||
assert len(gate.blocks) == 1
|
||||
|
||||
gate.reset()
|
||||
assert len(gate.blocks) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user