#!/usr/bin/env python3
"""Phase 2: Ablation study -- the independent contribution of each module.

Builds several ``ContextGatekeeper`` instances over the same interleaved
synthetic conversation history, disables one component per instance by
monkeypatching it, then runs the same query sequence against every instance
and compares the estimated prompt size and cross-topic contamination of the
selected context. Results are printed and saved as JSON next to this script.
"""
import sys
import os
import json

# Make the project root (the parent of this script's directory) importable so
# ``src.gatekeeper`` resolves regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.gatekeeper import ContextGatekeeper


def estimate_tokens(text):
    """Return a rough token estimate for *text* (a heuristic, not a tokenizer).

    The estimate is 0.4 per CJK Unified Ideograph (U+4E00..U+9FFF), plus 1.3
    per whitespace-separated all-ASCII word, plus 0.05 per character of total
    length, truncated to ``int``. Empty or ``None`` input yields 0.
    """
    if not text:
        return 0
    chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
    english = sum(1 for w in text.split() if w.isascii())
    return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)


def measure_prompt_tokens(selected, query):
    """Estimate the token cost of a prompt made of *selected* turns plus *query*.

    Each item of *selected* is a dict with ``'user'`` and ``'assistant'`` keys.
    The turns are rendered as "用户: ...\\n助手: ...\\n\\n" blocks and estimated
    with ``estimate_tokens``; a flat 8% of that estimate is added as formatting
    overhead, and the query's own estimate is added on top.
    """
    ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
    context_tok = estimate_tokens(ctx)
    fmt_overhead = int(context_tok * 0.08)
    return context_tok + fmt_overhead + estimate_tokens(query)


# Topic keywords used to detect cross-topic contamination.
_TOPIC_KEYWORDS = ('Redis', 'asyncio', 'PostgreSQL', 'Git')


def evaluate_contamination(selected, target):
    """Check whether *selected* turns mention any topic other than *target*.

    Returns ``(contaminated, found)``. ``found`` lists, in ``_TOPIC_KEYWORDS``
    order, every keyword other than *target* that occurs in the selected text.
    Matching is case-insensitive and by plain substring.
    NOTE(review): substring matching means e.g. 'git' would also match inside
    'digit'; acceptable for the synthetic data below.
    """
    # Separate user/assistant text with a space so a keyword cannot be
    # spuriously formed across the boundary between the two fields.
    text = " ".join(f"{i['user']} {i['assistant']}" for i in selected).lower()
    target_l = target.lower()
    found = [t for t in _TOPIC_KEYWORDS if t.lower() != target_l and t.lower() in text]
    return bool(found), found


# Synthetic single-topic QA turns (answers are abbreviated placeholders).
redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]

# (label, query, target topic). Queries run in this order against each gate,
# so the sequence includes topic switches and a return to Git at the end.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]


def build_gate(token_budget=4000):
    """Return a ``ContextGatekeeper`` pre-loaded with the synthetic QA turns.

    Turns are added round-robin (Redis, asyncio, PostgreSQL, Git, then the next
    question of each topic), so the topics are interleaved in the history.
    """
    g = ContextGatekeeper(token_budget=token_budget)
    for round_turns in zip(redis_qa, asyncio_qa, pg_qa, git_qa):
        for user, assistant in round_turns:
            g.add_turn(user, assistant)
    return g


def gate5_select(gate, query):
    """'Gate-only' ablation: topic gate + retrieval, no coverage optimisation.

    Uses *gate*'s anchor extractor, topic gate and retriever directly and
    returns every retrieved block (up to ``top_m=20``) as dicts with
    ``'user'``, ``'assistant'`` and ``'turn_id'`` keys. It presumably mirrors
    the early stages of ``ContextGatekeeper.select`` -- verify against
    src/gatekeeper. This function never writes ``gate._active_topic``.
    """
    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
    # On a topic switch only the last 15 blocks (presumably the most recent)
    # are considered as candidates.
    candidates = gate.blocks[-15:] if switched else gate.blocks
    retrieved = gate.retriever.retrieve(
        candidates,
        q_anchors,
        top_m=20,
        idf_cache=gate.anchor_extractor._idf_cache,
        active_topic_anchors=gate._active_topic[0] if gate._active_topic else None,
        topic_switched=switched,
        query_text=query,
    )
    return [
        {"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id}
        for b, _score in retrieved
    ]


def _summarize(data):
    """Return ``(avg_prompt_tokens, contamination_rate_percent)`` for one variant."""
    if not data:
        return 0.0, 0.0
    avg_pt = sum(d['pt'] for d in data) / len(data)
    cont_rate = sum(1 for d in data if d['cont']) / len(data) * 100
    return avg_pt, cont_rate


def main():
    """Run every ablation variant over TEST_SEQ, print a summary and save JSON."""
    # ---- Ablation 1: no deictic (reference-word) rule ----
    print("Ablation 1: 无指代词规则")
    g1 = build_gate()
    orig_extract = g1.anchor_extractor.extract_with_deictic

    def no_deictic(text):
        # Keep the extracted anchors but always report "no deictic reference".
        anchors, _ = orig_extract(text)
        return anchors, False

    g1.anchor_extractor.extract_with_deictic = no_deictic

    # ---- Ablation 2: no exact-match score ----
    print("Ablation 2: 无 Exact Match")
    g2 = build_gate()

    def no_exact(block, qa):
        return 0.0

    # Instance attribute shadows the method; assumes the retriever invokes it
    # as ``self._exact_match(block, qa)`` -- TODO confirm in src/gatekeeper.
    g2.retriever._exact_match = no_exact

    # ---- Ablation 3: no recency ----
    # Per the original author's note, recency is part of the retriever score
    # and cannot be removed from outside. The previous wrapper was never
    # installed, so this variant was never actually run; say so explicitly.
    print("Ablation 3: 无 Recency (skipped: recency cannot be disabled from outside the retriever)")

    # ---- Ablation 4: no sentence-level trimming ----
    print("Ablation 4: 无句级裁剪")
    g4 = build_gate()
    g4._trim_blocks_to_query = lambda blocks, qa: blocks  # bypass trimming

    # ---- Ablation 5: gate + retrieve only (no coverage optimisation) ----
    print("Ablation 5: Gate-only (无覆盖优化)")
    g5 = build_gate()

    variants = {
        'Full CGK': (build_gate(), lambda g, q: g.select(q)),
        '-Deictic': (g1, lambda g, q: g.select(q)),
        '-Exact Match': (g2, lambda g, q: g.select(q)),
        '-Trim': (g4, lambda g, q: g.select(q)),
        # Evaluated on one persistent gate across the whole query sequence,
        # like every other variant (it used to be rebuilt for every query).
        'Gate-only': (g5, gate5_select),
    }
    results = {name: [] for name in variants}

    print("\n" + "=" * 70)
    print("Phase 2: Ablation Study")
    print("=" * 70)
    for label, query, target in TEST_SEQ:
        print(f"\n[{label}] {query[:45]}...")
        for name, (gate, select_fn) in variants.items():
            sel = select_fn(gate, query)
            pt = measure_prompt_tokens(sel, query)
            cont, _ = evaluate_contamination(sel, target)
            results[name].append({'pt': pt, 'cont': cont})
            print(f" {name:15s}: {pt:5.0f} tokens, 污染={cont}")

    # ---- Summary ----
    print("\n" + "=" * 70)
    print("Ablation Summary (avg over {} queries)".format(len(TEST_SEQ)))
    print("=" * 70)
    full_avg_pt, full_cont = _summarize(results['Full CGK'])
    print(f"Full CGK: {full_avg_pt:5.1f} tokens, 污染率 {full_cont:.0f}% (baseline)")
    for name, data in results.items():
        if name == 'Full CGK':
            continue
        avg_pt, cont = _summarize(data)
        diff = avg_pt - full_avg_pt
        print(f"{name:15s}: {avg_pt:5.1f} tokens, 污染率 {cont:.0f}% ({diff:+.1f} vs Full)")

    out_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'phase2_ablation_results.json'
    )
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")


if __name__ == '__main__':
    main()