Files
context-gatekeeper/tests/test_e2e.py
Elaina 071f9ef418 feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
2026-04-22 01:09:35 +08:00

147 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
端到端测试 - 使用 MiniMax API 验证上下文门控器效果
"""
import os
import sys
from dotenv import load_dotenv
# 加载 .env
load_dotenv()
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
# Chat-completion endpoint that call_minimax() posts to.
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"

# API key read from the environment; load_dotenv() earlier in this module
# may have populated it from a .env file.
API_KEY = os.environ.get("MINIMAX_API_KEY")
if API_KEY is None or API_KEY == "":
    # No key means no request can be authorised, so stop with exit status 1.
    print("❌ 未找到 MINIMAX_API_KEY请检查 .env 文件")
    sys.exit(1)
def call_minimax(prompt: str) -> str:
    """Send ``prompt`` to the MiniMax chat-completion endpoint and return the reply.

    Only the first 2000 characters of ``prompt`` are sent. The request is a
    JSON POST to ``BASE_URL`` authorised with ``API_KEY`` and limited to a
    30-second timeout.

    Returns:
        The ``content`` of the first choice in the response. If opening the
        URL or decoding the response raises, the exception is not propagated;
        a string starting with ``"API 调用失败: "`` is returned instead.
    """
    import json
    import urllib.request

    body = json.dumps(
        {
            "model": "MiniMax-M2.7",
            # Cap the prompt length before sending.
            "messages": [{"role": "user", "content": prompt[:2000]}],
            "max_tokens": 500,
            "temperature": 0.7,
        }
    ).encode("utf-8")

    request = urllib.request.Request(
        BASE_URL,
        data=body,
        method="POST",
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        },
    )

    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            raw = response.read()
            reply = json.loads(raw.decode("utf-8"))
            first_choice = reply["choices"][0]
            return first_choice["message"]["content"]
    except Exception as exc:
        # Best-effort: report the failure as text so the caller keeps running.
        return f"API 调用失败: {exc}"
def _ask_and_record(gate: "ContextGatekeeper", question: str,
                    length_hint: str = "2-3") -> str:
    """Ask MiniMax ``question``, store the exchange in ``gate`` and echo it.

    ``length_hint`` is the sentence-count range inserted into the prompt.
    Returns whatever ``call_minimax`` returned for the prompt.
    """
    answer = call_minimax(f"请用 {length_hint} 句话回答: {question}")
    gate.add_turn(question, answer)
    print(f"用户: {question}")
    print(f"助手: {answer}")
    return answer


def _select_and_report(gate: "ContextGatekeeper", query: str) -> tuple:
    """Run ``gate.select(query)``, echo the outcome, return ``(blocks, turn_ids)``."""
    blocks = gate.select(query)
    print(f"用户查询: {query}")
    # Build the id list once; it is both printed and asserted on by the caller.
    turn_ids = [block['turn_id'] for block in blocks]
    print(f"召回的上下文轮次: {turn_ids}")
    return blocks, turn_ids


def _estimate_tokens(blocks) -> float:
    """Estimate tokens as 1.5 per character of each block's user and assistant text."""
    return sum(
        (len(block['user']) + len(block['assistant'])) * 1.5
        for block in blocks
    )


def test_e2e_conversation():
    """End-to-end conversation test for the context gatekeeper.

    Feeds a scripted conversation into a ``ContextGatekeeper`` with a token
    budget of 2000, using ``call_minimax`` for the assistant replies, and
    checks what ``select`` recalls at three points:

    1. After two Redis turns, a Redis query must recall turn 1.
    2. After a Python turn, an asyncio query must recall turn 3 or 4.
    3. After a pronoun turn, a pronoun query must recall at least 2 blocks.

    Finally the estimated token size of the most recent selection must not
    exceed 1.5 times the gatekeeper's budget.

    Raises:
        AssertionError: if any of the expectations above does not hold.
    """
    banner = "=" * 60
    print(banner)
    print("端到端对话测试 - 验证上下文门控器")
    print(banner)
    gate = ContextGatekeeper(token_budget=2000)

    # === Phase 1: Redis distributed-lock topic ===
    print("\n📌 第1轮Redis 分布式锁话题")
    _ask_and_record(gate, "Redis 锁续租为什么会脑裂")

    print("\n📌 第2轮继续 Redis 话题")
    _ask_and_record(gate, "如何避免脑裂?")

    # Round 3 asks a Redis-related question: turn 1 should be recalled;
    # turn 2 may or may not be, depending on anchor overlap.
    print("\n📌 第3轮问 Redis 相关问题")
    _, turn_ids3 = _select_and_report(gate, "锁的 TTL 怎么设置才合理")
    assert 1 in turn_ids3, "❌ 应召回第1轮 Redis 内容"
    print("✅ Redis 相关问题召回第1轮内容")

    # === Phase 2: switch to a Python topic ===
    print("\n" + banner)
    print("📌 第4轮切换到 Python 话题")
    _ask_and_record(gate, "Python 异步编程怎么做?用 asyncio 举例子", "3-4")

    # A Python query should recall the Python exchange. It is the third
    # add_turn() call, and the assertion accepts turn id 3 or 4.
    print("\n📌 第5轮问 Python asyncio 相关")
    _, turn_ids5 = _select_and_report(gate, "asyncio 怎么用?举一个爬虫的例子")
    has_python_topic = 3 in turn_ids5 or 4 in turn_ids5
    assert has_python_topic, f"❌ 应召回 Python 内容,实际: {turn_ids5}"
    print(f"✅ Python 相关问题召回正确轮次: {turn_ids5}")

    # === Phase 3: pronoun handling ===
    print("\n" + banner)
    print("📌 第6轮指代词测试")
    _ask_and_record(gate, "它的并发性能怎么样")

    # With a pronoun in the query, the most recent turns should be
    # force-inherited into the selection.
    print("\n📌 第7轮指代词强制继承验证")
    selected7, _ = _select_and_report(gate, "它和 ThreadPool 比哪个更好")
    assert len(selected7) >= 2, "❌ 有指代词时应强制继承最近 2 个 block"
    print(f"✅ 指代词触发强制继承,召回了 {len(selected7)} 个 block")

    # === Token budget check ===
    # Measure the latest selection (round 7): the message below reports the
    # "current" context, so the stale round-3 selection must not be used.
    print("\n" + banner)
    print("📌 Token 预算验证")
    total_context_tokens = _estimate_tokens(selected7)
    print(f"当前上下文 token 估算: {total_context_tokens:.0f} / {gate.token_budget}")
    # The character-based estimate is rough, so 50% headroom is tolerated.
    assert total_context_tokens <= gate.token_budget * 1.5, "❌ 超出 token 预算"
    print("✅ Token 预算控制正常")

    print("\n" + banner)
    print("✅ 所有端到端测试通过!")
    print(banner)
# Script entry point: run the end-to-end scenario when this file is executed
# directly rather than imported.
if __name__ == "__main__":
    test_e2e_conversation()