Files
context-gatekeeper/experiments/phase3_sensitivity.py
Elaina 9e44748f91 fix: anchor stopwords - remove generic question patterns causing cross-topic contamination
- Add ANCHOR_STOPWORDS set in anchor.py (truly generic question patterns / 真正通用的疑问pattern)
- Filter Chinese n-grams against stopwords in extract()
- Update sparse.py content_words extraction to use stopword-filtered query
- Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks
- Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%)
- Phase2 ablation: Gate-only accounts for most of the benefit
- Phase3 sensitivity: results are insensitive to the OVERLAP/NEW_RATIO thresholds on clean data;
  RECENT_WINDOW is the primary token-budget control

Known honest limitations:
- Test set is clean, synthetic 4-topic data (no real, noisy dialogue)
- No strong baselines (BM25 ablation incomplete)
- No answer-level evaluation (only retrieval blocks measured)
- No parameter sensitivity on noisy real-world data
- Zero contamination measured on only 5 queries does not generalize
2026-04-22 22:30:18 +08:00

160 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Phase 3: Parameter Sensitivity Analysis"""
import json
import os
import re
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
def estimate_tokens(text):
    """Return a rough token-count estimate for mixed Chinese/English text.

    The estimate is a weighted sum, truncated with ``int()``:
      * 0.4 per character in the range U+4E00..U+9FFF,
      * 1.3 per whitespace-separated word made only of ASCII characters,
      * 0.05 per character of any kind (i.e. ``len(text)``).

    Empty or falsy input (``""``, ``None``) yields 0.
    """
    if not text:
        return 0
    cjk_chars = sum('\u4e00' <= ch <= '\u9fff' for ch in text)
    ascii_words = sum(1 for word in text.split() if word.isascii())
    weighted = cjk_chars * 0.4 + ascii_words * 1.3 + len(text) * 0.05
    return int(weighted)
def measure_prompt_tokens(selected, query):
    """Estimate the prompt size (in tokens) for the selected turns plus the query.

    Each item of *selected* must provide 'user' and 'assistant' keys. Items are
    rendered as a "用户: <user>" line, a "助手: <assistant>" line and a blank
    line, then concatenated into one context string.

    Returns:
        context estimate + int(8% of the context estimate) + query estimate,
        all computed with estimate_tokens().
    """
    turns = []
    for item in selected:
        turns.append(f"用户: {item['user']}\n助手: {item['assistant']}\n\n")
    context = "".join(turns)
    # estimate_tokens is deterministic, so compute the context estimate once
    # and reuse it for the 8% surcharge.
    context_tokens = estimate_tokens(context)
    surcharge = int(context_tokens * 0.08)
    return context_tokens + surcharge + estimate_tokens(query)
def evaluate_contamination(selected, target,
                           topics=('Redis', 'asyncio', 'PostgreSQL', 'Git')):
    """Return True if the selected turns mention any topic other than *target*.

    Args:
        selected: iterable of dicts with 'user' and 'assistant' string values.
        target: the topic the query belongs to; compared case-insensitively and
            excluded from the scan.
        topics: topic keywords to scan for; defaults to the four topics used
            by this experiment.

    Returns:
        True if at least one non-target topic keyword occurs, else False.

    A keyword only counts when it is not embedded inside a longer ASCII
    word/number, so "digit" or "legit" are not mentions of "Git". Non-ASCII
    neighbours (e.g. CJK characters, as in "Git的") do not block a match.
    Matching is case-insensitive.
    """
    # Join every field with a newline so a keyword cannot be formed across the
    # user/assistant or turn boundary (e.g. "Re" + "dis").
    text = "\n".join(f"{item['user']}\n{item['assistant']}" for item in selected)
    target_key = target.lower()
    for topic in topics:
        if topic.lower() == target_key:
            continue
        # ASCII-only lookarounds instead of \b: in Unicode mode \b treats CJK
        # characters as word characters and would miss "Git的".
        pattern = r'(?<![A-Za-z0-9])' + re.escape(topic) + r'(?![A-Za-z0-9])'
        if re.search(pattern, text, re.IGNORECASE):
            return True
    return False
# Synthetic 4-topic dialogue history. Each list holds 5 (user question,
# assistant answer) pairs for one topic; the answers are short "..."
# placeholders, so the questions carry nearly all of the topic content.
redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
# NOTE: the EXPLAIN ANALYZE question does not mention "PostgreSQL" by name.
pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
# Evaluation queries, each (label, query, target topic). The target is the
# topic whose turns the query should retrieve; any other topic keyword in the
# selected turns counts as contamination (see evaluate_contamination).
# Query wording need not match the history exactly: the Git rebase query is
# rephrased, and the final reset/revert query has no counterpart in git_qa.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_gate(token_budget=4000):
    """Build a ContextGatekeeper pre-loaded with the interleaved topic history.

    Turns are added round-robin in the order Redis, asyncio, PostgreSQL, Git,
    one QA pair per topic per round, so every topic is spread across the
    whole history.

    Args:
        token_budget: forwarded to ContextGatekeeper (default 4000).

    Returns:
        The populated ContextGatekeeper.
    """
    g = ContextGatekeeper(token_budget=token_budget)
    # zip stops at the shortest list. All four lists currently hold 5 pairs,
    # so this adds the same 20 turns as the former hard-coded range(5), but it
    # cannot raise IndexError if one list is ever shortened.
    for round_pairs in zip(redis_qa, asyncio_qa, pg_qa, git_qa):
        for user_msg, assistant_msg in round_pairs:
            g.add_turn(user_msg, assistant_msg)
    return g
def _run_test_sequence(g):
    """Run every TEST_SEQ query through gate *g*, in order, and summarize.

    Returns:
        (average estimated prompt tokens, contamination rate in percent)
        over all queries in TEST_SEQ.
    """
    tokens, contaminated = [], []
    for _, query, target in TEST_SEQ:
        selected = g.select(query)
        tokens.append(measure_prompt_tokens(selected, query))
        contaminated.append(evaluate_contamination(selected, target))
    return sum(tokens)/len(tokens), sum(contaminated)/len(contaminated)*100


def run_with_overlap_threshold(threshold):
    """Evaluate TEST_SEQ with a custom topic-overlap threshold.

    Sets OVERLAP_CONTINUE_THRESHOLD to *threshold* and
    OVERLAP_SWITCH_THRESHOLD to threshold * 0.44 on a fresh gate.

    Returns:
        (average prompt tokens, contamination %) as from _run_test_sequence.
    """
    g = build_gate()
    g.topic_gate.OVERLAP_CONTINUE_THRESHOLD = threshold
    # 0.44 ~= 0.20 / 0.45: keeps SWITCH at the same ratio to CONTINUE.
    # NOTE(review): assumed to mirror the topic gate's defaults — confirm.
    g.topic_gate.OVERLAP_SWITCH_THRESHOLD = threshold * 0.44
    return _run_test_sequence(g)


def run_with_new_ratio_threshold(threshold):
    """Evaluate TEST_SEQ with a custom NEW_RATIO_SWITCH_THRESHOLD.

    Returns:
        (average prompt tokens, contamination %) as from _run_test_sequence.
    """
    g = build_gate()
    g.topic_gate.NEW_RATIO_SWITCH_THRESHOLD = threshold
    return _run_test_sequence(g)


def run_with_recent_window(window):
    """Evaluate TEST_SEQ while limiting the retriever to the last *window* blocks.

    g.retriever.retrieve is wrapped on this gate instance only. The wrapper
    cuts the block list it receives to its last *window* entries (none when
    window <= 0) and forwards every other argument unchanged, so the
    retriever's own defaults (e.g. top_m) still apply when the caller omits
    them.

    NOTE(review): this trims the retriever's input; whether it corresponds to
    the gatekeeper's RECENT_WINDOW setting is not visible here — confirm.

    Returns:
        (average prompt tokens, contamination %) as from _run_test_sequence.
    """
    g = build_gate()
    orig_retrieve = g.retriever.retrieve

    def patched_retrieve(blocks, *args, **kwargs):
        if window < len(blocks):
            # Slice from an explicit start index: blocks[-window:] would
            # return the WHOLE list when window == 0.
            blocks = blocks[len(blocks) - max(window, 0):]
        return orig_retrieve(blocks, *args, **kwargs)

    g.retriever.retrieve = patched_retrieve
    return _run_test_sequence(g)
def main():
    """Run the Phase 3 sweeps, print a report, and save all results as JSON.

    Prints the baseline (default parameters), then sweeps over
    OVERLAP_CONTINUE_THRESHOLD, NEW_RATIO_SWITCH_THRESHOLD and the retrieval
    window. The baseline and every sweep point are written to
    phase3_sensitivity_results.json next to this script. Previously only the
    baseline was saved.
    """
    print("="*70)
    print("Phase 3: Parameter Sensitivity Analysis")
    print("="*70)

    # Baseline: an unmodified gate run over the whole test sequence.
    g = build_gate()
    base_tokens = []
    base_cont = []
    for _, q, t in TEST_SEQ:
        sel = g.select(q)
        base_tokens.append(measure_prompt_tokens(sel, q))
        base_cont.append(evaluate_contamination(sel, t))
    base_tok_avg = sum(base_tokens)/len(base_tokens)
    base_cont_pct = sum(base_cont)/len(base_cont)*100
    print(f"\nBaseline (default params): {base_tok_avg:.1f} tokens, 污染率 {base_cont_pct:.0f}%")
    # NOTE(review): these defaults are hard-coded here rather than read from
    # the gate — confirm they still match src/.
    print(" Default: OVERLAP=0.45, NEW_RATIO=0.70, RECENT_WINDOW=15")

    def sweep(title, label, values, default, runner):
        """Print one sweep section and return its points as JSON-ready dicts."""
        print(f"\n--- {title} sweep ---")
        points = []
        for value in values:
            avg_tok, cont_pct = runner(value)
            flag = " ← default" if value == default else ""
            print(f" {label}={value}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
            points.append({label: value,
                           'avg_tokens': avg_tok,
                           'contamination_pct': cont_pct})
        return points

    overlap_points = sweep("OVERLAP_CONTINUE_THRESHOLD", "overlap",
                           [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60], 0.45,
                           run_with_overlap_threshold)
    new_ratio_points = sweep("NEW_RATIO_SWITCH_THRESHOLD", "new_ratio",
                             [0.50, 0.60, 0.70, 0.80, 0.90], 0.70,
                             run_with_new_ratio_threshold)
    window_points = sweep("RECENT_WINDOW", "window",
                          [5, 10, 15, 20, 30], 15,
                          run_with_recent_window)

    # Keep the original baseline keys so existing readers of the JSON still work.
    out = {
        'baseline_tokens': base_tok_avg,
        'baseline_contamination_pct': base_cont_pct,
        'overlap_sweep': overlap_points,
        'new_ratio_sweep': new_ratio_points,
        'recent_window_sweep': window_points,
    }
    out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'phase3_sensitivity_results.json')
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")
if __name__ == '__main__':
    # Run the sensitivity sweep only when executed as a script, not on import.
    main()