Files
context-gatekeeper/experiments/phase2_complete_ablation.py
Elaina 97e1ddf138 complete: full ablation + Phase4 quality evaluation + honest blog post
Phase2 complete ablation (added missing variants):
- Coverage-only: 20% contamination rate (confirms Gate is critical)
- Gate-only: +5.2 tokens vs Full (coverage optimization marginal on clean data)
- -Recency: no effect on clean data
- -IDF: no effect on clean data

Phase4 end-to-end quality evaluation:
- CGK vs Last-5 across 5 queries:
  * CGK: 42.2 tok, purity=1.000, anchor_recall=0.638, term_cov=0.380, contamination=0
  * Last-5: 67.6 tok, purity=0.280, anchor_recall=0.066, term_cov=0.080, contamination=5
- CGK beats Last-5 by a wide margin on every quality metric (synthetic clean data only)

Known honest limitations:
- Still no real dialogue data (synthetic 4-topic only)
- No real LLM calls (quality is rule-estimated)
- Parameter sensitivity only on clean data, not noisy real data
2026-04-22 22:48:25 +08:00

357 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Phase 2 补全: 完整的 Ablation Study包含缺失的 -recency, -coverage, -gate 变体)"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
def estimate_tokens(text):
    """Roughly estimate the LLM token count of a mixed Chinese/English string.

    Heuristic: 0.4 tokens per CJK ideograph (U+4E00..U+9FFF), 1.3 tokens per
    whitespace-separated pure-ASCII word, plus 0.05 tokens per character of
    total length; the sum is truncated to an int. Falsy input yields 0.
    """
    if not text:
        return 0
    n_cjk = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            n_cjk += 1
    n_ascii_words = sum(1 for word in text.split() if word.isascii())
    return int(n_cjk * 0.4 + n_ascii_words * 1.3 + len(text) * 0.05)
def measure_prompt_tokens(selected, query):
    """Estimate the prompt size for a set of selected turns plus the query.

    Each selected turn (a dict with 'user' and 'assistant' keys) is rendered
    as a "用户: <user>" line, a "助手: <assistant>" line and a blank line.
    The rendered context is estimated with estimate_tokens(), an extra 8%
    (truncated to int) is added for formatting overhead, and the estimate
    for the query itself is added on top.
    """
    rendered = []
    for turn in selected:
        rendered.append(f"用户: {turn['user']}\n助手: {turn['assistant']}\n\n")
    context_tokens = estimate_tokens("".join(rendered))
    overhead = int(context_tokens * 0.08)
    return context_tokens + overhead + estimate_tokens(query)
def evaluate_contamination(selected, target, topics=('Redis', 'asyncio', 'PostgreSQL', 'Git')):
    """Check whether the selected turns mention any topic other than ``target``.

    Args:
        selected: iterable of turn dicts with 'user' and 'assistant' strings.
        target: the topic the query is about; it is never reported.
        topics: topic names to scan for. Matching is a case-insensitive
            substring test, so short names (e.g. "Git") can also match inside
            longer words. Defaults to the four synthetic topics used here.

    Returns:
        ``(is_contaminated, found)`` where ``found`` lists the off-target
        topics that appear, in ``topics`` order.
    """
    # Put a space between user and assistant text (and between turns) so a
    # topic name cannot be "found" by gluing the end of one string to the
    # start of the next.
    text = " ".join(f"{i['user']} {i['assistant']}" for i in selected).lower()
    target_lower = target.lower()
    found = [t for t in topics if t.lower() != target_lower and t.lower() in text]
    return len(found) > 0, found
def evaluate_answer_quality(gate, query, target_topic,
                            topics=('Redis', 'asyncio', 'PostgreSQL', 'Git')):
    """Rule-based proxy for end-to-end answer quality (no real LLM is called).

    Runs ``gate.select(query)`` and scores the selection with two proxies:

    1. purity: fraction of selected turns that mention no off-target topic.
    2. anchor coverage: fraction of the query's anchors (as returned by
       ``gate.anchor_extractor.extract_with_deictic(query)``) that occur,
       case-insensitively, somewhere in the selected text.

    Args:
        gate: object exposing ``select(query)`` (returning turn dicts with
            'user'/'assistant' keys) and ``anchor_extractor``.
        query: the user query.
        target_topic: topic the query is about; never counted as contamination.
        topics: topic names scanned for contamination (case-insensitive
            substring match).

    Returns:
        dict with 'total_blocks', 'topic_blocks', 'purity' (0 when nothing is
        selected), 'other_topics' (unique, first-seen order),
        'anchor_coverage' (0 when the query has no anchors) and
        'is_contaminated'.
    """
    # NOTE(review): main() has already called gate.select(query) on this gate,
    # so select runs twice; if select mutates gate state the two selections
    # could differ — confirm select is side-effect free.
    sel = gate.select(query)
    target_lower = target_topic.lower()

    total_blocks = len(sel)
    topic_blocks = 0
    other_topic_texts = []
    for item in sel:
        text = (item['user'] + ' ' + item['assistant']).lower()
        found_topics = [t for t in topics if t.lower() != target_lower and t.lower() in text]
        if found_topics:
            other_topic_texts.extend(found_topics)
        else:
            topic_blocks += 1

    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    # Same user/assistant separator as the per-turn text above, so an anchor
    # cannot match across the seam between two strings.
    context_lower = ' '.join(f"{i['user']} {i['assistant']}" for i in sel).lower()
    covered_anchors = sum(1 for a in q_anchors if a.lower() in context_lower)
    anchor_coverage = covered_anchors / len(q_anchors) if q_anchors else 0

    return {
        'total_blocks': total_blocks,
        'topic_blocks': topic_blocks,
        'purity': topic_blocks / total_blocks if total_blocks > 0 else 0,
        # dict.fromkeys dedupes in first-seen order; list(set(...)) ordering
        # depends on string-hash randomization and changes between runs.
        'other_topics': list(dict.fromkeys(other_topic_texts)),
        'anchor_coverage': anchor_coverage,
        'is_contaminated': len(other_topic_texts) > 0,
    }
# Synthetic 4-topic dialogue. Each entry is a (user_question, assistant_answer)
# pair; the answers are short placeholder stubs, not real responses.
redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
# Evaluation queries as (label, query, target_topic) tuples; main() unpacks
# them in this order. Labels read "ask PG", "ask Git", ..., "ask Git again".
# target_topic is the only topic allowed in a clean selection. Some queries
# repeat a stored question verbatim (e.g. the PG and asyncio ones); the last
# Git query (reset vs revert) does not appear in the history.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_gate():
    """Return a fresh ContextGatekeeper (token_budget=4000) preloaded with the
    synthetic dialogue.

    The 20 turns are added round-robin across topics: Redis, asyncio,
    PostgreSQL, Git for question 0, then the same four for question 1, and
    so on. The history therefore switches topic on every turn.
    """
    gatekeeper = ContextGatekeeper(token_budget=4000)
    topic_qas = (redis_qa, asyncio_qa, pg_qa, git_qa)
    for idx in range(5):
        for qa in topic_qas:
            gatekeeper.add_turn(qa[idx][0], qa[idx][1])
    return gatekeeper
def run_gate_only(gate, query):
    """Gate-only ablation: topic filtering only, no coverage optimization.

    Runs the gate's topic-switch check and candidate restriction, then
    returns every block the retriever ranks (top_m=20) directly instead of
    passing them through gate.select()'s coverage step.

    Returns:
        list of dicts with 'user', 'assistant' and 'turn_id' keys, in the
        order the retriever returned them.
    """
    anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    topic_switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
    idf = gate.anchor_extractor._idf_cache
    # After a topic switch only the 15 most recent blocks are candidates;
    # otherwise the whole history is.
    pool = gate.blocks[-15:] if topic_switched else gate.blocks
    ranked = gate.retriever.retrieve(
        pool,
        anchors,
        top_m=20,
        idf_cache=idf,
        active_topic_anchors=gate._active_topic[0] if gate._active_topic else None,
        topic_switched=topic_switched,
        query_text=query,
    )
    return [
        {"user": blk.user_text, "assistant": blk.assistant_text, "turn_id": blk.turn_id}
        for blk, _score in ranked
    ]
def run_coverage_only(gate, query):
    """Coverage-only ablation: coverage optimization without topic gating.

    Temporarily monkeypatches the gate so that
      * ``topic_gate.is_topic_switch`` always reports "no switch", and
      * ``retriever.retrieve`` scores every candidate block with
        ``retriever.score`` (skipping whatever filtering retrieve normally
        applies) and returns the ``top_m`` highest-scoring
        ``(block, score)`` pairs,
    then runs ``gate.select(query)``. Both patches are undone before
    returning, even if ``select`` raises.

    Returns:
        whatever ``gate.select(query)`` returns.
    """
    topic_gate = gate.topic_gate
    retriever = gate.retriever
    orig_switched = topic_gate.is_topic_switch
    orig_retrieve = retriever.retrieve

    def fake_switch(*args, **kwargs):
        return False  # never treat the query as a topic switch

    def no_filter_retrieve(blocks, qa, top_m=20, **kwargs):
        # Score every candidate block; no content-word filtering.
        idf_cache = kwargs.get('idf_cache', {})
        # NOTE(review): score()'s third positional argument is passed as 0.0;
        # its meaning is defined by the retriever — confirm.
        scored = [(block, gate.retriever.score(block, qa, 0.0, idf_cache)) for block in blocks]
        scored.sort(key=lambda x: x[1], reverse=True)
        return scored[:top_m]

    topic_gate.is_topic_switch = fake_switch
    retriever.retrieve = no_filter_retrieve
    try:
        return gate.select(query)
    finally:
        # Restore even when select() raises, so the gate is never left patched.
        topic_gate.is_topic_switch = orig_switched
        retriever.retrieve = orig_retrieve
def run_no_recency(gate, query):
    """-Recency ablation: run ``gate.select(query)`` with the retriever's
    ``WEIGHT_RECENT`` temporarily set to 0.0.

    The original weight is restored afterwards, even if select() raises.

    Returns:
        whatever ``gate.select(query)`` returns.
    """
    retriever = gate.retriever
    saved_weight = retriever.WEIGHT_RECENT
    retriever.WEIGHT_RECENT = 0.0
    try:
        return gate.select(query)
    finally:
        retriever.WEIGHT_RECENT = saved_weight
def run_no_idf(gate, query):
    """-IDF ablation: run ``gate.select(query)`` with IDF weighting neutralised.

    While select runs:
      * ``anchor_extractor.idf`` is replaced by a stub returning 1.0 for every
        term, and
      * ``anchor_extractor._idf_cache`` (if present) is emptied. The intent is
        that cached lookups fall back to their default value; whether they do
        depends on the extractor's internals.
    Both are restored afterwards, even if select() raises. The cache is
    restored *in place*, so any other object holding a reference to the same
    dict also sees the original contents again. Rebinding the attribute to a
    copy would leave those references pointing at the cleared dict.

    Returns:
        whatever ``gate.select(query)`` returns.
    """
    extractor = gate.anchor_extractor

    def fake_idf(*_args, **_kwargs):
        return 1.0  # identical weight for every term: no IDF discrimination

    cache = getattr(extractor, '_idf_cache', None)
    saved_cache = cache.copy() if cache is not None else None
    orig_idf_fn = extractor.idf
    extractor.idf = fake_idf
    if cache is not None:
        cache.clear()
    try:
        return gate.select(query)
    finally:
        extractor.idf = orig_idf_fn
        if cache is not None:
            # Drop anything cached during the ablated run, refill the original
            # dict object, and make sure the attribute points at it again.
            cache.clear()
            cache.update(saved_cache)
            extractor._idf_cache = cache
def _prompt_stats(selected, query, target):
    """Return the per-query record stored for every variant.

    Keys: 'pt' (estimated prompt tokens for ``selected`` plus ``query``) and
    'cont' (True if ``selected`` mentions a topic other than ``target``).
    """
    pt = measure_prompt_tokens(selected, query)
    cont, _ = evaluate_contamination(selected, target)
    return {'pt': pt, 'cont': cont}


def _avg_tokens(records):
    """Mean of the 'pt' values in a non-empty list of per-query records."""
    return sum(d['pt'] for d in records) / len(records)


def _contamination_pct(records):
    """Percentage (0-100) of per-query records whose 'cont' flag is truthy."""
    return sum(1 for d in records if d['cont']) / len(records) * 100


def _run_no_deictic(gate, query):
    """-Deictic ablation: keep the extracted anchors but always report has_deictic=False."""
    orig = gate.anchor_extractor.extract_with_deictic

    def no_d(text):
        a, _ = orig(text)
        return a, False

    gate.anchor_extractor.extract_with_deictic = no_d
    return gate.select(query)


def _run_no_exact_match(gate, query):
    """-Exact Match ablation: make retriever._exact_match always return 0.0."""
    gate.retriever._exact_match = lambda b, qa: 0.0
    return gate.select(query)


def _run_no_trim(gate, query):
    """-Trim ablation: make gate._trim_blocks_to_query return its blocks unchanged."""
    gate._trim_blocks_to_query = lambda blocks, qa: blocks
    return gate.select(query)


def _last5_conversation():
    """Flatten the synthetic dialogue into turn dicts in build_gate()'s order.

    build_gate() adds turns round-robin across topics (Redis, asyncio,
    PostgreSQL, Git for question 0, then question 1, ...). The Last-5 baseline
    must slice this same sequence to see the history CGK sees; concatenating
    topic by topic would make the "last 5" turns all Git turns.
    """
    return [
        {'user': qa[i][0], 'assistant': qa[i][1]}
        for i in range(5)
        for qa in (redis_qa, asyncio_qa, pg_qa, git_qa)
    ]


def main():
    """Run every ablation variant on each TEST_SEQ query, print per-query and
    summary tables, and save the raw per-query records as JSON next to this
    script.
    """
    results = {
        'Full CGK': [],
        'Gate-only': [],      # no coverage optimization (missing variant flagged by ChatGPT)
        'Coverage-only': [],  # no gate filtering (missing variant flagged by ChatGPT)
        '-Recency': [],       # no recency preference
        '-IDF': [],           # no IDF weighting
        '-Deictic': [],
        '-Exact Match': [],
        '-Trim': [],
        'Last-5 (baseline)': [],
    }
    # Gate-based ablations, run in this order for every query:
    # (results key, printed line prefix, runner(gate, query) -> selected turns).
    variants = [
        ('Gate-only', " Gate-only: ", run_gate_only),
        ('Coverage-only', " Coverage-only:", run_coverage_only),
        ('-Recency', " -Recency: ", run_no_recency),
        ('-IDF', " -IDF: ", run_no_idf),
        ('-Deictic', " -Deictic: ", _run_no_deictic),
        ('-Exact Match', " -Exact Match: ", _run_no_exact_match),
        ('-Trim', " -Trim: ", _run_no_trim),
    ]
    # Last-5 baseline: the five most recent turns of the same conversation the
    # gate is built from. It does not depend on the query, so build it once.
    last5 = _last5_conversation()[-5:]

    print("="*70)
    print("Phase 2 COMPLETE Ablation Study (补全)")
    print("="*70)
    for label, query, target in TEST_SEQ:
        print(f"\n[{label}] {query[:45]}...")
        # Full CGK, additionally scored with the rule-based answer-quality proxy.
        g = build_gate()
        sel = g.select(query)
        rec = _prompt_stats(sel, query, target)
        pt, cont = rec['pt'], rec['cont']
        aq = evaluate_answer_quality(g, query, target)
        results['Full CGK'].append({'pt': pt, 'cont': cont, 'aq': aq})
        print(f" Full CGK: {pt:5.0f} tok, 污染={cont}, 纯度={aq['purity']:.2f}, 锚点覆盖={aq['anchor_coverage']:.2f}")
        # Each ablated variant gets its own freshly built gate.
        for name, prefix, runner in variants:
            rec = _prompt_stats(runner(build_gate(), query), query, target)
            results[name].append(rec)
            print(f"{prefix}{rec['pt']:5.0f} tok, 污染={rec['cont']}")
        # Last-5 baseline
        rec = _prompt_stats(last5, query, target)
        results['Last-5 (baseline)'].append(rec)
        print(f" Last-5: {rec['pt']:5.0f} tok, 污染={rec['cont']}")

    # Summary
    print("\n" + "="*70)
    print("Ablation Summary (avg over {} queries)".format(len(TEST_SEQ)))
    print("="*70)
    full_avg = _avg_tokens(results['Full CGK'])
    full_cont = _contamination_pct(results['Full CGK'])
    print(f"\n{'Method':<20} {'Avg Tokens':>12} {'Cont%':>8} {'ΔTokens':>10} {'Notes':<30}")
    print("-"*80)
    print(f"{'Full CGK':<20} {full_avg:>12.1f} {full_cont:>8.1f} {'':>10} {'baseline':<30}")
    notes_by_name = {
        'Gate-only': "← 关键模块",
        'Coverage-only': "← 无门控",
        '-Recency': "← recency权重→0",
        '-IDF': "← IDF=1.0固定",
    }
    for name in [n for n in results if n != 'Full CGK']:
        data = results[name]
        avg_tok = _avg_tokens(data)
        cont_pct = _contamination_pct(data)
        diff = avg_tok - full_avg
        notes = notes_by_name.get(name, "")
        print(f"{name:<20} {avg_tok:>12.1f} {cont_pct:>8.1f} {diff:>+10.1f} {notes:<30}")

    # Key findings
    print("\n" + "="*70)
    print("Key Findings")
    print("="*70)
    # Compare Gate-only vs Coverage-only
    go_avg = _avg_tokens(results['Gate-only'])
    co_avg = _avg_tokens(results['Coverage-only'])
    go_cont = _contamination_pct(results['Gate-only'])
    co_cont = _contamination_pct(results['Coverage-only'])
    print("\n1. Gate-only vs Coverage-only (最关键的两个 ablated variants):")
    print(f" Gate-only: {go_avg:.1f} tokens, 污染率 {go_cont:.0f}%")
    print(f" Coverage-only: {co_avg:.1f} tokens, 污染率 {co_cont:.0f}%")
    print(f" 结论: 门控过滤(Gate)对污染率的影响{'远大于' if co_cont > go_cont else '相当'}覆盖优化(Coverage)")
    # -Recency effect
    rec_avg = _avg_tokens(results['-Recency'])
    rec_cont = _contamination_pct(results['-Recency'])
    print("\n2. -Recency (移除 recency 权重):")
    print(f" -Recency: {rec_avg:.1f} tokens, 污染率 {rec_cont:.0f}% (vs Full {full_cont:.0f}%)")
    # -IDF effect
    idf_avg = _avg_tokens(results['-IDF'])
    idf_cont = _contamination_pct(results['-IDF'])
    print("\n3. -IDF (固定所有词 IDF=1.0):")
    print(f" -IDF: {idf_avg:.1f} tokens, 污染率 {idf_cont:.0f}% (vs Full {full_cont:.0f}%)")
    print(f" 结论: IDF 加权对 token 消耗的影响 {'明显' if abs(idf_avg - full_avg) > 5 else '较小'}")

    out_path = os.path.join(os.path.dirname(__file__), 'phase2_complete_ablation_results.json')
    # Explicit UTF-8: with ensure_ascii=False any non-ASCII text is written
    # verbatim, which the platform's default encoding may not be able to hold.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")
if __name__ == "__main__":
    # Run the full ablation only when executed as a script, not on import.
    main()