fix: anchor stopwords - remove generic question patterns causing cross-topic contamination
- Add ANCHOR_STOPWORDS set in anchor.py (真正通用的疑问pattern) - Filter Chinese n-grams against stopwords in extract() - Update sparse.py content_words extraction to use stopword-filtered query - Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks - Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%) - Phase2 ablation: Gate-only accounts for most of the benefit - Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data; RECENT_WINDOW is the primary token budget control Known honest limitations: - Test set is clean 4-topic synthetic data (no real dirty dialogue) - No strong baselines (BM25 ablation incomplete) - No answer-level evaluation (only retrieval blocks measured) - No parameter sensitivity on noisy real-world data - Zero contamination on 5 queries is not generalizable
This commit is contained in:
195
experiments/phase2_ablation.py
Normal file
195
experiments/phase2_ablation.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Phase 2: Ablation Study - 每个模块的独立贡献"""
|
||||
import sys, os, json
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
def estimate_tokens(text):
|
||||
if not text: return 0
|
||||
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
english = len([w for w in text.split() if w.isascii()])
|
||||
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
|
||||
|
||||
def measure_prompt_tokens(selected, query):
|
||||
ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
|
||||
context_tok = estimate_tokens(ctx)
|
||||
fmt_overhead = int(context_tok * 0.08)
|
||||
return context_tok + fmt_overhead + estimate_tokens(query)
|
||||
|
||||
def evaluate_contamination(selected, target):
|
||||
text = " ".join(i['user'] + i['assistant'] for i in selected)
|
||||
found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']
|
||||
if t.lower() in text.lower() and t.lower() != target.lower()]
|
||||
return len(found) > 0, found
|
||||
|
||||
redis_qa = [
|
||||
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
|
||||
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
|
||||
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
|
||||
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
|
||||
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
|
||||
]
|
||||
asyncio_qa = [
|
||||
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
|
||||
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
|
||||
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
|
||||
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
|
||||
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
|
||||
]
|
||||
pg_qa = [
|
||||
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
|
||||
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
|
||||
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
|
||||
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
|
||||
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
|
||||
]
|
||||
git_qa = [
|
||||
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
|
||||
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
|
||||
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
|
||||
("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
|
||||
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
|
||||
]
|
||||
|
||||
TEST_SEQ = [
|
||||
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
|
||||
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
|
||||
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
|
||||
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
|
||||
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
|
||||
]
|
||||
|
||||
def build_gate():
|
||||
g = ContextGatekeeper(token_budget=4000)
|
||||
for i in range(5):
|
||||
g.add_turn(redis_qa[i][0], redis_qa[i][1])
|
||||
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
|
||||
g.add_turn(pg_qa[i][0], pg_qa[i][1])
|
||||
g.add_turn(git_qa[i][0], git_qa[i][1])
|
||||
return g
|
||||
|
||||
def main():
|
||||
# ---- Ablation 1: 无指代词规则 ----
|
||||
print("Ablation 1: 无指代词规则")
|
||||
g1 = build_gate()
|
||||
orig_extract = g1.anchor_extractor.extract_with_deictic
|
||||
def no_deictic(text):
|
||||
anchors, _ = orig_extract(text)
|
||||
return anchors, False
|
||||
g1.anchor_extractor.extract_with_deictic = no_deictic
|
||||
|
||||
# ---- Ablation 2: 无 Exact Match ----
|
||||
print("Ablation 2: 无 Exact Match")
|
||||
g2 = build_gate()
|
||||
orig_score = g2.retriever._exact_match
|
||||
def no_exact(block, qa):
|
||||
return 0.0
|
||||
g2.retriever._exact_match = no_exact
|
||||
|
||||
# ---- Ablation 3: 无 Recency ----
|
||||
print("Ablation 3: 无 Recency")
|
||||
g3 = build_gate()
|
||||
orig_retrieve = g3.retriever.retrieve
|
||||
def no_recency_retrieve(blocks, qa, top_m=20, **kwargs):
|
||||
kwargs.pop('recency', None)
|
||||
return orig_retrieve(blocks, qa, top_m, **kwargs)
|
||||
# Can't easily remove recency from score, so we just note it
|
||||
|
||||
# ---- Ablation 4: 无句级裁剪 ----
|
||||
print("Ablation 4: 无句级裁剪")
|
||||
g4 = build_gate()
|
||||
g4._trim_blocks_to_query = lambda blocks, qa: blocks # bypass trim
|
||||
|
||||
# ---- Ablation 5: Gate-only (无覆盖优化) ----
|
||||
print("Ablation 5: Gate-only (无覆盖优化)")
|
||||
g5 = build_gate()
|
||||
orig_select = g5.select
|
||||
def gate_only_select(query):
|
||||
# Only do gate + retrieve, skip selector
|
||||
q_anchors, has_deictic = g5.anchor_extractor.extract_with_deictic(query)
|
||||
switched = g5.topic_gate.is_topic_switch(query, g5._active_topic)
|
||||
idf_cache = g5.anchor_extractor._idf_cache
|
||||
if switched:
|
||||
candidates = g5.blocks[-15:]
|
||||
else:
|
||||
candidates = g5.blocks
|
||||
retrieved = g5.retriever.retrieve(
|
||||
candidates, q_anchors, top_m=20, idf_cache=idf_cache,
|
||||
active_topic_anchors=g5._active_topic[0] if g5._active_topic else None,
|
||||
topic_switched=switched, query_text=query
|
||||
)
|
||||
# Return all retrieved without coverage optimization
|
||||
result = [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved]
|
||||
return result
|
||||
# Can't easily override, just note
|
||||
g5.select = gate_only_select
|
||||
|
||||
variants = {
|
||||
'Full CGK': (build_gate(), lambda g, q: g.select(q)),
|
||||
'-Deictic': (g1, lambda g, q: g.select(q)),
|
||||
'-Exact Match': (g2, lambda g, q: g.select(q)),
|
||||
'-Trim': (g4, lambda g, q: g.select(q)),
|
||||
}
|
||||
|
||||
results = {k: [] for k in variants}
|
||||
results['Gate-only'] = []
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("Phase 2: Ablation Study")
|
||||
print("="*70)
|
||||
|
||||
for label, query, target in TEST_SEQ:
|
||||
print(f"\n[{label}] {query[:45]}...")
|
||||
for name, (gate, fn) in variants.items():
|
||||
sel = fn(gate, query)
|
||||
pt = measure_prompt_tokens(sel, query)
|
||||
cont, _ = evaluate_contamination(sel, target)
|
||||
results[name].append({'pt': pt, 'cont': cont})
|
||||
print(f" {name:15s}: {pt:5.0f} tokens, 污染={cont}")
|
||||
|
||||
# Gate-only
|
||||
gate5 = build_gate()
|
||||
sel5 = gate5_select(gate5, query)
|
||||
pt5 = measure_prompt_tokens(sel5, query)
|
||||
cont5, _ = evaluate_contamination(sel5, target)
|
||||
results['Gate-only'].append({'pt': pt5, 'cont': cont5})
|
||||
print(f" {'Gate-only':15s}: {pt5:5.0f} tokens, 污染={cont5}")
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*70)
|
||||
print("Ablation Summary (avg over {} queries)".format(len(TEST_SEQ)))
|
||||
print("="*70)
|
||||
full_avg_pt = sum(d['pt'] for d in results['Full CGK']) / len(results['Full CGK'])
|
||||
full_cont = sum(1 for d in results['Full CGK'] if d['cont']) / len(results['Full CGK']) * 100
|
||||
print(f"Full CGK: {full_avg_pt:5.1f} tokens, 污染率 {full_cont:.0f}% (baseline)")
|
||||
for name in ['-Deictic', '-Exact Match', '-Trim', 'Gate-only']:
|
||||
data = results[name]
|
||||
avg_pt = sum(d['pt'] for d in data) / len(data)
|
||||
cont = sum(1 for d in data if d['cont']) / len(data) * 100
|
||||
diff = avg_pt - full_avg_pt
|
||||
print(f"{name:15s}: {avg_pt:5.1f} tokens, 污染率 {cont:.0f}% ({diff:+.1f} vs Full)")
|
||||
|
||||
out = {k: v for k, v in results.items()}
|
||||
out_path = os.path.join(os.path.dirname(__file__), 'phase2_ablation_results.json')
|
||||
with open(out_path, 'w') as f:
|
||||
json.dump(out, f, indent=2, ensure_ascii=False)
|
||||
print(f"\nSaved to: {out_path}")
|
||||
|
||||
def gate5_select(gate, query):
|
||||
q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
|
||||
switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
|
||||
idf_cache = gate.anchor_extractor._idf_cache
|
||||
if switched:
|
||||
candidates = gate.blocks[-15:]
|
||||
else:
|
||||
candidates = gate.blocks
|
||||
retrieved = gate.retriever.retrieve(
|
||||
candidates, q_anchors, top_m=20, idf_cache=idf_cache,
|
||||
active_topic_anchors=gate._active_topic[0] if gate._active_topic else None,
|
||||
topic_switched=switched, query_text=query
|
||||
)
|
||||
return [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved]
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user