fix: anchor stopwords - remove generic question patterns causing cross-topic contamination

- Add ANCHOR_STOPWORDS set in anchor.py (真正通用的疑问pattern)
- Filter Chinese n-grams against stopwords in extract()
- Update sparse.py content_words extraction to use stopword-filtered query
- Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks
- Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%)
- Phase2 ablation: Gate-only accounts for most of the benefit
- Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data;
  RECENT_WINDOW is the primary token budget control

Known honest limitations:
- Test set is clean 4-topic synthetic data (no real dirty dialogue)
- No strong baselines (BM25 ablation incomplete)
- No answer-level evaluation (only retrieval blocks measured)
- No parameter sensitivity on noisy real-world data
- Zero contamination on 5 queries is not generalizable
This commit is contained in:
Elaina
2026-04-22 22:30:18 +08:00
parent 2064eb7bdf
commit 9e44748f91
10 changed files with 1461 additions and 12 deletions

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""诊断:为什么 Full CGK 有 20% 污染率"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
redis_qa = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
def build_gate():
g = ContextGatekeeper(token_budget=4000)
for i in range(5):
g.add_turn(redis_qa[i][0], redis_qa[i][1])
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
g.add_turn(pg_qa[i][0], pg_qa[i][1])
g.add_turn(git_qa[i][0], git_qa[i][1])
return g
def diagnose(query, target_topic):
gate = build_gate()
print(f"\n{'='*60}")
print(f"Query: {query}")
print(f"Target: {target_topic}")
print(f"="*60)
# 提取 query 锚点
q_anchors, has_deictic = gate.anchor_extractor.extract_with_deictic(query)
print(f"Query anchors: {q_anchors}")
print(f"Has deictic: {has_deictic}")
# 话题切换检测
switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
print(f"Topic switched: {switched}")
# 召回的块
sel = gate.select(query)
print(f"Selected blocks: {len(sel)}")
for item in sel:
content = item['user'] + item['assistant']
found_topics = []
for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']:
if t.lower() in content.lower():
found_topics.append(t)
print(f" turn {item['turn_id']}: {found_topics} -> {content[:60]}")
# 检查污染
all_text = ' '.join(item['user'] + item['assistant'] for item in sel)
other = [t for t in ['Redis','asyncio','PostgreSQL','Git']
if t.lower() in all_text.lower() and t.lower() != target_topic.lower()]
print(f"Other topics in context: {other}")
print(f"IS CONTAMINATED: {len(other) > 0}")
# 诊断那两个污染案例
diagnose("Git 的 rebase 和 merge 有什么区别?", "Git")
diagnose("asyncio.Task 的 cancel 方法怎么工作的?", "asyncio")
# 对比:干净的例子
diagnose("Redis 惰性删除和定期删除有什么区别?", "Redis")
diagnose("再问Git", "Git reset 和 revert 的应用场景有什么区别?")

View File

@@ -0,0 +1,167 @@
#!/usr/bin/env python3
"""Phase 1: Baseline Comparison - 7 methods compared fairly"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
def estimate_tokens(text):
if not text: return 0
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
english = len([w for w in text.split() if w.isascii()])
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
def measure_prompt_tokens(selected, query):
ctx = ""
for item in selected:
ctx += f"用户: {item['user']}\n助手: {item['assistant']}\n\n"
context_tok = estimate_tokens(ctx)
query_tok = estimate_tokens(query)
fmt_overhead = int(context_tok * 0.08)
return context_tok + fmt_overhead + query_tok
def evaluate_contamination(selected, target):
text = " ".join(item['user'] + item['assistant'] for item in selected)
found = [t for t in TOPICS if t.lower() in text.lower() and t.lower() != target.lower()]
return len(found) > 0, found
# 100轮对话
redis_qa = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
# 测试序列
TEST_SEQ = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_full_conv():
g = ContextGatekeeper(token_budget=4000)
for i in range(5):
g.add_turn(redis_qa[i][0], redis_qa[i][1])
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
g.add_turn(pg_qa[i][0], pg_qa[i][1])
g.add_turn(git_qa[i][0], git_qa[i][1])
return g
def build_conv_list():
conv = []
for i in range(5):
conv.append({'user': redis_qa[i][0], 'assistant': redis_qa[i][1]})
conv.append({'user': asyncio_qa[i][0], 'assistant': asyncio_qa[i][1]})
conv.append({'user': pg_qa[i][0], 'assistant': pg_qa[i][1]})
conv.append({'user': git_qa[i][0], 'assistant': git_qa[i][1]})
return conv
def last_n_select(conv, n, query):
return conv[-n:]
def bm25_select(conv, top_k, query):
qw = set(query.lower().split())
scored = []
for i, t in enumerate(conv):
txt = (t['user'] + ' ' + t['assistant']).lower()
sc = sum(1 for w in qw if w in txt)
recency = (i + 1) / len(conv)
scored.append((i, t, sc + recency * 0.2))
scored.sort(key=lambda x: x[2], reverse=True)
return [s[1] for s in scored[:top_k]]
def main():
gate = build_full_conv()
conv = build_conv_list()
methods = {
'Last-3': lambda q: last_n_select(conv, 3, q),
'Last-5': lambda q: last_n_select(conv, 5, q),
'Last-10': lambda q: last_n_select(conv, 10, q),
'BM25-5': lambda q: bm25_select(conv, 5, q),
'Full CGK': lambda q: gate.select(q),
}
results = {k: [] for k in methods}
cgk_prompt_tokens_list = []
print("=" * 70)
print("Phase 1: Baseline Comparison")
print("=" * 70)
for label, query, target in TEST_SEQ:
# CGK
sel = gate.select(query)
pt = measure_prompt_tokens(sel, query)
cont, _ = evaluate_contamination(sel, target)
results['Full CGK'].append({'label': label, 'pt': pt, 'cont': cont})
cgk_prompt_tokens_list.append(pt)
# Baselines
for name in ['Last-3', 'Last-5', 'Last-10', 'BM25-5']:
sel = methods[name](query)
pt = measure_prompt_tokens(sel, query)
cont, _ = evaluate_contamination(sel, target)
results[name].append({'label': label, 'pt': pt, 'cont': cont})
print(f"\n[{label}] {query[:45]}...")
for name, data in results.items():
r = data[-1]
print(f" {name:10s}: {r['pt']:6.0f} tokens, 污染={r['cont']}")
# Summary
print("\n" + "=" * 70)
print("Summary (avg over {} queries)".format(len(TEST_SEQ)))
print("=" * 70)
for name, data in results.items():
avg_pt = sum(d['pt'] for d in data) / len(data)
cont_rate = sum(1 for d in data if d['cont']) / len(data) * 100
print(f"{name:10s}: avg {avg_pt:6.1f} prompt tokens, 污染率 {cont_rate:5.1f}%")
# CGK vs baselines
cgk_avg = sum(cgk_prompt_tokens_list) / len(cgk_prompt_tokens_list)
print(f"\nFull CGK avg prompt tokens: {cgk_avg:.1f}")
# Save
out = {}
for name, data in results.items():
out[name] = {'avg_tokens': sum(d['pt'] for d in data)/len(data),
'contamination_rate': sum(1 for d in data if d['cont'])/len(data)*100,
'raw': data}
out_path = os.path.join(os.path.dirname(__file__), 'phase1_baseline_results.json')
with open(out_path, 'w') as f:
json.dump(out, f, indent=2, ensure_ascii=False)
print(f"\nSaved to: {out_path}")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,157 @@
{
"Last-3": {
"avg_tokens": 43.6,
"contamination_rate": 100.0,
"raw": [
{
"label": "问PG",
"pt": 42,
"cont": true
},
{
"label": "问Git",
"pt": 44,
"cont": true
},
{
"label": "问Redis",
"pt": 43,
"cont": true
},
{
"label": "问asyncio",
"pt": 43,
"cont": true
},
{
"label": "再问Git",
"pt": 46,
"cont": true
}
]
},
"Last-5": {
"avg_tokens": 67.6,
"contamination_rate": 100.0,
"raw": [
{
"label": "问PG",
"pt": 66,
"cont": true
},
{
"label": "问Git",
"pt": 68,
"cont": true
},
{
"label": "问Redis",
"pt": 67,
"cont": true
},
{
"label": "问asyncio",
"pt": 67,
"cont": true
},
{
"label": "再问Git",
"pt": 70,
"cont": true
}
]
},
"Last-10": {
"avg_tokens": 137.6,
"contamination_rate": 100.0,
"raw": [
{
"label": "问PG",
"pt": 136,
"cont": true
},
{
"label": "问Git",
"pt": 138,
"cont": true
},
{
"label": "问Redis",
"pt": 137,
"cont": true
},
{
"label": "问asyncio",
"pt": 137,
"cont": true
},
{
"label": "再问Git",
"pt": 140,
"cont": true
}
]
},
"BM25-5": {
"avg_tokens": 70.6,
"contamination_rate": 60.0,
"raw": [
{
"label": "问PG",
"pt": 68,
"cont": true
},
{
"label": "问Git",
"pt": 74,
"cont": true
},
{
"label": "问Redis",
"pt": 70,
"cont": false
},
{
"label": "问asyncio",
"pt": 67,
"cont": true
},
{
"label": "再问Git",
"pt": 74,
"cont": false
}
]
},
"Full CGK": {
"avg_tokens": 42.6,
"contamination_rate": 0.0,
"raw": [
{
"label": "问PG",
"pt": 18,
"cont": false
},
{
"label": "问Git",
"pt": 59,
"cont": false
},
{
"label": "问Redis",
"pt": 19,
"cont": false
},
{
"label": "问asyncio",
"pt": 56,
"cont": false
},
{
"label": "再问Git",
"pt": 61,
"cont": false
}
]
}
}

View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python3
"""Phase 2: Ablation Study - 每个模块的独立贡献"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
def estimate_tokens(text):
if not text: return 0
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
english = len([w for w in text.split() if w.isascii()])
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
def measure_prompt_tokens(selected, query):
ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
context_tok = estimate_tokens(ctx)
fmt_overhead = int(context_tok * 0.08)
return context_tok + fmt_overhead + estimate_tokens(query)
def evaluate_contamination(selected, target):
text = " ".join(i['user'] + i['assistant'] for i in selected)
found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']
if t.lower() in text.lower() and t.lower() != target.lower()]
return len(found) > 0, found
redis_qa = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
TEST_SEQ = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_gate():
g = ContextGatekeeper(token_budget=4000)
for i in range(5):
g.add_turn(redis_qa[i][0], redis_qa[i][1])
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
g.add_turn(pg_qa[i][0], pg_qa[i][1])
g.add_turn(git_qa[i][0], git_qa[i][1])
return g
def main():
# ---- Ablation 1: 无指代词规则 ----
print("Ablation 1: 无指代词规则")
g1 = build_gate()
orig_extract = g1.anchor_extractor.extract_with_deictic
def no_deictic(text):
anchors, _ = orig_extract(text)
return anchors, False
g1.anchor_extractor.extract_with_deictic = no_deictic
# ---- Ablation 2: 无 Exact Match ----
print("Ablation 2: 无 Exact Match")
g2 = build_gate()
orig_score = g2.retriever._exact_match
def no_exact(block, qa):
return 0.0
g2.retriever._exact_match = no_exact
# ---- Ablation 3: 无 Recency ----
print("Ablation 3: 无 Recency")
g3 = build_gate()
orig_retrieve = g3.retriever.retrieve
def no_recency_retrieve(blocks, qa, top_m=20, **kwargs):
kwargs.pop('recency', None)
return orig_retrieve(blocks, qa, top_m, **kwargs)
# Can't easily remove recency from score, so we just note it
# ---- Ablation 4: 无句级裁剪 ----
print("Ablation 4: 无句级裁剪")
g4 = build_gate()
g4._trim_blocks_to_query = lambda blocks, qa: blocks # bypass trim
# ---- Ablation 5: Gate-only (无覆盖优化) ----
print("Ablation 5: Gate-only (无覆盖优化)")
g5 = build_gate()
orig_select = g5.select
def gate_only_select(query):
# Only do gate + retrieve, skip selector
q_anchors, has_deictic = g5.anchor_extractor.extract_with_deictic(query)
switched = g5.topic_gate.is_topic_switch(query, g5._active_topic)
idf_cache = g5.anchor_extractor._idf_cache
if switched:
candidates = g5.blocks[-15:]
else:
candidates = g5.blocks
retrieved = g5.retriever.retrieve(
candidates, q_anchors, top_m=20, idf_cache=idf_cache,
active_topic_anchors=g5._active_topic[0] if g5._active_topic else None,
topic_switched=switched, query_text=query
)
# Return all retrieved without coverage optimization
result = [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved]
return result
# Can't easily override, just note
g5.select = gate_only_select
variants = {
'Full CGK': (build_gate(), lambda g, q: g.select(q)),
'-Deictic': (g1, lambda g, q: g.select(q)),
'-Exact Match': (g2, lambda g, q: g.select(q)),
'-Trim': (g4, lambda g, q: g.select(q)),
}
results = {k: [] for k in variants}
results['Gate-only'] = []
print("\n" + "="*70)
print("Phase 2: Ablation Study")
print("="*70)
for label, query, target in TEST_SEQ:
print(f"\n[{label}] {query[:45]}...")
for name, (gate, fn) in variants.items():
sel = fn(gate, query)
pt = measure_prompt_tokens(sel, query)
cont, _ = evaluate_contamination(sel, target)
results[name].append({'pt': pt, 'cont': cont})
print(f" {name:15s}: {pt:5.0f} tokens, 污染={cont}")
# Gate-only
gate5 = build_gate()
sel5 = gate5_select(gate5, query)
pt5 = measure_prompt_tokens(sel5, query)
cont5, _ = evaluate_contamination(sel5, target)
results['Gate-only'].append({'pt': pt5, 'cont': cont5})
print(f" {'Gate-only':15s}: {pt5:5.0f} tokens, 污染={cont5}")
# Summary
print("\n" + "="*70)
print("Ablation Summary (avg over {} queries)".format(len(TEST_SEQ)))
print("="*70)
full_avg_pt = sum(d['pt'] for d in results['Full CGK']) / len(results['Full CGK'])
full_cont = sum(1 for d in results['Full CGK'] if d['cont']) / len(results['Full CGK']) * 100
print(f"Full CGK: {full_avg_pt:5.1f} tokens, 污染率 {full_cont:.0f}% (baseline)")
for name in ['-Deictic', '-Exact Match', '-Trim', 'Gate-only']:
data = results[name]
avg_pt = sum(d['pt'] for d in data) / len(data)
cont = sum(1 for d in data if d['cont']) / len(data) * 100
diff = avg_pt - full_avg_pt
print(f"{name:15s}: {avg_pt:5.1f} tokens, 污染率 {cont:.0f}% ({diff:+.1f} vs Full)")
out = {k: v for k, v in results.items()}
out_path = os.path.join(os.path.dirname(__file__), 'phase2_ablation_results.json')
with open(out_path, 'w') as f:
json.dump(out, f, indent=2, ensure_ascii=False)
print(f"\nSaved to: {out_path}")
def gate5_select(gate, query):
q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
idf_cache = gate.anchor_extractor._idf_cache
if switched:
candidates = gate.blocks[-15:]
else:
candidates = gate.blocks
retrieved = gate.retriever.retrieve(
candidates, q_anchors, top_m=20, idf_cache=idf_cache,
active_topic_anchors=gate._active_topic[0] if gate._active_topic else None,
topic_switched=switched, query_text=query
)
return [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved]
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,112 @@
{
"Full CGK": [
{
"pt": 16,
"cont": false
},
{
"pt": 59,
"cont": false
},
{
"pt": 19,
"cont": false
},
{
"pt": 56,
"cont": false
},
{
"pt": 61,
"cont": false
}
],
"-Deictic": [
{
"pt": 16,
"cont": false
},
{
"pt": 59,
"cont": false
},
{
"pt": 19,
"cont": false
},
{
"pt": 56,
"cont": false
},
{
"pt": 61,
"cont": false
}
],
"-Exact Match": [
{
"pt": 16,
"cont": false
},
{
"pt": 59,
"cont": false
},
{
"pt": 19,
"cont": false
},
{
"pt": 56,
"cont": false
},
{
"pt": 61,
"cont": false
}
],
"-Trim": [
{
"pt": 16,
"cont": false
},
{
"pt": 59,
"cont": false
},
{
"pt": 19,
"cont": false
},
{
"pt": 56,
"cont": false
},
{
"pt": 61,
"cont": false
}
],
"Gate-only": [
{
"pt": 16,
"cont": false
},
{
"pt": 59,
"cont": false
},
{
"pt": 45,
"cont": false
},
{
"pt": 56,
"cont": false
},
{
"pt": 61,
"cont": false
}
]
}

View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
"""Phase 3: Parameter Sensitivity Analysis"""
import sys, os, json
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.gatekeeper import ContextGatekeeper
def estimate_tokens(text):
if not text: return 0
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
english = len([w for w in text.split() if w.isascii()])
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
def measure_prompt_tokens(selected, query):
ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
return estimate_tokens(ctx) + int(estimate_tokens(ctx) * 0.08) + estimate_tokens(query)
def evaluate_contamination(selected, target):
text = " ".join(i['user'] + i['assistant'] for i in selected)
found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']
if t.lower() in text.lower() and t.lower() != target.lower()]
return len(found) > 0
redis_qa = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]
TEST_SEQ = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]
def build_gate():
g = ContextGatekeeper(token_budget=4000)
for i in range(5):
g.add_turn(redis_qa[i][0], redis_qa[i][1])
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
g.add_turn(pg_qa[i][0], pg_qa[i][1])
g.add_turn(git_qa[i][0], git_qa[i][1])
return g
def run_with_overlap_threshold(threshold):
g = build_gate()
g.topic_gate.OVERLAP_CONTINUE_THRESHOLD = threshold
g.topic_gate.OVERLAP_SWITCH_THRESHOLD = threshold * 0.44 # ratio ~0.20/0.45
tokens, cont = [], []
for _, q, t in TEST_SEQ:
sel = g.select(q)
pt = measure_prompt_tokens(sel, q)
c = evaluate_contamination(sel, t)
tokens.append(pt)
cont.append(c)
return sum(tokens)/len(tokens), sum(cont)/len(cont)*100
def run_with_new_ratio_threshold(threshold):
g = build_gate()
g.topic_gate.NEW_RATIO_SWITCH_THRESHOLD = threshold
tokens, cont = [], []
for _, q, t in TEST_SEQ:
sel = g.select(q)
pt = measure_prompt_tokens(sel, q)
c = evaluate_contamination(sel, t)
tokens.append(pt)
cont.append(c)
return sum(tokens)/len(tokens), sum(cont)/len(cont)*100
def run_with_recent_window(window):
g = build_gate()
# Patch retrieve to use custom window
orig_retrieve = g.retriever.retrieve
def patched_retrieve(blocks, qa, top_m=20, **kwargs):
# Use smaller/larger window
if window < len(blocks):
blocks = blocks[-window:]
return orig_retrieve(blocks, qa, top_m, **kwargs)
g.retriever.retrieve = patched_retrieve
tokens, cont = [], []
for _, q, t in TEST_SEQ:
sel = g.select(q)
pt = measure_prompt_tokens(sel, q)
c = evaluate_contamination(sel, t)
tokens.append(pt)
cont.append(c)
return sum(tokens)/len(tokens), sum(cont)/len(cont)*100
def main():
print("="*70)
print("Phase 3: Parameter Sensitivity Analysis")
print("="*70)
# Baseline
g = build_gate()
base_tokens = []
base_cont = []
for _, q, t in TEST_SEQ:
sel = g.select(q)
base_tokens.append(measure_prompt_tokens(sel, q))
base_cont.append(evaluate_contamination(sel, t))
base_tok_avg = sum(base_tokens)/len(base_tokens)
base_cont_pct = sum(base_cont)/len(base_cont)*100
print(f"\nBaseline (default params): {base_tok_avg:.1f} tokens, 污染率 {base_cont_pct:.0f}%")
print(f" Default: OVERLAP=0.45, NEW_RATIO=0.70, RECENT_WINDOW=15")
# OVERLAP threshold sweep
print("\n--- OVERLAP_CONTINUE_THRESHOLD sweep ---")
for th in [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60]:
avg_tok, cont_pct = run_with_overlap_threshold(th)
flag = " ← default" if th == 0.45 else ""
print(f" overlap={th}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
# NEW_RATIO threshold sweep
print("\n--- NEW_RATIO_SWITCH_THRESHOLD sweep ---")
for th in [0.50, 0.60, 0.70, 0.80, 0.90]:
avg_tok, cont_pct = run_with_new_ratio_threshold(th)
flag = " ← default" if th == 0.70 else ""
print(f" new_ratio={th}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
# RECENT_WINDOW sweep
print("\n--- RECENT_WINDOW sweep ---")
for win in [5, 10, 15, 20, 30]:
avg_tok, cont_pct = run_with_recent_window(win)
flag = " ← default" if win == 15 else ""
print(f" window={win}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
out = {'baseline_tokens': base_tok_avg, 'baseline_contamination_pct': base_cont_pct}
out_path = os.path.join(os.path.dirname(__file__), 'phase3_sensitivity_results.json')
with open(out_path, 'w') as f:
json.dump(out, f, indent=2, ensure_ascii=False)
print(f"\nSaved to: {out_path}")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,4 @@
{
"baseline_tokens": 42.2,
"baseline_contamination_pct": 0.0
}

511
experiments/run_phase1_2.py Normal file
View File

@@ -0,0 +1,511 @@
"""
Phase 1 & 2: Baseline Comparison + Ablation Study
===========================================
对比7种策略在相同测试集上的表现
基线方法:
- Last-3/5/10: 只保留最近N轮
- BM25-only: 纯BM25检索无门控
- Gate-only: 门控过滤,无覆盖优化
- Coverage-only: 覆盖优化,无门控
ablation:
- Full CGK: 完整方法
- -deictic: 无指代词规则
- -exact: 无Exact Match加分
- -recency: 无近期偏好
- -trim: 无句级裁剪
- -min_cov: 无最小覆盖选择(直接截断)
统计口径(实事求是):
- Token计数: 按 GPT4 tokenize 规则估算1 token ≈ 4 chars 中文1 token ≈ 0.75 words 英文)
- 完整上下文 = system_prompt + 历史上下文 + current_query + formatting_overhead
- 不只算"选中的块",也算入拼接开销
"""
import sys
import os
import json
import math
from typing import List, Dict, Tuple
sys.path.insert(0, os.path.dirname(__file__))
from src.gatekeeper import ContextGatekeeper
# ============================================================
# 测试数据4 话题,每话题 25 轮(总计 100 轮)
# ============================================================
redis_topics = [
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
("Redis 的 Lua 脚本有什么应用场景?", "Lua脚本..."),
("Redis GeoHash 在附近的人功能里怎么用的?", "GeoHash..."),
("Redis 的大 key 问题怎么排查和处理?", "bigkey..."),
("缓存穿透、击穿、雪崩分别是什么?", "穿透..."),
("Redis Cluster 的槽迁移过程是怎样的?", "槽迁移..."),
("Redis 和 Memcached 的核心区别是什么?", "Memcached..."),
("Redis LRU 缓存淘汰策略怎么配置的?", "LRU..."),
("Redis Pipeline 和事务的区别是什么?", "Pipeline..."),
("Redis 慢查询日志怎么分析?", "SLOWLOG..."),
("Redis 的发布订阅有什么缺点?", "pubsub..."),
("Redis Cluster 为什么用 16384 个槽?", "16384..."),
("Redis 哨兵模式下主节点故障切换流程是什么?", "哨兵..."),
("Redis ZSet 的实现为什么用跳表而不是 B+树?", "跳表..."),
("Redis 内存碎片怎么产生的,怎么处理?", "碎片..."),
("Redis 数据类型和应用场景怎么对应?", "数据类型..."),
("Redis 加锁后服务挂了导致锁无法释放怎么办?", "锁释放..."),
("Redis 如何实现延迟队列?", "延迟队列..."),
("Redis 客户端分片怎么做,有什么优缺点?", "客户端分片..."),
("Redis Cluster 的最大限制是什么?", "最大限制..."),
("Redis 的 AOF 和 RDB 怎么配合使用?", "AOF RDB..."),
]
asyncio_topics = [
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
("asyncio.create_task 和 ensure_future 的区别是什么?", "create_task..."),
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
("Python 异步上下文管理器的写法是什么?", "异步上下文..."),
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
("asyncio 的 wait_for 和 shield 组合使用注意什么?", "shield..."),
("asyncio 服务怎么实现优雅关闭?", "优雅关闭..."),
("asyncio 的 run_in_executor 什么时候用?", "run_in_executor..."),
("Python 异步迭代器和异步生成器有什么区别?", "异步迭代..."),
("asyncio 怎么限制并发数?", "限制并发..."),
("asyncio 的 timeout 错误怎么捕获?", "timeout..."),
("Python 协程和普通函数的区别是什么?", "协程..."),
("asyncio 事件循环可以嵌套吗?", "嵌套..."),
("asyncio 异常怎么处理?", "异常处理..."),
("Python 异步 HTTP 请求用什么库?", "异步HTTP..."),
("asyncio 里有条件变量吗?", "条件变量..."),
("asyncio 如何实现心跳/keepalive", "心跳..."),
("asyncio 的 callback 怎么转换为协程?", "callback..."),
("asyncio 的 wait 和 as_completed 有什么区别?", "as_completed..."),
("Python 异步编程里怎么避免回调地狱?", "回调地狱..."),
("asyncio 事件循环是怎么工作的?", "事件循环..."),
("asyncio.Task 和 concurrent.futures.Future 有什么关系?", "concurrent..."),
("asyncio 怎么检测任务是否完成?", "检测完成..."),
]
pg_topics = [
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
("PostgreSQL 的 JSONB 和 JSON 类型的区别是什么?", "JSONB..."),
("PostgreSQL 的 CTE 和子查询的性能差异是什么?", "CTE..."),
("PostgreSQL 的数组类型怎么建索引?", "数组索引..."),
("PostgreSQL 的触发器能用于什么场景?", "触发器..."),
("PostgreSQL 的窗口函数和聚合函数的区别是什么?", "窗口函数..."),
("PostgreSQL 的逻辑复制和物理复制的适用场景是什么?", "逻辑复制..."),
("PostgreSQL 的行安全策略 RLS 怎么配置?", "RLS..."),
("PostgreSQL 的 COPY 和 INSERT 性能差多少?", "COPY..."),
("PostgreSQL 的 pg_stat_statements 怎么用于慢查询分析?", "pg_stat..."),
("PostgreSQL 的物化视图和普通视图的区别是什么?", "物化视图..."),
("PostgreSQL 的 JOIN 类型有哪些?", "JOIN..."),
("PostgreSQL 的索引失效有哪些情况?", "索引失效..."),
("PostgreSQL 的 NOTIFY 和 LISTEN 适合什么场景?", "NOTIFY..."),
("PostgreSQL 的查询优化器怎么选择执行计划的?", "优化器..."),
("PostgreSQL 的 WAL 段文件是什么?", "WAL..."),
("PostgreSQL 的 SERIAL 和 IDENTITY 的区别是什么?", "SERIAL..."),
("PostgreSQL 的全文搜索怎么配置中文分词?", "全文搜索..."),
("PostgreSQL 的分区表怎么提升查询性能?", "分区表..."),
("PostgreSQL 的连接池用什么方案?", "连接池..."),
("PostgreSQL 的 EXPLAIN 输出里 Seq Scan 是什么含义?", "Seq Scan..."),
]
git_topics = [
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
("Git cherry-pick 怎么把特定提交应用到当前分支?", "cherry-pick..."),
("Git 的 hook 怎么配置自动化任务?", "hook..."),
("Git 的 bisect 怎么用来快速定位 bug", "bisect..."),
("Git 的 worktree 和 submodule 的区别是什么?", "worktree..."),
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
("Git 的 sparse-checkout 怎么只检出部分目录?", "sparse-checkout..."),
("Git 的 bundle 命令在什么场景下用?", "bundle..."),
("Git 的 Interactive Rebase 怎么用?", "Interactive..."),
("Git 的 clean 命令怎么删除未跟踪文件?", "clean..."),
("Git 的 describe 命令输出版本号格式是什么?", "describe..."),
("Git 的 log 怎么配合 grep 过滤提交?", "log grep..."),
("Git 的 blame 显示每行最后修改者和时间怎么用的?", "blame..."),
("Git 的 fetch 和 pull 的区别是什么?", "fetch..."),
("Git 的 merge 冲突怎么规范解决?", "merge冲突..."),
("Git 的 revert 和 reset 的应用场景有什么区别?", "revert..."),
("Git 的 alias 怎么配置常用命令缩写?", "alias..."),
("Git 的 hook 能做什么自动化的事?", "hook自动化..."),
("Git 的 rev-parse 怎么获取仓库信息?", "rev-parse..."),
("Git 的 tag 和 branch 有什么区别?", "tag..."),
("Git 的 remote 怎么管理和使用多个远程仓库?", "remote..."),
("Git 的 grep 怎么在版本历史里搜索代码?", "grep..."),
("Git 的 show 和 log 的区别是什么?", "show..."),
]
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
# ============================================================
# Token 估算(更接近真实 GPT-4 计数方式)
# ============================================================
def estimate_tokens(text: str) -> int:
"""
估算 token 数量(近似 GPT-4 tokenize
规则:
- 中文: 1 token ≈ 1.5-2 characters
- 英文单词: 1 token ≈ 0.75 words
- 标点/空格: 计入 overhead
这里用简化的 approximation:
中文 chars * 0.4 + 英文 words * 1.3 + 总字符数 * 0.05
"""
if not text:
return 0
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
english_words = len([w for w in text.split() if w.isascii()])
base_overhead = len(text) * 0.05
return int(chinese_chars * 0.4 + english_words * 1.3 + base_overhead)
def estimate_prompt_tokens(context_tokens: int, query: str, system_prompt: str = "") -> int:
"""
估算完整 prompt 的 token 数
包含:
- system prompt (如果有)
- formatting overhead (【轮次】【当前问题】等标签)
- 历史上下文
- current query
按保守估计formatting overhead 约为上下文的 8%
"""
formatting_overhead = int(context_tokens * 0.08)
query_tokens = estimate_tokens(query)
system_tokens = estimate_tokens(system_prompt) if system_prompt else 0
return context_tokens + formatting_overhead + query_tokens + system_tokens
# ============================================================
# 测试序列:交替查询,模拟真实使用场景
# ============================================================
TEST_SEQUENCE = [
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
("问PG-2", "PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "PostgreSQL"),
("问Redis-2", "Redis 的大 key 问题怎么排查和处理?", "Redis"),
("问asyncio-2", "asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "asyncio"),
]
# ============================================================
# Baseline 方法实现
# ============================================================
class BaselineLastN:
"""基线:只保留最近 N 轮"""
def __init__(self, n):
self.n = n
def select(self, conversation: List[dict], query: str) -> List[dict]:
return conversation[-self.n:]
class BaselineBM25:
"""基线:纯 BM25 检索,无门控"""
def __init__(self, top_k=5):
self.top_k = top_k
def select(self, conversation: List[dict], query: str) -> List[dict]:
# 简单 BM25: 按 query 词在 conversation 中的重叠次数排序
query_words = set(query.lower().split())
scored = []
for i, turn in enumerate(conversation):
text = (turn.get('user', '') + ' ' + turn.get('assistant', '')).lower()
score = sum(1 for w in query_words if w in text)
recency = (i + 1) / len(conversation)
scored.append((i, turn, score + recency * 0.2))
scored.sort(key=lambda x: x[2], reverse=True)
return [s[1] for s in scored[:self.top_k]]
# ============================================================
# Ablation 变体
# ============================================================
class CGKMinusDeictic:
"""CGK去掉指代词规则"""
def __init__(self, gatekeeper: ContextGatekeeper):
self.gatekeeper = gatekeeper
def select(self, query: str) -> List[Dict]:
# 临时禁用指代词检测
orig_extract = self.gatekeeper.anchor_extractor.extract_with_deictic
def no_deictic(text):
anchors, _ = orig_extract(text)
return anchors, False # 强制 has_deictic=False
self.gatekeeper.anchor_extractor.extract_with_deictic = no_deictic
try:
result = self.gatekeeper.select(query)
finally:
self.gatekeeper.anchor_extractor.extract_with_deictic = orig_extract
return result
# ============================================================
# 实验运行
# ============================================================
def build_conversation():
"""构建100轮对话"""
gate = ContextGatekeeper(token_budget=4000)
for i in range(25):
gate.add_turn(redis_topics[i][0], redis_topics[i][1])
gate.add_turn(asyncio_topics[i][0], asyncio_topics[i][1])
gate.add_turn(pg_topics[i][0], pg_topics[i][1])
gate.add_turn(git_topics[i][0], git_topics[i][1])
return gate
def measure_context_stats(selected: List[Dict]) -> Dict:
"""统计 context 的 token 详情"""
total_text = ""
for item in selected:
total_text += f"用户: {item['user']}\n助手: {item['assistant']}\n\n"
context_tokens = estimate_tokens(total_text)
prompt_tokens = estimate_prompt_tokens(context_tokens, "")
return {
'context_chars': len(total_text),
'context_tokens': context_tokens,
'prompt_tokens': prompt_tokens,
'num_blocks': len(selected)
}
def evaluate_contamination(selected: List[Dict], target_topic: str) -> Dict:
"""
评估污染情况
注意:这里测的是"检索到的块是否包含其他话题的关键词"
而不是"模型回答是否被污染"
"""
combined = ""
for item in selected:
combined += item['user'] + item['assistant']
topics_found = []
for t in TOPICS:
if t.lower() in combined.lower() and t.lower() != target_topic.lower():
topics_found.append(t)
return {
'is_contaminated': len(topics_found) > 0,
'other_topics_found': topics_found
}
def run_baseline_comparison():
"""Phase 1: 基线对比"""
print("=" * 70)
print("Phase 1: Baseline Comparison")
print("=" * 70)
gate = build_conversation()
conversation = [
{'user': redis_topics[i][0], 'assistant': redis_topics[i][1]}
for i in range(25)
] + [
{'user': asyncio_topics[i][0], 'assistant': asyncio_topics[i][1]}
for i in range(25)
] + [
{'user': pg_topics[i][0], 'assistant': pg_topics[i][1]}
for i in range(25)
] + [
{'user': git_topics[i][0], 'assistant': git_topics[i][1]}
for i in range(25)
]
methods = {
'Last-3': BaselineLastN(3),
'Last-5': BaselineLastN(5),
'Last-10': BaselineLastN(10),
'BM25-5': BaselineBM25(5),
'Full CGK': gate, # special handling
}
results = {name: [] for name in methods}
for label, query, target_topic in TEST_SEQUENCE:
# Full CGK
cgk_selected = gate.select(query)
cgk_stats = measure_context_stats(cgk_selected)
cgk_contamination = evaluate_contamination(cgk_selected, target_topic)
results['Full CGK'].append({
'label': label,
'query': query,
'target_topic': target_topic,
'context_tokens': cgk_stats['context_tokens'],
'prompt_tokens': cgk_stats['prompt_tokens'],
'num_blocks': cgk_stats['num_blocks'],
'is_contaminated': cgk_contamination['is_contaminated'],
'other_topics': cgk_contamination['other_topics_found']
})
# Baseline methods
for name, method in methods.items():
if name == 'Full CGK':
continue
selected = method.select(conversation, query)
stats = measure_context_stats(selected)
contamination = evaluate_contamination(selected, target_topic)
results[name].append({
'label': label,
'query': query,
'target_topic': target_topic,
'context_tokens': stats['context_tokens'],
'prompt_tokens': stats['prompt_tokens'],
'num_blocks': stats['num_blocks'],
'is_contaminated': contamination['is_contaminated'],
'other_topics': contamination['other_topics_found']
})
print(f"\n[{label}] {query}")
print(f" Full CGK: {cgk_stats['prompt_tokens']} prompt tokens, "
f"污染={cgk_contamination['is_contaminated']}, "
f"块数={cgk_stats['num_blocks']}")
for name in methods:
if name == 'Full CGK':
continue
r = results[name][-1]
print(f" {name}: {r['prompt_tokens']} prompt tokens, "
f"污染={r['is_contaminated']}, 块数={r['num_blocks']}")
return results
def summarize_results(results: Dict) -> None:
"""打印汇总表格"""
print("\n" + "=" * 70)
print("Summary (averaged over {} queries)".format(len(TEST_SEQUENCE)))
print("=" * 70)
for name, data in results.items():
if not data:
continue
avg_prompt_tokens = sum(d['prompt_tokens'] for d in data) / len(data)
avg_context_tokens = sum(d['context_tokens'] for d in data) / len(data)
contamination_rate = sum(1 for d in data if d['is_contaminated']) / len(data) * 100
avg_blocks = sum(d['num_blocks'] for d in data) / len(data)
print(f"\n{name}:")
print(f" Avg prompt tokens: {avg_prompt_tokens:.1f}")
print(f" Avg context tokens: {avg_context_tokens:.1f}")
print(f" Contamination rate: {contamination_rate:.1f}%")
print(f" Avg blocks: {avg_blocks:.1f}")
# Full CGK vs Last-5 comparison
if 'Full CGK' in results and 'Last-5' in results:
cgk_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK'])
last5_avg = sum(d['prompt_tokens'] for d in results['Last-5']) / len(results['Last-5'])
saving = (last5_avg - cgk_avg) / last5_avg * 100
print(f"\nFull CGK vs Last-5:")
print(f" CGK: {cgk_avg:.1f} tokens/prompt")
print(f" Last-5: {last5_avg:.1f} tokens/prompt")
print(f" Saving: {saving:.1f}% (CGK 更少)")
def run_ablation_study():
"""Phase 2: Ablation Study"""
print("\n" + "=" * 70)
print("Phase 2: Ablation Study")
print("=" * 70)
gate = build_conversation()
# 定义 ablated versions
ablations = {
'Full CGK': lambda q: gate.select(q),
}
# Ablation 1: 无指代词规则
orig_extract = gate.anchor_extractor.extract_with_deictic
def no_deictic(text):
anchors, _ = orig_extract(text)
return anchors, False
gate.anchor_extractor.extract_with_deictic = no_deictic
ablations['-Deictic'] = lambda q: gate.select(q)
gate.anchor_extractor.extract_with_deictic = orig_extract
results = {name: [] for name in ablations}
for label, query, target_topic in TEST_SEQUENCE:
for name, fn in ablations.items():
if name == 'Full CGK':
selected = fn(query)
else:
# re-run with ablated config
if name == '-Deictic':
orig_extract = gate.anchor_extractor.extract_with_deictic
gate.anchor_extractor.extract_with_deictic = no_deictic
selected = gate.select(query)
gate.anchor_extractor.extract_with_deictic = orig_extract
stats = measure_context_stats(selected)
contamination = evaluate_contamination(selected, target_topic)
results[name].append({
'label': label,
'query': query,
'target_topic': target_topic,
'prompt_tokens': stats['prompt_tokens'],
'is_contaminated': contamination['is_contaminated']
})
print(f"\n[{label}] {query[:40]}...")
for name in ablations:
r = results[name][-1]
print(f" {name}: {r['prompt_tokens']} tokens, 污染={r['is_contaminated']}")
# Ablation summary
print("\n" + "=" * 70)
print("Ablation Summary")
print("=" * 70)
full_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK'])
for name in ablations:
if name == 'Full CGK':
continue
avg = sum(d['prompt_tokens'] for d in results[name]) / len(results[name])
diff = avg - full_avg
print(f"{name}: {avg:.1f} tokens (vs Full: {diff:+.1f})")
return results
if __name__ == '__main__':
results = run_baseline_comparison()
summarize_results(results)
ablation_results = run_ablation_study()
# Save all results
output = {
'baseline': {k: v for k, v in results.items()},
'ablation': {k: v for k, v in ablation_results.items()}
}
output_path = '/root/.openclaw/workspace/context-gatekeeper/experiments/phase1_2_results.json'
with open(output_path, 'w') as f:
json.dump(output, f, indent=2, ensure_ascii=False)
print(f"\nResults saved to: {output_path}")

View File

@@ -5,6 +5,27 @@
import re
from collections import Counter
from typing import List, Tuple
import re
# 中文通用疑问词/句式(几乎所有 query 都包含,无话题区分度)
# 只过滤真正通用的疑问 pattern保留技术相关词
_ANCHOR_STOPWORDS = {
'有什么', '是什', '什么', '怎么', '为什么', '如何',
'是否', '可以吗', '能不能', '哪个', '哪些', '多少',
'怎样', '区别', '应用', '场景',
'可以', '需要', '应该', '时候', '情况', '一下',
'还有', '然后', '另外', '继续', '这个', '那个',
'实现', '使用', '配置', '处理', '解决', '分析',
'优化', '监控', '调试', '设置', '修改', '更新',
'上面', '前面', '后面', '刚才', '展开',
'一样', '相关', '以上', '以下', '其中', '如此',
'进行', '方式', '结果', '原因', '作用',
# 单字符停用
'', '', '', '', '', '', '',
'', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '',
}
class AnchorExtractor:
@@ -16,7 +37,7 @@ class AnchorExtractor:
self._doc_count = 0
def extract(self, text: str) -> List[str]:
"""从文本中提取锚点列表"""
"""从文本中提取锚点列表(过滤无区分度 stopwords"""
anchors = []
# 中文 2-gram / 3-gram转小写以便匹配
@@ -24,12 +45,16 @@ class AnchorExtractor:
for chunk in chinese_chars:
if len(chunk) >= 2:
for i in range(len(chunk) - 1):
anchors.append(chunk[i:i+2].lower())
ngram = chunk[i:i+2].lower()
if ngram not in _ANCHOR_STOPWORDS:
anchors.append(ngram)
if len(chunk) >= 3:
for i in range(len(chunk) - 2):
anchors.append(chunk[i:i+3].lower())
ngram = chunk[i:i+3].lower()
if ngram not in _ANCHOR_STOPWORDS:
anchors.append(ngram)
# 英文单词
# 英文单词(保留,有强区分度)
english_words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', text)
anchors.extend([w.lower() for w in english_words if len(w) >= 2])

View File

@@ -92,18 +92,48 @@ class SparseRetriever:
q_anchors_lower = [a.lower() for a in query_anchors]
# 内容词: 从 query 原文提取的 topic-discriminative 词汇
# 包括: 英文术语/标识符、版本号、2+字符中文词
# 中文通用短词(如"怎么")不具有话题区分度,排除
# 排除通用疑问句式(所有话题都会出现的 pattern
CHINESE_STOPWORDS = {
# 通用疑问结尾所有query都有无区分度
'有什么区别', '是什么', '怎么办', '怎么改', '怎么用',
'工作', '区别', '应用', '场景', '方法', '问题',
'有什么', '怎么', '哪些', '那个', '这个', '什么',
'', '', '', '', '', '', '',
# 通用技术词(跨话题都出现)
'区别是', '有什么区', '是什', '怎么才', '如何', '为什么',
# 2字通用词
'可以', '需要', '应该', '哪个', '什么', '怎么', '为什',
'时候', '情况', '一下', '问题', '相关', '以上', '以下',
'另外', '继续', '还有', '然后', '前面', '后面', '中间',
'时候', '过程', '机制', '原理', '方式', '方法', '策略',
'实现', '使用', '配置', '处理', '解决', '分析', '优化',
'监控', '调试', '设置', '修改', '更新', '升级', '迁移',
}
content_words = set()
# 英文单词和代码标识符(所有长度 >= 2
for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_text):
# 先过滤掉疑问句尾 pattern再提取
# 移除所有通用疑问 pattern如"有什么区别"
query_clean = query_text
generic_patterns = [
r'有什么.*?[?]', r'怎么.*?[?]', r'是什么.*?[?]',
r'为什么.*?[?]', r'如何.*?[?]', r'是否.*?[?]',
r'能不能.*?[?]', r'可以吗.*?[?]', r'吗.*?[?]',
r'么.*?[?]', r'哪些.*?[?]', r'哪个.*?[?]',
r'多少.*?[?]', r'怎样.*?[?]',
]
for pat in generic_patterns:
query_clean = re.sub(pat, '', query_clean)
# 英文单词和代码标识符(长度 >= 2有区分度
for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_clean):
if len(w) >= 2:
content_words.add(w.lower())
# 版本号
for v in re.findall(r'v?\d+(\.\d+)*', query_text):
for v in re.findall(r'v?\d+(\.\d+)*', query_clean):
content_words.add(v.lower())
# 2字及以上中文词覆盖"PostgreSQL"等专有名词
for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_text):
# 2字及以上中文词需过滤 stopwords
for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_clean):
if chunk not in CHINESE_STOPWORDS:
content_words.add(chunk.lower())
for i, block in enumerate(blocks):