#!/usr/bin/env python3
"""Phase 2 follow-up: complete ablation study.

Runs the full ContextGatekeeper selection plus the ablated variants
(Gate-only, Coverage-only, -Recency, -IDF, -Deictic, -Exact Match, -Trim)
and a Last-5 baseline over a fixed query sequence, prints a summary table,
and saves the raw per-query results as JSON next to this script.
"""

import contextlib
import json
import os
import sys

# Make the parent of this script's directory importable.  abspath() keeps
# this correct when __file__ is a bare relative filename, in which case
# dirname(dirname(__file__)) would collapse to '' instead of the parent.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.gatekeeper import ContextGatekeeper  # noqa: E402

# Topic keywords used to detect cross-topic contamination in selected turns.
TOPICS = ('Redis', 'asyncio', 'PostgreSQL', 'Git')


@contextlib.contextmanager
def _patched(obj, attr, value):
    """Temporarily set ``obj.attr`` to *value*; restore it even on error."""
    original = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield original
    finally:
        setattr(obj, attr, original)


def _foreign_topics(text, target):
    """Return the TOPICS entries (other than *target*) mentioned in *text*.

    Matching is a case-insensitive substring test, in TOPICS order.
    """
    lowered = text.lower()
    target_lower = target.lower()
    return [t for t in TOPICS
            if t.lower() in lowered and t.lower() != target_lower]


def estimate_tokens(text):
    """Heuristically estimate the token count of *text*.

    Combines the number of CJK characters (U+4E00..U+9FFF), the number of
    whitespace-separated all-ASCII words, and the raw length, each with a
    fixed weight.  Returns 0 for empty or ``None`` input.
    """
    if not text:
        return 0
    chinese = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
    english = sum(1 for word in text.split() if word.isascii())
    return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)


def measure_prompt_tokens(selected, query):
    """Estimate prompt size for *selected* turns plus *query*.

    Each turn is a dict with ``'user'`` and ``'assistant'`` keys.  An 8%
    formatting overhead is added on top of the rendered context estimate.
    """
    ctx = "".join(
        f"用户: {turn['user']}\n助手: {turn['assistant']}\n\n" for turn in selected
    )
    context_tok = estimate_tokens(ctx)
    fmt_overhead = int(context_tok * 0.08)
    return context_tok + fmt_overhead + estimate_tokens(query)


def evaluate_contamination(selected, target):
    """Check whether *selected* turns mention any topic other than *target*.

    Returns:
        ``(is_contaminated, found)`` where *found* lists the foreign topics.
    """
    text = " ".join(turn['user'] + turn['assistant'] for turn in selected)
    found = _foreign_topics(text, target)
    return bool(found), found


def evaluate_answer_quality(gate, query, target_topic, selected=None):
    """Rule-based proxy for end-to-end answer quality (no LLM is called).

    Metrics derived from the selected blocks:
      * purity: blocks mentioning no other topic / total blocks
      * anchor coverage: share of query anchors found in the selected text

    Args:
        gate: gatekeeper providing ``select`` and ``anchor_extractor``.
        query: the user query being evaluated.
        target_topic: topic the query is about; other TOPICS entries count
            as contamination.
        selected: optional, already computed selection for *query*.  When
            omitted, ``gate.select(query)`` is called here.  Callers that
            already selected on this gate should pass the result so the
            metrics describe the same blocks they report elsewhere
            (a repeated ``select`` may not return the same blocks if the
            gate keeps state between calls -- not visible from this file).

    Returns:
        dict with keys ``total_blocks``, ``topic_blocks``, ``purity``,
        ``other_topics``, ``anchor_coverage`` and ``is_contaminated``.
    """
    sel = gate.select(query) if selected is None else selected

    total_blocks = len(sel)
    topic_blocks = 0
    other_topic_hits = []
    for item in sel:
        hits = _foreign_topics(item['user'] + ' ' + item['assistant'], target_topic)
        if hits:
            other_topic_hits.extend(hits)
        else:
            topic_blocks += 1

    # Share of query anchors that occur anywhere in the selected context.
    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    context_lower = ' '.join(i['user'] + i['assistant'] for i in sel).lower()
    covered_anchors = sum(1 for a in q_anchors if a.lower() in context_lower)
    anchor_coverage = covered_anchors / len(q_anchors) if q_anchors else 0

    return {
        'total_blocks': total_blocks,
        'topic_blocks': topic_blocks,
        'purity': topic_blocks / total_blocks if total_blocks > 0 else 0,
        # sorted() keeps the saved JSON stable across runs (set order is not).
        'other_topics': sorted(set(other_topic_hits)),
        'anchor_coverage': anchor_coverage,
        'is_contaminated': len(other_topic_hits) > 0,
    }


redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]

asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]

pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]

git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]

# (label, query, target topic) evaluated in order.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]


def build_gate():
    """Build a fresh gatekeeper pre-loaded with the interleaved history.

    Turns are added round-robin: one Redis, asyncio, PostgreSQL and Git
    turn per round, for five rounds.
    """
    g = ContextGatekeeper(token_budget=4000)
    for round_turns in zip(redis_qa, asyncio_qa, pg_qa, git_qa):
        for user_text, assistant_text in round_turns:
            g.add_turn(user_text, assistant_text)
    return g


def run_gate_only(gate, query):
    """Gate-only: topic filtering without coverage optimisation.

    Returns every retrieved block directly (top 20), as dicts with
    ``user``, ``assistant`` and ``turn_id`` keys.
    """
    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
    idf_cache = gate.anchor_extractor._idf_cache

    # On a topic switch only the 15 most recent blocks are candidates.
    candidates = gate.blocks[-15:] if switched else gate.blocks

    retrieved = gate.retriever.retrieve(
        candidates, q_anchors, top_m=20,
        idf_cache=idf_cache,
        active_topic_anchors=gate._active_topic[0] if gate._active_topic else None,
        topic_switched=switched,
        query_text=query
    )
    return [
        {"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id}
        for b, _ in retrieved
    ]


def run_coverage_only(gate, query):
    """Coverage-only: run ``select`` with topic gating bypassed.

    ``is_topic_switch`` is forced to ``False`` and ``retrieve`` is replaced
    by a plain score-and-rank over all candidate blocks.  Both patches are
    undone afterwards, even if ``select`` raises.
    """
    def never_switched(*args, **kwargs):
        return False  # force "no topic switch" so nothing is gated out

    def no_filter_retrieve(blocks, qa, top_m=20, **kwargs):
        # Score every block; skip the retriever's own filtering.
        idf_cache = kwargs.get('idf_cache', {})
        scored = [
            (block, gate.retriever.score(block, qa, 0.0, idf_cache))
            for block in blocks
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_m]

    with _patched(gate.topic_gate, 'is_topic_switch', never_switched), \
            _patched(gate.retriever, 'retrieve', no_filter_retrieve):
        return gate.select(query)


def run_no_recency(gate, query):
    """-Recency: run ``select`` with the retriever's recency weight at 0."""
    with _patched(gate.retriever, 'WEIGHT_RECENT', 0.0):
        return gate.select(query)


def run_no_idf(gate, query):
    """-IDF: run ``select`` with IDF fixed at 1.0 for every anchor.

    The extractor's ``idf`` is replaced and its ``_idf_cache`` emptied for
    the duration of the call.  Afterwards the cache is restored *in place*
    (same dict object, original contents) so any other holder of that dict
    sees the original values again; restoration also happens on error.
    """
    extractor = gate.anchor_extractor
    cache = extractor._idf_cache
    saved_entries = dict(cache)

    def fixed_idf(anchor):
        return 1.0  # constant IDF removes any discrimination between anchors

    cache.clear()  # cached per-anchor values would otherwise still be used
    try:
        with _patched(extractor, 'idf', fixed_idf):
            return gate.select(query)
    finally:
        cache.clear()
        cache.update(saved_entries)
        extractor._idf_cache = cache


def _run_no_deictic(gate, query):
    """-Deictic: run ``select`` with the deictic flag always reported False."""
    original = gate.anchor_extractor.extract_with_deictic

    def anchors_without_deictic(text):
        anchors, _ = original(text)
        return anchors, False

    with _patched(gate.anchor_extractor, 'extract_with_deictic',
                  anchors_without_deictic):
        return gate.select(query)


def _run_no_exact_match(gate, query):
    """-Exact Match: run ``select`` with the exact-match score forced to 0."""
    with _patched(gate.retriever, '_exact_match', lambda b, qa: 0.0):
        return gate.select(query)


def _run_no_trim(gate, query):
    """-Trim: run ``select`` with block trimming replaced by a pass-through.

    The override is left in place; callers pass a throwaway gate.
    """
    gate._trim_blocks_to_query = lambda blocks, qa: blocks
    return gate.select(query)


# (results key, console prefix, runner) for every ablated variant, in
# reporting order.  Each runner receives a fresh gate.
_ABLATIONS = (
    ('Gate-only', " Gate-only: ", run_gate_only),          # no coverage optimisation
    ('Coverage-only', " Coverage-only:", run_coverage_only),  # no gate filtering
    ('-Recency', " -Recency: ", run_no_recency),           # no recency preference
    ('-IDF', " -IDF: ", run_no_idf),                       # no IDF weighting
    ('-Deictic', " -Deictic: ", _run_no_deictic),
    ('-Exact Match', " -Exact Match: ", _run_no_exact_match),
    ('-Trim', " -Trim: ", _run_no_trim),
)

_BASELINE_NAME = 'Last-5 (baseline)'

_SUMMARY_NOTES = {
    'Gate-only': "← 关键模块",
    'Coverage-only': "← 无门控",
    '-Recency': "← recency权重→0",
    '-IDF': "← IDF=1.0固定",
}


def _score_selection(selected, query, target):
    """Return the ``{'pt', 'cont'}`` result row for one selection."""
    pt = measure_prompt_tokens(selected, query)
    cont, _ = evaluate_contamination(selected, target)
    return {'pt': pt, 'cont': cont}


def _avg_tokens(rows):
    """Mean prompt-token estimate over result rows (0.0 when empty)."""
    return sum(r['pt'] for r in rows) / len(rows) if rows else 0.0


def _contamination_pct(rows):
    """Percentage of result rows flagged as contaminated (0.0 when empty)."""
    return sum(1 for r in rows if r['cont']) / len(rows) * 100 if rows else 0.0


def main():
    """Run every variant on TEST_SEQ, print the report and save the JSON."""
    results = {'Full CGK': []}
    for name, _, _ in _ABLATIONS:
        results[name] = []
    results[_BASELINE_NAME] = []

    # Conversation used by the Last-5 baseline.
    # NOTE(review): turns here are grouped by topic (all Redis, then asyncio,
    # PostgreSQL, Git), whereas build_gate() interleaves the topics, so the
    # baseline's "last 5 turns" are not the last 5 turns the gate saw.
    # Kept as-is to preserve the reported numbers -- confirm this is intended.
    conversation = [
        {'user': user_text, 'assistant': assistant_text}
        for qa_list in (redis_qa, asyncio_qa, pg_qa, git_qa)
        for user_text, assistant_text in qa_list
    ]

    print("=" * 70)
    print("Phase 2 COMPLETE Ablation Study (补全)")
    print("=" * 70)

    for label, query, target in TEST_SEQ:
        print(f"\n[{label}] {query[:45]}...")

        # Full pipeline.  The answer-quality metrics reuse this selection so
        # every figure in the row describes the same set of blocks.
        gate = build_gate()
        sel = gate.select(query)
        row = _score_selection(sel, query, target)
        aq = evaluate_answer_quality(gate, query, target, selected=sel)
        row['aq'] = aq
        results['Full CGK'].append(row)
        pt, cont = row['pt'], row['cont']
        print(f" Full CGK: {pt:5.0f} tok, 污染={cont}, 纯度={aq['purity']:.2f}, 锚点覆盖={aq['anchor_coverage']:.2f}")

        # Ablated variants, each on its own fresh gate.
        for name, prefix, runner in _ABLATIONS:
            row = _score_selection(runner(build_gate(), query), query, target)
            results[name].append(row)
            print(f"{prefix}{row['pt']:5.0f} tok, 污染={row['cont']}")

        # Last-5 baseline: simply the final five turns of the conversation.
        row = _score_selection(conversation[-5:], query, target)
        results[_BASELINE_NAME].append(row)
        print(f" Last-5: {row['pt']:5.0f} tok, 污染={row['cont']}")

    # Summary table
    print("\n" + "=" * 70)
    print(f"Ablation Summary (avg over {len(TEST_SEQ)} queries)")
    print("=" * 70)

    full_avg = _avg_tokens(results['Full CGK'])
    full_cont = _contamination_pct(results['Full CGK'])

    print(f"\n{'Method':<20} {'Avg Tokens':>12} {'Cont%':>8} {'ΔTokens':>10} {'Notes':<30}")
    print("-" * 80)
    print(f"{'Full CGK':<20} {full_avg:>12.1f} {full_cont:>8.1f} {'—':>10} {'baseline':<30}")

    for name in [n for n, _, _ in _ABLATIONS] + [_BASELINE_NAME]:
        avg_tok = _avg_tokens(results[name])
        cont_pct = _contamination_pct(results[name])
        diff = avg_tok - full_avg
        notes = _SUMMARY_NOTES.get(name, "")
        print(f"{name:<20} {avg_tok:>12.1f} {cont_pct:>8.1f} {diff:>+10.1f} {notes:<30}")

    # Key findings
    print("\n" + "=" * 70)
    print("Key Findings")
    print("=" * 70)

    # Gate-only vs Coverage-only
    go_avg = _avg_tokens(results['Gate-only'])
    co_avg = _avg_tokens(results['Coverage-only'])
    go_cont = _contamination_pct(results['Gate-only'])
    co_cont = _contamination_pct(results['Coverage-only'])

    print("\n1. Gate-only vs Coverage-only (最关键的两个 ablated variants):")
    print(f" Gate-only: {go_avg:.1f} tokens, 污染率 {go_cont:.0f}%")
    print(f" Coverage-only: {co_avg:.1f} tokens, 污染率 {co_cont:.0f}%")
    print(f" 结论: 门控过滤(Gate)对污染率的影响{'远大于' if co_cont > go_cont else '相当'}覆盖优化(Coverage)")

    # -Recency effect
    rec_avg = _avg_tokens(results['-Recency'])
    rec_cont = _contamination_pct(results['-Recency'])
    print("\n2. -Recency (移除 recency 权重):")
    print(f" -Recency: {rec_avg:.1f} tokens, 污染率 {rec_cont:.0f}% (vs Full {full_cont:.0f}%)")

    # -IDF effect
    idf_avg = _avg_tokens(results['-IDF'])
    idf_cont = _contamination_pct(results['-IDF'])
    print("\n3. -IDF (固定所有词 IDF=1.0):")
    print(f" -IDF: {idf_avg:.1f} tokens, 污染率 {idf_cont:.0f}% (vs Full {full_cont:.0f}%)")
    print(f" 结论: IDF 加权对 token 消耗的影响 {'明显' if abs(idf_avg - full_avg) > 5 else '较小'}")

    out_path = os.path.join(os.path.dirname(__file__), 'phase2_complete_ablation_results.json')
    # Explicit UTF-8: the dump uses ensure_ascii=False, so the platform's
    # default encoding must not decide whether the write succeeds.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")


if __name__ == '__main__':
    main()