- Add ANCHOR_STOPWORDS set in anchor.py (真正通用的疑问pattern) - Filter Chinese n-grams against stopwords in extract() - Update sparse.py content_words extraction to use stopword-filtered query - Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks - Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%) - Phase2 ablation: Gate-only accounts for most of the benefit - Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data; RECENT_WINDOW is the primary token budget control Known honest limitations: - Test set is clean 4-topic synthetic data (no real dirty dialogue) - No strong baselines (BM25 ablation incomplete) - No answer-level evaluation (only retrieval blocks measured) - No parameter sensitivity on noisy real-world data - Zero contamination on 5 queries is not generalizable
160 lines
6.6 KiB
Python
160 lines
6.6 KiB
Python
#!/usr/bin/env python3
|
||
"""Phase 3: Parameter Sensitivity Analysis"""
|
||
import sys, os, json
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||
|
||
from src.gatekeeper import ContextGatekeeper
|
||
|
||
def estimate_tokens(text):
    """Return a rough integer token estimate for *text*.

    The estimate is a weighted sum of three counts, truncated with ``int``:

    * characters in the range U+4E00..U+9FFF, weighted 0.4 each;
    * whitespace-separated words made only of ASCII characters, weighted
      1.3 each;
    * the total character count, weighted 0.05 per character.

    Falsy input (``None`` or an empty string) yields 0.
    """
    if not text:
        return 0

    cjk_chars = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            cjk_chars += 1

    # Words mixing ASCII and non-ASCII characters are not counted here.
    ascii_words = sum(1 for word in text.split() if word.isascii())

    weighted = cjk_chars * 0.4 + ascii_words * 1.3 + len(text) * 0.05
    return int(weighted)
|
||
|
||
def measure_prompt_tokens(selected, query):
    """Estimate the token cost of a prompt built from *selected* plus *query*.

    Args:
        selected: iterable of mappings with ``'user'`` and ``'assistant'``
            keys; each one is rendered as a two-line dialogue turn.
        query: the query string appended to the cost.

    Returns:
        The estimated tokens of the rendered context, plus 8% of that
        figure (truncated to an int), plus the estimated tokens of *query*.
    """
    ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
    # Estimate the context once and reuse it for both the base cost and the
    # 8% surcharge instead of scanning the same string twice.
    ctx_tokens = estimate_tokens(ctx)
    return ctx_tokens + int(ctx_tokens * 0.08) + estimate_tokens(query)
|
||
|
||
def evaluate_contamination(selected, target,
                           topics=('Redis', 'asyncio', 'PostgreSQL', 'Git')):
    """Report whether *selected* mentions any topic other than *target*.

    Matching is a case-insensitive substring search over the text of all
    selected turns.

    Args:
        selected: iterable of mappings with ``'user'`` and ``'assistant'``
            string values.
        target: the topic keyword the query is about; it is excluded from
            the check (compared case-insensitively).
        topics: topic keywords to look for. Defaults to the four topics
            used by this script.

    Returns:
        ``True`` if at least one non-target topic keyword occurs in the
        selected text, otherwise ``False``.
    """
    target_key = target.lower()
    # Put a separator between the user and assistant parts as well, so a
    # keyword cannot appear by accident across that boundary
    # (e.g. "...Red" followed by "is...").
    haystack = " ".join(
        f"{turn['user']} {turn['assistant']}" for turn in selected
    ).lower()
    return any(
        topic.lower() != target_key and topic.lower() in haystack
        for topic in topics
    )
|
||
|
||
# Redis fixture turns. Each entry is a 2-tuple: a question and a truncated
# placeholder reply; build_gate() passes them to add_turn() in that order.
redis_qa = [
    (
        "Redis 分布式锁和 RedLock 算法有什么区别?",
        "RedLock是...",
    ),
    (
        "Redis 集群环境下怎么做分布式锁?",
        "集群下...",
    ),
    (
        "Redis 惰性删除和定期删除有什么区别?",
        "惰性删除...",
    ),
    (
        "Redis 的过期 key 对 RDB 快照有什么影响?",
        "过期key...",
    ),
    (
        "Redis 主从复制断线后如何增量同步?",
        "PSYNC...",
    ),
]
|
||
# asyncio fixture turns. Each entry is a 2-tuple: a question and a truncated
# placeholder reply; build_gate() passes them to add_turn() in that order.
asyncio_qa = [
    (
        "asyncio.Task 的 cancel 方法怎么工作的?",
        "cancel...",
    ),
    (
        "asyncio.gather 和 asyncio.wait 的返回结果有什么区别?",
        "gather...",
    ),
    (
        "asyncio 的事件循环怎么启动和停止?",
        "事件循环...",
    ),
    (
        "asyncio.sleep 和 time.sleep 的区别是什么?",
        "sleep...",
    ),
    (
        "asyncio 的 Future 对象怎么获取结果?",
        "Future...",
    ),
]
|
||
# PostgreSQL fixture turns. Each entry is a 2-tuple: a question and a
# truncated placeholder reply; build_gate() passes them to add_turn() in
# that order. Note the third question does not contain the word "PostgreSQL".
pg_qa = [
    (
        "PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?",
        "MVCC...",
    ),
    (
        "PostgreSQL 的 VACUUM 为什么要定期运行?",
        "VACUUM...",
    ),
    (
        "EXPLAIN ANALYZE 怎么看执行计划?",
        "EXPLAIN...",
    ),
    (
        "PostgreSQL B-tree 索引和 Hash 索引的区别是什么?",
        "B-tree...",
    ),
    (
        "PostgreSQL 的 TOAST 机制是什么?",
        "TOAST...",
    ),
]
|
||
# Git fixture turns. Each entry is a 2-tuple: a question and a truncated
# placeholder reply; build_gate() passes them to add_turn() in that order.
git_qa = [
    (
        "Git 的 rebase 和 merge 的区别是什么?",
        "rebase...",
    ),
    (
        "Git reset 的 --soft、--mixed、--hard 有什么区别?",
        "reset...",
    ),
    (
        "Git stash 暂存区和工作目录的区别是什么?",
        "stash...",
    ),
    (
        "Git 的 bisect 怎么用来快速定位 bug?",
        "bisect...",
    ),
    (
        "Git 的 reflog 怎么用来恢复误删的提交?",
        "reflog...",
    ),
]
|
||
|
||
# Evaluation queries, run in this order by every experiment in this file.
# Each entry is (label, query text, target topic): the label is ignored by
# the evaluation loops, the query is passed to select(), and the target topic
# is the keyword that evaluate_contamination() treats as on-topic.
TEST_SEQ = [
    (
        "问PG",
        "EXPLAIN ANALYZE 怎么看执行计划?",
        "PostgreSQL",
    ),
    (
        "问Git",
        "Git 的 rebase 和 merge 有什么区别?",
        "Git",
    ),
    (
        "问Redis",
        "Redis 惰性删除和定期删除有什么区别?",
        "Redis",
    ),
    (
        "问asyncio",
        "asyncio.Task 的 cancel 方法怎么工作的?",
        "asyncio",
    ),
    (
        "再问Git",
        "Git 的 reset 和 revert 的应用场景有什么区别?",
        "Git",
    ),
]
|
||
|
||
def build_gate():
    """Create a ContextGatekeeper preloaded with the interleaved fixture turns.

    The gatekeeper is constructed with ``token_budget=4000``. Five rounds of
    history are then added; each round appends one turn from ``redis_qa``,
    ``asyncio_qa``, ``pg_qa`` and ``git_qa``, in that order.

    Returns:
        The populated ContextGatekeeper instance.
    """
    gate = ContextGatekeeper(token_budget=4000)
    topic_histories = (redis_qa, asyncio_qa, pg_qa, git_qa)
    for round_no in range(5):
        # One turn per topic per round, so topics alternate in the history.
        for history in topic_histories:
            turn = history[round_no]
            gate.add_turn(turn[0], turn[1])
    return gate
|
||
|
||
def run_with_overlap_threshold(threshold):
    """Evaluate TEST_SEQ with the topic gate's overlap thresholds overridden.

    A fresh gate is built, ``OVERLAP_CONTINUE_THRESHOLD`` is set to
    *threshold* and ``OVERLAP_SWITCH_THRESHOLD`` to ``threshold * 0.44`` on
    ``gate.topic_gate``, and every query in TEST_SEQ is run against it.

    Returns:
        A tuple ``(average prompt tokens, contamination percentage)`` over
        all queries in TEST_SEQ.
    """
    gate = build_gate()
    gate.topic_gate.OVERLAP_CONTINUE_THRESHOLD = threshold
    # Scale the switch threshold with the continue threshold
    # (ratio ~0.20/0.45).
    gate.topic_gate.OVERLAP_SWITCH_THRESHOLD = threshold * 0.44

    token_total = 0
    contaminated = 0
    for _label, query, target in TEST_SEQ:
        picked = gate.select(query)
        token_total += measure_prompt_tokens(picked, query)
        contaminated += evaluate_contamination(picked, target)

    runs = len(TEST_SEQ)
    return token_total / runs, contaminated / runs * 100
|
||
|
||
def run_with_new_ratio_threshold(threshold):
    """Evaluate TEST_SEQ with ``NEW_RATIO_SWITCH_THRESHOLD`` overridden.

    A fresh gate is built, the attribute is set to *threshold* on
    ``gate.topic_gate``, and every query in TEST_SEQ is run against it.

    Returns:
        A tuple ``(average prompt tokens, contamination percentage)`` over
        all queries in TEST_SEQ.
    """
    gate = build_gate()
    gate.topic_gate.NEW_RATIO_SWITCH_THRESHOLD = threshold

    token_total = 0
    contaminated = 0
    for _label, query, target in TEST_SEQ:
        picked = gate.select(query)
        token_total += measure_prompt_tokens(picked, query)
        contaminated += evaluate_contamination(picked, target)

    runs = len(TEST_SEQ)
    return token_total / runs, contaminated / runs * 100
|
||
|
||
def run_with_recent_window(window):
    """Evaluate TEST_SEQ with the retriever limited to the last *window* blocks.

    A fresh gate is built and ``gate.retriever.retrieve`` is replaced by a
    wrapper that trims its ``blocks`` argument to the trailing *window*
    entries before delegating to the original method.

    NOTE(review): this emulates a window by trimming the retriever's input;
    whether that matches the gatekeeper's own RECENT_WINDOW setting is not
    visible from this file. A *window* of 0 does not trim anything, because
    ``blocks[-0:]`` is the whole list.

    Returns:
        A tuple ``(average prompt tokens, contamination percentage)`` over
        all queries in TEST_SEQ.
    """
    gate = build_gate()
    original_retrieve = gate.retriever.retrieve

    def patched_retrieve(blocks, qa, top_m=20, **kwargs):
        # Keep only the newest `window` blocks when there are more than that.
        candidates = blocks[-window:] if window < len(blocks) else blocks
        return original_retrieve(candidates, qa, top_m, **kwargs)

    gate.retriever.retrieve = patched_retrieve

    token_total = 0
    contaminated = 0
    for _label, query, target in TEST_SEQ:
        picked = gate.select(query)
        token_total += measure_prompt_tokens(picked, query)
        contaminated += evaluate_contamination(picked, target)

    runs = len(TEST_SEQ)
    return token_total / runs, contaminated / runs * 100
|
||
|
||
def main():
    """Run the baseline and three parameter sweeps, print them, and save JSON.

    The baseline uses an unmodified gate from ``build_gate()``. Each sweep
    calls the matching ``run_with_*`` helper once per candidate value and
    prints one line per value, flagging the default.

    Results are written to ``phase3_sensitivity_results.json`` next to this
    script. The file keeps the ``baseline_tokens`` and
    ``baseline_contamination_pct`` keys and also stores every sweep point
    under ``sweeps``, so the saved file carries the same numbers as the
    console output rather than the baseline alone.
    """
    rule = "=" * 70
    print(rule)
    print("Phase 3: Parameter Sensitivity Analysis")
    print(rule)

    # Baseline
    g = build_gate()
    base_tokens = []
    base_cont = []
    for _, q, t in TEST_SEQ:
        sel = g.select(q)
        base_tokens.append(measure_prompt_tokens(sel, q))
        base_cont.append(evaluate_contamination(sel, t))
    base_tok_avg = sum(base_tokens) / len(base_tokens)
    base_cont_pct = sum(base_cont) / len(base_cont) * 100
    print(f"\nBaseline (default params): {base_tok_avg:.1f} tokens, 污染率 {base_cont_pct:.0f}%")
    # These defaults are hard-coded text; they are not read back from the gate.
    print(" Default: OVERLAP=0.45, NEW_RATIO=0.70, RECENT_WINDOW=15")

    # One row per sweep:
    # (section title, printed label, runner, candidate values, default value)
    sweep_plan = [
        ("OVERLAP_CONTINUE_THRESHOLD", "overlap", run_with_overlap_threshold,
         [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60], 0.45),
        ("NEW_RATIO_SWITCH_THRESHOLD", "new_ratio", run_with_new_ratio_threshold,
         [0.50, 0.60, 0.70, 0.80, 0.90], 0.70),
        ("RECENT_WINDOW", "window", run_with_recent_window,
         [5, 10, 15, 20, 30], 15),
    ]

    sweeps = {}
    for title, label, runner, values, default in sweep_plan:
        print(f"\n--- {title} sweep ---")
        points = []
        for value in values:
            avg_tok, cont_pct = runner(value)
            flag = " ← default" if value == default else ""
            print(f" {label}={value}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
            points.append({
                'value': value,
                'avg_tokens': avg_tok,
                'contamination_pct': cont_pct,
            })
        sweeps[title] = points

    out = {
        'baseline_tokens': base_tok_avg,
        'baseline_contamination_pct': base_cont_pct,
        'sweeps': sweeps,
    }
    out_path = os.path.join(os.path.dirname(__file__), 'phase3_sensitivity_results.json')
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")
|
||
|
||
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()