Files
context-gatekeeper/experiments/run_phase1_2.py
Elaina 9e44748f91 fix: anchor stopwords - remove generic question patterns causing cross-topic contamination
- Add ANCHOR_STOPWORDS set in anchor.py (真正通用的疑问pattern)
- Filter Chinese n-grams against stopwords in extract()
- Update sparse.py content_words extraction to use stopword-filtered query
- Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks
- Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%)
- Phase2 ablation: Gate-only accounts for most of the benefit
- Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data;
  RECENT_WINDOW is the primary token budget control

Known honest limitations:
- Test set is clean 4-topic synthetic data (no real dirty dialogue)
- No strong baselines (BM25 ablation incomplete)
- No answer-level evaluation (only retrieval blocks measured)
- No parameter sensitivity on noisy real-world data
- Zero contamination on 5 queries is not generalizable
2026-04-22 22:30:18 +08:00

511 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Phase 1 & 2: Baseline Comparison + Ablation Study
===========================================
对比7种策略在相同测试集上的表现
基线方法:
- Last-3/5/10: 只保留最近N轮
- BM25-only: 纯BM25检索无门控
- Gate-only: 门控过滤,无覆盖优化
- Coverage-only: 覆盖优化,无门控
ablation:
- Full CGK: 完整方法
- -deictic: 无指代词规则
- -exact: 无Exact Match加分
- -recency: 无近期偏好
- -trim: 无句级裁剪
- -min_cov: 无最小覆盖选择(直接截断)
统计口径(实事求是):
- Token计数: 按 GPT4 tokenize 规则估算(1 token ≈ 4 chars 中文,1 token ≈ 0.75 words 英文)
- 完整上下文 = system_prompt + 历史上下文 + current_query + formatting_overhead
- 不只算"选中的块",也算入拼接开销
"""
import sys
import os
import json
import math
from typing import List, Dict, Tuple
sys.path.insert(0, os.path.dirname(__file__))
from src.gatekeeper import ContextGatekeeper
# ============================================================
# Test data: 4 topics, 25 turns per topic (100 turns in total)
# ============================================================
# Redis topic: 25 (user question, assistant answer) pairs. The answers are
# short placeholder strings. build_conversation() reads indices 0..24.
redis_topics = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
("Redis 的 Lua 脚本有什么应用场景?", "Lua脚本..."),
("Redis GeoHash 在附近的人功能里怎么用的?", "GeoHash..."),
("Redis 的大 key 问题怎么排查和处理?", "bigkey..."),
("缓存穿透、击穿、雪崩分别是什么?", "穿透..."),
("Redis Cluster 的槽迁移过程是怎样的?", "槽迁移..."),
("Redis 和 Memcached 的核心区别是什么?", "Memcached..."),
("Redis LRU 缓存淘汰策略怎么配置的?", "LRU..."),
("Redis Pipeline 和事务的区别是什么?", "Pipeline..."),
("Redis 慢查询日志怎么分析?", "SLOWLOG..."),
("Redis 的发布订阅有什么缺点?", "pubsub..."),
("Redis Cluster 为什么用 16384 个槽?", "16384..."),
("Redis 哨兵模式下主节点故障切换流程是什么?", "哨兵..."),
("Redis ZSet 的实现为什么用跳表而不是 B+树?", "跳表..."),
("Redis 内存碎片怎么产生的,怎么处理?", "碎片..."),
("Redis 数据类型和应用场景怎么对应?", "数据类型..."),
("Redis 加锁后服务挂了导致锁无法释放怎么办?", "锁释放..."),
("Redis 如何实现延迟队列?", "延迟队列..."),
("Redis 客户端分片怎么做,有什么优缺点?", "客户端分片..."),
("Redis Cluster 的最大限制是什么?", "最大限制..."),
("Redis 的 AOF 和 RDB 怎么配合使用?", "AOF RDB..."),
]
# asyncio topic: 25 (user question, assistant answer) pairs. The answers are
# short placeholder strings. build_conversation() reads indices 0..24.
asyncio_topics = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio.create_task 和 ensure_future 的区别是什么?", "create_task..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("Python 异步上下文管理器的写法是什么?", "异步上下文..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
("asyncio 的 wait_for 和 shield 组合使用注意什么?", "shield..."),
("asyncio 服务怎么实现优雅关闭?", "优雅关闭..."),
("asyncio 的 run_in_executor 什么时候用?", "run_in_executor..."),
("Python 异步迭代器和异步生成器有什么区别?", "异步迭代..."),
("asyncio 怎么限制并发数?", "限制并发..."),
("asyncio 的 timeout 错误怎么捕获?", "timeout..."),
("Python 协程和普通函数的区别是什么?", "协程..."),
("asyncio 事件循环可以嵌套吗?", "嵌套..."),
("asyncio 异常怎么处理?", "异常处理..."),
("Python 异步 HTTP 请求用什么库?", "异步HTTP..."),
("asyncio 里有条件变量吗?", "条件变量..."),
("asyncio 如何实现心跳/keepalive", "心跳..."),
("asyncio 的 callback 怎么转换为协程?", "callback..."),
("asyncio 的 wait 和 as_completed 有什么区别?", "as_completed..."),
("Python 异步编程里怎么避免回调地狱?", "回调地狱..."),
("asyncio 事件循环是怎么工作的?", "事件循环..."),
("asyncio.Task 和 concurrent.futures.Future 有什么关系?", "concurrent..."),
("asyncio 怎么检测任务是否完成?", "检测完成..."),
]
# PostgreSQL topic: 25 (user question, assistant answer) pairs. The answers
# are short placeholder strings. build_conversation() reads indices 0..24.
pg_topics = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
("PostgreSQL 的 JSONB 和 JSON 类型的区别是什么?", "JSONB..."),
("PostgreSQL 的 CTE 和子查询的性能差异是什么?", "CTE..."),
("PostgreSQL 的数组类型怎么建索引?", "数组索引..."),
("PostgreSQL 的触发器能用于什么场景?", "触发器..."),
("PostgreSQL 的窗口函数和聚合函数的区别是什么?", "窗口函数..."),
("PostgreSQL 的逻辑复制和物理复制的适用场景是什么?", "逻辑复制..."),
("PostgreSQL 的行安全策略 RLS 怎么配置?", "RLS..."),
("PostgreSQL 的 COPY 和 INSERT 性能差多少?", "COPY..."),
("PostgreSQL 的 pg_stat_statements 怎么用于慢查询分析?", "pg_stat..."),
("PostgreSQL 的物化视图和普通视图的区别是什么?", "物化视图..."),
("PostgreSQL 的 JOIN 类型有哪些?", "JOIN..."),
("PostgreSQL 的索引失效有哪些情况?", "索引失效..."),
("PostgreSQL 的 NOTIFY 和 LISTEN 适合什么场景?", "NOTIFY..."),
("PostgreSQL 的查询优化器怎么选择执行计划的?", "优化器..."),
("PostgreSQL 的 WAL 段文件是什么?", "WAL..."),
("PostgreSQL 的 SERIAL 和 IDENTITY 的区别是什么?", "SERIAL..."),
("PostgreSQL 的全文搜索怎么配置中文分词?", "全文搜索..."),
("PostgreSQL 的分区表怎么提升查询性能?", "分区表..."),
("PostgreSQL 的连接池用什么方案?", "连接池..."),
("PostgreSQL 的 EXPLAIN 输出里 Seq Scan 是什么含义?", "Seq Scan..."),
]
# Git topic: 25 (user question, assistant answer) pairs. The answers are
# short placeholder strings. build_conversation() reads indices 0..24.
git_topics = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git cherry-pick 怎么把特定提交应用到当前分支?", "cherry-pick..."),
("Git 的 hook 怎么配置自动化任务?", "hook..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 worktree 和 submodule 的区别是什么?", "worktree..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
("Git 的 sparse-checkout 怎么只检出部分目录?", "sparse-checkout..."),
("Git 的 bundle 命令在什么场景下用?", "bundle..."),
("Git 的 Interactive Rebase 怎么用?", "Interactive..."),
("Git 的 clean 命令怎么删除未跟踪文件?", "clean..."),
("Git 的 describe 命令输出版本号格式是什么?", "describe..."),
("Git 的 log 怎么配合 grep 过滤提交?", "log grep..."),
("Git 的 blame 显示每行最后修改者和时间怎么用的?", "blame..."),
("Git 的 fetch 和 pull 的区别是什么?", "fetch..."),
("Git 的 merge 冲突怎么规范解决?", "merge冲突..."),
("Git 的 revert 和 reset 的应用场景有什么区别?", "revert..."),
("Git 的 alias 怎么配置常用命令缩写?", "alias..."),
("Git 的 hook 能做什么自动化的事?", "hook自动化..."),
("Git 的 rev-parse 怎么获取仓库信息?", "rev-parse..."),
("Git 的 tag 和 branch 有什么区别?", "tag..."),
("Git 的 remote 怎么管理和使用多个远程仓库?", "remote..."),
("Git 的 grep 怎么在版本历史里搜索代码?", "grep..."),
("Git 的 show 和 log 的区别是什么?", "show..."),
]
# Topic labels. evaluate_contamination() searches the selected text for each
# of these names case-insensitively, as a substring.
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
# ============================================================
# Token 估算(更接近真实 GPT-4 计数方式)
# ============================================================
def estimate_tokens(text: str) -> int:
    """
    Roughly estimate how many tokens ``text`` occupies.

    This is a heuristic, not a real tokenizer. The estimate is

        CJK chars * 0.4 + ASCII words * 1.3 + total chars * 0.05

    where:
    - "CJK chars" are code points in U+4E00..U+9FFF;
    - "ASCII words" are whitespace-separated pieces made only of ASCII;
    - the 5% term is a flat overhead for punctuation and whitespace.

    Returns 0 for empty input; otherwise the sum truncated to an int.
    """
    if not text:
        return 0

    chinese_chars = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            chinese_chars += 1

    english_words = sum(1 for piece in text.split() if piece.isascii())
    base_overhead = len(text) * 0.05

    # Keep this summation order: the result is truncated, so float rounding
    # near an integer boundary must stay stable.
    return int(chinese_chars * 0.4 + english_words * 1.3 + base_overhead)
def estimate_prompt_tokens(context_tokens: int, query: str, system_prompt: str = "") -> int:
    """
    Estimate the token count of a complete prompt.

    The total is the sum of:
    - the history context (``context_tokens``, already counted by the caller);
    - formatting overhead for turn/question labels, taken as 8% of the
      context tokens and truncated to an int;
    - the current query, measured with ``estimate_tokens``;
    - the system prompt, measured with ``estimate_tokens`` when non-empty.
    """
    total = context_tokens
    total += int(context_tokens * 0.08)
    total += estimate_tokens(query)
    if system_prompt:
        total += estimate_tokens(system_prompt)
    return total
# ============================================================
# Test sequence: queries alternate between topics to mimic real usage
# ============================================================
# Each entry is (label, query, target_topic). The experiment loops unpack the
# tuples in that order, and target_topic is compared against TOPICS in
# evaluate_contamination().
TEST_SEQUENCE = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
("问PG-2", "PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "PostgreSQL"),
("问Redis-2", "Redis 的大 key 问题怎么排查和处理?", "Redis"),
("问asyncio-2", "asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "asyncio"),
]
# ============================================================
# Baseline 方法实现
# ============================================================
class BaselineLastN:
    """Baseline: keep only the most recent N turns of the conversation."""

    def __init__(self, n):
        """
        Args:
            n: Number of trailing turns to keep. Must be >= 0.

        Raises:
            ValueError: If ``n`` is negative. A negative value would make the
                slice ``conversation[-n:]`` silently drop the *oldest* turns
                instead of keeping the newest ones.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        self.n = n

    def select(self, conversation: List[dict], query: str) -> List[dict]:
        """
        Return the last ``n`` turns (a new list); ``query`` is ignored.

        If the conversation has fewer than ``n`` turns, all of them are
        returned.
        """
        # conversation[-0:] is conversation[0:], i.e. the whole history, so
        # n == 0 has to be handled explicitly.
        if self.n == 0:
            return []
        return conversation[-self.n:]
class BaselineBM25:
    """
    Baseline: retrieval only, no gating.

    Despite the name this is a simplified stand-in for BM25. Each turn is
    scored by how many whitespace-separated query terms occur as substrings
    of its lower-cased text, plus a small recency bonus.
    """

    def __init__(self, top_k=5):
        self.top_k = top_k

    def select(self, conversation: List[dict], query: str) -> List[dict]:
        """Return up to ``top_k`` turns, highest score first."""
        terms = set(query.lower().split())
        total_turns = len(conversation)

        weights = []
        for position, turn in enumerate(conversation):
            haystack = (turn.get('user', '') + ' ' + turn.get('assistant', '')).lower()
            overlap = sum(1 for term in terms if term in haystack)
            # Later turns get a bonus of up to 0.2 so that recency breaks ties.
            recency = (position + 1) / total_turns
            weights.append(overlap + recency * 0.2)

        # sorted() is stable, so equal weights keep conversation order.
        order = sorted(range(total_turns), key=lambda idx: weights[idx], reverse=True)
        return [conversation[idx] for idx in order[:self.top_k]]
# ============================================================
# Ablation 变体
# ============================================================
class CGKMinusDeictic:
    """CGK ablation: the deictic-word rule is switched off."""

    def __init__(self, gatekeeper: ContextGatekeeper):
        self.gatekeeper = gatekeeper

    def select(self, query: str) -> List[Dict]:
        """
        Run ``gatekeeper.select(query)`` with deictic detection disabled.

        The extractor's ``extract_with_deictic`` is temporarily replaced by a
        wrapper that returns the same anchors but always reports
        ``has_deictic=False``. The original callable is put back afterwards,
        even if ``select`` raises.
        """
        extractor = self.gatekeeper.anchor_extractor
        original_extract = extractor.extract_with_deictic

        def without_deictic(text):
            anchors, _ignored_flag = original_extract(text)
            return anchors, False  # force has_deictic=False

        extractor.extract_with_deictic = without_deictic
        try:
            return self.gatekeeper.select(query)
        finally:
            extractor.extract_with_deictic = original_extract
# ============================================================
# 实验运行
# ============================================================
def build_conversation():
    """
    Build the 100-turn conversation inside a fresh ContextGatekeeper.

    Turns are interleaved round by round: round ``i`` adds the i-th Redis,
    asyncio, PostgreSQL and Git turn, in that order, for i in 0..24.
    """
    gate = ContextGatekeeper(token_budget=4000)
    topic_order = (redis_topics, asyncio_topics, pg_topics, git_topics)
    for round_index in range(25):
        for topic_turns in topic_order:
            turn = topic_turns[round_index]
            gate.add_turn(turn[0], turn[1])
    return gate
def measure_context_stats(selected: List[Dict]) -> Dict:
    """
    Compute size statistics for a list of selected turns.

    Every turn is rendered as a "user / assistant" text block. The
    concatenation is measured with ``estimate_tokens`` and then passed to
    ``estimate_prompt_tokens``.

    NOTE: the query passed to ``estimate_prompt_tokens`` is the empty string,
    so ``prompt_tokens`` covers the context plus formatting overhead only,
    not the current query.

    Returns:
        Dict with ``context_chars``, ``context_tokens``, ``prompt_tokens``
        and ``num_blocks``.
    """
    rendered = "".join(
        f"用户: {turn['user']}\n助手: {turn['assistant']}\n\n"
        for turn in selected
    )
    token_count = estimate_tokens(rendered)
    prompt_count = estimate_prompt_tokens(token_count, "")
    stats = {}
    stats['context_chars'] = len(rendered)
    stats['context_tokens'] = token_count
    stats['prompt_tokens'] = prompt_count
    stats['num_blocks'] = len(selected)
    return stats
def evaluate_contamination(selected: List[Dict], target_topic: str) -> Dict:
    """
    Check whether the selected turns mention topics other than the target.

    This only measures whether the *retrieved blocks* contain another topic's
    name, matched as a case-insensitive substring of the concatenated user
    and assistant text. It does not measure whether a model's answer would
    be contaminated.

    Returns:
        Dict with ``is_contaminated`` (bool) and ``other_topics_found``
        (names from ``TOPICS``, in ``TOPICS`` order).
    """
    haystack = "".join(turn['user'] + turn['assistant'] for turn in selected).lower()
    leaked = [
        name
        for name in TOPICS
        if name.lower() in haystack and name.lower() != target_topic.lower()
    ]
    return {
        'is_contaminated': len(leaked) > 0,
        'other_topics_found': leaked
    }
def _baseline_record(label, query, target_topic, selected):
    """Measure one selection and package it as a single result row."""
    stats = measure_context_stats(selected)
    contamination = evaluate_contamination(selected, target_topic)
    return {
        'label': label,
        'query': query,
        'target_topic': target_topic,
        'context_tokens': stats['context_tokens'],
        'prompt_tokens': stats['prompt_tokens'],
        'num_blocks': stats['num_blocks'],
        'is_contaminated': contamination['is_contaminated'],
        'other_topics': contamination['other_topics_found']
    }


def run_baseline_comparison():
    """
    Phase 1: compare the baseline strategies against Full CGK.

    For every entry of ``TEST_SEQUENCE``, each strategy selects context for
    the query, the selection is measured (tokens, blocks, contamination),
    and a per-query line is printed.

    Returns:
        Dict mapping method name to a list of per-query result rows, in
        ``TEST_SEQUENCE`` order.
    """
    print("=" * 70)
    print("Phase 1: Baseline Comparison")
    print("=" * 70)
    gate = build_conversation()

    # The baselines must see the same history, in the same order, as the one
    # build_conversation() feeds to the gatekeeper: interleaved round by
    # round (Redis, asyncio, PostgreSQL, Git). A topic-grouped list would
    # make the Last-N baselines see Git turns only and skew the recency term
    # of the retrieval baseline.
    topic_order = (redis_topics, asyncio_topics, pg_topics, git_topics)
    conversation = [
        {'user': topic_turns[i][0], 'assistant': topic_turns[i][1]}
        for i in range(25)
        for topic_turns in topic_order
    ]

    methods = {
        'Last-3': BaselineLastN(3),
        'Last-5': BaselineLastN(5),
        'Last-10': BaselineLastN(10),
        'BM25-5': BaselineBM25(5),
        'Full CGK': gate,  # different select() signature, handled separately
    }
    results = {name: [] for name in methods}

    for label, query, target_topic in TEST_SEQUENCE:
        # Full CGK keeps its own history; select() takes the query only.
        cgk_row = _baseline_record(label, query, target_topic, gate.select(query))
        results['Full CGK'].append(cgk_row)

        # Baseline methods take the explicit conversation list.
        for name, method in methods.items():
            if name == 'Full CGK':
                continue
            selected = method.select(conversation, query)
            results[name].append(_baseline_record(label, query, target_topic, selected))

        print(f"\n[{label}] {query}")
        print(f" Full CGK: {cgk_row['prompt_tokens']} prompt tokens, "
              f"污染={cgk_row['is_contaminated']}, "
              f"块数={cgk_row['num_blocks']}")
        for name in methods:
            if name == 'Full CGK':
                continue
            r = results[name][-1]
            print(f" {name}: {r['prompt_tokens']} prompt tokens, "
                  f"污染={r['is_contaminated']}, 块数={r['num_blocks']}")
    return results
def summarize_results(results: Dict) -> None:
    """
    Print a per-method summary table, then a Full CGK vs Last-5 comparison.

    Args:
        results: Mapping of method name to a list of per-query rows. Each row
            must provide ``prompt_tokens``, ``context_tokens``,
            ``is_contaminated`` and ``num_blocks``.

    Methods with no rows are skipped. The CGK vs Last-5 section is printed
    only when both methods have at least one row. If the Last-5 average is
    zero, the saving is reported as ``n/a`` rather than dividing by zero.
    """
    def mean(values):
        return sum(values) / len(values)

    print("\n" + "=" * 70)
    print("Summary (averaged over {} queries)".format(len(TEST_SEQUENCE)))
    print("=" * 70)
    for name, data in results.items():
        if not data:
            continue
        avg_prompt_tokens = mean([d['prompt_tokens'] for d in data])
        avg_context_tokens = mean([d['context_tokens'] for d in data])
        contamination_rate = sum(1 for d in data if d['is_contaminated']) / len(data) * 100
        avg_blocks = mean([d['num_blocks'] for d in data])
        print(f"\n{name}:")
        print(f" Avg prompt tokens: {avg_prompt_tokens:.1f}")
        print(f" Avg context tokens: {avg_context_tokens:.1f}")
        print(f" Contamination rate: {contamination_rate:.1f}%")
        print(f" Avg blocks: {avg_blocks:.1f}")

    # Full CGK vs Last-5 comparison. Apply the same "skip empty data" rule as
    # the loop above so an empty result list cannot raise ZeroDivisionError.
    cgk_rows = results.get('Full CGK')
    last5_rows = results.get('Last-5')
    if not cgk_rows or not last5_rows:
        return
    cgk_avg = mean([d['prompt_tokens'] for d in cgk_rows])
    last5_avg = mean([d['prompt_tokens'] for d in last5_rows])
    print("\nFull CGK vs Last-5:")
    print(f" CGK: {cgk_avg:.1f} tokens/prompt")
    print(f" Last-5: {last5_avg:.1f} tokens/prompt")
    if last5_avg:
        saving = (last5_avg - cgk_avg) / last5_avg * 100
        print(f" Saving: {saving:.1f}% (CGK 更少)")
    else:
        # A relative saving is undefined against a zero-token baseline.
        print(" Saving: n/a (Last-5 average is 0 tokens)")
def run_ablation_study():
    """
    Phase 2: ablation study.

    Runs every query in ``TEST_SEQUENCE`` through Full CGK and through each
    ablated variant, printing per-query prompt tokens and contamination,
    then the average token difference of each variant against Full CGK.

    Only the "-Deictic" ablation (deictic-word rule disabled) is implemented
    here. It is delegated to ``CGKMinusDeictic``, which patches the anchor
    extractor for a single ``select`` call and restores it in a ``finally``
    block, so a failing query cannot leave the gatekeeper patched.

    Returns:
        Dict mapping variant name to a list of per-query result rows.
    """
    print("\n" + "=" * 70)
    print("Phase 2: Ablation Study")
    print("=" * 70)
    gate = build_conversation()

    # Variant name -> callable(query) returning the selected blocks.
    # All variants share the same gatekeeper and therefore the same history.
    ablations = {
        'Full CGK': gate.select,
        '-Deictic': CGKMinusDeictic(gate).select,
    }
    results = {name: [] for name in ablations}

    for label, query, target_topic in TEST_SEQUENCE:
        for name, select in ablations.items():
            selected = select(query)
            stats = measure_context_stats(selected)
            contamination = evaluate_contamination(selected, target_topic)
            results[name].append({
                'label': label,
                'query': query,
                'target_topic': target_topic,
                'prompt_tokens': stats['prompt_tokens'],
                'is_contaminated': contamination['is_contaminated']
            })
        print(f"\n[{label}] {query[:40]}...")
        for name in ablations:
            r = results[name][-1]
            print(f" {name}: {r['prompt_tokens']} tokens, 污染={r['is_contaminated']}")

    # Ablation summary: average prompt tokens relative to Full CGK.
    print("\n" + "=" * 70)
    print("Ablation Summary")
    print("=" * 70)
    full_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK'])
    for name in ablations:
        if name == 'Full CGK':
            continue
        avg = sum(d['prompt_tokens'] for d in results[name]) / len(results[name])
        diff = avg - full_avg
        print(f"{name}: {avg:.1f} tokens (vs Full: {diff:+.1f})")
    return results
if __name__ == '__main__':
    # Run Phase 1 (baselines) and Phase 2 (ablation), then persist both.
    results = run_baseline_comparison()
    summarize_results(results)
    ablation_results = run_ablation_study()

    # Save all results
    output = {
        'baseline': dict(results),
        'ablation': dict(ablation_results)
    }
    # Write next to this script instead of a machine-specific absolute path,
    # so the experiment runs from any checkout location.
    output_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'phase1_2_results.json',
    )
    # ensure_ascii=False emits raw Chinese text, so the file encoding must be
    # pinned to UTF-8; the locale default may not be able to encode it.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to: {output_path}")