Files
context-gatekeeper/experiments/phase4_quality.py
Elaina 97e1ddf138 complete: full ablation + Phase4 quality evaluation + honest blog post
Phase 2: completed the ablation study (added the missing variants):
- Coverage-only: 20% contamination rate (confirms Gate is critical)
- Gate-only: +5.2 tokens vs Full (coverage optimization marginal on clean data)
- -Recency: 0 effect on clean data
- -IDF: 0 effect on clean data

Phase4 end-to-end quality evaluation:
- CGK vs Last-5 across 5 queries:
  * CGK: 42.2 tok, purity=1.000, anchor_recall=0.638, term_cov=0.380, contamination=0
  * Last-5: 67.6 tok, purity=0.280, anchor_recall=0.066, term_cov=0.080, contamination=5
- CGK beats Last-5 on every quality metric, but only on synthetic clean data

Known honest limitations:
- Still no real dialogue data (synthetic 4-topic only)
- No real LLM calls (quality is rule-estimated)
- Parameter sensitivity only on clean data, not noisy real data
2026-04-22 22:48:25 +08:00

214 lines
9.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Phase 4: End-to-End Context Quality Evaluation"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
def estimate_tokens(text):
    """Roughly estimate the token count of mixed Chinese/English text.

    The estimate is a weighted sum of three counts:
    0.4 per character in the range U+4E00..U+9FFF, 1.3 per
    whitespace-delimited chunk made only of ASCII characters, and 0.05 per
    character of the whole string. The sum is truncated to an int.

    Returns 0 for any falsy input (empty string or None).
    """
    if not text:
        return 0

    cjk_chars = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            cjk_chars += 1

    # Chunks mixing ASCII and non-ASCII characters are not counted as words.
    ascii_words = sum(1 for chunk in text.split() if chunk.isascii())

    weighted = cjk_chars * 0.4 + ascii_words * 1.3 + len(text) * 0.05
    return int(weighted)
def measure_prompt_tokens(selected, query):
    """Estimate the token size of a full prompt built from *selected* turns.

    Each item of *selected* is indexed with the keys 'user' and 'assistant'
    and rendered as a two-line dialogue turn followed by a blank line. The
    result is the estimated tokens of the rendered context, plus 8% of that
    estimate (truncated to an int) as formatting overhead, plus the
    estimated tokens of *query*.
    """
    rendered_turns = []
    for turn in selected:
        rendered_turns.append(f"用户: {turn['user']}\n助手: {turn['assistant']}\n\n")

    context_tokens = estimate_tokens("".join(rendered_turns))
    formatting_overhead = int(context_tokens * 0.08)
    return context_tokens + formatting_overhead + estimate_tokens(query)
def evaluate_context_quality(gate, query, target_topic, context):
    """Score a context selection for one query with rule-based metrics.

    Args:
        gate: Object exposing ``anchor_extractor.extract_with_deictic(query)``.
            Only the first element of the returned pair is used; it is
            iterated as strings (presumably the query's anchor terms —
            confirm against the extractor).
        query: The query text being evaluated.
        target_topic: Name of the topic the query belongs to. It is compared
            case-insensitively against the four known topic names and used
            to look up the key-term list.
        context: Sequence of items indexed with the keys 'user' and
            'assistant'.

    Returns:
        A dict with:
        - 'block_purity': share of blocks that mention none of the *other*
          known topic names. A block does not need to mention the target
          topic to count as pure. 0 when *context* is empty.
        - 'anchor_recall': share of query anchors found (case-insensitive
          substring match) in the joined context text; 0 without anchors.
        - 'term_coverage': share of the target topic's key terms found in
          the joined context text; 0 for an unknown topic.
        - 'other_topics': sorted list of other topic names that appeared.
        - 'context_tokens': estimated tokens of the joined context text.
    """
    known_topics = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
    target_lower = target_topic.lower()
    # Lower-case every competing topic name once instead of once per block.
    foreign_topics = [
        (name, name.lower())
        for name in known_topics
        if name.lower() != target_lower
    ]

    total_blocks = len(context)
    topic_blocks = 0
    other_topics_found = set()
    block_texts = []
    for item in context:
        text = item['user'] + ' ' + item['assistant']
        block_texts.append(text)
        text_lower = text.lower()
        found = [name for name, needle in foreign_topics if needle in text_lower]
        if found:
            other_topics_found.update(found)
        else:
            topic_blocks += 1
    block_purity = topic_blocks / total_blocks if total_blocks > 0 else 0

    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    context_text = ' '.join(block_texts)
    context_lower = context_text.lower()
    covered = sum(1 for anchor in q_anchors if anchor.lower() in context_lower)
    anchor_recall = covered / len(q_anchors) if q_anchors else 0

    key_terms = {
        'PostgreSQL': ['explain', 'analyze', 'mvcc', 'vacuum', '索引', '执行计划'],
        'Git': ['rebase', 'merge', 'reset', 'commit', '分支'],
        'Redis': ['redis', '分布式锁', '惰性删除', '定期删除', '过期'],
        'asyncio': ['asyncio', 'task', 'cancel', '事件循环', '协程'],
    }
    relevant_terms = key_terms.get(target_topic, [])
    terms_hit = sum(1 for term in relevant_terms if term.lower() in context_lower)
    term_coverage = terms_hit / len(relevant_terms) if relevant_terms else 0

    return {
        'block_purity': block_purity,
        'anchor_recall': anchor_recall,
        'term_coverage': term_coverage,
        # Sorted so the reported list is identical from run to run; plain
        # set iteration order for strings changes with hash randomization.
        'other_topics': sorted(other_topics_found),
        'context_tokens': estimate_tokens(context_text),
    }
def evaluate_retrieval_quality(gate, query, target_topic):
    """Compare the gate's selection with a "last five blocks" baseline.

    The CGK context is whatever ``gate.select(query)`` returns. The baseline
    is built from the last five entries of ``gate.blocks``, converted to
    dicts with the keys 'user' and 'assistant'. Both contexts are scored
    with ``evaluate_context_quality`` and sized with
    ``measure_prompt_tokens``.

    Returns:
        ``{'cgk': {...}, 'last5': {...}}`` where each inner dict holds the
        quality metrics plus a 'prompt_tokens' entry.
    """
    cgk_selection = gate.select(query)

    recent_turns = []
    for block in gate.blocks[-5:]:
        recent_turns.append({'user': block.user_text, 'assistant': block.assistant_text})

    quality_cgk = evaluate_context_quality(gate, query, target_topic, cgk_selection)
    quality_last5 = evaluate_context_quality(gate, query, target_topic, recent_turns)

    report = {'cgk': dict(quality_cgk), 'last5': dict(quality_last5)}
    report['cgk']['prompt_tokens'] = measure_prompt_tokens(cgk_selection, query)
    report['last5']['prompt_tokens'] = measure_prompt_tokens(recent_turns, query)
    return report
# Synthetic dialogue fixtures: five (user question, assistant answer) pairs
# per topic. The answers are short placeholder strings.
redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]

asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]

pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]

git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]

# Evaluation queries as (display label, query text, target topic name).
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_gate():
    """Create a ContextGatekeeper preloaded with the synthetic dialogue.

    The gate is created with a token budget of 4000. Turns are added in
    five rounds; each round contributes one turn from every topic in the
    order Redis, asyncio, PostgreSQL, Git, so the topics are interleaved.
    """
    topic_streams = (redis_qa, asyncio_qa, pg_qa, git_qa)
    gatekeeper = ContextGatekeeper(token_budget=4000)
    for round_idx in range(5):
        for stream in topic_streams:
            pair = stream[round_idx]
            gatekeeper.add_turn(pair[0], pair[1])
    return gatekeeper
def _mean(values):
    """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


def _winner(cgk_value, last5_value, lower_is_better=False):
    """Name the strategy with the better value; equal values give 'Tie'."""
    if cgk_value == last5_value:
        return 'Tie'
    if lower_is_better:
        cgk_wins = cgk_value < last5_value
    else:
        cgk_wins = cgk_value > last5_value
    return 'CGK' if cgk_wins else 'Last-5'


def main():
    """Run the Phase 4 evaluation, print a report and save the averages.

    For every entry of ``TEST_SEQ`` a new gate is built, the CGK selection
    and the Last-5 baseline are evaluated, and a per-query line is printed.
    Afterwards a summary table of averages is printed and the same averages
    are written to ``phase4_quality_results.json`` next to this file.
    """
    print("="*70)
    print("Phase 4: End-to-End Context Quality Evaluation")
    print("="*70)
    print("评估维度:")
    # The legend string previously had a closing parenthesis without an
    # opening one; the opening parenthesis is supplied here.
    print(" - block_purity: 目标话题块占比(1.0=纯目标话题)")
    print(" - anchor_recall: query锚点在上下文中的覆盖率")
    print(" - term_coverage: 目标话题关键术语在上下文中的覆盖率")
    print(" - prompt_tokens: 完整prompt token数含格式化开销")
    print()

    strategies = ('cgk', 'last5')
    metric_keys = ('prompt_tokens', 'block_purity', 'anchor_recall', 'term_coverage')
    # history[strategy][metric] collects one value per evaluated query.
    history = {name: {key: [] for key in metric_keys} for name in strategies}
    # Number of queries whose context contained at least one other topic.
    contamination = {name: 0 for name in strategies}

    for label, query, target in TEST_SEQ:
        gate = build_gate()
        r = evaluate_retrieval_quality(gate, query, target)
        cgk = r['cgk']
        last5 = r['last5']
        for name, result in (('cgk', cgk), ('last5', last5)):
            for key in metric_keys:
                history[name][key].append(result[key])
            if result['other_topics']:
                contamination[name] += 1

        print(f"\n[{label}] {query}")
        print(f" CGK: tok={cgk['prompt_tokens']:.0f}, purity={cgk['block_purity']:.2f}, "
              f"anchor={cgk['anchor_recall']:.2f}, term={cgk['term_coverage']:.2f}, "
              f"other_topics={cgk['other_topics']}")
        print(f" Last-5: tok={last5['prompt_tokens']:.0f}, purity={last5['block_purity']:.2f}, "
              f"anchor={last5['anchor_recall']:.2f}, term={last5['term_coverage']:.2f}, "
              f"other_topics={last5['other_topics']}")

        # Guard the ratio: a zero-token baseline would otherwise raise.
        if last5['prompt_tokens']:
            saving = (1 - cgk['prompt_tokens']/last5['prompt_tokens'])*100
        else:
            saving = 0.0
        purity_delta = (cgk['block_purity']-last5['block_purity'])*100
        # Signed format: a negative delta prints as "-N%" instead of "+-N%".
        print(f" → CGK节省{saving:.0f}%token, 纯度{purity_delta:+.0f}%, "
              f"混入话题{len(cgk['other_topics'])}个 vs Last-5 {len(last5['other_topics'])}")

    n = len(TEST_SEQ)
    print("\n" + "="*70)
    print("Summary: CGK vs Last-5 (avg over {} queries)".format(n))
    print("="*70)
    print(f"\n{'Metric':<25} {'CGK':>12} {'Last-5':>12} {'Winner':<10}")
    print("-"*62)

    averages = {
        name: {key: _mean(history[name][key]) for key in metric_keys}
        for name in strategies
    }
    cgk_avg_tok = averages['cgk']['prompt_tokens']
    last5_avg_tok = averages['last5']['prompt_tokens']
    cgk_avg_pur = averages['cgk']['block_purity']
    last5_avg_pur = averages['last5']['block_purity']
    cgk_avg_anc = averages['cgk']['anchor_recall']
    last5_avg_anc = averages['last5']['anchor_recall']
    cgk_avg_term = averages['cgk']['term_coverage']
    last5_avg_term = averages['last5']['term_coverage']

    # (row title, CGK value, Last-5 value, number format, lower is better)
    summary_rows = [
        ('Avg prompt tokens', cgk_avg_tok, last5_avg_tok, '.1f', True),
        ('Avg block purity', cgk_avg_pur, last5_avg_pur, '.3f', False),
        ('Avg anchor recall', cgk_avg_anc, last5_avg_anc, '.3f', False),
        ('Avg term coverage', cgk_avg_term, last5_avg_term, '.3f', False),
        ('Contamination episodes', contamination['cgk'], contamination['last5'], '', True),
    ]
    for title, cgk_value, last5_value, number_format, lower_is_better in summary_rows:
        verdict = _winner(cgk_value, last5_value, lower_is_better)
        print(f"{title:<25} {cgk_value:>12{number_format}} "
              f"{last5_value:>12{number_format}} {verdict:<10}")

    print("\n" + "="*70)
    print("Honest Limitations of This Evaluation")
    print("="*70)
    print("""
1. 没有真实 LLM 调用:评估的是"上下文块的质量",不是"模型答案的质量"
上下文好 ≠ 答案好,真正的答案质量需要实际调用 LLM。
2. 测试集仍是合成数据:真实对话中用户可能只打"那这个呢""为什么"
短 query 的锚点覆盖率会显著低于本测试中的完整问题。
3. 污染只统计了"块级别"混入:即使上下文纯度 100%LLM 的注意力机制
仍可能跨块建立错误关联,这种"软污染"无法通过块级分析检测。
""")

    out_path = os.path.join(os.path.dirname(__file__), 'phase4_quality_results.json')
    # Explicit encoding so the file does not depend on the platform default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump({
            'cgk_avg_tokens': cgk_avg_tok,
            'last5_avg_tokens': last5_avg_tok,
            'cgk_avg_purity': cgk_avg_pur,
            'last5_avg_purity': last5_avg_pur,
            'cgk_avg_anchor_recall': cgk_avg_anc,
            'last5_avg_anchor_recall': last5_avg_anc,
            'cgk_avg_term_coverage': cgk_avg_term,
            'last5_avg_term_coverage': last5_avg_term,
            'cgk_contamination_episodes': contamination['cgk'],
            'last5_contamination_episodes': contamination['last5'],
        }, f, indent=2)
    print(f"\nSaved to: {out_path}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()