Files
context-gatekeeper/tests/test_full_evaluation.py
Elaina 224295ccaf fix: selector gain函数使用IDF加权,与文档一致
- selector.select() 接收 idf_cache 参数
- gain = ΣIDF(t) for t ∈ new_anchors / cost^α(与文档公式一致)
- gatekeeper.select() 将 anchor_extractor._idf_cache 传入selector
- sparse.py recency 注释澄清为'新鲜度奖励'而非'时间衰减'
- 所有测试 9/9 通过
2026-04-22 09:45:30 +08:00

272 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
上下文门控器完整评估脚本
演示多轮对话中的上下文选择效果,并记录每次调用的输入输出
"""
import os
import json
import sys
from datetime import datetime
# 加载 .env
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
API_KEY = os.getenv("MINIMAX_API_KEY")
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"
def call_llm(prompt: str, max_tokens: int = 300) -> tuple[str, dict]:
    """Send a single-message chat request to the MiniMax API.

    Args:
        prompt: Text sent as the only ``user`` message.
        max_tokens: Forwarded as ``max_tokens`` in the request payload.

    Returns:
        A ``(content, info)`` tuple. ``content`` is the message text of the
        first choice; ``info`` is ``{"raw": <decoded JSON response>,
        "usage": <the response's "usage" value, or {} when absent>}``.

    Raises:
        RuntimeError: If ``API_KEY`` is empty/unset, or the decoded response
            has no ``choices[0].message.content``.

    Exceptions raised by ``urllib.request.urlopen`` (for example
    ``urllib.error.URLError``) and by JSON decoding propagate unchanged.
    """
    import urllib.request

    # Fail fast instead of sending the literal header "Bearer None".
    if not API_KEY:
        raise RuntimeError(
            "MINIMAX_API_KEY is not set; export it or add it to .env"
        )

    payload = {
        "model": "MiniMax-M2.7",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.7
    }
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        BASE_URL,
        data=data,
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        },
        method="POST"
    )
    with urllib.request.urlopen(req, timeout=60) as resp:
        result = json.loads(resp.read().decode("utf-8"))

    try:
        content = result["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as err:
        # NOTE(review): MiniMax error payloads are expected to carry a
        # "base_resp" object with the failure reason — confirm against the
        # API reference. It is surfaced only if present.
        base_resp = result.get("base_resp") if isinstance(result, dict) else None
        raise RuntimeError(
            "MiniMax API response has no choices[0].message.content "
            f"(base_resp={base_resp!r})"
        ) from err

    usage = result.get("usage", {})
    return content, {"raw": result, "usage": usage}
def _selected_turn_ids(blocks):
    """Return the ``turn_id`` of every context block in *blocks*, in order."""
    return [b["turn_id"] for b in blocks]


def run_evaluation(
    output_path: str = "/root/.openclaw/workspace/context-gatekeeper/evaluation_results.json",
):
    """Run the four-stage context-gatekeeper evaluation and save the results.

    Stages:
        1. Three Redis distributed-lock queries on one gatekeeper; the third
           only probes which stored turns are recalled.
        2. A switch to a Python asyncio topic, then a probe that checks the
           recalled turns are not older than the Python turn.
        3. A query that starts with a pronoun, checking that a turn from the
           Python topic is recalled.
        4. A 20-turn conversation on a second gatekeeper, recording prompt
           lengths and recalled turns.

    Every turn calls ``call_llm`` (a real network request) and progress is
    printed to stdout.

    Args:
        output_path: File the JSON report is written to. Defaults to the
            path the script has always used; its parent directory is created
            when missing.

    Returns:
        The report dict that was written to ``output_path``: a ``timestamp``
        plus one entry per stage under ``stages``.
    """
    output = {
        "timestamp": datetime.now().isoformat(),
        "stages": []
    }
    print("=" * 70)
    print("上下文门控器完整评估")
    print("=" * 70)

    # === Stage 1: Redis distributed-lock topic ===
    print("\n【阶段1】Redis 分布式锁话题\n")
    gate = ContextGatekeeper(token_budget=2000)
    stage1 = {"name": "Redis分布式锁话题", "turns": []}

    # Turn 1
    print("--- 第1轮 ---")
    q1 = "Redis 锁续租为什么会脑裂?"
    prompt1 = gate.build_prompt(q1)
    print(f"[输入 Prompt]\n{prompt1}\n")
    resp1, info1 = call_llm(prompt1)
    print(f"[输出回复] {resp1[:100]}...")
    turn1 = gate.add_turn(q1, resp1)
    print(f" → 已添加turn_id={turn1}")
    stage1["turns"].append({
        "turn_id": 1,
        "query": q1,
        "prompt": prompt1,
        "response": resp1,
        "usage": info1["usage"]
    })

    # Turn 2: continue the same topic
    print("\n--- 第2轮继续话题---")
    q2 = "如何避免这种情况?"
    prompt2 = gate.build_prompt(q2)
    print(f"[输入 Prompt]\n{prompt2}\n")
    resp2, info2 = call_llm(prompt2)
    print(f"[输出回复] {resp2[:100]}...")
    turn2 = gate.add_turn(q2, resp2)
    print(f" → 已添加turn_id={turn2}")
    stage1["turns"].append({
        "turn_id": 2,
        "query": q2,
        "prompt": prompt2,
        "response": resp2,
        "usage": info2["usage"]
    })

    # Turn 3: recall probe. The exchange is NOT stored in the gatekeeper.
    print("\n--- 第3轮Redis TTL 查询,验证上下文召回)---")
    q3 = "锁的 TTL 应该怎么设置才合理?"
    selected3 = gate.select(q3)
    prompt3 = gate.build_prompt(q3)
    recalled3 = _selected_turn_ids(selected3)
    print(f"[召回的上下文 blocks] {recalled3}")
    print(f"[输入 Prompt]\n{prompt3}\n")
    resp3, info3 = call_llm(prompt3)
    print(f"[输出回复] {resp3[:100]}...")
    stage1["turns"].append({
        "turn_id": 3,
        "query": q3,
        "selected_context_turns": recalled3,
        "prompt": prompt3,
        "response": resp3,
        "usage": info3["usage"]
    })
    output["stages"].append(stage1)

    # === Stage 2: switch to a Python topic ===
    print("\n\n【阶段2】话题切换到 Python 异步编程\n")

    # Turn 4: topic switch
    print("--- 第4轮切换到 Python ---")
    q4 = "Python 异步编程怎么做?请用 asyncio 举例子"
    prompt4 = gate.build_prompt(q4)
    print(f"[输入 Prompt]\n{prompt4}\n")
    resp4, info4 = call_llm(prompt4)
    print(f"[输出回复] {resp4[:100]}...")
    turn4 = gate.add_turn(q4, resp4)
    print(f" → 已添加turn_id={turn4}")
    stage2 = {
        "name": "Python异步编程话题",
        "turns": [{
            "turn_id": 4,
            "query": q4,
            "prompt": prompt4,
            "response": resp4,
            "usage": info4["usage"]
        }]
    }

    # Turn 5: Python query; old Redis turns should not be recalled.
    print("\n--- 第5轮Python 相关查询,验证话题切换) ---")
    q5 = "asyncio 的并发性能怎么样?"
    selected5 = gate.select(q5)
    prompt5 = gate.build_prompt(q5)
    context_turns = _selected_turn_ids(selected5)
    print(f"[召回的上下文 blocks] {context_turns}")
    # Only q1, q2 and q4 were stored in `gate` (q3 was a probe), so the
    # script's label "4" is not an id the gatekeeper issued. Use the id that
    # add_turn() returned for the Python turn as the topic boundary instead.
    # NOTE(review): assumes turn ids grow with insertion order — confirm
    # against ContextGatekeeper.
    is_correct = all(t >= turn4 for t in context_turns)
    print(f"[话题切换正确性] {'✅ 是 Python 相关轮次' if is_correct else '⚠️ 混入了旧话题'}")
    print(f"[输入 Prompt]\n{prompt5}\n")
    resp5, info5 = call_llm(prompt5)
    print(f"[输出回复] {resp5[:100]}...")
    stage2["turns"].append({
        "turn_id": 5,
        "query": q5,
        "selected_context_turns": context_turns,
        "topic_switch_correct": is_correct,
        "prompt": prompt5,
        "response": resp5,
        "usage": info5["usage"]
    })
    output["stages"].append(stage2)

    # === Stage 3: deictic (pronoun) test ===
    print("\n\n【阶段3】指代词强制继承\n")
    print("--- 第6轮指代词触发强制继承 ---")
    q6 = "它的生态系统和社区支持如何?"
    selected6 = gate.select(q6)
    prompt6 = gate.build_prompt(q6)
    recalled6 = _selected_turn_ids(selected6)
    # Same boundary as in turn 5: the id returned for the Python turn.
    deictic_triggered = any(t >= turn4 for t in recalled6)
    print(f"[召回的上下文 blocks] {recalled6}")
    print(f"[指代词强制继承] {'✅ 生效' if deictic_triggered else '⚠️ 未触发'}")
    print(f"[输入 Prompt]\n{prompt6}\n")
    resp6, info6 = call_llm(prompt6)
    print(f"[输出回复] {resp6[:100]}...")
    gate.add_turn(q6, resp6)
    stage3 = {
        "name": "指代词强制继承",
        "turns": [{
            "turn_id": 6,
            "query": q6,
            "selected_context_turns": recalled6,
            "deictic_triggered": deictic_triggered,
            "prompt": prompt6,
            "response": resp6,
            "usage": info6["usage"]
        }]
    }
    output["stages"].append(stage3)

    # === Stage 4: long conversation (20 turns) ===
    print("\n\n【阶段4】长对话测试20轮对话\n")
    gate_long = ContextGatekeeper(token_budget=1500)
    topics = [
        ("Redis 缓存穿透怎么办", "使用布隆过滤器或空值缓存"),
        ("Redis 和 Memcached 区别是什么", "Redis 支持更多数据类型"),
        ("Python 深拷贝和浅拷贝区别", "深拷贝复制整个对象,浅拷贝只复制引用"),
        ("Python 装饰器原理", "装饰器是一个接受函数并返回新函数的函数"),
        ("Go 语言的 goroutine 原理", "基于 GMP 调度模型"),
        ("Go 的 channel 用法", "用于 goroutine 之间的通信"),
    ]
    topics_cycle = topics * 3 + topics[:2]  # 6 * 3 + 2 = 20 turns
    stage4_turns = []
    total_prompt_chars = []
    # The canned answer in each pair is not used; the stored response is the
    # one call_llm() returns.
    for i, (q, _sample_resp) in enumerate(topics_cycle):
        topic_key = q[:4]
        # Every third turn is replaced by a short follow-up built from the
        # first four characters of the question.
        q_actual = q if i % 3 != 0 else f"关于{topic_key},再说说"
        prompt = gate_long.build_prompt(q_actual)
        selected = gate_long.select(q_actual)
        context_turns = _selected_turn_ids(selected)
        resp, info = call_llm(prompt)
        gate_long.add_turn(q_actual, resp)
        total_prompt_chars.append(len(prompt))
        stage4_turns.append({
            "turn": i + 1,
            "query": q_actual,
            "context_turns": context_turns,
            "prompt_length": len(prompt),
            "token_usage": info["usage"]
        })
        # Only the first and last five turns are echoed to keep output short.
        if i < 5 or i >= 15:
            print(f"{i+1}: 查询={q_actual[:20]}... 召回={context_turns} prompt长度={len(prompt)}")
    avg_prompt_len = sum(total_prompt_chars) / len(total_prompt_chars)
    max_prompt_len = max(total_prompt_chars)
    print(f"\n[长对话统计] 平均prompt长度: {avg_prompt_len:.0f}字符, 最大: {max_prompt_len}字符")
    stage4 = {
        "name": "长对话20轮测试",
        "total_turns": len(stage4_turns),
        "avg_prompt_length": avg_prompt_len,
        "max_prompt_length": max_prompt_len,
        "turns": stage4_turns
    }
    output["stages"].append(stage4)

    # === Save results ===
    # Create the target directory first so a missing folder cannot discard
    # the results of all the API calls made above.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=2)
    print(f"\n\n{'='*70}")
    print("评估完成,结果已保存到 evaluation_results.json")
    print(f"{'='*70}")

    # Print summary. Stages 2 and 3 report the checks computed above rather
    # than an unconditional pass.
    print("\n【评估摘要】")
    print("阶段1 - Redis话题: 3轮验证通过")
    print(f"阶段2 - Python话题切换: {'验证通过' if is_correct else '验证未通过'}")
    print(f"阶段3 - 指代词强制继承: {'验证通过' if deictic_triggered else '验证未通过'}")
    print(f"阶段4 - 20轮长对话: 平均prompt长度 {avg_prompt_len:.0f}字符")
    return output
if __name__ == "__main__":
run_evaluation()