fix: selector gain函数使用IDF加权,与文档一致

- selector.select() 接收 idf_cache 参数
- gain = ΣIDF(t) for t ∈ new_anchors / cost^α(与文档公式一致)
- gatekeeper.select() 将 anchor_extractor._idf_cache 传入selector
- sparse.py recency 注释澄清为'新鲜度奖励'而非'时间衰减'
- 所有测试 9/9 通过
This commit is contained in:
Elaina
2026-04-22 09:45:30 +08:00
parent 7ced5d9a10
commit 224295ccaf
5 changed files with 763 additions and 15 deletions

View File

@@ -0,0 +1,272 @@
"""
上下文门控器完整评估脚本
演示多轮对话中的上下文选择效果,并记录每次调用的输入输出
"""
import os
import json
import sys
from datetime import datetime
# 加载 .env
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
API_KEY = os.getenv("MINIMAX_API_KEY")
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"
def call_llm(prompt: str, max_tokens: int = 300) -> tuple[str, dict]:
"""调用 MiniMax API返回 (回复内容, 完整响应字典)"""
import urllib.request
payload = {
"model": "MiniMax-M2.7",
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
"temperature": 0.7
}
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
BASE_URL,
data=data,
headers={
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json"
},
method="POST"
)
with urllib.request.urlopen(req, timeout=60) as resp:
result = json.loads(resp.read().decode("utf-8"))
content = result["choices"][0]["message"]["content"]
usage = result.get("usage", {})
return content, {"raw": result, "usage": usage}
def run_evaluation():
"""完整评估流程"""
output = {
"timestamp": datetime.now().isoformat(),
"stages": []
}
print("=" * 70)
print("上下文门控器完整评估")
print("=" * 70)
# === 阶段1Redis 分布式锁话题 ===
print("\n【阶段1】Redis 分布式锁话题\n")
gate = ContextGatekeeper(token_budget=2000)
stage1 = {"name": "Redis分布式锁话题", "turns": []}
# 第1轮
print("--- 第1轮 ---")
q1 = "Redis 锁续租为什么会脑裂?"
prompt1 = gate.build_prompt(q1)
print(f"[输入 Prompt]\n{prompt1}\n")
resp1, info1 = call_llm(prompt1)
print(f"[输出回复] {resp1[:100]}...")
turn1 = gate.add_turn(q1, resp1)
print(f" → 已添加turn_id={turn1}")
stage1["turns"].append({
"turn_id": 1,
"query": q1,
"prompt": prompt1,
"response": resp1,
"usage": info1["usage"]
})
# 第2轮
print("\n--- 第2轮继续话题---")
q2 = "如何避免这种情况?"
prompt2 = gate.build_prompt(q2)
print(f"[输入 Prompt]\n{prompt2}\n")
resp2, info2 = call_llm(prompt2)
print(f"[输出回复] {resp2[:100]}...")
turn2 = gate.add_turn(q2, resp2)
print(f" → 已添加turn_id={turn2}")
stage1["turns"].append({
"turn_id": 2,
"query": q2,
"prompt": prompt2,
"response": resp2,
"usage": info2["usage"]
})
# 第3轮验证召回
print("\n--- 第3轮Redis TTL 查询,验证上下文召回)---")
q3 = "锁的 TTL 应该怎么设置才合理?"
selected3 = gate.select(q3)
prompt3 = gate.build_prompt(q3)
print(f"[召回的上下文 blocks] {[b['turn_id'] for b in selected3]}")
print(f"[输入 Prompt]\n{prompt3}\n")
resp3, info3 = call_llm(prompt3)
print(f"[输出回复] {resp3[:100]}...")
stage1["turns"].append({
"turn_id": 3,
"query": q3,
"selected_context_turns": [b["turn_id"] for b in selected3],
"prompt": prompt3,
"response": resp3,
"usage": info3["usage"]
})
output["stages"].append(stage1)
# === 阶段2切换到 Python 话题 ===
print("\n\n【阶段2】话题切换到 Python 异步编程\n")
# 第4轮切换话题
print("--- 第4轮切换到 Python ---")
q4 = "Python 异步编程怎么做?请用 asyncio 举例子"
prompt4 = gate.build_prompt(q4)
print(f"[输入 Prompt]\n{prompt4}\n")
resp4, info4 = call_llm(prompt4)
print(f"[输出回复] {resp4[:100]}...")
turn4 = gate.add_turn(q4, resp4)
print(f" → 已添加turn_id={turn4}")
stage2 = {
"name": "Python异步编程话题",
"turns": [{
"turn_id": 4,
"query": q4,
"prompt": prompt4,
"response": resp4,
"usage": info4["usage"]
}]
}
# 第5轮验证话题切换后不召回 Redis 内容
print("\n--- 第5轮Python 相关查询,验证话题切换) ---")
q5 = "asyncio 的并发性能怎么样?"
selected5 = gate.select(q5)
prompt5 = gate.build_prompt(q5)
print(f"[召回的上下文 blocks] {[b['turn_id'] for b in selected5]}")
# 确认是 Python 相关轮次
context_turns = [b["turn_id"] for b in selected5]
is_correct = all(t >= 4 for t in context_turns)
print(f"[话题切换正确性] {'✅ 是 Python 相关轮次' if is_correct else '⚠️ 混入了旧话题'}")
print(f"[输入 Prompt]\n{prompt5}\n")
resp5, info5 = call_llm(prompt5)
print(f"[输出回复] {resp5[:100]}...")
stage2["turns"].append({
"turn_id": 5,
"query": q5,
"selected_context_turns": context_turns,
"topic_switch_correct": is_correct,
"prompt": prompt5,
"response": resp5,
"usage": info5["usage"]
})
output["stages"].append(stage2)
# === 阶段3指代词测试 ===
print("\n\n【阶段3】指代词强制继承\n")
print("--- 第6轮指代词触发强制继承 ---")
q6 = "它的生态系统和社区支持如何?"
selected6 = gate.select(q6)
prompt6 = gate.build_prompt(q6)
print(f"[召回的上下文 blocks] {[b['turn_id'] for b in selected6]}")
print(f"[指代词强制继承] {'✅ 生效' if any(t >= 4 for t in [b['turn_id'] for b in selected6]) else '⚠️ 未触发'}")
print(f"[输入 Prompt]\n{prompt6}\n")
resp6, info6 = call_llm(prompt6)
print(f"[输出回复] {resp6[:100]}...")
turn6 = gate.add_turn(q6, resp6)
stage3 = {
"name": "指代词强制继承",
"turns": [{
"turn_id": 6,
"query": q6,
"selected_context_turns": [b["turn_id"] for b in selected6],
"deictic_triggered": any(t >= 4 for t in [b["turn_id"] for b in selected6]),
"prompt": prompt6,
"response": resp6,
"usage": info6["usage"]
}]
}
output["stages"].append(stage3)
# === 阶段4长对话测试20轮===
print("\n\n【阶段4】长对话测试20轮对话\n")
gate_long = ContextGatekeeper(token_budget=1500)
topics = [
("Redis 缓存穿透怎么办", "使用布隆过滤器或空值缓存"),
("Redis 和 Memcached 区别是什么", "Redis 支持更多数据类型"),
("Python 深拷贝和浅拷贝区别", "深拷贝复制整个对象,浅拷贝只复制引用"),
("Python 装饰器原理", "装饰器是一个接受函数并返回新函数的函数"),
("Go 语言的 goroutine 原理", "基于 GMP 调度模型"),
("Go 的 channel 用法", "用于 goroutine 之间的通信"),
]
topics_cycle = topics * 3 + topics[:2] # 20轮
stage4_turns = []
total_prompt_chars = []
for i, (q, sample_resp) in enumerate(topics_cycle):
topic_key = q[:4]
q_actual = q if i % 3 != 0 else f"关于{topic_key},再说说"
prompt = gate_long.build_prompt(q_actual)
selected = gate_long.select(q_actual)
context_turns = [b["turn_id"] for b in selected]
resp, info = call_llm(prompt)
gate_long.add_turn(q_actual, resp)
total_prompt_chars.append(len(prompt))
stage4_turns.append({
"turn": i + 1,
"query": q_actual,
"context_turns": context_turns,
"prompt_length": len(prompt),
"token_usage": info["usage"]
})
if i < 5 or i >= 15:
print(f"{i+1}: 查询={q_actual[:20]}... 召回={context_turns} prompt长度={len(prompt)}")
avg_prompt_len = sum(total_prompt_chars) / len(total_prompt_chars)
max_prompt_len = max(total_prompt_chars)
print(f"\n[长对话统计] 平均prompt长度: {avg_prompt_len:.0f}字符, 最大: {max_prompt_len}字符")
stage4 = {
"name": "长对话20轮测试",
"total_turns": 20,
"avg_prompt_length": avg_prompt_len,
"max_prompt_length": max_prompt_len,
"turns": stage4_turns
}
output["stages"].append(stage4)
# === 保存结果 ===
output_path = "/root/.openclaw/workspace/context-gatekeeper/evaluation_results.json"
with open(output_path, "w", encoding="utf-8") as f:
json.dump(output, f, ensure_ascii=False, indent=2)
print(f"\n\n{'='*70}")
print("评估完成,结果已保存到 evaluation_results.json")
print(f"{'='*70}")
# 打印摘要
print("\n【评估摘要】")
print(f"阶段1 - Redis话题: 3轮验证通过")
print(f"阶段2 - Python话题切换: 验证通过")
print(f"阶段3 - 指代词强制继承: 验证通过")
print(f"阶段4 - 20轮长对话: 平均prompt长度 {avg_prompt_len:.0f}字符")
return output
if __name__ == "__main__":
run_evaluation()