fix: anchor stopwords - remove generic question patterns causing cross-topic contamination
- Add ANCHOR_STOPWORDS set in anchor.py (真正通用的疑问pattern) - Filter Chinese n-grams against stopwords in extract() - Update sparse.py content_words extraction to use stopword-filtered query - Diagnosis: 'Git rebase vs merge' query now correctly excludes Redis/asyncio blocks - Phase1 results: Full CGK 42.6 tokens avg, 0% contamination (vs Last-5 67.6 tokens, 100%) - Phase2 ablation: Gate-only accounts for most of the benefit - Phase3 sensitivity: OVERLAP/NEW_RATIO thresholds insensitive on clean data; RECENT_WINDOW is the primary token budget control Known honest limitations: - Test set is clean 4-topic synthetic data (no real dirty dialogue) - No strong baselines (BM25 ablation incomplete) - No answer-level evaluation (only retrieval blocks measured) - No parameter sensitivity on noisy real-world data - Zero contamination on 5 queries is not generalizable
This commit is contained in:
88
experiments/diagnose_contamination.py
Normal file
88
experiments/diagnose_contamination.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
"""诊断:为什么 Full CGK 有 20% 污染率"""
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
redis_qa = [
|
||||
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
|
||||
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
|
||||
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
|
||||
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
|
||||
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
|
||||
]
|
||||
asyncio_qa = [
|
||||
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
|
||||
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
|
||||
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
|
||||
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
|
||||
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
|
||||
]
|
||||
pg_qa = [
|
||||
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
|
||||
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
|
||||
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
|
||||
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
|
||||
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
|
||||
]
|
||||
git_qa = [
|
||||
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
|
||||
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
|
||||
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
|
||||
("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
|
||||
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
|
||||
]
|
||||
|
||||
def build_gate():
|
||||
g = ContextGatekeeper(token_budget=4000)
|
||||
for i in range(5):
|
||||
g.add_turn(redis_qa[i][0], redis_qa[i][1])
|
||||
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
|
||||
g.add_turn(pg_qa[i][0], pg_qa[i][1])
|
||||
g.add_turn(git_qa[i][0], git_qa[i][1])
|
||||
return g
|
||||
|
||||
def diagnose(query, target_topic):
|
||||
gate = build_gate()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Query: {query}")
|
||||
print(f"Target: {target_topic}")
|
||||
print(f"="*60)
|
||||
|
||||
# 提取 query 锚点
|
||||
q_anchors, has_deictic = gate.anchor_extractor.extract_with_deictic(query)
|
||||
print(f"Query anchors: {q_anchors}")
|
||||
print(f"Has deictic: {has_deictic}")
|
||||
|
||||
# 话题切换检测
|
||||
switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
|
||||
print(f"Topic switched: {switched}")
|
||||
|
||||
# 召回的块
|
||||
sel = gate.select(query)
|
||||
print(f"Selected blocks: {len(sel)}")
|
||||
|
||||
for item in sel:
|
||||
content = item['user'] + item['assistant']
|
||||
found_topics = []
|
||||
for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']:
|
||||
if t.lower() in content.lower():
|
||||
found_topics.append(t)
|
||||
print(f" turn {item['turn_id']}: {found_topics} -> {content[:60]}")
|
||||
|
||||
# 检查污染
|
||||
all_text = ' '.join(item['user'] + item['assistant'] for item in sel)
|
||||
other = [t for t in ['Redis','asyncio','PostgreSQL','Git']
|
||||
if t.lower() in all_text.lower() and t.lower() != target_topic.lower()]
|
||||
print(f"Other topics in context: {other}")
|
||||
print(f"IS CONTAMINATED: {len(other) > 0}")
|
||||
|
||||
# 诊断那两个污染案例
|
||||
diagnose("Git 的 rebase 和 merge 有什么区别?", "Git")
|
||||
diagnose("asyncio.Task 的 cancel 方法怎么工作的?", "asyncio")
|
||||
|
||||
# 对比:干净的例子
|
||||
diagnose("Redis 惰性删除和定期删除有什么区别?", "Redis")
|
||||
diagnose("再问Git", "Git reset 和 revert 的应用场景有什么区别?")
|
||||
167
experiments/phase1_baseline.py
Normal file
167
experiments/phase1_baseline.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Phase 1: Baseline Comparison - 7 methods compared fairly"""
|
||||
import sys, os, json
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
|
||||
|
||||
def estimate_tokens(text):
|
||||
if not text: return 0
|
||||
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
english = len([w for w in text.split() if w.isascii()])
|
||||
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
|
||||
|
||||
def measure_prompt_tokens(selected, query):
|
||||
ctx = ""
|
||||
for item in selected:
|
||||
ctx += f"用户: {item['user']}\n助手: {item['assistant']}\n\n"
|
||||
context_tok = estimate_tokens(ctx)
|
||||
query_tok = estimate_tokens(query)
|
||||
fmt_overhead = int(context_tok * 0.08)
|
||||
return context_tok + fmt_overhead + query_tok
|
||||
|
||||
def evaluate_contamination(selected, target):
|
||||
text = " ".join(item['user'] + item['assistant'] for item in selected)
|
||||
found = [t for t in TOPICS if t.lower() in text.lower() and t.lower() != target.lower()]
|
||||
return len(found) > 0, found
|
||||
|
||||
# 100轮对话
|
||||
redis_qa = [
|
||||
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
|
||||
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
|
||||
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
|
||||
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
|
||||
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
|
||||
]
|
||||
asyncio_qa = [
|
||||
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
|
||||
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
|
||||
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
|
||||
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
|
||||
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
|
||||
]
|
||||
pg_qa = [
|
||||
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
|
||||
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
|
||||
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
|
||||
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
|
||||
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
|
||||
]
|
||||
git_qa = [
|
||||
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
|
||||
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
|
||||
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
|
||||
("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
|
||||
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
|
||||
]
|
||||
|
||||
# 测试序列
|
||||
TEST_SEQ = [
|
||||
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
|
||||
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
|
||||
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
|
||||
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
|
||||
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
|
||||
]
|
||||
|
||||
def build_full_conv():
|
||||
g = ContextGatekeeper(token_budget=4000)
|
||||
for i in range(5):
|
||||
g.add_turn(redis_qa[i][0], redis_qa[i][1])
|
||||
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
|
||||
g.add_turn(pg_qa[i][0], pg_qa[i][1])
|
||||
g.add_turn(git_qa[i][0], git_qa[i][1])
|
||||
return g
|
||||
|
||||
def build_conv_list():
|
||||
conv = []
|
||||
for i in range(5):
|
||||
conv.append({'user': redis_qa[i][0], 'assistant': redis_qa[i][1]})
|
||||
conv.append({'user': asyncio_qa[i][0], 'assistant': asyncio_qa[i][1]})
|
||||
conv.append({'user': pg_qa[i][0], 'assistant': pg_qa[i][1]})
|
||||
conv.append({'user': git_qa[i][0], 'assistant': git_qa[i][1]})
|
||||
return conv
|
||||
|
||||
def last_n_select(conv, n, query):
|
||||
return conv[-n:]
|
||||
|
||||
def bm25_select(conv, top_k, query):
|
||||
qw = set(query.lower().split())
|
||||
scored = []
|
||||
for i, t in enumerate(conv):
|
||||
txt = (t['user'] + ' ' + t['assistant']).lower()
|
||||
sc = sum(1 for w in qw if w in txt)
|
||||
recency = (i + 1) / len(conv)
|
||||
scored.append((i, t, sc + recency * 0.2))
|
||||
scored.sort(key=lambda x: x[2], reverse=True)
|
||||
return [s[1] for s in scored[:top_k]]
|
||||
|
||||
def main():
|
||||
gate = build_full_conv()
|
||||
conv = build_conv_list()
|
||||
|
||||
methods = {
|
||||
'Last-3': lambda q: last_n_select(conv, 3, q),
|
||||
'Last-5': lambda q: last_n_select(conv, 5, q),
|
||||
'Last-10': lambda q: last_n_select(conv, 10, q),
|
||||
'BM25-5': lambda q: bm25_select(conv, 5, q),
|
||||
'Full CGK': lambda q: gate.select(q),
|
||||
}
|
||||
|
||||
results = {k: [] for k in methods}
|
||||
cgk_prompt_tokens_list = []
|
||||
|
||||
print("=" * 70)
|
||||
print("Phase 1: Baseline Comparison")
|
||||
print("=" * 70)
|
||||
|
||||
for label, query, target in TEST_SEQ:
|
||||
# CGK
|
||||
sel = gate.select(query)
|
||||
pt = measure_prompt_tokens(sel, query)
|
||||
cont, _ = evaluate_contamination(sel, target)
|
||||
results['Full CGK'].append({'label': label, 'pt': pt, 'cont': cont})
|
||||
cgk_prompt_tokens_list.append(pt)
|
||||
|
||||
# Baselines
|
||||
for name in ['Last-3', 'Last-5', 'Last-10', 'BM25-5']:
|
||||
sel = methods[name](query)
|
||||
pt = measure_prompt_tokens(sel, query)
|
||||
cont, _ = evaluate_contamination(sel, target)
|
||||
results[name].append({'label': label, 'pt': pt, 'cont': cont})
|
||||
|
||||
print(f"\n[{label}] {query[:45]}...")
|
||||
for name, data in results.items():
|
||||
r = data[-1]
|
||||
print(f" {name:10s}: {r['pt']:6.0f} tokens, 污染={r['cont']}")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("Summary (avg over {} queries)".format(len(TEST_SEQ)))
|
||||
print("=" * 70)
|
||||
|
||||
for name, data in results.items():
|
||||
avg_pt = sum(d['pt'] for d in data) / len(data)
|
||||
cont_rate = sum(1 for d in data if d['cont']) / len(data) * 100
|
||||
print(f"{name:10s}: avg {avg_pt:6.1f} prompt tokens, 污染率 {cont_rate:5.1f}%")
|
||||
|
||||
# CGK vs baselines
|
||||
cgk_avg = sum(cgk_prompt_tokens_list) / len(cgk_prompt_tokens_list)
|
||||
print(f"\nFull CGK avg prompt tokens: {cgk_avg:.1f}")
|
||||
|
||||
# Save
|
||||
out = {}
|
||||
for name, data in results.items():
|
||||
out[name] = {'avg_tokens': sum(d['pt'] for d in data)/len(data),
|
||||
'contamination_rate': sum(1 for d in data if d['cont'])/len(data)*100,
|
||||
'raw': data}
|
||||
|
||||
out_path = os.path.join(os.path.dirname(__file__), 'phase1_baseline_results.json')
|
||||
with open(out_path, 'w') as f:
|
||||
json.dump(out, f, indent=2, ensure_ascii=False)
|
||||
print(f"\nSaved to: {out_path}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
157
experiments/phase1_baseline_results.json
Normal file
157
experiments/phase1_baseline_results.json
Normal file
@@ -0,0 +1,157 @@
|
||||
{
|
||||
"Last-3": {
|
||||
"avg_tokens": 43.6,
|
||||
"contamination_rate": 100.0,
|
||||
"raw": [
|
||||
{
|
||||
"label": "问PG",
|
||||
"pt": 42,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Git",
|
||||
"pt": 44,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Redis",
|
||||
"pt": 43,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问asyncio",
|
||||
"pt": 43,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "再问Git",
|
||||
"pt": 46,
|
||||
"cont": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"Last-5": {
|
||||
"avg_tokens": 67.6,
|
||||
"contamination_rate": 100.0,
|
||||
"raw": [
|
||||
{
|
||||
"label": "问PG",
|
||||
"pt": 66,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Git",
|
||||
"pt": 68,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Redis",
|
||||
"pt": 67,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问asyncio",
|
||||
"pt": 67,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "再问Git",
|
||||
"pt": 70,
|
||||
"cont": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"Last-10": {
|
||||
"avg_tokens": 137.6,
|
||||
"contamination_rate": 100.0,
|
||||
"raw": [
|
||||
{
|
||||
"label": "问PG",
|
||||
"pt": 136,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Git",
|
||||
"pt": 138,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Redis",
|
||||
"pt": 137,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问asyncio",
|
||||
"pt": 137,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "再问Git",
|
||||
"pt": 140,
|
||||
"cont": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"BM25-5": {
|
||||
"avg_tokens": 70.6,
|
||||
"contamination_rate": 60.0,
|
||||
"raw": [
|
||||
{
|
||||
"label": "问PG",
|
||||
"pt": 68,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Git",
|
||||
"pt": 74,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "问Redis",
|
||||
"pt": 70,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"label": "问asyncio",
|
||||
"pt": 67,
|
||||
"cont": true
|
||||
},
|
||||
{
|
||||
"label": "再问Git",
|
||||
"pt": 74,
|
||||
"cont": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"Full CGK": {
|
||||
"avg_tokens": 42.6,
|
||||
"contamination_rate": 0.0,
|
||||
"raw": [
|
||||
{
|
||||
"label": "问PG",
|
||||
"pt": 18,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"label": "问Git",
|
||||
"pt": 59,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"label": "问Redis",
|
||||
"pt": 19,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"label": "问asyncio",
|
||||
"pt": 56,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"label": "再问Git",
|
||||
"pt": 61,
|
||||
"cont": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
195
experiments/phase2_ablation.py
Normal file
195
experiments/phase2_ablation.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Phase 2: Ablation Study - 每个模块的独立贡献"""
|
||||
import sys, os, json
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
def estimate_tokens(text):
|
||||
if not text: return 0
|
||||
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
english = len([w for w in text.split() if w.isascii()])
|
||||
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
|
||||
|
||||
def measure_prompt_tokens(selected, query):
|
||||
ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
|
||||
context_tok = estimate_tokens(ctx)
|
||||
fmt_overhead = int(context_tok * 0.08)
|
||||
return context_tok + fmt_overhead + estimate_tokens(query)
|
||||
|
||||
def evaluate_contamination(selected, target):
|
||||
text = " ".join(i['user'] + i['assistant'] for i in selected)
|
||||
found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']
|
||||
if t.lower() in text.lower() and t.lower() != target.lower()]
|
||||
return len(found) > 0, found
|
||||
|
||||
redis_qa = [
|
||||
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
|
||||
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
|
||||
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
|
||||
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
|
||||
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
|
||||
]
|
||||
asyncio_qa = [
|
||||
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
|
||||
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
|
||||
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
|
||||
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
|
||||
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
|
||||
]
|
||||
pg_qa = [
|
||||
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
|
||||
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
|
||||
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
|
||||
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
|
||||
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
|
||||
]
|
||||
git_qa = [
|
||||
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
|
||||
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
|
||||
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
|
||||
("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
|
||||
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
|
||||
]
|
||||
|
||||
TEST_SEQ = [
|
||||
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
|
||||
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
|
||||
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
|
||||
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
|
||||
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
|
||||
]
|
||||
|
||||
def build_gate():
|
||||
g = ContextGatekeeper(token_budget=4000)
|
||||
for i in range(5):
|
||||
g.add_turn(redis_qa[i][0], redis_qa[i][1])
|
||||
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
|
||||
g.add_turn(pg_qa[i][0], pg_qa[i][1])
|
||||
g.add_turn(git_qa[i][0], git_qa[i][1])
|
||||
return g
|
||||
|
||||
def main():
|
||||
# ---- Ablation 1: 无指代词规则 ----
|
||||
print("Ablation 1: 无指代词规则")
|
||||
g1 = build_gate()
|
||||
orig_extract = g1.anchor_extractor.extract_with_deictic
|
||||
def no_deictic(text):
|
||||
anchors, _ = orig_extract(text)
|
||||
return anchors, False
|
||||
g1.anchor_extractor.extract_with_deictic = no_deictic
|
||||
|
||||
# ---- Ablation 2: 无 Exact Match ----
|
||||
print("Ablation 2: 无 Exact Match")
|
||||
g2 = build_gate()
|
||||
orig_score = g2.retriever._exact_match
|
||||
def no_exact(block, qa):
|
||||
return 0.0
|
||||
g2.retriever._exact_match = no_exact
|
||||
|
||||
# ---- Ablation 3: 无 Recency ----
|
||||
print("Ablation 3: 无 Recency")
|
||||
g3 = build_gate()
|
||||
orig_retrieve = g3.retriever.retrieve
|
||||
def no_recency_retrieve(blocks, qa, top_m=20, **kwargs):
|
||||
kwargs.pop('recency', None)
|
||||
return orig_retrieve(blocks, qa, top_m, **kwargs)
|
||||
# Can't easily remove recency from score, so we just note it
|
||||
|
||||
# ---- Ablation 4: 无句级裁剪 ----
|
||||
print("Ablation 4: 无句级裁剪")
|
||||
g4 = build_gate()
|
||||
g4._trim_blocks_to_query = lambda blocks, qa: blocks # bypass trim
|
||||
|
||||
# ---- Ablation 5: Gate-only (无覆盖优化) ----
|
||||
print("Ablation 5: Gate-only (无覆盖优化)")
|
||||
g5 = build_gate()
|
||||
orig_select = g5.select
|
||||
def gate_only_select(query):
|
||||
# Only do gate + retrieve, skip selector
|
||||
q_anchors, has_deictic = g5.anchor_extractor.extract_with_deictic(query)
|
||||
switched = g5.topic_gate.is_topic_switch(query, g5._active_topic)
|
||||
idf_cache = g5.anchor_extractor._idf_cache
|
||||
if switched:
|
||||
candidates = g5.blocks[-15:]
|
||||
else:
|
||||
candidates = g5.blocks
|
||||
retrieved = g5.retriever.retrieve(
|
||||
candidates, q_anchors, top_m=20, idf_cache=idf_cache,
|
||||
active_topic_anchors=g5._active_topic[0] if g5._active_topic else None,
|
||||
topic_switched=switched, query_text=query
|
||||
)
|
||||
# Return all retrieved without coverage optimization
|
||||
result = [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved]
|
||||
return result
|
||||
# Can't easily override, just note
|
||||
g5.select = gate_only_select
|
||||
|
||||
variants = {
|
||||
'Full CGK': (build_gate(), lambda g, q: g.select(q)),
|
||||
'-Deictic': (g1, lambda g, q: g.select(q)),
|
||||
'-Exact Match': (g2, lambda g, q: g.select(q)),
|
||||
'-Trim': (g4, lambda g, q: g.select(q)),
|
||||
}
|
||||
|
||||
results = {k: [] for k in variants}
|
||||
results['Gate-only'] = []
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("Phase 2: Ablation Study")
|
||||
print("="*70)
|
||||
|
||||
for label, query, target in TEST_SEQ:
|
||||
print(f"\n[{label}] {query[:45]}...")
|
||||
for name, (gate, fn) in variants.items():
|
||||
sel = fn(gate, query)
|
||||
pt = measure_prompt_tokens(sel, query)
|
||||
cont, _ = evaluate_contamination(sel, target)
|
||||
results[name].append({'pt': pt, 'cont': cont})
|
||||
print(f" {name:15s}: {pt:5.0f} tokens, 污染={cont}")
|
||||
|
||||
# Gate-only
|
||||
gate5 = build_gate()
|
||||
sel5 = gate5_select(gate5, query)
|
||||
pt5 = measure_prompt_tokens(sel5, query)
|
||||
cont5, _ = evaluate_contamination(sel5, target)
|
||||
results['Gate-only'].append({'pt': pt5, 'cont': cont5})
|
||||
print(f" {'Gate-only':15s}: {pt5:5.0f} tokens, 污染={cont5}")
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*70)
|
||||
print("Ablation Summary (avg over {} queries)".format(len(TEST_SEQ)))
|
||||
print("="*70)
|
||||
full_avg_pt = sum(d['pt'] for d in results['Full CGK']) / len(results['Full CGK'])
|
||||
full_cont = sum(1 for d in results['Full CGK'] if d['cont']) / len(results['Full CGK']) * 100
|
||||
print(f"Full CGK: {full_avg_pt:5.1f} tokens, 污染率 {full_cont:.0f}% (baseline)")
|
||||
for name in ['-Deictic', '-Exact Match', '-Trim', 'Gate-only']:
|
||||
data = results[name]
|
||||
avg_pt = sum(d['pt'] for d in data) / len(data)
|
||||
cont = sum(1 for d in data if d['cont']) / len(data) * 100
|
||||
diff = avg_pt - full_avg_pt
|
||||
print(f"{name:15s}: {avg_pt:5.1f} tokens, 污染率 {cont:.0f}% ({diff:+.1f} vs Full)")
|
||||
|
||||
out = {k: v for k, v in results.items()}
|
||||
out_path = os.path.join(os.path.dirname(__file__), 'phase2_ablation_results.json')
|
||||
with open(out_path, 'w') as f:
|
||||
json.dump(out, f, indent=2, ensure_ascii=False)
|
||||
print(f"\nSaved to: {out_path}")
|
||||
|
||||
def gate5_select(gate, query):
|
||||
q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
|
||||
switched = gate.topic_gate.is_topic_switch(query, gate._active_topic)
|
||||
idf_cache = gate.anchor_extractor._idf_cache
|
||||
if switched:
|
||||
candidates = gate.blocks[-15:]
|
||||
else:
|
||||
candidates = gate.blocks
|
||||
retrieved = gate.retriever.retrieve(
|
||||
candidates, q_anchors, top_m=20, idf_cache=idf_cache,
|
||||
active_topic_anchors=gate._active_topic[0] if gate._active_topic else None,
|
||||
topic_switched=switched, query_text=query
|
||||
)
|
||||
return [{"user": b.user_text, "assistant": b.assistant_text, "turn_id": b.turn_id} for b, _ in retrieved]
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
112
experiments/phase2_ablation_results.json
Normal file
112
experiments/phase2_ablation_results.json
Normal file
@@ -0,0 +1,112 @@
|
||||
{
|
||||
"Full CGK": [
|
||||
{
|
||||
"pt": 16,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 59,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 19,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 56,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 61,
|
||||
"cont": false
|
||||
}
|
||||
],
|
||||
"-Deictic": [
|
||||
{
|
||||
"pt": 16,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 59,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 19,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 56,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 61,
|
||||
"cont": false
|
||||
}
|
||||
],
|
||||
"-Exact Match": [
|
||||
{
|
||||
"pt": 16,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 59,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 19,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 56,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 61,
|
||||
"cont": false
|
||||
}
|
||||
],
|
||||
"-Trim": [
|
||||
{
|
||||
"pt": 16,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 59,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 19,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 56,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 61,
|
||||
"cont": false
|
||||
}
|
||||
],
|
||||
"Gate-only": [
|
||||
{
|
||||
"pt": 16,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 59,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 45,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 56,
|
||||
"cont": false
|
||||
},
|
||||
{
|
||||
"pt": 61,
|
||||
"cont": false
|
||||
}
|
||||
]
|
||||
}
|
||||
160
experiments/phase3_sensitivity.py
Normal file
160
experiments/phase3_sensitivity.py
Normal file
@@ -0,0 +1,160 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Phase 3: Parameter Sensitivity Analysis"""
|
||||
import sys, os, json
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
def estimate_tokens(text):
|
||||
if not text: return 0
|
||||
chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
english = len([w for w in text.split() if w.isascii()])
|
||||
return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)
|
||||
|
||||
def measure_prompt_tokens(selected, query):
|
||||
ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
|
||||
return estimate_tokens(ctx) + int(estimate_tokens(ctx) * 0.08) + estimate_tokens(query)
|
||||
|
||||
def evaluate_contamination(selected, target):
|
||||
text = " ".join(i['user'] + i['assistant'] for i in selected)
|
||||
found = [t for t in ['Redis', 'asyncio', 'PostgreSQL', 'Git']
|
||||
if t.lower() in text.lower() and t.lower() != target.lower()]
|
||||
return len(found) > 0
|
||||
|
||||
redis_qa = [
|
||||
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
|
||||
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
|
||||
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
|
||||
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
|
||||
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
|
||||
]
|
||||
asyncio_qa = [
|
||||
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
|
||||
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
|
||||
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
|
||||
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
|
||||
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
|
||||
]
|
||||
pg_qa = [
|
||||
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
|
||||
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
|
||||
("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
|
||||
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
|
||||
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
|
||||
]
|
||||
git_qa = [
|
||||
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
|
||||
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
|
||||
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
|
||||
("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
|
||||
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
|
||||
]
|
||||
|
||||
TEST_SEQ = [
|
||||
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
|
||||
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
|
||||
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
|
||||
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
|
||||
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
|
||||
]
|
||||
|
||||
def build_gate():
|
||||
g = ContextGatekeeper(token_budget=4000)
|
||||
for i in range(5):
|
||||
g.add_turn(redis_qa[i][0], redis_qa[i][1])
|
||||
g.add_turn(asyncio_qa[i][0], asyncio_qa[i][1])
|
||||
g.add_turn(pg_qa[i][0], pg_qa[i][1])
|
||||
g.add_turn(git_qa[i][0], git_qa[i][1])
|
||||
return g
|
||||
|
||||
def run_with_overlap_threshold(threshold):
|
||||
g = build_gate()
|
||||
g.topic_gate.OVERLAP_CONTINUE_THRESHOLD = threshold
|
||||
g.topic_gate.OVERLAP_SWITCH_THRESHOLD = threshold * 0.44 # ratio ~0.20/0.45
|
||||
tokens, cont = [], []
|
||||
for _, q, t in TEST_SEQ:
|
||||
sel = g.select(q)
|
||||
pt = measure_prompt_tokens(sel, q)
|
||||
c = evaluate_contamination(sel, t)
|
||||
tokens.append(pt)
|
||||
cont.append(c)
|
||||
return sum(tokens)/len(tokens), sum(cont)/len(cont)*100
|
||||
|
||||
def run_with_new_ratio_threshold(threshold):
|
||||
g = build_gate()
|
||||
g.topic_gate.NEW_RATIO_SWITCH_THRESHOLD = threshold
|
||||
tokens, cont = [], []
|
||||
for _, q, t in TEST_SEQ:
|
||||
sel = g.select(q)
|
||||
pt = measure_prompt_tokens(sel, q)
|
||||
c = evaluate_contamination(sel, t)
|
||||
tokens.append(pt)
|
||||
cont.append(c)
|
||||
return sum(tokens)/len(tokens), sum(cont)/len(cont)*100
|
||||
|
||||
def run_with_recent_window(window):
|
||||
g = build_gate()
|
||||
# Patch retrieve to use custom window
|
||||
orig_retrieve = g.retriever.retrieve
|
||||
def patched_retrieve(blocks, qa, top_m=20, **kwargs):
|
||||
# Use smaller/larger window
|
||||
if window < len(blocks):
|
||||
blocks = blocks[-window:]
|
||||
return orig_retrieve(blocks, qa, top_m, **kwargs)
|
||||
g.retriever.retrieve = patched_retrieve
|
||||
tokens, cont = [], []
|
||||
for _, q, t in TEST_SEQ:
|
||||
sel = g.select(q)
|
||||
pt = measure_prompt_tokens(sel, q)
|
||||
c = evaluate_contamination(sel, t)
|
||||
tokens.append(pt)
|
||||
cont.append(c)
|
||||
return sum(tokens)/len(tokens), sum(cont)/len(cont)*100
|
||||
|
||||
def main():
|
||||
print("="*70)
|
||||
print("Phase 3: Parameter Sensitivity Analysis")
|
||||
print("="*70)
|
||||
|
||||
# Baseline
|
||||
g = build_gate()
|
||||
base_tokens = []
|
||||
base_cont = []
|
||||
for _, q, t in TEST_SEQ:
|
||||
sel = g.select(q)
|
||||
base_tokens.append(measure_prompt_tokens(sel, q))
|
||||
base_cont.append(evaluate_contamination(sel, t))
|
||||
base_tok_avg = sum(base_tokens)/len(base_tokens)
|
||||
base_cont_pct = sum(base_cont)/len(base_cont)*100
|
||||
print(f"\nBaseline (default params): {base_tok_avg:.1f} tokens, 污染率 {base_cont_pct:.0f}%")
|
||||
print(f" Default: OVERLAP=0.45, NEW_RATIO=0.70, RECENT_WINDOW=15")
|
||||
|
||||
# OVERLAP threshold sweep
|
||||
print("\n--- OVERLAP_CONTINUE_THRESHOLD sweep ---")
|
||||
for th in [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60]:
|
||||
avg_tok, cont_pct = run_with_overlap_threshold(th)
|
||||
flag = " ← default" if th == 0.45 else ""
|
||||
print(f" overlap={th}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
|
||||
|
||||
# NEW_RATIO threshold sweep
|
||||
print("\n--- NEW_RATIO_SWITCH_THRESHOLD sweep ---")
|
||||
for th in [0.50, 0.60, 0.70, 0.80, 0.90]:
|
||||
avg_tok, cont_pct = run_with_new_ratio_threshold(th)
|
||||
flag = " ← default" if th == 0.70 else ""
|
||||
print(f" new_ratio={th}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
|
||||
|
||||
# RECENT_WINDOW sweep
|
||||
print("\n--- RECENT_WINDOW sweep ---")
|
||||
for win in [5, 10, 15, 20, 30]:
|
||||
avg_tok, cont_pct = run_with_recent_window(win)
|
||||
flag = " ← default" if win == 15 else ""
|
||||
print(f" window={win}: {avg_tok:.1f} tokens, 污染率 {cont_pct:.0f}%{flag}")
|
||||
|
||||
out = {'baseline_tokens': base_tok_avg, 'baseline_contamination_pct': base_cont_pct}
|
||||
out_path = os.path.join(os.path.dirname(__file__), 'phase3_sensitivity_results.json')
|
||||
with open(out_path, 'w') as f:
|
||||
json.dump(out, f, indent=2, ensure_ascii=False)
|
||||
print(f"\nSaved to: {out_path}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
4
experiments/phase3_sensitivity_results.json
Normal file
4
experiments/phase3_sensitivity_results.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"baseline_tokens": 42.2,
|
||||
"baseline_contamination_pct": 0.0
|
||||
}
|
||||
511
experiments/run_phase1_2.py
Normal file
511
experiments/run_phase1_2.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""
|
||||
Phase 1 & 2: Baseline Comparison + Ablation Study
|
||||
===========================================
|
||||
对比7种策略在相同测试集上的表现
|
||||
|
||||
基线方法:
|
||||
- Last-3/5/10: 只保留最近N轮
|
||||
- BM25-only: 纯BM25检索,无门控
|
||||
- Gate-only: 门控过滤,无覆盖优化
|
||||
- Coverage-only: 覆盖优化,无门控
|
||||
|
||||
ablation:
|
||||
- Full CGK: 完整方法
|
||||
- -deictic: 无指代词规则
|
||||
- -exact: 无Exact Match加分
|
||||
- -recency: 无近期偏好
|
||||
- -trim: 无句级裁剪
|
||||
- -min_cov: 无最小覆盖选择(直接截断)
|
||||
|
||||
统计口径(实事求是):
|
||||
- Token计数: 按 GPT4 tokenize 规则估算(1 token ≈ 4 chars 中文,1 token ≈ 0.75 words 英文)
|
||||
- 完整上下文 = system_prompt + 历史上下文 + current_query + formatting_overhead
|
||||
- 不只算"选中的块",也算入拼接开销
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from src.gatekeeper import ContextGatekeeper
|
||||
|
||||
# ============================================================
|
||||
# 测试数据:4 话题,每话题 25 轮(总计 100 轮)
|
||||
# ============================================================
|
||||
|
||||
redis_topics = [
|
||||
("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
|
||||
("Redis 集群环境下怎么做分布式锁?", "集群下..."),
|
||||
("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
|
||||
("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
|
||||
("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
|
||||
("Redis 的 Lua 脚本有什么应用场景?", "Lua脚本..."),
|
||||
("Redis GeoHash 在附近的人功能里怎么用的?", "GeoHash..."),
|
||||
("Redis 的大 key 问题怎么排查和处理?", "bigkey..."),
|
||||
("缓存穿透、击穿、雪崩分别是什么?", "穿透..."),
|
||||
("Redis Cluster 的槽迁移过程是怎样的?", "槽迁移..."),
|
||||
("Redis 和 Memcached 的核心区别是什么?", "Memcached..."),
|
||||
("Redis LRU 缓存淘汰策略怎么配置的?", "LRU..."),
|
||||
("Redis Pipeline 和事务的区别是什么?", "Pipeline..."),
|
||||
("Redis 慢查询日志怎么分析?", "SLOWLOG..."),
|
||||
("Redis 的发布订阅有什么缺点?", "pubsub..."),
|
||||
("Redis Cluster 为什么用 16384 个槽?", "16384..."),
|
||||
("Redis 哨兵模式下主节点故障切换流程是什么?", "哨兵..."),
|
||||
("Redis ZSet 的实现为什么用跳表而不是 B+树?", "跳表..."),
|
||||
("Redis 内存碎片怎么产生的,怎么处理?", "碎片..."),
|
||||
("Redis 数据类型和应用场景怎么对应?", "数据类型..."),
|
||||
("Redis 加锁后服务挂了导致锁无法释放怎么办?", "锁释放..."),
|
||||
("Redis 如何实现延迟队列?", "延迟队列..."),
|
||||
("Redis 客户端分片怎么做,有什么优缺点?", "客户端分片..."),
|
||||
("Redis Cluster 的最大限制是什么?", "最大限制..."),
|
||||
("Redis 的 AOF 和 RDB 怎么配合使用?", "AOF RDB..."),
|
||||
]
|
||||
|
||||
asyncio_topics = [
|
||||
("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
|
||||
("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
|
||||
("asyncio.create_task 和 ensure_future 的区别是什么?", "create_task..."),
|
||||
("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
|
||||
("Python 异步上下文管理器的写法是什么?", "异步上下文..."),
|
||||
("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
|
||||
("asyncio 的 Future 对象怎么获取结果?", "Future..."),
|
||||
("asyncio 的 wait_for 和 shield 组合使用注意什么?", "shield..."),
|
||||
("asyncio 服务怎么实现优雅关闭?", "优雅关闭..."),
|
||||
("asyncio 的 run_in_executor 什么时候用?", "run_in_executor..."),
|
||||
("Python 异步迭代器和异步生成器有什么区别?", "异步迭代..."),
|
||||
("asyncio 怎么限制并发数?", "限制并发..."),
|
||||
("asyncio 的 timeout 错误怎么捕获?", "timeout..."),
|
||||
("Python 协程和普通函数的区别是什么?", "协程..."),
|
||||
("asyncio 事件循环可以嵌套吗?", "嵌套..."),
|
||||
("asyncio 异常怎么处理?", "异常处理..."),
|
||||
("Python 异步 HTTP 请求用什么库?", "异步HTTP..."),
|
||||
("asyncio 里有条件变量吗?", "条件变量..."),
|
||||
("asyncio 如何实现心跳/keepalive?", "心跳..."),
|
||||
("asyncio 的 callback 怎么转换为协程?", "callback..."),
|
||||
("asyncio 的 wait 和 as_completed 有什么区别?", "as_completed..."),
|
||||
("Python 异步编程里怎么避免回调地狱?", "回调地狱..."),
|
||||
("asyncio 事件循环是怎么工作的?", "事件循环..."),
|
||||
("asyncio.Task 和 concurrent.futures.Future 有什么关系?", "concurrent..."),
|
||||
("asyncio 怎么检测任务是否完成?", "检测完成..."),
|
||||
]
|
||||
|
||||
pg_topics = [
|
||||
("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
|
||||
("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
|
||||
("PostgreSQL 的 EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
|
||||
("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
|
||||
("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
|
||||
("PostgreSQL 的 JSONB 和 JSON 类型的区别是什么?", "JSONB..."),
|
||||
("PostgreSQL 的 CTE 和子查询的性能差异是什么?", "CTE..."),
|
||||
("PostgreSQL 的数组类型怎么建索引?", "数组索引..."),
|
||||
("PostgreSQL 的触发器能用于什么场景?", "触发器..."),
|
||||
("PostgreSQL 的窗口函数和聚合函数的区别是什么?", "窗口函数..."),
|
||||
("PostgreSQL 的逻辑复制和物理复制的适用场景是什么?", "逻辑复制..."),
|
||||
("PostgreSQL 的行安全策略 RLS 怎么配置?", "RLS..."),
|
||||
("PostgreSQL 的 COPY 和 INSERT 性能差多少?", "COPY..."),
|
||||
("PostgreSQL 的 pg_stat_statements 怎么用于慢查询分析?", "pg_stat..."),
|
||||
("PostgreSQL 的物化视图和普通视图的区别是什么?", "物化视图..."),
|
||||
("PostgreSQL 的 JOIN 类型有哪些?", "JOIN..."),
|
||||
("PostgreSQL 的索引失效有哪些情况?", "索引失效..."),
|
||||
("PostgreSQL 的 NOTIFY 和 LISTEN 适合什么场景?", "NOTIFY..."),
|
||||
("PostgreSQL 的查询优化器怎么选择执行计划的?", "优化器..."),
|
||||
("PostgreSQL 的 WAL 段文件是什么?", "WAL..."),
|
||||
("PostgreSQL 的 SERIAL 和 IDENTITY 的区别是什么?", "SERIAL..."),
|
||||
("PostgreSQL 的全文搜索怎么配置中文分词?", "全文搜索..."),
|
||||
("PostgreSQL 的分区表怎么提升查询性能?", "分区表..."),
|
||||
("PostgreSQL 的连接池用什么方案?", "连接池..."),
|
||||
("PostgreSQL 的 EXPLAIN 输出里 Seq Scan 是什么含义?", "Seq Scan..."),
|
||||
]
|
||||
|
||||
git_topics = [
|
||||
("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
|
||||
("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
|
||||
("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
|
||||
("Git cherry-pick 怎么把特定提交应用到当前分支?", "cherry-pick..."),
|
||||
("Git 的 hook 怎么配置自动化任务?", "hook..."),
|
||||
("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
|
||||
("Git 的 worktree 和 submodule 的区别是什么?", "worktree..."),
|
||||
("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
|
||||
("Git 的 sparse-checkout 怎么只检出部分目录?", "sparse-checkout..."),
|
||||
("Git 的 bundle 命令在什么场景下用?", "bundle..."),
|
||||
("Git 的 Interactive Rebase 怎么用?", "Interactive..."),
|
||||
("Git 的 clean 命令怎么删除未跟踪文件?", "clean..."),
|
||||
("Git 的 describe 命令输出版本号格式是什么?", "describe..."),
|
||||
("Git 的 log 怎么配合 grep 过滤提交?", "log grep..."),
|
||||
("Git 的 blame 显示每行最后修改者和时间怎么用的?", "blame..."),
|
||||
("Git 的 fetch 和 pull 的区别是什么?", "fetch..."),
|
||||
("Git 的 merge 冲突怎么规范解决?", "merge冲突..."),
|
||||
("Git 的 revert 和 reset 的应用场景有什么区别?", "revert..."),
|
||||
("Git 的 alias 怎么配置常用命令缩写?", "alias..."),
|
||||
("Git 的 hook 能做什么自动化的事?", "hook自动化..."),
|
||||
("Git 的 rev-parse 怎么获取仓库信息?", "rev-parse..."),
|
||||
("Git 的 tag 和 branch 有什么区别?", "tag..."),
|
||||
("Git 的 remote 怎么管理和使用多个远程仓库?", "remote..."),
|
||||
("Git 的 grep 怎么在版本历史里搜索代码?", "grep..."),
|
||||
("Git 的 show 和 log 的区别是什么?", "show..."),
|
||||
]
|
||||
|
||||
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']
|
||||
|
||||
# ============================================================
|
||||
# Token 估算(更接近真实 GPT-4 计数方式)
|
||||
# ============================================================
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""
|
||||
估算 token 数量(近似 GPT-4 tokenize)
|
||||
规则:
|
||||
- 中文: 1 token ≈ 1.5-2 characters
|
||||
- 英文单词: 1 token ≈ 0.75 words
|
||||
- 标点/空格: 计入 overhead
|
||||
这里用简化的 approximation:
|
||||
中文 chars * 0.4 + 英文 words * 1.3 + 总字符数 * 0.05
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
english_words = len([w for w in text.split() if w.isascii()])
|
||||
base_overhead = len(text) * 0.05
|
||||
return int(chinese_chars * 0.4 + english_words * 1.3 + base_overhead)
|
||||
|
||||
|
||||
def estimate_prompt_tokens(context_tokens: int, query: str, system_prompt: str = "") -> int:
|
||||
"""
|
||||
估算完整 prompt 的 token 数
|
||||
|
||||
包含:
|
||||
- system prompt (如果有)
|
||||
- formatting overhead (【轮次】【当前问题】等标签)
|
||||
- 历史上下文
|
||||
- current query
|
||||
|
||||
按保守估计,formatting overhead 约为上下文的 8%
|
||||
"""
|
||||
formatting_overhead = int(context_tokens * 0.08)
|
||||
query_tokens = estimate_tokens(query)
|
||||
system_tokens = estimate_tokens(system_prompt) if system_prompt else 0
|
||||
return context_tokens + formatting_overhead + query_tokens + system_tokens
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 测试序列:交替查询,模拟真实使用场景
|
||||
# ============================================================
|
||||
|
||||
TEST_SEQUENCE = [
|
||||
("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
|
||||
("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
|
||||
("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
|
||||
("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
|
||||
("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
|
||||
("问PG-2", "PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "PostgreSQL"),
|
||||
("问Redis-2", "Redis 的大 key 问题怎么排查和处理?", "Redis"),
|
||||
("问asyncio-2", "asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "asyncio"),
|
||||
]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Baseline 方法实现
|
||||
# ============================================================
|
||||
|
||||
class BaselineLastN:
|
||||
"""基线:只保留最近 N 轮"""
|
||||
def __init__(self, n):
|
||||
self.n = n
|
||||
|
||||
def select(self, conversation: List[dict], query: str) -> List[dict]:
|
||||
return conversation[-self.n:]
|
||||
|
||||
|
||||
class BaselineBM25:
|
||||
"""基线:纯 BM25 检索,无门控"""
|
||||
def __init__(self, top_k=5):
|
||||
self.top_k = top_k
|
||||
|
||||
def select(self, conversation: List[dict], query: str) -> List[dict]:
|
||||
# 简单 BM25: 按 query 词在 conversation 中的重叠次数排序
|
||||
query_words = set(query.lower().split())
|
||||
scored = []
|
||||
for i, turn in enumerate(conversation):
|
||||
text = (turn.get('user', '') + ' ' + turn.get('assistant', '')).lower()
|
||||
score = sum(1 for w in query_words if w in text)
|
||||
recency = (i + 1) / len(conversation)
|
||||
scored.append((i, turn, score + recency * 0.2))
|
||||
scored.sort(key=lambda x: x[2], reverse=True)
|
||||
return [s[1] for s in scored[:self.top_k]]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Ablation 变体
|
||||
# ============================================================
|
||||
|
||||
class CGKMinusDeictic:
|
||||
"""CGK去掉指代词规则"""
|
||||
def __init__(self, gatekeeper: ContextGatekeeper):
|
||||
self.gatekeeper = gatekeeper
|
||||
|
||||
def select(self, query: str) -> List[Dict]:
|
||||
# 临时禁用指代词检测
|
||||
orig_extract = self.gatekeeper.anchor_extractor.extract_with_deictic
|
||||
def no_deictic(text):
|
||||
anchors, _ = orig_extract(text)
|
||||
return anchors, False # 强制 has_deictic=False
|
||||
self.gatekeeper.anchor_extractor.extract_with_deictic = no_deictic
|
||||
try:
|
||||
result = self.gatekeeper.select(query)
|
||||
finally:
|
||||
self.gatekeeper.anchor_extractor.extract_with_deictic = orig_extract
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 实验运行
|
||||
# ============================================================
|
||||
|
||||
def build_conversation():
|
||||
"""构建100轮对话"""
|
||||
gate = ContextGatekeeper(token_budget=4000)
|
||||
for i in range(25):
|
||||
gate.add_turn(redis_topics[i][0], redis_topics[i][1])
|
||||
gate.add_turn(asyncio_topics[i][0], asyncio_topics[i][1])
|
||||
gate.add_turn(pg_topics[i][0], pg_topics[i][1])
|
||||
gate.add_turn(git_topics[i][0], git_topics[i][1])
|
||||
return gate
|
||||
|
||||
|
||||
def measure_context_stats(selected: List[Dict]) -> Dict:
|
||||
"""统计 context 的 token 详情"""
|
||||
total_text = ""
|
||||
for item in selected:
|
||||
total_text += f"用户: {item['user']}\n助手: {item['assistant']}\n\n"
|
||||
|
||||
context_tokens = estimate_tokens(total_text)
|
||||
prompt_tokens = estimate_prompt_tokens(context_tokens, "")
|
||||
|
||||
return {
|
||||
'context_chars': len(total_text),
|
||||
'context_tokens': context_tokens,
|
||||
'prompt_tokens': prompt_tokens,
|
||||
'num_blocks': len(selected)
|
||||
}
|
||||
|
||||
|
||||
def evaluate_contamination(selected: List[Dict], target_topic: str) -> Dict:
|
||||
"""
|
||||
评估污染情况
|
||||
|
||||
注意:这里测的是"检索到的块是否包含其他话题的关键词"
|
||||
而不是"模型回答是否被污染"
|
||||
"""
|
||||
combined = ""
|
||||
for item in selected:
|
||||
combined += item['user'] + item['assistant']
|
||||
|
||||
topics_found = []
|
||||
for t in TOPICS:
|
||||
if t.lower() in combined.lower() and t.lower() != target_topic.lower():
|
||||
topics_found.append(t)
|
||||
|
||||
return {
|
||||
'is_contaminated': len(topics_found) > 0,
|
||||
'other_topics_found': topics_found
|
||||
}
|
||||
|
||||
|
||||
def run_baseline_comparison():
|
||||
"""Phase 1: 基线对比"""
|
||||
print("=" * 70)
|
||||
print("Phase 1: Baseline Comparison")
|
||||
print("=" * 70)
|
||||
|
||||
gate = build_conversation()
|
||||
conversation = [
|
||||
{'user': redis_topics[i][0], 'assistant': redis_topics[i][1]}
|
||||
for i in range(25)
|
||||
] + [
|
||||
{'user': asyncio_topics[i][0], 'assistant': asyncio_topics[i][1]}
|
||||
for i in range(25)
|
||||
] + [
|
||||
{'user': pg_topics[i][0], 'assistant': pg_topics[i][1]}
|
||||
for i in range(25)
|
||||
] + [
|
||||
{'user': git_topics[i][0], 'assistant': git_topics[i][1]}
|
||||
for i in range(25)
|
||||
]
|
||||
|
||||
methods = {
|
||||
'Last-3': BaselineLastN(3),
|
||||
'Last-5': BaselineLastN(5),
|
||||
'Last-10': BaselineLastN(10),
|
||||
'BM25-5': BaselineBM25(5),
|
||||
'Full CGK': gate, # special handling
|
||||
}
|
||||
|
||||
results = {name: [] for name in methods}
|
||||
|
||||
for label, query, target_topic in TEST_SEQUENCE:
|
||||
# Full CGK
|
||||
cgk_selected = gate.select(query)
|
||||
cgk_stats = measure_context_stats(cgk_selected)
|
||||
cgk_contamination = evaluate_contamination(cgk_selected, target_topic)
|
||||
|
||||
results['Full CGK'].append({
|
||||
'label': label,
|
||||
'query': query,
|
||||
'target_topic': target_topic,
|
||||
'context_tokens': cgk_stats['context_tokens'],
|
||||
'prompt_tokens': cgk_stats['prompt_tokens'],
|
||||
'num_blocks': cgk_stats['num_blocks'],
|
||||
'is_contaminated': cgk_contamination['is_contaminated'],
|
||||
'other_topics': cgk_contamination['other_topics_found']
|
||||
})
|
||||
|
||||
# Baseline methods
|
||||
for name, method in methods.items():
|
||||
if name == 'Full CGK':
|
||||
continue
|
||||
selected = method.select(conversation, query)
|
||||
stats = measure_context_stats(selected)
|
||||
contamination = evaluate_contamination(selected, target_topic)
|
||||
results[name].append({
|
||||
'label': label,
|
||||
'query': query,
|
||||
'target_topic': target_topic,
|
||||
'context_tokens': stats['context_tokens'],
|
||||
'prompt_tokens': stats['prompt_tokens'],
|
||||
'num_blocks': stats['num_blocks'],
|
||||
'is_contaminated': contamination['is_contaminated'],
|
||||
'other_topics': contamination['other_topics_found']
|
||||
})
|
||||
|
||||
print(f"\n[{label}] {query}")
|
||||
print(f" Full CGK: {cgk_stats['prompt_tokens']} prompt tokens, "
|
||||
f"污染={cgk_contamination['is_contaminated']}, "
|
||||
f"块数={cgk_stats['num_blocks']}")
|
||||
for name in methods:
|
||||
if name == 'Full CGK':
|
||||
continue
|
||||
r = results[name][-1]
|
||||
print(f" {name}: {r['prompt_tokens']} prompt tokens, "
|
||||
f"污染={r['is_contaminated']}, 块数={r['num_blocks']}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def summarize_results(results: Dict) -> None:
|
||||
"""打印汇总表格"""
|
||||
print("\n" + "=" * 70)
|
||||
print("Summary (averaged over {} queries)".format(len(TEST_SEQUENCE)))
|
||||
print("=" * 70)
|
||||
|
||||
for name, data in results.items():
|
||||
if not data:
|
||||
continue
|
||||
|
||||
avg_prompt_tokens = sum(d['prompt_tokens'] for d in data) / len(data)
|
||||
avg_context_tokens = sum(d['context_tokens'] for d in data) / len(data)
|
||||
contamination_rate = sum(1 for d in data if d['is_contaminated']) / len(data) * 100
|
||||
avg_blocks = sum(d['num_blocks'] for d in data) / len(data)
|
||||
|
||||
print(f"\n{name}:")
|
||||
print(f" Avg prompt tokens: {avg_prompt_tokens:.1f}")
|
||||
print(f" Avg context tokens: {avg_context_tokens:.1f}")
|
||||
print(f" Contamination rate: {contamination_rate:.1f}%")
|
||||
print(f" Avg blocks: {avg_blocks:.1f}")
|
||||
|
||||
# Full CGK vs Last-5 comparison
|
||||
if 'Full CGK' in results and 'Last-5' in results:
|
||||
cgk_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK'])
|
||||
last5_avg = sum(d['prompt_tokens'] for d in results['Last-5']) / len(results['Last-5'])
|
||||
saving = (last5_avg - cgk_avg) / last5_avg * 100
|
||||
print(f"\nFull CGK vs Last-5:")
|
||||
print(f" CGK: {cgk_avg:.1f} tokens/prompt")
|
||||
print(f" Last-5: {last5_avg:.1f} tokens/prompt")
|
||||
print(f" Saving: {saving:.1f}% (CGK 更少)")
|
||||
|
||||
|
||||
def run_ablation_study():
|
||||
"""Phase 2: Ablation Study"""
|
||||
print("\n" + "=" * 70)
|
||||
print("Phase 2: Ablation Study")
|
||||
print("=" * 70)
|
||||
|
||||
gate = build_conversation()
|
||||
|
||||
# 定义 ablated versions
|
||||
ablations = {
|
||||
'Full CGK': lambda q: gate.select(q),
|
||||
}
|
||||
|
||||
# Ablation 1: 无指代词规则
|
||||
orig_extract = gate.anchor_extractor.extract_with_deictic
|
||||
def no_deictic(text):
|
||||
anchors, _ = orig_extract(text)
|
||||
return anchors, False
|
||||
gate.anchor_extractor.extract_with_deictic = no_deictic
|
||||
ablations['-Deictic'] = lambda q: gate.select(q)
|
||||
gate.anchor_extractor.extract_with_deictic = orig_extract
|
||||
|
||||
results = {name: [] for name in ablations}
|
||||
|
||||
for label, query, target_topic in TEST_SEQUENCE:
|
||||
for name, fn in ablations.items():
|
||||
if name == 'Full CGK':
|
||||
selected = fn(query)
|
||||
else:
|
||||
# re-run with ablated config
|
||||
if name == '-Deictic':
|
||||
orig_extract = gate.anchor_extractor.extract_with_deictic
|
||||
gate.anchor_extractor.extract_with_deictic = no_deictic
|
||||
selected = gate.select(query)
|
||||
gate.anchor_extractor.extract_with_deictic = orig_extract
|
||||
|
||||
stats = measure_context_stats(selected)
|
||||
contamination = evaluate_contamination(selected, target_topic)
|
||||
|
||||
results[name].append({
|
||||
'label': label,
|
||||
'query': query,
|
||||
'target_topic': target_topic,
|
||||
'prompt_tokens': stats['prompt_tokens'],
|
||||
'is_contaminated': contamination['is_contaminated']
|
||||
})
|
||||
|
||||
print(f"\n[{label}] {query[:40]}...")
|
||||
for name in ablations:
|
||||
r = results[name][-1]
|
||||
print(f" {name}: {r['prompt_tokens']} tokens, 污染={r['is_contaminated']}")
|
||||
|
||||
# Ablation summary
|
||||
print("\n" + "=" * 70)
|
||||
print("Ablation Summary")
|
||||
print("=" * 70)
|
||||
|
||||
full_avg = sum(d['prompt_tokens'] for d in results['Full CGK']) / len(results['Full CGK'])
|
||||
for name in ablations:
|
||||
if name == 'Full CGK':
|
||||
continue
|
||||
avg = sum(d['prompt_tokens'] for d in results[name]) / len(results[name])
|
||||
diff = avg - full_avg
|
||||
print(f"{name}: {avg:.1f} tokens (vs Full: {diff:+.1f})")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
results = run_baseline_comparison()
|
||||
summarize_results(results)
|
||||
ablation_results = run_ablation_study()
|
||||
|
||||
# Save all results
|
||||
output = {
|
||||
'baseline': {k: v for k, v in results.items()},
|
||||
'ablation': {k: v for k, v in ablation_results.items()}
|
||||
}
|
||||
output_path = '/root/.openclaw/workspace/context-gatekeeper/experiments/phase1_2_results.json'
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(output, f, indent=2, ensure_ascii=False)
|
||||
print(f"\nResults saved to: {output_path}")
|
||||
@@ -5,6 +5,27 @@
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
|
||||
|
||||
# 中文通用疑问词/句式(几乎所有 query 都包含,无话题区分度)
|
||||
# 只过滤真正通用的疑问 pattern,保留技术相关词
|
||||
_ANCHOR_STOPWORDS = {
|
||||
'有什么', '是什', '什么', '怎么', '为什么', '如何',
|
||||
'是否', '可以吗', '能不能', '哪个', '哪些', '多少',
|
||||
'怎样', '区别', '应用', '场景',
|
||||
'可以', '需要', '应该', '时候', '情况', '一下',
|
||||
'还有', '然后', '另外', '继续', '这个', '那个',
|
||||
'实现', '使用', '配置', '处理', '解决', '分析',
|
||||
'优化', '监控', '调试', '设置', '修改', '更新',
|
||||
'上面', '前面', '后面', '刚才', '展开',
|
||||
'一样', '相关', '以上', '以下', '其中', '如此',
|
||||
'进行', '方式', '结果', '原因', '作用',
|
||||
# 单字符停用
|
||||
'吗', '呢', '吧', '啊', '呀', '哦', '哈',
|
||||
'的', '了', '在', '是', '和', '与', '或', '有',
|
||||
'为', '于', '从', '到', '被', '把', '给', '让',
|
||||
}
|
||||
|
||||
|
||||
class AnchorExtractor:
|
||||
@@ -16,7 +37,7 @@ class AnchorExtractor:
|
||||
self._doc_count = 0
|
||||
|
||||
def extract(self, text: str) -> List[str]:
|
||||
"""从文本中提取锚点列表"""
|
||||
"""从文本中提取锚点列表(过滤无区分度 stopwords)"""
|
||||
anchors = []
|
||||
|
||||
# 中文 2-gram / 3-gram(转小写以便匹配)
|
||||
@@ -24,12 +45,16 @@ class AnchorExtractor:
|
||||
for chunk in chinese_chars:
|
||||
if len(chunk) >= 2:
|
||||
for i in range(len(chunk) - 1):
|
||||
anchors.append(chunk[i:i+2].lower())
|
||||
ngram = chunk[i:i+2].lower()
|
||||
if ngram not in _ANCHOR_STOPWORDS:
|
||||
anchors.append(ngram)
|
||||
if len(chunk) >= 3:
|
||||
for i in range(len(chunk) - 2):
|
||||
anchors.append(chunk[i:i+3].lower())
|
||||
ngram = chunk[i:i+3].lower()
|
||||
if ngram not in _ANCHOR_STOPWORDS:
|
||||
anchors.append(ngram)
|
||||
|
||||
# 英文单词
|
||||
# 英文单词(保留,有强区分度)
|
||||
english_words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', text)
|
||||
anchors.extend([w.lower() for w in english_words if len(w) >= 2])
|
||||
|
||||
|
||||
@@ -92,18 +92,48 @@ class SparseRetriever:
|
||||
q_anchors_lower = [a.lower() for a in query_anchors]
|
||||
|
||||
# 内容词: 从 query 原文提取的 topic-discriminative 词汇
|
||||
# 包括: 英文术语/标识符、版本号、2+字符中文词
|
||||
# 中文通用短词(如"怎么")不具有话题区分度,排除
|
||||
# 排除通用疑问句式(所有话题都会出现的 pattern)
|
||||
CHINESE_STOPWORDS = {
|
||||
# 通用疑问结尾(所有query都有,无区分度)
|
||||
'有什么区别', '是什么', '怎么办', '怎么改', '怎么用',
|
||||
'工作', '区别', '应用', '场景', '方法', '问题',
|
||||
'有什么', '怎么', '哪些', '那个', '这个', '什么',
|
||||
'吗', '呢', '吧', '啊', '呀', '哦', '哈',
|
||||
# 通用技术词(跨话题都出现)
|
||||
'区别是', '有什么区', '是什', '怎么才', '如何', '为什么',
|
||||
# 2字通用词
|
||||
'可以', '需要', '应该', '哪个', '什么', '怎么', '为什',
|
||||
'时候', '情况', '一下', '问题', '相关', '以上', '以下',
|
||||
'另外', '继续', '还有', '然后', '前面', '后面', '中间',
|
||||
'时候', '过程', '机制', '原理', '方式', '方法', '策略',
|
||||
'实现', '使用', '配置', '处理', '解决', '分析', '优化',
|
||||
'监控', '调试', '设置', '修改', '更新', '升级', '迁移',
|
||||
}
|
||||
content_words = set()
|
||||
# 英文单词和代码标识符(所有长度 >= 2)
|
||||
for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_text):
|
||||
|
||||
# 先过滤掉疑问句尾 pattern,再提取
|
||||
# 移除所有通用疑问 pattern(如"有什么区别")
|
||||
query_clean = query_text
|
||||
generic_patterns = [
|
||||
r'有什么.*?[???]', r'怎么.*?[???]', r'是什么.*?[???]',
|
||||
r'为什么.*?[???]', r'如何.*?[???]', r'是否.*?[???]',
|
||||
r'能不能.*?[???]', r'可以吗.*?[???]', r'吗.*?[???]',
|
||||
r'么.*?[???]', r'哪些.*?[???]', r'哪个.*?[???]',
|
||||
r'多少.*?[???]', r'怎样.*?[???]',
|
||||
]
|
||||
for pat in generic_patterns:
|
||||
query_clean = re.sub(pat, '', query_clean)
|
||||
|
||||
# 英文单词和代码标识符(长度 >= 2,有区分度)
|
||||
for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_clean):
|
||||
if len(w) >= 2:
|
||||
content_words.add(w.lower())
|
||||
# 版本号
|
||||
for v in re.findall(r'v?\d+(\.\d+)*', query_text):
|
||||
for v in re.findall(r'v?\d+(\.\d+)*', query_clean):
|
||||
content_words.add(v.lower())
|
||||
# 2字及以上中文词(覆盖"PostgreSQL"等专有名词)
|
||||
for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_text):
|
||||
# 2字及以上中文词(需过滤 stopwords)
|
||||
for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_clean):
|
||||
if chunk not in CHINESE_STOPWORDS:
|
||||
content_words.add(chunk.lower())
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
|
||||
Reference in New Issue
Block a user