Files
context-gatekeeper/experiments/phase1_baseline.py
Elaina 9e44748f91 fix: anchor stopwords - remove generic question patterns causing cross-topic contamination
- Add ANCHOR_STOPWORDS set in anchor.py (truly generic question patterns)
- Filter Chinese n-grams against stopwords in extract()
- Update sparse.py content_words extraction to use stopword-filtered query
- Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks
- Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%)
- Phase2 ablation: Gate-only accounts for most of the benefit
- Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data;
  RECENT_WINDOW is the primary token budget control

Known limitations:
- Test set is clean 4-topic synthetic data (no real dirty dialogue)
- No strong baselines (BM25 ablation incomplete)
- No answer-level evaluation (only retrieval blocks measured)
- No parameter sensitivity on noisy real-world data
- Zero contamination on only 5 queries does not generalize
2026-04-22 22:30:18 +08:00

167 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Phase 1: Baseline Comparison - 5 methods compared fairly.

Every method selects context turns from the same synthetic conversation
history and is scored on estimated prompt tokens and on cross-topic
contamination (mentions of topics other than the query's target).
"""
import sys, os, json
# Put the directory two levels above this file (the project root) first on
# sys.path so the `src` package below can be imported; this must run before
# the `from src...` import.
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
# Topic names that evaluate_contamination() looks for (case-insensitively) in a
# selection; any topic other than the query's target counts as contamination.
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
def estimate_tokens(text):
    """Return a rough, heuristic token estimate for *text*.

    The estimate is the truncated sum of three terms:

    * 0.4 per CJK Unified Ideograph (U+4E00..U+9FFF),
    * 1.3 per whitespace-separated chunk consisting only of ASCII characters,
    * 0.05 per character of the raw string.

    Falsy input (``""`` or ``None``) yields 0.
    """
    if not text:
        return 0
    cjk_chars = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            cjk_chars += 1
    ascii_words = sum(1 for chunk in text.split() if chunk.isascii())
    return int(cjk_chars * 0.4 + ascii_words * 1.3 + len(text) * 0.05)
def measure_prompt_tokens(selected, query):
    """Estimate the prompt size, in tokens, of *query* plus the *selected* turns.

    Each selected turn is rendered as a user line and an assistant line
    (labelled with the Chinese speaker tags in the format string below)
    followed by a blank line; the rendered context is then estimated with
    estimate_tokens().

    Args:
        selected: Iterable of turn dicts with ``'user'`` and ``'assistant'``
            string values.
        query: The query text that will be appended to the prompt.

    Returns:
        int: context estimate + 8% of the context estimate (truncated) +
        query estimate.
    """
    # Build the context in one pass; repeated ``+=`` on a str is quadratic in
    # the worst case.
    ctx = "".join(
        f"用户: {item['user']}\n助手: {item['assistant']}\n\n" for item in selected
    )
    context_tok = estimate_tokens(ctx)
    query_tok = estimate_tokens(query)
    # Flat 8% surcharge on the context estimate; presumably models prompt
    # template/formatting overhead (heuristic, not measured).
    fmt_overhead = int(context_tok * 0.08)
    return context_tok + fmt_overhead + query_tok
def evaluate_contamination(selected, target, topics=None):
    """Report which non-target topics are mentioned in the selected turns.

    Matching is case-insensitive and requires the topic name not to be glued
    to other ASCII letters/digits, so e.g. "Git" does not match inside
    "digit" or "legitimate", while "Redis惰性" or "asyncio.Task" still match.

    Args:
        selected: Iterable of turn dicts with ``'user'`` and ``'assistant'``
            string values.
        target: The topic the query is about; it is never reported.
        topics: Topic names to check. Defaults to the module-level ``TOPICS``.

    Returns:
        tuple[bool, list[str]]: ``(contaminated, found)`` where ``found`` lists
        the matching non-target topics in ``topics`` order.
    """
    import re

    if topics is None:
        topics = TOPICS
    # Separate the fields with a space so the tail of one field and the head
    # of the next cannot fuse into a spurious match.
    text = " ".join(
        f"{item['user']} {item['assistant']}" for item in selected
    ).lower()
    target_lower = target.lower()
    found = []
    for topic in topics:
        name = topic.lower()
        if name == target_lower:
            continue
        pattern = rf"(?<![a-z0-9]){re.escape(name)}(?![a-z0-9])"
        if re.search(pattern, text):
            found.append(topic)
    return len(found) > 0, found
# Synthetic conversation history: 5 (user question, assistant answer) pairs for
# each of the 4 topics in TOPICS, i.e. 20 turns in total once interleaved.
# The answers are short placeholder stubs, not real assistant responses.
redis_qa = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
# Evaluation queries, evaluated in this order. Each entry is
# (label, query text, target topic). The target is one of TOPICS and is the
# only topic a selection may mention without counting as contamination.
# Labels are Chinese shorthand: "问X" = "ask about X", "再问Git" = "ask about Git again".
TEST_SEQ = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_full_conv(token_budget=4000):
    """Return a ContextGatekeeper pre-loaded with the synthetic history.

    Turns come from build_conv_list(), so the gatekeeper receives exactly the
    same turns, in the same order, as the list-based baselines in main().
    This keeps the comparison fair by construction instead of maintaining two
    copies of the interleaving loop.

    Args:
        token_budget: Passed to ContextGatekeeper; defaults to 4000 as before.

    Returns:
        The gatekeeper after ``add_turn(user, assistant)`` was called for
        every turn.
    """
    gate = ContextGatekeeper(token_budget=token_budget)
    for turn in build_conv_list():
        gate.add_turn(turn['user'], turn['assistant'])
    return gate
def build_conv_list(*qa_lists):
    """Return the synthetic history as a flat list of turn dicts.

    Topics are interleaved round-robin: the first pair of every topic, then
    the second pair of every topic, and so on. Interleaving stops at the
    shortest topic list, instead of relying on a hard-coded count of 5, which
    raised IndexError for shorter lists and silently ignored extra items.

    Args:
        *qa_lists: Optional per-topic sequences of ``(user, assistant)`` pairs.
            When omitted, the module-level ``redis_qa``, ``asyncio_qa``,
            ``pg_qa`` and ``git_qa`` lists are used, in that order.

    Returns:
        list[dict]: ``{'user': ..., 'assistant': ...}`` dicts in conversation
        order.
    """
    if not qa_lists:
        qa_lists = (redis_qa, asyncio_qa, pg_qa, git_qa)
    return [
        {'user': user, 'assistant': assistant}
        for round_pairs in zip(*qa_lists)
        for user, assistant in round_pairs
    ]
def last_n_select(conv, n, query):
    """Return the *n* most recent turns of *conv* (recency-only baseline).

    Args:
        conv: Chronologically ordered list of turns.
        n: Number of trailing turns to keep. Values <= 0 select nothing;
            values larger than ``len(conv)`` return the whole list.
        query: Ignored; accepted so every selector shares one signature.

    Returns:
        list: A new list with the last ``n`` turns, oldest first.
    """
    if n <= 0:
        # conv[-0:] is conv[0:], the whole history, not an empty selection;
        # negative n would likewise produce a meaningless slice.
        return []
    return conv[-n:]
def bm25_select(conv, top_k, query):
    """Pick the *top_k* turns by lexical overlap with *query*, plus a recency bonus.

    Despite its name this is not Okapi BM25: there is no IDF weighting or
    length normalisation. Each turn scores

        (# distinct whitespace-split query terms found as substrings of the
         lowercased "user assistant" text) + 0.2 * (position + 1) / len(conv)

    Turns are ranked by descending score; ties keep conversation order.
    """
    terms = set(query.lower().split())
    total = len(conv)

    def score(pair):
        idx, turn = pair
        haystack = f"{turn['user']} {turn['assistant']}".lower()
        hits = sum(1 for term in terms if term in haystack)
        return hits + (idx + 1) / total * 0.2

    ranked = sorted(enumerate(conv), key=score, reverse=True)
    return [turn for _, turn in ranked[:top_k]]
def _summarize(data):
    """Return ``(avg_tokens, contamination_rate_percent)`` for one method's rows.

    Args:
        data: Non-empty list of per-query dicts with ``'pt'`` (estimated
            prompt tokens) and ``'cont'`` (contamination flag).
    """
    avg_tokens = sum(d['pt'] for d in data) / len(data)
    contamination_rate = sum(1 for d in data if d['cont']) / len(data) * 100
    return avg_tokens, contamination_rate


def main():
    """Run every selection method on TEST_SEQ, print a report, and save JSON.

    All methods select from the same synthetic history. Each selection is
    scored with measure_prompt_tokens() and evaluate_contamination(). The
    script prints per-query rows and a per-method summary, then writes the
    summary plus raw rows to ``phase1_baseline_results.json`` next to this
    script.
    """
    gate = build_full_conv()
    conv = build_conv_list()
    # Every method maps a query to a list of turn dicts, so the loop below
    # can evaluate them uniformly.
    methods = {
        'Last-3': lambda q: last_n_select(conv, 3, q),
        'Last-5': lambda q: last_n_select(conv, 5, q),
        'Last-10': lambda q: last_n_select(conv, 10, q),
        'BM25-5': lambda q: bm25_select(conv, 5, q),
        'Full CGK': gate.select,
    }
    results = {name: [] for name in methods}
    print("=" * 70)
    print("Phase 1: Baseline Comparison")
    print("=" * 70)
    for label, query, target in TEST_SEQ:
        for name, select in methods.items():
            sel = select(query)
            pt = measure_prompt_tokens(sel, query)
            cont, _ = evaluate_contamination(sel, target)
            results[name].append({'label': label, 'pt': pt, 'cont': cont})
        print(f"\n[{label}] {query[:45]}...")
        for name, data in results.items():
            r = data[-1]
            print(f"  {name:10s}: {r['pt']:6.0f} tokens, 污染={r['cont']}")
    # Summary
    print("\n" + "=" * 70)
    print("Summary (avg over {} queries)".format(len(TEST_SEQ)))
    print("=" * 70)
    out = {}
    for name, data in results.items():
        avg_pt, cont_rate = _summarize(data)
        print(f"{name:10s}: avg {avg_pt:6.1f} prompt tokens, 污染率 {cont_rate:5.1f}%")
        out[name] = {'avg_tokens': avg_pt,
                     'contamination_rate': cont_rate,
                     'raw': data}
    cgk_avg, _ = _summarize(results['Full CGK'])
    print(f"\nFull CGK avg prompt tokens: {cgk_avg:.1f}")
    # Save
    out_path = os.path.join(os.path.dirname(__file__), 'phase1_baseline_results.json')
    # ensure_ascii=False emits raw non-ASCII labels, so the file encoding
    # must be explicit rather than the platform/locale default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()