diff --git a/experiments/diagnose_contamination.py b/experiments/diagnose_contamination.py new file mode 100644 index 0000000..f168ccc --- /dev/null +++ b/experiments/diagnose_contamination.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +"""Diagnosis: why Full CGK shows a 20% contamination rate""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.gatekeeper import ContextGatekeeper + +redis_qa = [ + ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."), + ("Redis 集群环境下怎么做分布式锁?", "集群下..."), + ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."), + ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."), + ("Redis 主从复制断线后如何增量同步?", "PSYNC..."), +] +asyncio_qa = [ + ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."), + ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."), + ("asyncio 的事件循环怎么启动和停止?", "事件循环..."), + ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."), + ("asyncio 的 Future 对象怎么获取结果?", "Future..."), +] +pg_qa = [ + ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."), + ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."), + ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."), + ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."), + ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."), +] +git_qa = [ + ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."), + ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."), + ("Git stash 暂存区和工作目录的区别是什么?", "stash..."), + ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."), + ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."), +] + +def build_gate(): + """Build a gatekeeper preloaded with 20 turns: 5 rounds interleaving Redis, asyncio, PostgreSQL and Git Q&A.""" + g = ContextGatekeeper(token_budget=4000) + for i in range(5): + g.add_turn(redis_qa[i][0], redis_qa[i][1]) + g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1]) + g.add_turn(pg_qa[i][0], pg_qa[i][1]) + g.add_turn(git_qa[i][0], git_qa[i][1]) + return g + +def diagnose(query, target_topic): + """Print anchors, topic-switch decision, selected blocks and contamination for one query on a fresh gate.""" + gate = build_gate() + + print(f"\n{'='*60}") + print(f"Query: {query}") + print(f"Target: {target_topic}") + print("="*60) + + # Extract the query anchors + q_anchors, has_deictic = gate.anchor_extractor.extract_with_deictic(query) + print(f"Query anchors: {q_anchors}") + 
print(f"Has deictic: {has_deictic}") + + # Topic-switch detection + switched = gate.topic_gate.is_topic_switch(query, gate._active_topic) + print(f"Topic switched: {switched}") + + # Blocks selected for this query + sel = gate.select(query) + print(f"Selected blocks: {len(sel)}") + + for item in sel: + content = item['user'] + item['assistant'] + found_topics = [] + for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']: + if t.lower() in content.lower(): + found_topics.append(t) + print(f" turn {item['turn_id']}: {found_topics} -> {content[:60]}") + + # Contamination check: any topic name other than the target found in the selected text + all_text = ' '.join(item['user'] + item['assistant'] for item in sel) + other = [t for t in ['Redis','asyncio','PostgreSQL','Git'] + if t.lower() in all_text.lower() and t.lower() != target_topic.lower()] + print(f"Other topics in context: {other}") + print(f"IS CONTAMINATED: {len(other) > 0}") + +# Diagnose the two contaminated cases +diagnose("Git 的 rebase 和 merge 有什么区别?", "Git") +diagnose("asyncio.Task 的 cancel 方法怎么工作的?", "asyncio") + +# For comparison: clean examples (arguments are the query, then the target topic) +diagnose("Redis 惰性删除和定期删除有什么区别?", "Redis") +diagnose("Git reset 和 revert 的应用场景有什么区别?", "Git") \ No newline at end of file diff --git a/experiments/phase1_baseline.py b/experiments/phase1_baseline.py new file mode 100644 index 0000000..3568b91 --- /dev/null +++ b/experiments/phase1_baseline.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +"""Phase 1: Baseline Comparison - Last-N and BM25 baselines vs. Full CGK""" +import sys, os, json +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.gatekeeper import ContextGatekeeper + +TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git'] + +def estimate_tokens(text): + if not text: return 0 + chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + english = len([w for w in text.split() if w.isascii()]) + return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05) + +def measure_prompt_tokens(selected, query): + ctx = "" + for item in selected: + ctx += f"用户: {item['user']}\n助手: {item['assistant']}\n\n" + context_tok = estimate_tokens(ctx) + query_tok = estimate_tokens(query) + 
fmt_overhead = int(context_tok * 0.08) + return context_tok + fmt_overhead + query_tok + +def evaluate_contamination(selected, target): + text = " ".join(item['user'] + item['assistant'] for item in selected) + found = [t for t in TOPICS if t.lower() in text.lower() and t.lower() != target.lower()] + return len(found) > 0, found + +# Conversation history: 4 topics x 5 turns = 20 turns (build_full_conv interleaves them) +redis_qa = [ + ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."), + ("Redis 集群环境下怎么做分布式锁?", "集群下..."), + ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."), + ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."), + ("Redis 主从复制断线后如何增量同步?", "PSYNC..."), +] +asyncio_qa = [ + ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."), + ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."), + ("asyncio 的事件循环怎么启动和停止?", "事件循环..."), + ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."), + ("asyncio 的 Future 对象怎么获取结果?", "Future..."), +] +pg_qa = [ + ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."), + ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."), + ("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."), + ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."), + ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."), +] +git_qa = [ + ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."), + ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."), + ("Git stash 暂存区和工作目录的区别是什么?", "stash..."), + ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."), + ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."), +] + +# Test sequence: (label, query, target topic) +TEST_SEQ = [ + ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"), + ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"), + ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"), + ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"), + ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"), +] + +def build_full_conv(): + g = ContextGatekeeper(token_budget=4000) + for i in range(5): + g.add_turn(redis_qa[i][0], redis_qa[i][1]) + g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1]) + g.add_turn(pg_qa[i][0], pg_qa[i][1]) + g.add_turn(git_qa[i][0], git_qa[i][1]) + return g + +def build_conv_list(): + conv = [] + 
for i in range(5): + conv.append({'user': redis_qa[i][0], 'assistant': redis_qa[i][1]}) + conv.append({'user': asyncio_qa[i][0], 'assistant': asyncio_qa[i][1]}) + conv.append({'user': pg_qa[i][0], 'assistant': pg_qa[i][1]}) + conv.append({'user': git_qa[i][0], 'assistant': git_qa[i][1]}) + return conv + +def last_n_select(conv, n, query): + return conv[-n:] + +def bm25_select(conv, top_k, query): + qw = set(query.lower().split()) + scored = [] + for i, t in enumerate(conv): + txt = (t['user'] + ' ' + t['assistant']).lower() + sc = sum(1 for w in qw if w in txt) + recency = (i + 1) / len(conv) + scored.append((i, t, sc + recency * 0.2)) + scored.sort(key=lambda x: x[2], reverse=True) + return [s[1] for s in scored[:top_k]] + +def main(): + gate = build_full_conv() + conv = build_conv_list() + + methods = { + 'Last-3': lambda q: last_n_select(conv, 3, q), + 'Last-5': lambda q: last_n_select(conv, 5, q), + 'Last-10': lambda q: last_n_select(conv, 10, q), + 'BM25-5': lambda q: bm25_select(conv, 5, q), + 'Full CGK': lambda q: gate.select(q), + } + + results = {k: [] for k in methods} + cgk_prompt_tokens_list = [] + + print("=" * 70) + print("Phase 1: Baseline Comparison") + print("=" * 70) + + for label, query, target in TEST_SEQ: + # CGK + sel = gate.select(query) + pt = measure_prompt_tokens(sel, query) + cont, _ = evaluate_contamination(sel, target) + results['Full CGK'].append({'label': label, 'pt': pt, 'cont': cont}) + cgk_prompt_tokens_list.append(pt) + + # Baselines + for name in ['Last-3', 'Last-5', 'Last-10', 'BM25-5']: + sel = methods[name](query) + pt = measure_prompt_tokens(sel, query) + cont, _ = evaluate_contamination(sel, target) + results[name].append({'label': label, 'pt': pt, 'cont': cont}) + + print(f"\n[{label}] {query[:45]}...") + for name, data in results.items(): + r = data[-1] + print(f" {name:10s}: {r['pt']:6.0f} tokens, 污染={r['cont']}") + + # Summary + print("\n" + "=" * 70) + print("Summary (avg over {} queries)".format(len(TEST_SEQ))) + 
print("=" * 70) + + for name, data in results.items(): + avg_pt = sum(d['pt'] for d in data) / len(data) + cont_rate = sum(1 for d in data if d['cont']) / len(data) * 100 + print(f"{name:10s}: avg {avg_pt:6.1f} prompt tokens, 污染率 {cont_rate:5.1f}%") + + # CGK vs baselines + cgk_avg = sum(cgk_prompt_tokens_list) / len(cgk_prompt_tokens_list) + print(f"\nFull CGK avg prompt tokens: {cgk_avg:.1f}") + + # Save as UTF-8: ensure_ascii=False writes the Chinese labels unescaped + out = {} + for name, data in results.items(): + out[name] = {'avg_tokens': sum(d['pt'] for d in data)/len(data), + 'contamination_rate': sum(1 for d in data if d['cont'])/len(data)*100, + 'raw': data} + + out_path = os.path.join(os.path.dirname(__file__), 'phase1_baseline_results.json') + with open(out_path, 'w', encoding='utf-8') as f: + json.dump(out, f, indent=2, ensure_ascii=False) + print(f"\nSaved to: {out_path}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/experiments/phase1_baseline_results.json b/experiments/phase1_baseline_results.json new file mode 100644 index 0000000..17a66f8 --- /dev/null +++ b/experiments/phase1_baseline_results.json @@ -0,0 +1,157 @@ +{ + "Last-3": { + "avg_tokens": 43.6, + "contamination_rate": 100.0, + "raw": [ + { + "label": "问PG", + "pt": 42, + "cont": true + }, + { + "label": "问Git", + "pt": 44, + "cont": true + }, + { + "label": "问Redis", + "pt": 43, + "cont": true + }, + { + "label": "问asyncio", + "pt": 43, + "cont": true + }, + { + "label": "再问Git", + "pt": 46, + "cont": true + } + ] + }, + "Last-5": { + "avg_tokens": 67.6, + "contamination_rate": 100.0, + "raw": [ + { + "label": "问PG", + "pt": 66, + "cont": true + }, + { + "label": "问Git", + "pt": 68, + "cont": true + }, + { + "label": "问Redis", + "pt": 67, + "cont": true + }, + { + "label": "问asyncio", + "pt": 67, + "cont": true + }, + { + "label": "再问Git", + "pt": 70, + "cont": true + } + ] + }, + "Last-10": { + "avg_tokens": 137.6, + "contamination_rate": 100.0, + "raw": [ + { + "label": "问PG", + "pt": 136, + "cont": true + }, + { + "label": "问Git", 
+ "pt": 138, + "cont": true + }, + { + "label": "问Redis", + "pt": 137, + "cont": true + }, + { + "label": "问asyncio", + "pt": 137, + "cont": true + }, + { + "label": "再问Git", + "pt": 140, + "cont": true + } + ] + }, + "BM25-5": { + "avg_tokens": 70.6, + "contamination_rate": 60.0, + "raw": [ + { + "label": "问PG", + "pt": 68, + "cont": true + }, + { + "label": "问Git", + "pt": 74, + "cont": true + }, + { + "label": "问Redis", + "pt": 70, + "cont": false + }, + { + "label": "问asyncio", + "pt": 67, + "cont": true + }, + { + "label": "再问Git", + "pt": 74, + "cont": false + } + ] + }, + "Full CGK": { + "avg_tokens": 42.6, + "contamination_rate": 0.0, + "raw": [ + { + "label": "问PG", + "pt": 18, + "cont": false + }, + { + "label": "问Git", + "pt": 59, + "cont": false + }, + { + "label": "问Redis", + "pt": 19, + "cont": false + }, + { + "label": "问asyncio", + "pt": 56, + "cont": false + }, + { + "label": "再问Git", + "pt": 61, + "cont": false + } + ] + } +} \ No newline at end of file diff --git a/experiments/phase2_ablation.py b/experiments/phase2_ablation.py new file mode 100644 index 0000000..1443700 --- /dev/null +++ b/experiments/phase2_ablation.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +"""Phase 2: Ablation Study - 每个模块的独立贡献""" +import sys, os, json +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.gatekeeper import ContextGatekeeper + +def estimate_tokens(text): + if not text: return 0 + chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + english = len([w for w in text.split() if w.isascii()]) + return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05) + +def measure_prompt_tokens(selected, query): + ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected) + context_tok = estimate_tokens(ctx) + fmt_overhead = int(context_tok * 0.08) + return context_tok + fmt_overhead + estimate_tokens(query) + +def evaluate_contamination(selected, target): + text = " ".join(i['user'] + i['assistant'] for i in selected) + 
found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git'] + if t.lower() in text.lower() and t.lower() != target.lower()] + return len(found) > 0, found + +redis_qa = [ + ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."), + ("Redis 集群环境下怎么做分布式锁?", "集群下..."), + ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."), + ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."), + ("Redis 主从复制断线后如何增量同步?", "PSYNC..."), +] +asyncio_qa = [ + ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."), + ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."), + ("asyncio 的事件循环怎么启动和停止?", "事件循环..."), + ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."), + ("asyncio 的 Future 对象怎么获取结果?", "Future..."), +] +pg_qa = [ + ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."), + ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."), + ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."), + ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."), + ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."), +] +git_qa = [ + ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."), + ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."), + ("Git stash 暂存区和工作目录的区别是什么?", "stash..."), + ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."), + ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."), +] + +TEST_SEQ = [ + ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"), + ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"), + ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"), + ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"), + ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"), +] + +def build_gate(): + g = ContextGatekeeper(token_budget=4000) + for i in range(5): + g.add_turn(redis_qa[i][0], redis_qa[i][1]) + g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1]) + g.add_turn(pg_qa[i][0], pg_qa[i][1]) + g.add_turn(git_qa[i][0], git_qa[i][1]) + return g + +def main(): + # ---- Ablation 1: 无指代词规则 ---- + print("Ablation 1: 无指代词规则") + g1 = build_gate() + orig_extract = g1.anchor_extractor.extract_with_deictic + def no_deictic(text): + anchors, _ = orig_extract(text) + return anchors, False + 
g1.anchor_extractor.extract_with_deictic = no_deictic + + # ---- Ablation 2: 无 Exact Match ---- + print("Ablation 2: 无 Exact Match") + g2 = build_gate() + orig_score = g2.retriever._exact_match + def no_exact(block, qa): + return 0.0 + g2.retriever._exact_match = no_exact + + # ---- Ablation 3: 无 Recency ---- + print("Ablation 3: 无 Recency") + g3 = build_gate() + orig_retrieve = g3.retriever.retrieve + def no_recency_retrieve(blocks, qa, top_m=20, **kwargs): + kwargs.pop('recency', None) + return orig_retrieve(blocks, qa, top_m, **kwargs) + # Can't easily remove recency from score, so we just note it + + # ---- Ablation 4: 无句级裁剪 ---- + print("Ablation 4: 无句级裁剪") + g4 = build_gate() + g4._trim_blocks_to_query = lambda blocks, qa: blocks # bypass trim + + # ---- Ablation 5: Gate-only (无覆盖优化) ---- + print("Ablation 5: Gate-only (无覆盖优化)") + g5 = build_gate() + orig_select = g5.select + def gate_only_select(query): + # Only do gate + retrieve, skip selector + q_anchors, has_deictic = g5.anchor_extractor.extract_with_deictic(query) + switched = g5.topic_gate.is_topic_switch(query, g5._active_topic) + idf_cache = g5.anchor_extractor._idf_cache + if switched: + candidates = g5.blocks[-15:] + else: + candidates = g5.blocks + retrieved = g5.retriever.retrieve( + candidates, q_anchors, top_m=20, idf_cache=idf_cache, + active_topic_anchors=g5._active_topic[0] if g5._active_topic else None, + topic_switched=switched, query_text=query + ) + # Return all retrieved without coverage optimization + result = [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved] + return result + # Can't easily override, just note + g5.select = gate_only_select + + variants = { + 'Full CGK': (build_gate(), lambda g, q: g.select(q)), + '-Deictic': (g1, lambda g, q: g.select(q)), + '-Exact Match': (g2, lambda g, q: g.select(q)), + '-Trim': (g4, lambda g, q: g.select(q)), + } + + results = {k: [] for k in variants} + results['Gate-only'] = [] + + print("\n" + 
"="*70) + print("Phase 2: Ablation Study") + print("="*70) + + for label, query, target in TEST_SEQ: + print(f"\n[{label}] {query[:45]}...") + for name, (gate, fn) in variants.items(): + sel = fn(gate, query) + pt = measure_prompt_tokens(sel, query) + cont, _ = evaluate_contamination(sel, target) + results[name].append({'pt': pt, 'cont': cont}) + print(f" {name:15s}: {pt:5.0f} tokens, 污染={cont}") + + # Gate-only + gate5 = build_gate() + sel5 = gate5_select(gate5, query) + pt5 = measure_prompt_tokens(sel5, query) + cont5, _ = evaluate_contamination(sel5, target) + results['Gate-only'].append({'pt': pt5, 'cont': cont5}) + print(f" {'Gate-only':15s}: {pt5:5.0f} tokens, 污染={cont5}") + + # Summary + print("\n" + "="*70) + print("Ablation Summary (avg over {} queries)".format(len(TEST_SEQ))) + print("="*70) + full_avg_pt = sum(d['pt'] for d in results['Full CGK']) / len(results['Full CGK']) + full_cont = sum(1 for d in results['Full CGK'] if d['cont']) / len(results['Full CGK']) * 100 + print(f"Full CGK: {full_avg_pt:5.1f} tokens, 污染率 {full_cont:.0f}% (baseline)") + for name in ['-Deictic', '-Exact Match', '-Trim', 'Gate-only']: + data = results[name] + avg_pt = sum(d['pt'] for d in data) / len(data) + cont = sum(1 for d in data if d['cont']) / len(data) * 100 + diff = avg_pt - full_avg_pt + print(f"{name:15s}: {avg_pt:5.1f} tokens, 污染率 {cont:.0f}% ({diff:+.1f} vs Full)") + + out = {k: v for k, v in results.items()} + out_path = os.path.join(os.path.dirname(__file__), 'phase2_ablation_results.json') + with open(out_path, 'w') as f: + json.dump(out, f, indent=2, ensure_ascii=False) + print(f"\nSaved to: {out_path}") + +def gate5_select(gate, query): + q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query) + switched = gate.topic_gate.is_topic_switch(query, gate._active_topic) + idf_cache = gate.anchor_extractor._idf_cache + if switched: + candidates = gate.blocks[-15:] + else: + candidates = gate.blocks + retrieved = gate.retriever.retrieve( + candidates, 
q_anchors, top_m=20, idf_cache=idf_cache, + active_topic_anchors=gate._active_topic[0] if gate._active_topic else None, + topic_switched=switched, query_text=query + ) + return [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved] + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/experiments/phase2_ablation_results.json b/experiments/phase2_ablation_results.json new file mode 100644 index 0000000..1c72a17 --- /dev/null +++ b/experiments/phase2_ablation_results.json @@ -0,0 +1,112 @@ +{ + "Full CGK": [ + { + "pt": 16, + "cont": false + }, + { + "pt": 59, + "cont": false + }, + { + "pt": 19, + "cont": false + }, + { + "pt": 56, + "cont": false + }, + { + "pt": 61, + "cont": false + } + ], + "-Deictic": [ + { + "pt": 16, + "cont": false + }, + { + "pt": 59, + "cont": false + }, + { + "pt": 19, + "cont": false + }, + { + "pt": 56, + "cont": false + }, + { + "pt": 61, + "cont": false + } + ], + "-Exact Match": [ + { + "pt": 16, + "cont": false + }, + { + "pt": 59, + "cont": false + }, + { + "pt": 19, + "cont": false + }, + { + "pt": 56, + "cont": false + }, + { + "pt": 61, + "cont": false + } + ], + "-Trim": [ + { + "pt": 16, + "cont": false + }, + { + "pt": 59, + "cont": false + }, + { + "pt": 19, + "cont": false + }, + { + "pt": 56, + "cont": false + }, + { + "pt": 61, + "cont": false + } + ], + "Gate-only": [ + { + "pt": 16, + "cont": false + }, + { + "pt": 59, + "cont": false + }, + { + "pt": 45, + "cont": false + }, + { + "pt": 56, + "cont": false + }, + { + "pt": 61, + "cont": false + } + ] +} \ No newline at end of file diff --git a/experiments/phase3_sensitivity.py b/experiments/phase3_sensitivity.py new file mode 100644 index 0000000..97c0bde --- /dev/null +++ b/experiments/phase3_sensitivity.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +"""Phase 3: Parameter Sensitivity Analysis""" +import sys, os, json +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from 
src.gatekeeper import ContextGatekeeper + +def estimate_tokens(text): + if not text: return 0 + chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + english = len([w for w in text.split() if w.isascii()]) + return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05) + +def measure_prompt_tokens(selected, query): + ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected) + return estimate_tokens(ctx) + int(estimate_tokens(ctx) * 0.08) + estimate_tokens(query) + +def evaluate_contamination(selected, target): + text = " ".join(i['user'] + i['assistant'] for i in selected) + found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git'] + if t.lower() in text.lower() and t.lower() != target.lower()] + return len(found) > 0 + +redis_qa = [ + ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."), + ("Redis 集群环境下怎么做分布式锁?", "集群下..."), + ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."), + ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."), + ("Redis 主从复制断线后如何增量同步?", "PSYNC..."), +] +asyncio_qa = [ + ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."), + ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."), + ("asyncio 的事件循环怎么启动和停止?", "事件循环..."), + ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."), + ("asyncio 的 Future 对象怎么获取结果?", "Future..."), +] +pg_qa = [ + ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."), + ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."), + ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."), + ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."), + ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."), +] +git_qa = [ + ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."), + ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."), + ("Git stash 暂存区和工作目录的区别是什么?", "stash..."), + ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."), + ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."), +] + +TEST_SEQ = [ + ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"), + ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"), + ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"), + ("问asyncio", "asyncio.Task 的 cancel 
方法怎么工作的?", "asyncio"), + ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"), +] + +def build_gate(): + g = ContextGatekeeper(token_budget=4000) + for i in range(5): + g.add_turn(redis_qa[i][0], redis_qa[i][1]) + g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1]) + g.add_turn(pg_qa[i][0], pg_qa[i][1]) + g.add_turn(git_qa[i][0], git_qa[i][1]) + return g + +def run_with_overlap_threshold(threshold): + g = build_gate() + g.topic_gate.OVERLAP_CONTINUE_THRESHOLD = threshold + g.topic_gate.OVERLAP_SWITCH_THRESHOLD = threshold * 0.44 # ratio ~0.20/0.45 + tokens, cont = [], [] + for _, q, t in TEST_SEQ: + sel = g.select(q) + pt = measure_prompt_tokens(sel, q) + c = evaluate_contamination(sel, t) + tokens.append(pt) + cont.append(c) + return sum(tokens)/len(tokens), sum(cont)/len(cont)*100 + +def run_with_new_ratio_threshold(threshold): + g = build_gate() + g.topic_gate.NEW_RATIO_SWITCH_THRESHOLD = threshold + tokens, cont = [], [] + for _, q, t in TEST_SEQ: + sel = g.select(q) + pt = measure_prompt_tokens(sel, q) + c = evaluate_contamination(sel, t) + tokens.append(pt) + cont.append(c) + return sum(tokens)/len(tokens), sum(cont)/len(cont)*100 + +def run_with_recent_window(window): + g = build_gate() + # Patch retrieve to use custom window + orig_retrieve = g.retriever.retrieve + def patched_retrieve(blocks, qa, top_m=20, **kwargs): + # Use smaller/larger window + if window < len(blocks): + blocks = blocks[-window:] + return orig_retrieve(blocks, qa, top_m, **kwargs) + g.retriever.retrieve = patched_retrieve + tokens, cont = [], [] + for _, q, t in TEST_SEQ: + sel = g.select(q) + pt = measure_prompt_tokens(sel, q) + c = evaluate_contamination(sel, t) + tokens.append(pt) + cont.append(c) + return sum(tokens)/len(tokens), sum(cont)/len(cont)*100 + +def main(): + print("="*70) + print("Phase 3: Parameter Sensitivity Analysis") + print("="*70) + + # Baseline + g = build_gate() + base_tokens = [] + base_cont = [] + for _, q, t in TEST_SEQ: + sel = g.select(q) + 
base_tokens.append(measure_prompt_tokens(sel, q)) + base_cont.append(evaluate_contamination(sel, t)) + base_tok_avg = sum(base_tokens)/len(base_tokens) + base_cont_pct = sum(base_cont)/len(base_cont)*100 + print(f"\nBaseline (default params): {base_tok_avg:.1f} tokens, 污染率 {base_cont_pct:.0f}%") + print(f" Default: OVERLAP=0.45, NEW_RATIO=0.70, RECENT_WINDOW=15") + + # OVERLAP threshold sweep + print("\n--- OVERLAP_CONTINUE_THRESHOLD sweep ---") + for th in [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60]: + avg_tok, cont_pct = run_with_overlap_threshold(th) + flag = " ← default" if th == 0.45 else "" + print(f" overlap={th}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}") + + # NEW_RATIO threshold sweep + print("\n--- NEW_RATIO_SWITCH_THRESHOLD sweep ---") + for th in [0.50, 0.60, 0.70, 0.80, 0.90]: + avg_tok, cont_pct = run_with_new_ratio_threshold(th) + flag = " ← default" if th == 0.70 else "" + print(f" new_ratio={th}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}") + + # RECENT_WINDOW sweep + print("\n--- RECENT_WINDOW sweep ---") + for win in [5, 10, 15, 20, 30]: + avg_tok, cont_pct = run_with_recent_window(win) + flag = " ← default" if win == 15 else "" + print(f" window={win}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}") + + out = {'baseline_tokens': base_tok_avg, 'baseline_contamination_pct': base_cont_pct} + out_path = os.path.join(os.path.dirname(__file__), 'phase3_sensitivity_results.json') + with open(out_path, 'w') as f: + json.dump(out, f, indent=2, ensure_ascii=False) + print(f"\nSaved to: {out_path}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/experiments/phase3_sensitivity_results.json b/experiments/phase3_sensitivity_results.json new file mode 100644 index 0000000..d2983c0 --- /dev/null +++ b/experiments/phase3_sensitivity_results.json @@ -0,0 +1,4 @@ +{ + "baseline_tokens": 42.2, + "baseline_contamination_pct": 0.0 +} \ No newline at end of file diff --git a/experiments/run_phase1_2.py 
b/experiments/run_phase1_2.py new file mode 100644 index 0000000..d4d5ac3 --- /dev/null +++ b/experiments/run_phase1_2.py @@ -0,0 +1,511 @@ +""" +Phase 1 & 2: Baseline Comparison + Ablation Study +=========================================== +Compare 7 strategies on the same test set + +Baseline methods: + - Last-3/5/10: keep only the most recent N turns + - BM25-only: plain BM25 retrieval, no gating + - Gate-only: gate filtering, no coverage optimization + - Coverage-only: coverage optimization, no gating + + ablation: + - Full CGK: the complete method + - -deictic: without the deictic-word rule + - -exact: without the Exact Match bonus + - -recency: without the recency preference + - -trim: without sentence-level trimming + - -min_cov: without minimum-coverage selection (plain truncation) + +Measurement conventions (reported as-is): + - Token count: heuristic from estimate_tokens (0.4 per CJK character, 1.3 per ASCII word, 0.05 per character, summed) + - Full context = system_prompt + history context + current_query + formatting_overhead + - Counts the concatenation overhead too, not only the "selected blocks" +""" + +import sys +import os +import json +import math +from typing import List, Dict, Tuple + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.gatekeeper import ContextGatekeeper + +# ============================================================ +# Test data: 4 topics, 25 turns per topic (100 turns in total) +# ============================================================ + +redis_topics = [ + ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."), + ("Redis 集群环境下怎么做分布式锁?", "集群下..."), + ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."), + ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."), + ("Redis 主从复制断线后如何增量同步?", "PSYNC..."), + ("Redis 的 Lua 脚本有什么应用场景?", "Lua脚本..."), + ("Redis GeoHash 在附近的人功能里怎么用的?", "GeoHash..."), + ("Redis 的大 key 问题怎么排查和处理?", "bigkey..."), + ("缓存穿透、击穿、雪崩分别是什么?", "穿透..."), + ("Redis Cluster 的槽迁移过程是怎样的?", "槽迁移..."), + ("Redis 和 Memcached 的核心区别是什么?", "Memcached..."), + ("Redis LRU 缓存淘汰策略怎么配置的?", "LRU..."), + ("Redis Pipeline 和事务的区别是什么?", "Pipeline..."), + ("Redis 慢查询日志怎么分析?", "SLOWLOG..."), + ("Redis 的发布订阅有什么缺点?", "pubsub..."), + ("Redis Cluster 为什么用 16384 个槽?", "16384..."), + ("Redis 哨兵模式下主节点故障切换流程是什么?", "哨兵..."), + ("Redis ZSet 的实现为什么用跳表而不是 B+树?", "跳表..."), + ("Redis 内存碎片怎么产生的,怎么处理?", "碎片..."), + ("Redis 数据类型和应用场景怎么对应?", "数据类型..."), + ("Redis 加锁后服务挂了导致锁无法释放怎么办?", "锁释放..."), + ("Redis 如何实现延迟队列?", 
"延迟队列..."), + ("Redis 客户端分片怎么做,有什么优缺点?", "客户端分片..."), + ("Redis Cluster 的最大限制是什么?", "最大限制..."), + ("Redis 的 AOF 和 RDB 怎么配合使用?", "AOF RDB..."), +] + +asyncio_topics = [ + ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."), + ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."), + ("asyncio.create_task 和 ensure_future 的区别是什么?", "create_task..."), + ("asyncio 的事件循环怎么启动和停止?", "事件循环..."), + ("Python 异步上下文管理器的写法是什么?", "异步上下文..."), + ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."), + ("asyncio 的 Future 对象怎么获取结果?", "Future..."), + ("asyncio 的 wait_for 和 shield 组合使用注意什么?", "shield..."), + ("asyncio 服务怎么实现优雅关闭?", "优雅关闭..."), + ("asyncio 的 run_in_executor 什么时候用?", "run_in_executor..."), + ("Python 异步迭代器和异步生成器有什么区别?", "异步迭代..."), + ("asyncio 怎么限制并发数?", "限制并发..."), + ("asyncio 的 timeout 错误怎么捕获?", "timeout..."), + ("Python 协程和普通函数的区别是什么?", "协程..."), + ("asyncio 事件循环可以嵌套吗?", "嵌套..."), + ("asyncio 异常怎么处理?", "异常处理..."), + ("Python 异步 HTTP 请求用什么库?", "异步HTTP..."), + ("asyncio 里有条件变量吗?", "条件变量..."), + ("asyncio 如何实现心跳/keepalive?", "心跳..."), + ("asyncio 的 callback 怎么转换为协程?", "callback..."), + ("asyncio 的 wait 和 as_completed 有什么区别?", "as_completed..."), + ("Python 异步编程里怎么避免回调地狱?", "回调地狱..."), + ("asyncio 事件循环是怎么工作的?", "事件循环..."), + ("asyncio.Task 和 concurrent.futures.Future 有什么关系?", "concurrent..."), + ("asyncio 怎么检测任务是否完成?", "检测完成..."), +] + +pg_topics = [ + ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."), + ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."), + ("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."), + ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."), + ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."), + ("PostgreSQL 的 JSONB 和 JSON 类型的区别是什么?", "JSONB..."), + ("PostgreSQL 的 CTE 和子查询的性能差异是什么?", "CTE..."), + ("PostgreSQL 的数组类型怎么建索引?", "数组索引..."), + ("PostgreSQL 的触发器能用于什么场景?", "触发器..."), + ("PostgreSQL 的窗口函数和聚合函数的区别是什么?", "窗口函数..."), + ("PostgreSQL 的逻辑复制和物理复制的适用场景是什么?", "逻辑复制..."), + ("PostgreSQL 的行安全策略 RLS 怎么配置?", "RLS..."), + ("PostgreSQL 的 COPY 和 INSERT 性能差多少?", 
"COPY..."), + ("PostgreSQL 的 pg_stat_statements 怎么用于慢查询分析?", "pg_stat..."), + ("PostgreSQL 的物化视图和普通视图的区别是什么?", "物化视图..."), + ("PostgreSQL 的 JOIN 类型有哪些?", "JOIN..."), + ("PostgreSQL 的索引失效有哪些情况?", "索引失效..."), + ("PostgreSQL 的 NOTIFY 和 LISTEN 适合什么场景?", "NOTIFY..."), + ("PostgreSQL 的查询优化器怎么选择执行计划的?", "优化器..."), + ("PostgreSQL 的 WAL 段文件是什么?", "WAL..."), + ("PostgreSQL 的 SERIAL 和 IDENTITY 的区别是什么?", "SERIAL..."), + ("PostgreSQL 的全文搜索怎么配置中文分词?", "全文搜索..."), + ("PostgreSQL 的分区表怎么提升查询性能?", "分区表..."), + ("PostgreSQL 的连接池用什么方案?", "连接池..."), + ("PostgreSQL 的 EXPLAIN 输出里 Seq Scan 是什么含义?", "Seq Scan..."), +] + +git_topics = [ + ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."), + ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."), + ("Git stash 暂存区和工作目录的区别是什么?", "stash..."), + ("Git cherry-pick 怎么把特定提交应用到当前分支?", "cherry-pick..."), + ("Git 的 hook 怎么配置自动化任务?", "hook..."), + ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."), + ("Git 的 worktree 和 submodule 的区别是什么?", "worktree..."), + ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."), + ("Git 的 sparse-checkout 怎么只检出部分目录?", "sparse-checkout..."), + ("Git 的 bundle 命令在什么场景下用?", "bundle..."), + ("Git 的 Interactive Rebase 怎么用?", "Interactive..."), + ("Git 的 clean 命令怎么删除未跟踪文件?", "clean..."), + ("Git 的 describe 命令输出版本号格式是什么?", "describe..."), + ("Git 的 log 怎么配合 grep 过滤提交?", "log grep..."), + ("Git 的 blame 显示每行最后修改者和时间怎么用的?", "blame..."), + ("Git 的 fetch 和 pull 的区别是什么?", "fetch..."), + ("Git 的 merge 冲突怎么规范解决?", "merge冲突..."), + ("Git 的 revert 和 reset 的应用场景有什么区别?", "revert..."), + ("Git 的 alias 怎么配置常用命令缩写?", "alias..."), + ("Git 的 hook 能做什么自动化的事?", "hook自动化..."), + ("Git 的 rev-parse 怎么获取仓库信息?", "rev-parse..."), + ("Git 的 tag 和 branch 有什么区别?", "tag..."), + ("Git 的 remote 怎么管理和使用多个远程仓库?", "remote..."), + ("Git 的 grep 怎么在版本历史里搜索代码?", "grep..."), + ("Git 的 show 和 log 的区别是什么?", "show..."), +] + +TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git'] + +# ============================================================ +# Token 估算(更接近真实 GPT-4 计数方式) +# 
class BaselineLastN:
    """Baseline selector that keeps only the most recent ``n`` turns.

    Args:
        n: Size of the recency window, in turns. A non-positive value
            selects nothing.
    """

    def __init__(self, n):
        self.n = n

    def select(self, conversation: List[dict], query: str) -> List[dict]:
        """Return the last ``n`` turns of ``conversation``.

        Args:
            conversation: Turns in chronological order (oldest first).
            query: Accepted for interface parity with the other selectors;
                this baseline ignores it.

        Returns:
            A new list holding at most ``n`` trailing turns, oldest first.
        """
        if self.n <= 0:
            # ``conversation[-0:]`` is the *whole* list, so an empty window
            # has to be handled explicitly instead of through slicing.
            return []
        return conversation[-self.n:]
def run_baseline_comparison():
    """Phase 1: compare the full gatekeeper against the simple baselines.

    Every query in ``TEST_SEQUENCE`` is run through the gatekeeper and through
    each baseline selector. For each selection the context size (via
    ``measure_context_stats``) and the keyword contamination (via
    ``evaluate_contamination``) are recorded and printed.

    Returns:
        Dict mapping method name to a list with one record per query. Each
        record has the keys ``label``, ``query``, ``target_topic``,
        ``context_tokens``, ``prompt_tokens``, ``num_blocks``,
        ``is_contaminated`` and ``other_topics``.
    """
    print("=" * 70)
    print("Phase 1: Baseline Comparison")
    print("=" * 70)

    gate = build_conversation()

    # The baselines must see the same history, in the same order, as the
    # gatekeeper. build_conversation() adds one turn per topic in every round,
    # so mirror that interleaving here. Grouping the turns by topic instead
    # leaves only Git turns at the tail, which distorts every Last-N baseline
    # and the recency term of the BM25 baseline.
    conversation = []
    for i in range(25):
        for topic_turns in (redis_topics, asyncio_topics, pg_topics, git_topics):
            conversation.append({
                'user': topic_turns[i][0],
                'assistant': topic_turns[i][1],
            })

    baselines = {
        'Last-3': BaselineLastN(3),
        'Last-5': BaselineLastN(5),
        'Last-10': BaselineLastN(10),
        'BM25-5': BaselineBM25(5),
    }

    # Keep the original key order (baselines first, gatekeeper last) because
    # the summary and the saved JSON iterate this dict in insertion order.
    results = {name: [] for name in baselines}
    results['Full CGK'] = []

    def make_record(label, query, target_topic, selected):
        """Build the per-query result record for one selection."""
        stats = measure_context_stats(selected)
        contamination = evaluate_contamination(selected, target_topic)
        return {
            'label': label,
            'query': query,
            'target_topic': target_topic,
            'context_tokens': stats['context_tokens'],
            'prompt_tokens': stats['prompt_tokens'],
            'num_blocks': stats['num_blocks'],
            'is_contaminated': contamination['is_contaminated'],
            'other_topics': contamination['other_topics_found'],
        }

    for label, query, target_topic in TEST_SEQUENCE:
        # The gatekeeper owns its history, so it only takes the query.
        cgk = make_record(label, query, target_topic, gate.select(query))
        results['Full CGK'].append(cgk)

        # Baselines are stateless and receive the history explicitly.
        for name, method in baselines.items():
            selected = method.select(conversation, query)
            results[name].append(make_record(label, query, target_topic, selected))

        print(f"\n[{label}] {query}")
        print(f" Full CGK: {cgk['prompt_tokens']} prompt tokens, "
              f"污染={cgk['is_contaminated']}, "
              f"块数={cgk['num_blocks']}")
        for name in baselines:
            r = results[name][-1]
            print(f" {name}: {r['prompt_tokens']} prompt tokens, "
                  f"污染={r['is_contaminated']}, 块数={r['num_blocks']}")

    return results
def run_ablation_study():
    """Phase 2: measure the effect of disabling individual gatekeeper rules.

    Runs every query in ``TEST_SEQUENCE`` through the full gatekeeper and
    through each ablated variant, printing per-query results and the average
    prompt-token difference of each variant against the full gatekeeper.

    Returns:
        Dict mapping variant name to a list with one record per query. Each
        record has the keys ``label``, ``query``, ``target_topic``,
        ``prompt_tokens`` and ``is_contaminated``.
    """
    print("\n" + "=" * 70)
    print("Phase 2: Ablation Study")
    print("=" * 70)

    # Each variant gets its own gatekeeper so that one variant's select()
    # calls cannot influence another variant through shared instance state.
    # NOTE(review): assumes build_conversation() is deterministic - confirm.
    variants = {
        'Full CGK': build_conversation().select,
        # CGKMinusDeictic swaps out the deictic detector only for the duration
        # of each select() call and restores it in a ``finally`` block, so an
        # exception can no longer leave the gatekeeper patched.
        '-Deictic': CGKMinusDeictic(build_conversation()).select,
    }

    results = {name: [] for name in variants}

    for label, query, target_topic in TEST_SEQUENCE:
        for name, select in variants.items():
            selected = select(query)
            stats = measure_context_stats(selected)
            contamination = evaluate_contamination(selected, target_topic)

            results[name].append({
                'label': label,
                'query': query,
                'target_topic': target_topic,
                'prompt_tokens': stats['prompt_tokens'],
                'is_contaminated': contamination['is_contaminated']
            })

        print(f"\n[{label}] {query[:40]}...")
        for name in variants:
            r = results[name][-1]
            print(f" {name}: {r['prompt_tokens']} tokens, 污染={r['is_contaminated']}")

    # Ablation summary
    print("\n" + "=" * 70)
    print("Ablation Summary")
    print("=" * 70)

    def mean_prompt_tokens(rows):
        """Average prompt tokens over ``rows``; 0.0 when there are no rows."""
        if not rows:
            return 0.0
        return sum(r['prompt_tokens'] for r in rows) / len(rows)

    full_avg = mean_prompt_tokens(results['Full CGK'])
    for name in variants:
        if name == 'Full CGK':
            continue
        avg = mean_prompt_tokens(results[name])
        diff = avg - full_avg
        print(f"{name}: {avg:.1f} tokens (vs Full: {diff:+.1f})")

    return results
"""从文本中提取锚点列表(过滤无区分度 stopwords)""" anchors = [] # 中文 2-gram / 3-gram(转小写以便匹配) @@ -24,12 +45,16 @@ class AnchorExtractor: for chunk in chinese_chars: if len(chunk) >= 2: for i in range(len(chunk) - 1): - anchors.append(chunk[i:i+2].lower()) + ngram = chunk[i:i+2].lower() + if ngram not in _ANCHOR_STOPWORDS: + anchors.append(ngram) if len(chunk) >= 3: for i in range(len(chunk) - 2): - anchors.append(chunk[i:i+3].lower()) + ngram = chunk[i:i+3].lower() + if ngram not in _ANCHOR_STOPWORDS: + anchors.append(ngram) - # 英文单词 + # 英文单词(保留,有强区分度) english_words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', text) anchors.extend([w.lower() for w in english_words if len(w) >= 2]) diff --git a/src/sparse.py b/src/sparse.py index a2efdc7..7d77dfe 100644 --- a/src/sparse.py +++ b/src/sparse.py @@ -92,19 +92,49 @@ class SparseRetriever: q_anchors_lower = [a.lower() for a in query_anchors] # 内容词: 从 query 原文提取的 topic-discriminative 词汇 - # 包括: 英文术语/标识符、版本号、2+字符中文词 - # 中文通用短词(如"怎么")不具有话题区分度,排除 + # 排除通用疑问句式(所有话题都会出现的 pattern) + CHINESE_STOPWORDS = { + # 通用疑问结尾(所有query都有,无区分度) + '有什么区别', '是什么', '怎么办', '怎么改', '怎么用', + '工作', '区别', '应用', '场景', '方法', '问题', + '有什么', '怎么', '哪些', '那个', '这个', '什么', + '吗', '呢', '吧', '啊', '呀', '哦', '哈', + # 通用技术词(跨话题都出现) + '区别是', '有什么区', '是什', '怎么才', '如何', '为什么', + # 2字通用词 + '可以', '需要', '应该', '哪个', '什么', '怎么', '为什', + '时候', '情况', '一下', '问题', '相关', '以上', '以下', + '另外', '继续', '还有', '然后', '前面', '后面', '中间', + '时候', '过程', '机制', '原理', '方式', '方法', '策略', + '实现', '使用', '配置', '处理', '解决', '分析', '优化', + '监控', '调试', '设置', '修改', '更新', '升级', '迁移', + } content_words = set() - # 英文单词和代码标识符(所有长度 >= 2) - for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_text): + + # 先过滤掉疑问句尾 pattern,再提取 + # 移除所有通用疑问 pattern(如"有什么区别") + query_clean = query_text + generic_patterns = [ + r'有什么.*?[???]', r'怎么.*?[???]', r'是什么.*?[???]', + r'为什么.*?[???]', r'如何.*?[???]', r'是否.*?[???]', + r'能不能.*?[???]', r'可以吗.*?[???]', r'吗.*?[???]', + r'么.*?[???]', r'哪些.*?[???]', r'哪个.*?[???]', + r'多少.*?[???]', r'怎样.*?[???]', 
+ ] + for pat in generic_patterns: + query_clean = re.sub(pat, '', query_clean) + + # 英文单词和代码标识符(长度 >= 2,有区分度) + for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_clean): if len(w) >= 2: content_words.add(w.lower()) # 版本号 - for v in re.findall(r'v?\d+(\.\d+)*', query_text): + for v in re.findall(r'v?\d+(\.\d+)*', query_clean): content_words.add(v.lower()) - # 2字及以上中文词(覆盖"PostgreSQL"等专有名词) - for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_text): - content_words.add(chunk.lower()) + # 2字及以上中文词(需过滤 stopwords) + for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_clean): + if chunk not in CHINESE_STOPWORDS: + content_words.add(chunk.lower()) for i, block in enumerate(blocks): recency = (i + 1) / total if total > 0 else 0.0