#!/usr/bin/env python3
"""Phase 3: Parameter Sensitivity Analysis

Replays a fixed sequence of topic-switching queries against a
``ContextGatekeeper`` pre-loaded with interleaved Redis / asyncio /
PostgreSQL / Git conversation turns, and reports, for several parameter
settings, the average estimated prompt size and the percentage of queries
whose selected context mentions a topic other than the one asked about.
"""
import json
import os
import sys

# Resolve the script location to an absolute path first: when ``__file__`` is
# a bare relative filename, dirname(dirname(__file__)) collapses to '' and the
# project root would not end up on sys.path.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_SCRIPT_DIR))

from src.gatekeeper import ContextGatekeeper

# Topic keywords that evaluate_contamination() searches for in selected turns.
TOPICS = ('Redis', 'asyncio', 'PostgreSQL', 'Git')


def estimate_tokens(text):
    """Return a heuristic token estimate for *text*.

    The estimate is ``0.4`` per CJK ideograph (U+4E00..U+9FFF), ``1.3`` per
    whitespace-separated all-ASCII word, plus ``0.05`` per character of the
    whole string, truncated to an int. Empty or ``None`` input yields 0.
    """
    if not text:
        return 0
    chinese = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
    english = sum(1 for word in text.split() if word.isascii())
    return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)


def measure_prompt_tokens(selected, query):
    """Estimate the prompt size for *query* plus the *selected* turns.

    Args:
        selected: iterable of mappings with ``'user'`` and ``'assistant'``
            string entries.
        query: the new user question.

    Returns:
        Estimated tokens of the rendered context, plus 8% of that estimate
        (presumably prompt-formatting overhead -- TODO confirm), plus the
        estimated tokens of the query.
    """
    ctx = "".join(
        f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected
    )
    # Estimate the context once; it feeds both the base count and the 8% term.
    ctx_tokens = estimate_tokens(ctx)
    return ctx_tokens + int(ctx_tokens * 0.08) + estimate_tokens(query)


def evaluate_contamination(selected, target):
    """Return True if *selected* mentions any topic other than *target*.

    Matching is a case-insensitive substring search for each entry of
    ``TOPICS`` over the concatenated user/assistant text, so a keyword
    embedded in a longer word also counts as a hit.
    """
    text = " ".join(i['user'] + i['assistant'] for i in selected).lower()
    target = target.lower()
    return any(
        topic.lower() != target and topic.lower() in text for topic in TOPICS
    )


redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]

asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]

pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]

git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]

# (label, query, expected topic) triples replayed in order against one gate.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]


def build_gate():
    """Return a fresh gatekeeper loaded with the four topics interleaved.

    Turns are added round-robin (Redis, asyncio, PostgreSQL, Git) so that
    every topic is spread across the whole history.
    """
    g = ContextGatekeeper(token_budget=4000)
    for round_turns in zip(redis_qa, asyncio_qa, pg_qa, git_qa):
        for user, assistant in round_turns:
            g.add_turn(user, assistant)
    return g


def _evaluate(gate):
    """Replay TEST_SEQ on *gate*; return (avg tokens, contamination %)."""
    tokens, contaminated = [], []
    for _, query, target in TEST_SEQ:
        selected = gate.select(query)
        tokens.append(measure_prompt_tokens(selected, query))
        contaminated.append(evaluate_contamination(selected, target))
    avg_tokens = sum(tokens) / len(tokens)
    contamination_pct = sum(contaminated) / len(contaminated) * 100
    return avg_tokens, contamination_pct


def run_with_overlap_threshold(threshold):
    """Evaluate with ``OVERLAP_CONTINUE_THRESHOLD`` set to *threshold*.

    ``OVERLAP_SWITCH_THRESHOLD`` is scaled along with it so the two keep
    roughly the ratio of the 0.20 / 0.45 pair noted in the original code.

    Returns:
        ``(average prompt tokens, contamination percentage)``.
    """
    g = build_gate()
    g.topic_gate.OVERLAP_CONTINUE_THRESHOLD = threshold
    g.topic_gate.OVERLAP_SWITCH_THRESHOLD = threshold * 0.44  # ratio ~0.20/0.45
    return _evaluate(g)


def run_with_new_ratio_threshold(threshold):
    """Evaluate with ``NEW_RATIO_SWITCH_THRESHOLD`` set to *threshold*.

    Returns:
        ``(average prompt tokens, contamination percentage)``.
    """
    g = build_gate()
    g.topic_gate.NEW_RATIO_SWITCH_THRESHOLD = threshold
    return _evaluate(g)


def run_with_recent_window(window):
    """Evaluate with retrieval restricted to the last *window* blocks.

    The gate's ``retriever.retrieve`` is wrapped so that, whenever it is
    handed more than *window* blocks, only the trailing *window* are passed
    on to the real implementation.

    Args:
        window: number of most recent blocks to keep; must be >= 0.

    Returns:
        ``(average prompt tokens, contamination percentage)``.

    Raises:
        ValueError: if *window* is negative.
    """
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    g = build_gate()
    # Patch retrieve to use custom window
    orig_retrieve = g.retriever.retrieve

    def patched_retrieve(blocks, qa, top_m=20, **kwargs):
        if window < len(blocks):
            # Slice from the front index rather than with ``-window``:
            # ``blocks[-0:]`` would return the whole list for window == 0.
            blocks = blocks[len(blocks) - window:]
        return orig_retrieve(blocks, qa, top_m, **kwargs)

    g.retriever.retrieve = patched_retrieve
    return _evaluate(g)


def _run_sweep(header, label, values, runner, default):
    """Run *runner* for each value, print one row each, return the rows."""
    print(header)
    rows = []
    for value in values:
        avg_tok, cont_pct = runner(value)
        flag = " ← default" if value == default else ""
        print(f" {label}={value}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
        rows.append({
            'value': value,
            'avg_tokens': avg_tok,
            'contamination_pct': cont_pct,
        })
    return rows


def main():
    """Run the baseline and the three sweeps, print them, and save JSON.

    The results file is written next to this script. It keeps the original
    ``baseline_tokens`` / ``baseline_contamination_pct`` keys and adds a
    ``sweeps`` mapping holding the per-value rows of every sweep.
    """
    print("=" * 70)
    print("Phase 3: Parameter Sensitivity Analysis")
    print("=" * 70)

    # Baseline
    base_tok_avg, base_cont_pct = _evaluate(build_gate())
    print(f"\nBaseline (default params): {base_tok_avg:.1f} tokens, 污染率 {base_cont_pct:.0f}%")
    print(" Default: OVERLAP=0.45, NEW_RATIO=0.70, RECENT_WINDOW=15")

    # OVERLAP threshold sweep
    overlap_rows = _run_sweep(
        "\n--- OVERLAP_CONTINUE_THRESHOLD sweep ---",
        "overlap",
        [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60],
        run_with_overlap_threshold,
        0.45,
    )

    # NEW_RATIO threshold sweep
    new_ratio_rows = _run_sweep(
        "\n--- NEW_RATIO_SWITCH_THRESHOLD sweep ---",
        "new_ratio",
        [0.50, 0.60, 0.70, 0.80, 0.90],
        run_with_new_ratio_threshold,
        0.70,
    )

    # RECENT_WINDOW sweep
    window_rows = _run_sweep(
        "\n--- RECENT_WINDOW sweep ---",
        "window",
        [5, 10, 15, 20, 30],
        run_with_recent_window,
        15,
    )

    out = {
        'baseline_tokens': base_tok_avg,
        'baseline_contamination_pct': base_cont_pct,
        # Persist the sweep rows too; previously they were only printed.
        'sweeps': {
            'overlap_continue_threshold': overlap_rows,
            'new_ratio_switch_threshold': new_ratio_rows,
            'recent_window': window_rows,
        },
    }
    out_path = os.path.join(_SCRIPT_DIR, 'phase3_sensitivity_results.json')
    # ensure_ascii=False can emit non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")


if __name__ == '__main__':
    main()