#!/usr/bin/env python3
"""Phase 1: Baseline Comparison.

Runs the context-selection methods registered in ``main`` (three
"last N turns" windows, a keyword-overlap ranker and ``ContextGatekeeper``)
against a fixed, topic-interleaved conversation and reports, for every test
query, the estimated prompt size and whether the selected context mentions a
topic other than the one being asked about.
"""
import json
import os
import sys

# Make the parent of this script's directory importable.  abspath() matters:
# with a bare relative __file__ the double dirname() would collapse to ''.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.gatekeeper import ContextGatekeeper  # noqa: E402

TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']


def estimate_tokens(text):
    """Return a rough heuristic token count for mixed Chinese/English text.

    CJK unified ideographs, whitespace-separated ASCII words and the raw
    character count each contribute with a fixed weight.  Empty or ``None``
    input yields 0.
    """
    if not text:
        return 0
    chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
    english = sum(1 for w in text.split() if w.isascii())
    return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)


def measure_prompt_tokens(selected, query):
    """Estimate the prompt size for ``selected`` context turns plus ``query``.

    Each turn is a dict with ``'user'`` and ``'assistant'`` keys.  The
    rendered context is charged an extra 8% as formatting overhead.
    """
    ctx = "".join(
        f"用户: {item['user']}\n助手: {item['assistant']}\n\n" for item in selected
    )
    context_tok = estimate_tokens(ctx)
    fmt_overhead = int(context_tok * 0.08)
    return context_tok + fmt_overhead + estimate_tokens(query)


def evaluate_contamination(selected, target):
    """Check whether ``selected`` mentions any topic other than ``target``.

    Matching is a case-insensitive substring test against ``TOPICS``.

    Returns:
        ``(contaminated, found)`` where ``found`` lists the off-target topics
        that appear and ``contaminated`` is ``True`` when it is non-empty.
    """
    # Separate user and assistant text with a space so a topic name cannot be
    # formed by accident across the boundary between the two strings.
    text = " ".join(
        f"{item['user']} {item['assistant']}" for item in selected
    ).lower()
    target_lc = target.lower()
    found = [t for t in TOPICS if t.lower() != target_lc and t.lower() in text]
    return bool(found), found


# Conversation fixtures: five (user, assistant) pairs per topic, i.e. 20 turns
# once the four topics are interleaved.
redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]

# Test sequence: (label, query, expected topic).
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]


def build_conv_list():
    """Return the fixture conversation as a list of turn dicts.

    Turns are interleaved round-robin (Redis, asyncio, PostgreSQL, Git) and
    each one has the shape ``{'user': ..., 'assistant': ...}``.
    """
    return [
        {'user': user, 'assistant': assistant}
        for round_turns in zip(redis_qa, asyncio_qa, pg_qa, git_qa)
        for user, assistant in round_turns
    ]


def build_full_conv():
    """Return a ``ContextGatekeeper`` loaded with the fixture conversation.

    The turns are fed in the same order as ``build_conv_list`` produces them,
    so every method under comparison sees an identical history.
    """
    g = ContextGatekeeper(token_budget=4000)
    for turn in build_conv_list():
        g.add_turn(turn['user'], turn['assistant'])
    return g


def last_n_select(conv, n, query):
    """Return the last ``n`` turns of ``conv``.

    ``query`` is unused; it is accepted so that all selectors share one call
    shape.  A non-positive ``n`` selects nothing (a plain ``conv[-0:]`` would
    return the entire conversation).
    """
    if n <= 0:
        return []
    return conv[-n:]


def bm25_select(conv, top_k, query):
    """Return the ``top_k`` turns ranked by keyword overlap with ``query``.

    NOTE: despite the name this is not a full BM25 implementation.  The score
    is the number of whitespace-separated query terms that occur as
    substrings of the turn, plus a small recency bonus (at most 0.2) that
    favours later turns.
    """
    qw = set(query.lower().split())
    total = len(conv)
    scored = []
    for i, turn in enumerate(conv):
        txt = (turn['user'] + ' ' + turn['assistant']).lower()
        overlap = sum(1 for w in qw if w in txt)
        recency = (i + 1) / total
        scored.append((turn, overlap + recency * 0.2))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [turn for turn, _ in scored[:top_k]]


def _summarize(data):
    """Return ``(avg_tokens, contamination_rate_percent)`` for one method."""
    avg_tokens = sum(d['pt'] for d in data) / len(data)
    contamination_rate = sum(1 for d in data if d['cont']) / len(data) * 100
    return avg_tokens, contamination_rate


def main():
    """Run every method over ``TEST_SEQ``, print a report and save JSON.

    Results are written to ``phase1_baseline_results.json`` next to this
    script.
    """
    gate = build_full_conv()
    conv = build_conv_list()

    methods = {
        'Last-3': lambda q: last_n_select(conv, 3, q),
        'Last-5': lambda q: last_n_select(conv, 5, q),
        'Last-10': lambda q: last_n_select(conv, 10, q),
        'BM25-5': lambda q: bm25_select(conv, 5, q),
        'Full CGK': lambda q: gate.select(q),
    }
    results = {name: [] for name in methods}

    print("=" * 70)
    print("Phase 1: Baseline Comparison")
    print("=" * 70)

    for label, query, target in TEST_SEQ:
        # Every method goes through the same measurement path.
        for name, select in methods.items():
            sel = select(query)
            pt = measure_prompt_tokens(sel, query)
            cont, _ = evaluate_contamination(sel, target)
            results[name].append({'label': label, 'pt': pt, 'cont': cont})

        print(f"\n[{label}] {query[:45]}...")
        for name, data in results.items():
            r = data[-1]
            print(f" {name:10s}: {r['pt']:6.0f} tokens, 污染={r['cont']}")

    # Summary
    summary = {name: _summarize(data) for name, data in results.items()}

    print("\n" + "=" * 70)
    print("Summary (avg over {} queries)".format(len(TEST_SEQ)))
    print("=" * 70)
    for name, (avg_pt, cont_rate) in summary.items():
        print(f"{name:10s}: avg {avg_pt:6.1f} prompt tokens, 污染率 {cont_rate:5.1f}%")

    # CGK vs baselines
    cgk_avg = summary['Full CGK'][0]
    print(f"\nFull CGK avg prompt tokens: {cgk_avg:.1f}")

    # Save
    out = {
        name: {
            'avg_tokens': summary[name][0],
            'contamination_rate': summary[name][1],
            'raw': data,
        }
        for name, data in results.items()
    }
    out_path = os.path.join(os.path.dirname(__file__), 'phase1_baseline_results.json')
    # ensure_ascii=False emits raw Chinese text, so the encoding must be
    # pinned rather than left to the platform's locale default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to: {out_path}")


if __name__ == '__main__':
    main()