""" Phase 1 & 2: Baseline Comparison + Ablation Study =========================================== 对比7种策略在相同测试集上的表现 基线方法: - Last-3/5/10: 只保留最近N轮 - BM25-only: 纯BM25检索,无门控 - Gate-only: 门控过滤,无覆盖优化 - Coverage-only: 覆盖优化,无门控 ablation: - Full CGK: 完整方法 - -deictic: 无指代词规则 - -exact: 无Exact Match加分 - -recency: 无近期偏好 - -trim: 无句级裁剪 - -min_cov: 无最小覆盖选择(直接截断) 统计口径(实事求是): - Token计数: 按 GPT4 tokenize 规则估算(1 token ≈ 4 chars 中文,1 token ≈ 0.75 words 英文) - 完整上下文 = system_prompt + 历史上下文 + current_query + formatting_overhead - 不只算"选中的块",也算入拼接开销 """ import sys import os import json import math from typing import List, Dict, Tuple sys.path.insert(0, os.path.dirname(__file__)) from src.gatekeeper import ContextGatekeeper # ============================================================ # 测试数据:4 话题,每话题 25 轮(总计 100 轮) # ============================================================ redis_topics = [ ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."), ("Redis 集群环境下怎么做分布式锁?", "集群下..."), ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."), ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."), ("Redis 主从复制断线后如何增量同步?", "PSYNC..."), ("Redis 的 Lua 脚本有什么应用场景?", "Lua脚本..."), ("Redis GeoHash 在附近的人功能里怎么用的?", "GeoHash..."), ("Redis 的大 key 问题怎么排查和处理?", "bigkey..."), ("缓存穿透、击穿、雪崩分别是什么?", "穿透..."), ("Redis Cluster 的槽迁移过程是怎样的?", "槽迁移..."), ("Redis 和 Memcached 的核心区别是什么?", "Memcached..."), ("Redis LRU 缓存淘汰策略怎么配置的?", "LRU..."), ("Redis Pipeline 和事务的区别是什么?", "Pipeline..."), ("Redis 慢查询日志怎么分析?", "SLOWLOG..."), ("Redis 的发布订阅有什么缺点?", "pubsub..."), ("Redis Cluster 为什么用 16384 个槽?", "16384..."), ("Redis 哨兵模式下主节点故障切换流程是什么?", "哨兵..."), ("Redis ZSet 的实现为什么用跳表而不是 B+树?", "跳表..."), ("Redis 内存碎片怎么产生的,怎么处理?", "碎片..."), ("Redis 数据类型和应用场景怎么对应?", "数据类型..."), ("Redis 加锁后服务挂了导致锁无法释放怎么办?", "锁释放..."), ("Redis 如何实现延迟队列?", "延迟队列..."), ("Redis 客户端分片怎么做,有什么优缺点?", "客户端分片..."), ("Redis Cluster 的最大限制是什么?", "最大限制..."), ("Redis 的 AOF 和 RDB 怎么配合使用?", "AOF RDB..."), ] asyncio_topics = [ ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."), ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."), ("asyncio.create_task 和 ensure_future 的区别是什么?", "create_task..."), ("asyncio 的事件循环怎么启动和停止?", "事件循环..."), ("Python 异步上下文管理器的写法是什么?", "异步上下文..."), ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."), ("asyncio 的 Future 对象怎么获取结果?", "Future..."), ("asyncio 的 wait_for 和 shield 组合使用注意什么?", "shield..."), ("asyncio 服务怎么实现优雅关闭?", "优雅关闭..."), ("asyncio 的 run_in_executor 什么时候用?", "run_in_executor..."), ("Python 异步迭代器和异步生成器有什么区别?", "异步迭代..."), ("asyncio 怎么限制并发数?", "限制并发..."), ("asyncio 的 timeout 错误怎么捕获?", "timeout..."), ("Python 协程和普通函数的区别是什么?", "协程..."), ("asyncio 事件循环可以嵌套吗?", "嵌套..."), ("asyncio 异常怎么处理?", "异常处理..."), ("Python 异步 HTTP 请求用什么库?", "异步HTTP..."), ("asyncio 里有条件变量吗?", "条件变量..."), ("asyncio 如何实现心跳/keepalive?", "心跳..."), ("asyncio 的 callback 怎么转换为协程?", "callback..."), ("asyncio 的 wait 和 as_completed 有什么区别?", "as_completed..."), ("Python 异步编程里怎么避免回调地狱?", "回调地狱..."), ("asyncio 事件循环是怎么工作的?", "事件循环..."), ("asyncio.Task 和 concurrent.futures.Future 有什么关系?", "concurrent..."), ("asyncio 怎么检测任务是否完成?", "检测完成..."), ] pg_topics = [ ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."), ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."), ("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."), ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."), ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."), ("PostgreSQL 的 JSONB 和 JSON 类型的区别是什么?", "JSONB..."), ("PostgreSQL 的 CTE 和子查询的性能差异是什么?", "CTE..."), ("PostgreSQL 的数组类型怎么建索引?", "数组索引..."), ("PostgreSQL 的触发器能用于什么场景?", "触发器..."), ("PostgreSQL 的窗口函数和聚合函数的区别是什么?", "窗口函数..."), ("PostgreSQL 的逻辑复制和物理复制的适用场景是什么?", "逻辑复制..."), ("PostgreSQL 的行安全策略 RLS 怎么配置?", "RLS..."), ("PostgreSQL 的 COPY 和 INSERT 性能差多少?", "COPY..."), ("PostgreSQL 的 pg_stat_statements 怎么用于慢查询分析?", "pg_stat..."), ("PostgreSQL 的物化视图和普通视图的区别是什么?", "物化视图..."), ("PostgreSQL 的 JOIN 类型有哪些?", "JOIN..."), ("PostgreSQL 的索引失效有哪些情况?", "索引失效..."), ("PostgreSQL 的 NOTIFY 和 LISTEN 适合什么场景?", "NOTIFY..."), ("PostgreSQL 的查询优化器怎么选择执行计划的?", "优化器..."), ("PostgreSQL 的 WAL 段文件是什么?", "WAL..."), ("PostgreSQL 的 SERIAL 和 IDENTITY 的区别是什么?", "SERIAL..."), ("PostgreSQL 的全文搜索怎么配置中文分词?", "全文搜索..."), ("PostgreSQL 的分区表怎么提升查询性能?", "分区表..."), ("PostgreSQL 的连接池用什么方案?", "连接池..."), ("PostgreSQL 的 EXPLAIN 输出里 Seq Scan 是什么含义?", "Seq Scan..."), ] git_topics = [ ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."), ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."), ("Git stash 暂存区和工作目录的区别是什么?", "stash..."), ("Git cherry-pick 怎么把特定提交应用到当前分支?", "cherry-pick..."), ("Git 的 hook 怎么配置自动化任务?", "hook..."), ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."), ("Git 的 worktree 和 submodule 的区别是什么?", "worktree..."), ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."), ("Git 的 sparse-checkout 怎么只检出部分目录?", "sparse-checkout..."), ("Git 的 bundle 命令在什么场景下用?", "bundle..."), ("Git 的 Interactive Rebase 怎么用?", "Interactive..."), ("Git 的 clean 命令怎么删除未跟踪文件?", "clean..."), ("Git 的 describe 命令输出版本号格式是什么?", "describe..."), ("Git 的 log 怎么配合 grep 过滤提交?", "log grep..."), ("Git 的 blame 显示每行最后修改者和时间怎么用的?", "blame..."), ("Git 的 fetch 和 pull 的区别是什么?", "fetch..."), ("Git 的 merge 冲突怎么规范解决?", "merge冲突..."), ("Git 的 revert 和 reset 的应用场景有什么区别?", "revert..."), ("Git 的 alias 怎么配置常用命令缩写?", "alias..."), ("Git 的 hook 能做什么自动化的事?", "hook自动化..."), ("Git 的 rev-parse 怎么获取仓库信息?", "rev-parse..."), ("Git 的 tag 和 branch 有什么区别?", "tag..."), ("Git 的 remote 怎么管理和使用多个远程仓库?", "remote..."), ("Git 的 grep 怎么在版本历史里搜索代码?", "grep..."), ("Git 的 show 和 log 的区别是什么?", "show..."), ] TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git'] # ============================================================ # Token 估算(更接近真实 GPT-4 计数方式) # ============================================================ def estimate_tokens(text: str) -> int: """ 估算 token 数量(近似 GPT-4 tokenize) 规则: - 中文: 1 token ≈ 1.5-2 characters - 英文单词: 1 token ≈ 0.75 words - 标点/空格: 计入 overhead 这里用简化的 approximation: 中文 chars * 0.4 + 英文 words * 1.3 + 总字符数 * 0.05 """ if not text: return 0 chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') english_words = len([w for w in text.split() if w.isascii()]) base_overhead = len(text) * 0.05 return int(chinese_chars * 0.4 + english_words * 1.3 + base_overhead) def estimate_prompt_tokens(context_tokens: int, query: str, system_prompt: str = "") -> int: """ 估算完整 prompt 的 token 数 包含: - system prompt (如果有) - formatting overhead (【轮次】【当前问题】等标签) - 历史上下文 - current query 按保守估计,formatting overhead 约为上下文的 8% """ formatting_overhead = int(context_tokens * 0.08) query_tokens = estimate_tokens(query) system_tokens = estimate_tokens(system_prompt) if system_prompt else 0 return context_tokens + formatting_overhead + query_tokens + system_tokens # ============================================================ # 测试序列:交替查询,模拟真实使用场景 # ============================================================ TEST_SEQUENCE = [ ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"), ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"), ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"), ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"), ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"), ("问PG-2", "PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "PostgreSQL"), ("问Redis-2", "Redis 的大 key 问题怎么排查和处理?", "Redis"), ("问asyncio-2", "asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "asyncio"), ] # ============================================================ # Baseline 方法实现 # ============================================================ class BaselineLastN: """基线:只保留最近 N 轮""" def __init__(self, n): self.n = n def select(self, conversation: List[dict], query: str) -> List[dict]: return conversation[-self.n:] class BaselineBM25: """基线:纯 BM25 检索,无门控""" def __init__(self, top_k=5): self.top_k = top_k def select(self, conversation: List[dict], query: str) -> List[dict]: # 简单 BM25: 按 query 词在 conversation 中的重叠次数排序 query_words = set(query.lower().split()) scored = [] for i, turn in enumerate(conversation): text = (turn.get('user', '') + ' ' + turn.get('assistant', '')).lower() score = sum(1 for w in query_words if w in text) recency = (i + 1) / len(conversation) scored.append((i, turn, score + recency * 0.2)) scored.sort(key=lambda x: x[2], reverse=True) return [s[1] for s in scored[:self.top_k]] # ============================================================ # Ablation 变体 # ============================================================ class CGKMinusDeictic: """CGK去掉指代词规则""" def __init__(self, gatekeeper: ContextGatekeeper): self.gatekeeper = gatekeeper def select(self, query: str) -> List[Dict]: # 临时禁用指代词检测 orig_extract = self.gatekeeper.anchor_extractor.extract_with_deictic def no_deictic(text): anchors, _ = orig_extract(text) return anchors, False # 强制 has_deictic=False self.gatekeeper.anchor_extractor.extract_with_deictic = no_deictic try: result = self.gatekeeper.select(query) finally: self.gatekeeper.anchor_extractor.extract_with_deictic = orig_extract return result # ============================================================ # 实验运行 # ============================================================ def build_conversation(): """构建100轮对话""" gate = ContextGatekeeper(token_budget=4000) for i in range(25): gate.add_turn(redis_topics[i][0], redis_topics[i][1]) gate.add_turn(asyncio_topics[i][0], asyncio_topics[i][1]) gate.add_turn(pg_topics[i][0], pg_topics[i][1]) gate.add_turn(git_topics[i][0], git_topics[i][1]) return gate def measure_context_stats(selected: List[Dict]) -> Dict: """统计 context 的 token 详情""" total_text = "" for item in selected: total_text += f"用户: {item['user']}\n助手: {item['assistant']}\n\n" context_tokens = estimate_tokens(total_text) prompt_tokens = estimate_prompt_tokens(context_tokens, "") return { 'context_chars': len(total_text), 'context_tokens': context_tokens, 'prompt_tokens': prompt_tokens, 'num_blocks': len(selected) } def evaluate_contamination(selected: List[Dict], target_topic: str) -> Dict: """ 评估污染情况 注意:这里测的是"检索到的块是否包含其他话题的关键词" 而不是"模型回答是否被污染" """ combined = "" for item in selected: combined += item['user'] + item['assistant'] topics_found = [] for t in TOPICS: if t.lower() in combined.lower() and t.lower() != target_topic.lower(): topics_found.append(t) return { 'is_contaminated': len(topics_found) > 0, 'other_topics_found': topics_found } def run_baseline_comparison(): """Phase 1: 基线对比""" print("=" * 70) print("Phase 1: Baseline Comparison") print("=" * 70) gate = build_conversation() conversation = [ {'user': redis_topics[i][0], 'assistant': redis_topics[i][1]} for i in range(25) ] + [ {'user': asyncio_topics[i][0], 'assistant': asyncio_topics[i][1]} for i in range(25) ] + [ {'user': pg_topics[i][0], 'assistant': pg_topics[i][1]} for i in range(25) ] + [ {'user': git_topics[i][0], 'assistant': git_topics[i][1]} for i in range(25) ] methods = { 'Last-3': BaselineLastN(3), 'Last-5': BaselineLastN(5), 'Last-10': BaselineLastN(10), 'BM25-5': BaselineBM25(5), 'Full CGK': gate, # special handling } results = {name: [] for name in methods} for label, query, target_topic in TEST_SEQUENCE: # Full CGK cgk_selected = gate.select(query) cgk_stats = measure_context_stats(cgk_selected) cgk_contamination = evaluate_contamination(cgk_selected, target_topic) results['Full CGK'].append({ 'label': label, 'query': query, 'target_topic': target_topic, 'context_tokens': cgk_stats['context_tokens'], 'prompt_tokens': cgk_stats['prompt_tokens'], 'num_blocks': cgk_stats['num_blocks'], 'is_contaminated': cgk_contamination['is_contaminated'], 'other_topics': cgk_contamination['other_topics_found'] }) # Baseline methods for name, method in methods.items(): if name == 'Full CGK': continue selected = method.select(conversation, query) stats = measure_context_stats(selected) contamination = evaluate_contamination(selected, target_topic) results[name].append({ 'label': label, 'query': query, 'target_topic': target_topic, 'context_tokens': stats['context_tokens'], 'prompt_tokens': stats['prompt_tokens'], 'num_blocks': stats['num_blocks'], 'is_contaminated': contamination['is_contaminated'], 'other_topics': contamination['other_topics_found'] }) print(f"\n[{label}] {query}") print(f" Full CGK: {cgk_stats['prompt_tokens']} prompt tokens, " f"污染={cgk_contamination['is_contaminated']}, " f"块数={cgk_stats['num_blocks']}") for name in methods: if name == 'Full CGK': continue r = results[name][-1] print(f" {name}: {r['prompt_tokens']} prompt tokens, " f"污染={r['is_contaminated']}, 块数={r['num_blocks']}") return results def summarize_results(results: Dict) -> None: """打印汇总表格""" print("\n" + "=" * 70) print("Summary (averaged over {} queries)".format(len(TEST_SEQUENCE))) print("=" * 70) for name, data in results.items(): if not data: continue avg_prompt_tokens = sum(d['prompt_tokens'] for d in data) / len(data) avg_context_tokens = sum(d['context_tokens'] for d in data) / len(data) contamination_rate = sum(1 for d in data if d['is_contaminated']) / len(data) * 100 avg_blocks = sum(d['num_blocks'] for d in data) / len(data) print(f"\n{name}:") print(f" Avg prompt tokens: {avg_prompt_tokens:.1f}") print(f" Avg context tokens: {avg_context_tokens:.1f}") print(f" Contamination rate: {contamination_rate:.1f}%") print(f" Avg blocks: {avg_blocks:.1f}") # Full CGK vs Last-5 comparison if 'Full CGK' in results and 'Last-5' in results: cgk_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK']) last5_avg = sum(d['prompt_tokens'] for d in results['Last-5']) / len(results['Last-5']) saving = (last5_avg - cgk_avg) / last5_avg * 100 print(f"\nFull CGK vs Last-5:") print(f" CGK: {cgk_avg:.1f} tokens/prompt") print(f" Last-5: {last5_avg:.1f} tokens/prompt") print(f" Saving: {saving:.1f}% (CGK 更少)") def run_ablation_study(): """Phase 2: Ablation Study""" print("\n" + "=" * 70) print("Phase 2: Ablation Study") print("=" * 70) gate = build_conversation() # 定义 ablated versions ablations = { 'Full CGK': lambda q: gate.select(q), } # Ablation 1: 无指代词规则 orig_extract = gate.anchor_extractor.extract_with_deictic def no_deictic(text): anchors, _ = orig_extract(text) return anchors, False gate.anchor_extractor.extract_with_deictic = no_deictic ablations['-Deictic'] = lambda q: gate.select(q) gate.anchor_extractor.extract_with_deictic = orig_extract results = {name: [] for name in ablations} for label, query, target_topic in TEST_SEQUENCE: for name, fn in ablations.items(): if name == 'Full CGK': selected = fn(query) else: # re-run with ablated config if name == '-Deictic': orig_extract = gate.anchor_extractor.extract_with_deictic gate.anchor_extractor.extract_with_deictic = no_deictic selected = gate.select(query) gate.anchor_extractor.extract_with_deictic = orig_extract stats = measure_context_stats(selected) contamination = evaluate_contamination(selected, target_topic) results[name].append({ 'label': label, 'query': query, 'target_topic': target_topic, 'prompt_tokens': stats['prompt_tokens'], 'is_contaminated': contamination['is_contaminated'] }) print(f"\n[{label}] {query[:40]}...") for name in ablations: r = results[name][-1] print(f" {name}: {r['prompt_tokens']} tokens, 污染={r['is_contaminated']}") # Ablation summary print("\n" + "=" * 70) print("Ablation Summary") print("=" * 70) full_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK']) for name in ablations: if name == 'Full CGK': continue avg = sum(d['prompt_tokens'] for d in results[name]) / len(results[name]) diff = avg - full_avg print(f"{name}: {avg:.1f} tokens (vs Full: {diff:+.1f})") return results if __name__ == '__main__': results = run_baseline_comparison() summarize_results(results) ablation_results = run_ablation_study() # Save all results output = { 'baseline': {k: v for k, v in results.items()}, 'ablation': {k: v for k, v in ablation_results.items()} } output_path = '/root/.openclaw/workspace/context-gatekeeper/experiments/phase1_2_results.json' with open(output_path, 'w') as f: json.dump(output, f, indent=2, ensure_ascii=False) print(f"\nResults saved to: {output_path}")