feat: 上下文门控器初始实现

- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
Elaina
2026-04-22 01:09:35 +08:00
commit 071f9ef418
12 changed files with 1061 additions and 0 deletions

147
tests/test_e2e.py Normal file
View File

@@ -0,0 +1,147 @@
"""
端到端测试 - 使用 MiniMax API 验证上下文门控器效果
"""
import os
import sys
from dotenv import load_dotenv
# 加载 .env
load_dotenv()
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
# 获取 API Key
API_KEY = os.getenv("MINIMAX_API_KEY")
if not API_KEY:
print("❌ 未找到 MINIMAX_API_KEY请检查 .env 文件")
sys.exit(1)
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"
def call_minimax(prompt: str) -> str:
"""调用 MiniMax API"""
import urllib.request
import json
payload = {
"model": "MiniMax-M2.7",
"messages": [{"role": "user", "content": prompt[:2000]}], # 限制长度
"max_tokens": 500,
"temperature": 0.7
}
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
BASE_URL,
data=data,
headers={
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json"
},
method="POST"
)
try:
with urllib.request.urlopen(req, timeout=30) as resp:
result = json.loads(resp.read().decode("utf-8"))
return result["choices"][0]["message"]["content"]
except Exception as e:
return f"API 调用失败: {e}"
def test_e2e_conversation():
"""端到端对话测试"""
print("=" * 60)
print("端到端对话测试 - 验证上下文门控器")
print("=" * 60)
gate = ContextGatekeeper(token_budget=2000)
# === 第一阶段Redis 分布式锁话题 ===
print("\n📌 第1轮Redis 分布式锁话题")
user1 = "Redis 锁续租为什么会脑裂"
assistant1 = call_minimax(f"请用 2-3 句话回答: {user1}")
gate.add_turn(user1, assistant1)
print(f"用户: {user1}")
print(f"助手: {assistant1}")
print("\n📌 第2轮继续 Redis 话题")
user2 = "如何避免脑裂?"
assistant2 = call_minimax(f"请用 2-3 句话回答: {user2}")
gate.add_turn(user2, assistant2)
print(f"用户: {user2}")
print(f"助手: {assistant2}")
# 验证第3轮问 Redis 相关应该召回第1轮可能不召回第2轮取决于锚点重叠度
print("\n📌 第3轮问 Redis 相关问题")
query3 = "锁的 TTL 怎么设置才合理"
selected = gate.select(query3)
print(f"用户查询: {query3}")
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected]}")
turn_ids = [b['turn_id'] for b in selected]
assert 1 in turn_ids, "❌ 应召回第1轮 Redis 内容"
print("✅ Redis 相关问题召回第1轮内容")
# === 第二阶段:切换到 Python 话题 ===
print("\n" + "=" * 60)
print("📌 第4轮切换到 Python 话题")
user4 = "Python 异步编程怎么做?用 asyncio 举例子"
assistant4 = call_minimax(f"请用 3-4 句话回答: {user4}")
gate.add_turn(user4, assistant4)
print(f"用户: {user4}")
print(f"助手: {assistant4}")
# 验证:问 Python 相关,应该召回 Python 内容3或4
print("\n📌 第5轮问 Python asyncio 相关")
query5 = "asyncio 怎么用?举一个爬虫的例子"
selected5 = gate.select(query5)
print(f"用户查询: {query5}")
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected5]}")
turn_ids5 = [b['turn_id'] for b in selected5]
# Python 话题应该召回 3 或 4
has_python_topic = (3 in turn_ids5 or 4 in turn_ids5)
assert has_python_topic, f"❌ 应召回 Python 内容,实际: {turn_ids5}"
print(f"✅ Python 相关问题召回正确轮次: {turn_ids5}")
# === 第三阶段:指代词测试 ===
print("\n" + "=" * 60)
print("📌 第6轮指代词测试")
user6 = "它的并发性能怎么样"
assistant6 = call_minimax(f"请用 2-3 句话回答: {user6}")
gate.add_turn(user6, assistant6)
print(f"用户: {user6}")
print(f"助手: {assistant6}")
# 验证:有指代词时,应该强制继承最近轮次
print("\n📌 第7轮指代词强制继承验证")
query7 = "它和 ThreadPool 比哪个更好"
selected7 = gate.select(query7)
print(f"用户查询: {query7}")
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected7]}")
# 应该有指代词强制继承
assert len(selected7) >= 2, "❌ 有指代词时应强制继承最近 2 个 block"
print(f"✅ 指代词触发强制继承,召回了 {len(selected7)} 个 block")
# === 验证 Token 预算控制 ===
print("\n" + "=" * 60)
print("📌 Token 预算验证")
total_context_tokens = sum(
len(b['user']) * 1.5 + len(b['assistant']) * 1.5
for b in selected
)
print(f"当前上下文 token 估算: {total_context_tokens:.0f} / {gate.token_budget}")
assert total_context_tokens <= gate.token_budget * 1.5, "❌ 超出 token 预算"
print("✅ Token 预算控制正常")
print("\n" + "=" * 60)
print("✅ 所有端到端测试通过!")
print("=" * 60)
if __name__ == "__main__":
test_e2e_conversation()