Files
context-gatekeeper/experiments/phase2_ablation.py
Elaina 9e44748f91 fix: anchor stopwords - remove generic question patterns causing cross-topic contamination
- Add ANCHOR_STOPWORDS set in anchor.py (truly generic question patterns / 真正通用的疑问 pattern)
- Filter Chinese n-grams against stopwords in extract()
- Update sparse.py content_words extraction to use stopword-filtered query
- Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks
- Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%)
- Phase2 ablation: Gate-only accounts for most of the benefit
- Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data;
  RECENT_WINDOW is the primary token budget control

Known honest limitations:
- Test set is clean 4-topic synthetic data (no real dirty dialogue)
- No strong baselines (BM25 ablation incomplete)
- No answer-level evaluation (only retrieval blocks measured)
- No parameter sensitivity on noisy real-world data
- Zero contamination on 5 queries is not generalizable
2026-04-22 22:30:18 +08:00

195 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Phase 2: Ablation Study - 每个模块的独立贡献"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
def estimate_tokens(text):
    """Roughly estimate the LLM token count of *text*.

    Heuristic weights: 0.4 per CJK unified ideograph (U+4E00..U+9FFF),
    1.3 per whitespace-separated word made only of ASCII characters, plus
    0.05 per character of the whole string. The weighted sum is truncated
    to an int. Falsy input ("" or None) yields 0.
    """
    if text:
        cjk_count = len([ch for ch in text if "\u4e00" <= ch <= "\u9fff"])
        ascii_words = sum(1 for word in text.split() if word.isascii())
        return int(cjk_count * 0.4 + ascii_words * 1.3 + len(text) * 0.05)
    return 0
def measure_prompt_tokens(selected, query):
    """Estimate the prompt size for a set of selected turns plus the query.

    Each selected item (a dict with 'user' and 'assistant' keys) is rendered
    with the "用户:" / "助手:" labels and a blank line between turns. The
    rendered context is scored with estimate_tokens, an extra 8% of that
    score (truncated) is added, and finally the query's own estimate.
    """
    rendered = []
    for turn in selected:
        rendered.append(f"用户: {turn['user']}\n助手: {turn['assistant']}\n\n")
    context_tokens = estimate_tokens("".join(rendered))
    # Flat 8% surcharge; presumably approximates prompt-template formatting
    # overhead (the original named it fmt_overhead).
    overhead = int(context_tokens * 0.08)
    return context_tokens + overhead + estimate_tokens(query)
def evaluate_contamination(selected, target,
                           topics=('Redis', 'asyncio', 'PostgreSQL', 'Git')):
    """Check whether *selected* turns mention any topic other than *target*.

    Args:
        selected: iterable of dicts with 'user' and 'assistant' string values.
        target: topic the query belongs to; compared case-insensitively and
            excluded from the check.
        topics: topic keywords to scan for. Defaults to the four topics of
            this experiment's synthetic dialogue.

    Returns:
        (contaminated, found): *contaminated* is True when at least one
        non-target keyword occurs; *found* lists those keywords in *topics*
        order.
    """
    import re

    # Separate the user and assistant fields with a space so a keyword cannot
    # be formed across the field boundary by plain concatenation.
    text = " ".join(f"{turn['user']} {turn['assistant']}" for turn in selected).lower()
    target_lc = target.lower()
    found = []
    for topic in topics:
        topic_lc = topic.lower()
        if topic_lc == target_lc:
            continue
        # Reject matches embedded in a longer ASCII word (e.g. "git" inside
        # "digit"). \b is deliberately not used: CJK characters are word
        # characters to `re`, so \bgit\b would miss "Git的".
        pattern = rf"(?<![a-z]){re.escape(topic_lc)}(?![a-z])"
        if re.search(pattern, text):
            found.append(topic)
    return bool(found), found
# Synthetic 4-topic dialogue used to populate every gate (see build_gate).
# Each entry is (user_question, assistant_answer). The answers are short
# placeholder stubs that do not repeat the topic name, so topic keywords
# come almost entirely from the questions.
redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    # NOTE: unlike its siblings this question does not contain "PostgreSQL".
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    # NOTE(review): the only question without a trailing "?"; confirm this
    # is intentional since it may affect anchor extraction.
    ("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
# Evaluation queries as (label, query, target_topic), run in this order
# against every variant. target_topic is passed to evaluate_contamination,
# which excludes it (case-insensitively) from the contamination check.
# "问PG", "问Redis" and "问asyncio" repeat stored turns verbatim; "问Git"
# paraphrases a stored turn; "再问Git" (reset vs revert) has no stored
# counterpart.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_gate(token_budget=4000):
    """Build a ContextGatekeeper pre-loaded with the synthetic 4-topic dialogue.

    Turns are added round-robin: the i-th Redis, asyncio, PostgreSQL and Git
    turn, then the (i+1)-th of each, and so on.

    Args:
        token_budget: forwarded to ContextGatekeeper (default 4000, as before).

    Returns:
        The populated ContextGatekeeper.
    """
    gate = ContextGatekeeper(token_budget=token_budget)
    # zip() replaces the hard-coded range(5): every turn is used when the
    # topic lists grow together, and no IndexError is raised if one shrinks
    # (zip stops at the shortest list).
    for round_turns in zip(redis_qa, asyncio_qa, pg_qa, git_qa):
        for user_text, assistant_text in round_turns:
            gate.add_turn(user_text, assistant_text)
    return gate
def _build_variants():
    """Return {display_name: (gate, select_fn)} for every evaluated variant.

    select_fn(gate, query) returns a list of dicts with at least 'user' and
    'assistant' keys. Each variant owns one gate that is reused for the whole
    TEST_SEQ, so all variants (including Gate-only) see the queries in the
    same sequential setting.
    """
    # ---- Ablation 1: no deictic (pronoun-reference) rule ----
    print("Ablation 1: 无指代词规则")
    g1 = build_gate()
    orig_extract = g1.anchor_extractor.extract_with_deictic

    def no_deictic(text):
        # Keep the extracted anchors but always report "no deictic reference".
        anchors, _ = orig_extract(text)
        return anchors, False

    g1.anchor_extractor.extract_with_deictic = no_deictic

    # ---- Ablation 2: no exact match (exact-match score forced to 0) ----
    print("Ablation 2: 无 Exact Match")
    g2 = build_gate()
    g2.retriever._exact_match = lambda block, qa: 0.0

    # ---- Ablation 3: no recency ----
    # Not evaluated: per the original note, recency cannot easily be removed
    # from the retrieval score from here. Say so instead of silently building
    # an unused gate.
    print("Ablation 3: 无 Recency")
    print("  (skipped: recency cannot be removed from the retrieval score without modifying the retriever)")

    # ---- Ablation 4: no sentence-level trimming ----
    print("Ablation 4: 无句级裁剪")
    g4 = build_gate()
    g4._trim_blocks_to_query = lambda blocks, qa: blocks  # bypass trim

    # ---- Ablation 5: gate + retrieval only, no coverage optimisation ----
    print("Ablation 5: Gate-only (无覆盖优化)")

    def via_select(gate, query):
        return gate.select(query)

    return {
        'Full CGK': (build_gate(), via_select),
        '-Deictic': (g1, via_select),
        '-Exact Match': (g2, via_select),
        '-Trim': (g4, via_select),
        'Gate-only': (build_gate(), gate5_select),
    }


def _summarize(records):
    """Return (mean prompt tokens, contamination rate in %) for *records*.

    Each record is {'pt': int, 'cont': bool}. An empty list yields (0.0, 0.0)
    instead of raising ZeroDivisionError.
    """
    if not records:
        return 0.0, 0.0
    avg_pt = sum(r['pt'] for r in records) / len(records)
    cont_rate = sum(1 for r in records if r['cont']) / len(records) * 100
    return avg_pt, cont_rate


def _save_results(results):
    """Write *results* as JSON next to this script and return the file path."""
    out_path = os.path.join(os.path.dirname(__file__), 'phase2_ablation_results.json')
    # ensure_ascii=False writes non-ASCII text verbatim, so pin the encoding
    # instead of relying on the platform's default locale encoding.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    return out_path


def main():
    """Run the Phase 2 ablation study and save per-query results as JSON.

    For every query in TEST_SEQ, each variant selects context. The estimated
    prompt tokens and a contamination flag are recorded and printed. Then an
    average summary relative to 'Full CGK' is printed.
    """
    variants = _build_variants()
    results = {name: [] for name in variants}

    print("\n" + "=" * 70)
    print("Phase 2: Ablation Study")
    print("=" * 70)
    for label, query, target in TEST_SEQ:
        print(f"\n[{label}] {query[:45]}...")
        for name, (gate, select_fn) in variants.items():
            sel = select_fn(gate, query)
            pt = measure_prompt_tokens(sel, query)
            cont, _ = evaluate_contamination(sel, target)
            results[name].append({'pt': pt, 'cont': cont})
            print(f" {name:15s}: {pt:5.0f} tokens, 污染={cont}")

    # Summary
    print("\n" + "=" * 70)
    print(f"Ablation Summary (avg over {len(TEST_SEQ)} queries)")
    print("=" * 70)
    full_avg_pt, full_cont = _summarize(results['Full CGK'])
    print(f"Full CGK: {full_avg_pt:5.1f} tokens, 污染率 {full_cont:.0f}% (baseline)")
    for name, data in results.items():
        if name == 'Full CGK':
            continue
        avg_pt, cont = _summarize(data)
        diff = avg_pt - full_avg_pt
        print(f"{name:15s}: {avg_pt:5.1f} tokens, 污染率 {cont:.0f}% ({diff:+.1f} vs Full)")

    out_path = _save_results(results)
    print(f"\nSaved to: {out_path}")
def gate5_select(gate, query, *, recent_window=15, top_m=20):
    """Gate-only selection: topic gate + retrieval, with no coverage optimisation.

    Reads the gate's anchor extractor, topic gate, active topic and block
    list. It calls gate.retriever.retrieve directly instead of gate.select,
    so the gate's own selection step is skipped.

    Args:
        gate: a ContextGatekeeper-like object.
        query: the user query text.
        recent_window: when a topic switch is detected, only this many most
            recent blocks are retrieval candidates (default 15). A value of 0
            or less yields no candidates.
        top_m: forwarded to the retriever (default 20).

    Returns:
        A list of {"user", "assistant", "turn_id"} dicts, one per retrieved
        block, in the retriever's order.
    """
    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
    blocks = gate.blocks
    if switched:
        # Explicit start index: blocks[-0:] would wrongly mean "everything".
        candidates = blocks[max(len(blocks) - recent_window, 0):]
    else:
        candidates = blocks
    active = gate._active_topic
    retrieved = gate.retriever.retrieve(
        candidates, q_anchors, top_m=top_m,
        idf_cache=gate.anchor_extractor._idf_cache,
        active_topic_anchors=active[0] if active else None,
        topic_switched=switched, query_text=query,
    )
    return [
        {"user": block.user_text, "assistant": block.assistant_text, "turn_id": block.turn_id}
        for block, _score in retrieved
    ]
if __name__ == "__main__":
    # Run the ablation only when executed as a script, not when imported.
    main()