fix: anchor stopwords - remove generic question patterns causing cross-topic contamination

- Add ANCHOR_STOPWORDS set in anchor.py (真正通用的疑问pattern)
- Filter Chinese n-grams against stopwords in extract()
- Update sparse.py content_words extraction to use stopword-filtered query
- Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks
- Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%)
- Phase2 ablation: Gate-only accounts for most of the benefit
- Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data;
  RECENT_WINDOW is the primary token budget control

Known honest limitations:
- Test set is clean 4-topic synthetic data (no real dirty dialogue)
- No strong baselines (BM25 ablation incomplete)
- No answer-level evaluation (only retrieval blocks measured)
- No parameter sensitivity on noisy real-world data
- Zero contamination on 5 queries is not generalizable
This commit is contained in:
Elaina
2026-04-22 22:30:18 +08:00
parent 2064eb7bdf
commit 9e44748f91
10 changed files with 1461 additions and 12 deletions

View File

@@ -0,0 +1,167 @@
#!/usr/bin/env python3
"""Phase 1: Baseline Comparison - 7 methods compared fairly"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
def estimate_tokens(text):
if not text: return 0
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
english = len([w for w in text.split() if w.isascii()])
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
def measure_prompt_tokens(selected, query):
ctx = ""
for item in selected:
ctx += f"用户: {item['user']}\n助手: {item['assistant']}\n\n"
context_tok = estimate_tokens(ctx)
query_tok = estimate_tokens(query)
fmt_overhead = int(context_tok * 0.08)
return context_tok + fmt_overhead + query_tok
def evaluate_contamination(selected, target):
text = " ".join(item['user'] + item['assistant'] for item in selected)
found = [t for t in TOPICS if t.lower() in text.lower() and t.lower() != target.lower()]
return len(found) > 0, found
# 100轮对话
redis_qa = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
# 测试序列
TEST_SEQ = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_full_conv():
g = ContextGatekeeper(token_budget=4000)
for i in range(5):
g.add_turn(redis_qa[i][0], redis_qa[i][1])
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
g.add_turn(pg_qa[i][0], pg_qa[i][1])
g.add_turn(git_qa[i][0], git_qa[i][1])
return g
def build_conv_list():
conv = []
for i in range(5):
conv.append({'user': redis_qa[i][0], 'assistant': redis_qa[i][1]})
conv.append({'user': asyncio_qa[i][0], 'assistant': asyncio_qa[i][1]})
conv.append({'user': pg_qa[i][0], 'assistant': pg_qa[i][1]})
conv.append({'user': git_qa[i][0], 'assistant': git_qa[i][1]})
return conv
def last_n_select(conv, n, query):
return conv[-n:]
def bm25_select(conv, top_k, query):
qw = set(query.lower().split())
scored = []
for i, t in enumerate(conv):
txt = (t['user'] + ' ' + t['assistant']).lower()
sc = sum(1 for w in qw if w in txt)
recency = (i + 1) / len(conv)
scored.append((i, t, sc + recency * 0.2))
scored.sort(key=lambda x: x[2], reverse=True)
return [s[1] for s in scored[:top_k]]
def main():
gate = build_full_conv()
conv = build_conv_list()
methods = {
'Last-3': lambda q: last_n_select(conv, 3, q),
'Last-5': lambda q: last_n_select(conv, 5, q),
'Last-10': lambda q: last_n_select(conv, 10, q),
'BM25-5': lambda q: bm25_select(conv, 5, q),
'Full CGK': lambda q: gate.select(q),
}
results = {k: [] for k in methods}
cgk_prompt_tokens_list = []
print("=" * 70)
print("Phase 1: Baseline Comparison")
print("=" * 70)
for label, query, target in TEST_SEQ:
# CGK
sel = gate.select(query)
pt = measure_prompt_tokens(sel, query)
cont, _ = evaluate_contamination(sel, target)
results['Full CGK'].append({'label': label, 'pt': pt, 'cont': cont})
cgk_prompt_tokens_list.append(pt)
# Baselines
for name in ['Last-3', 'Last-5', 'Last-10', 'BM25-5']:
sel = methods[name](query)
pt = measure_prompt_tokens(sel, query)
cont, _ = evaluate_contamination(sel, target)
results[name].append({'label': label, 'pt': pt, 'cont': cont})
print(f"\n[{label}] {query[:45]}...")
for name, data in results.items():
r = data[-1]
print(f" {name:10s}: {r['pt']:6.0f} tokens, 污染={r['cont']}")
# Summary
print("\n" + "=" * 70)
print("Summary (avg over {} queries)".format(len(TEST_SEQ)))
print("=" * 70)
for name, data in results.items():
avg_pt = sum(d['pt'] for d in data) / len(data)
cont_rate = sum(1 for d in data if d['cont']) / len(data) * 100
print(f"{name:10s}: avg {avg_pt:6.1f} prompt tokens, 污染率 {cont_rate:5.1f}%")
# CGK vs baselines
cgk_avg = sum(cgk_prompt_tokens_list) / len(cgk_prompt_tokens_list)
print(f"\nFull CGK avg prompt tokens: {cgk_avg:.1f}")
# Save
out = {}
for name, data in results.items():
out[name] = {'avg_tokens': sum(d['pt'] for d in data)/len(data),
'contamination_rate': sum(1 for d in data if d['cont'])/len(data)*100,
'raw': data}
out_path = os.path.join(os.path.dirname(__file__), 'phase1_baseline_results.json')
with open(out_path, 'w') as f:
json.dump(out, f, indent=2, ensure_ascii=False)
print(f"\nSaved to: {out_path}")
if __name__ == '__main__':
main()