#!/usr/bin/env python3
"""Phase 4: End-to-End Context Quality Evaluation

Builds a synthetic 20-turn conversation that interleaves four topics
(Redis, asyncio, PostgreSQL, Git), then, for each query in ``TEST_SEQ``,
compares the context chosen by ``ContextGatekeeper.select`` ("CGK") with a
naive "last 5 turns" baseline. Per-query and averaged metrics are printed,
and the averages are written to ``phase4_quality_results.json`` next to
this script.
"""
import json
import math
import os
import sys

# Put the parent of this script's directory on sys.path so the project-local
# ``src`` package resolves no matter which directory the script is run from.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.gatekeeper import ContextGatekeeper  # noqa: E402  (needs the path tweak above)

# Topic names searched for (case-insensitively, by substring) to detect
# cross-topic contamination in a context block.
TOPICS = ['Redis', 'asyncio', 'PostgreSQL', 'Git']

# Key terms per topic, used for the term-coverage metric.
KEY_TERMS = {
    'PostgreSQL': ['explain', 'analyze', 'mvcc', 'vacuum', '索引', '执行计划'],
    'Git': ['rebase', 'merge', 'reset', 'commit', '分支'],
    'Redis': ['redis', '分布式锁', '惰性删除', '定期删除', '过期'],
    'asyncio': ['asyncio', 'task', 'cancel', '事件循环', '协程'],
}


def estimate_tokens(text):
    """Return a rough, heuristic token estimate for ``text``.

    Each CJK ideograph (U+4E00..U+9FFF) counts 0.4, each whitespace-separated
    ASCII-only word counts 1.3, and every character of the string adds 0.05.
    The sum is truncated to an int. Empty or ``None`` input yields 0.
    """
    if not text:
        return 0
    chinese = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
    english = sum(1 for w in text.split() if w.isascii())
    return int(chinese * 0.4 + english * 1.3 + len(text) * 0.05)


def measure_prompt_tokens(selected, query):
    """Estimate the prompt size for the ``selected`` turns plus ``query``.

    Each turn is rendered as ``"用户: <user>\\n助手: <assistant>\\n\\n"`` before
    estimating. An extra 8% of the context estimate is added as formatting
    overhead, and the query's own estimate is added on top.
    """
    ctx = "".join(f"用户: {i['user']}\n助手: {i['assistant']}\n\n" for i in selected)
    ctx_tokens = estimate_tokens(ctx)  # computed once, reused for the overhead term
    return ctx_tokens + int(ctx_tokens * 0.08) + estimate_tokens(query)


def evaluate_context_quality(gate, query, target_topic, context):
    """Score a list of context turns against ``target_topic`` and ``query``.

    Args:
        gate: object exposing ``anchor_extractor.extract_with_deictic(query)``,
            which returns an ``(anchors, _)`` pair; anchors are strings.
        query: the user question the context is meant to support.
        target_topic: topic name; selects the ``KEY_TERMS`` list and is
            excluded from the contamination check.
        context: sequence of mappings with ``'user'`` and ``'assistant'`` text.

    Returns:
        dict with:
          - ``block_purity``: fraction of blocks that name no other topic;
          - ``anchor_recall``: fraction of query anchors found in the context;
          - ``term_coverage``: fraction of the topic's key terms found;
          - ``other_topics``: sorted, de-duplicated foreign topic names;
          - ``context_tokens``: ``estimate_tokens`` of the joined context.
    """
    target = target_topic.lower()
    foreign_topics = [t for t in TOPICS if t.lower() != target]

    total_blocks = len(context)
    topic_blocks = 0
    other_topics_found = set()
    for item in context:
        text = (item['user'] + ' ' + item['assistant']).lower()
        # NOTE: plain substring matching on topic names. A block that names no
        # known topic at all (e.g. the "EXPLAIN ANALYZE ..." turn) is counted
        # as on-topic, and short names such as "git" could match inside
        # unrelated words.
        found = [t for t in foreign_topics if t.lower() in text]
        if found:
            other_topics_found.update(found)
        else:
            topic_blocks += 1
    block_purity = topic_blocks / total_blocks if total_blocks > 0 else 0

    q_anchors, _ = gate.anchor_extractor.extract_with_deictic(query)
    context_text = ' '.join(item['user'] + ' ' + item['assistant'] for item in context)
    context_lower = context_text.lower()  # hoisted: reused for anchors and terms
    covered = sum(1 for a in q_anchors if a.lower() in context_lower)
    anchor_recall = covered / len(q_anchors) if q_anchors else 0

    relevant_terms = KEY_TERMS.get(target_topic, [])
    terms_found = [t for t in relevant_terms if t.lower() in context_lower]
    term_coverage = len(terms_found) / len(relevant_terms) if relevant_terms else 0

    return {
        'block_purity': block_purity,
        'anchor_recall': anchor_recall,
        'term_coverage': term_coverage,
        # Sorted so printed output is stable across runs (set order is not).
        'other_topics': sorted(other_topics_found),
        'context_tokens': estimate_tokens(context_text),
    }


def evaluate_retrieval_quality(gate, query, target_topic):
    """Compare CGK-selected context with the last-5-turns baseline for one query.

    Returns ``{'cgk': {...}, 'last5': {...}}``. Each value is the
    ``evaluate_context_quality`` result extended with ``prompt_tokens``.
    """
    sel_cgk = gate.select(query)
    # Baseline: the five most recently stored blocks, converted to the
    # {'user', 'assistant'} mapping shape that evaluate_context_quality reads.
    last5_items = [{'user': b.user_text, 'assistant': b.assistant_text}
                   for b in gate.blocks[-5:]]

    cgk_quality = evaluate_context_quality(gate, query, target_topic, sel_cgk)
    last5_quality = evaluate_context_quality(gate, query, target_topic, last5_items)
    return {
        'cgk': {**cgk_quality, 'prompt_tokens': measure_prompt_tokens(sel_cgk, query)},
        'last5': {**last5_quality, 'prompt_tokens': measure_prompt_tokens(last5_items, query)},
    }


redis_qa = [
    ("Redis 分布式锁和 RedLock 算法有什么区别?", "RedLock是..."),
    ("Redis 集群环境下怎么做分布式锁?", "集群下..."),
    ("Redis 惰性删除和定期删除有什么区别?", "惰性删除..."),
    ("Redis 的过期 key 对 RDB 快照有什么影响?", "过期key..."),
    ("Redis 主从复制断线后如何增量同步?", "PSYNC..."),
]
asyncio_qa = [
    ("asyncio.Task 的 cancel 方法怎么工作的?", "cancel..."),
    ("asyncio.gather 和 asyncio.wait 的返回结果有什么区别?", "gather..."),
    ("asyncio 的事件循环怎么启动和停止?", "事件循环..."),
    ("asyncio.sleep 和 time.sleep 的区别是什么?", "sleep..."),
    ("asyncio 的 Future 对象怎么获取结果?", "Future..."),
]
pg_qa = [
    ("PostgreSQL 的 MVCC 机制是怎么保证读不阻塞写的?", "MVCC..."),
    ("PostgreSQL 的 VACUUM 为什么要定期运行?", "VACUUM..."),
    ("EXPLAIN ANALYZE 怎么看执行计划?", "EXPLAIN..."),
    ("PostgreSQL B-tree 索引和 Hash 索引的区别是什么?", "B-tree..."),
    ("PostgreSQL 的 TOAST 机制是什么?", "TOAST..."),
]
git_qa = [
    ("Git 的 rebase 和 merge 的区别是什么?", "rebase..."),
    ("Git reset 的 --soft、--mixed、--hard 有什么区别?", "reset..."),
    ("Git stash 暂存区和工作目录的区别是什么?", "stash..."),
    ("Git 的 bisect 怎么用来快速定位 bug?", "bisect..."),
    ("Git 的 reflog 怎么用来恢复误删的提交?", "reflog..."),
]

# (label, query, target topic) for each evaluated query.
TEST_SEQ = [
    ("问PG", "EXPLAIN ANALYZE 怎么看执行计划?", "PostgreSQL"),
    ("问Git", "Git 的 rebase 和 merge 有什么区别?", "Git"),
    ("问Redis", "Redis 惰性删除和定期删除有什么区别?", "Redis"),
    ("问asyncio", "asyncio.Task 的 cancel 方法怎么工作的?", "asyncio"),
    ("再问Git", "Git 的 reset 和 revert 的应用场景有什么区别?", "Git"),
]


def build_gate():
    """Return a ContextGatekeeper (token_budget=4000) holding 20 interleaved turns.

    Turns are added round-robin: Redis, asyncio, PostgreSQL, Git, Redis, ...
    """
    g = ContextGatekeeper(token_budget=4000)
    for round_turns in zip(redis_qa, asyncio_qa, pg_qa, git_qa):
        for user_text, assistant_text in round_turns:
            g.add_turn(user_text, assistant_text)
    return g


def _winner(cgk_value, last5_value, lower_is_better=False):
    """Return 'CGK', 'Last-5' or 'Tie' for one summary metric.

    Values equal within floating-point noise are reported as a tie rather
    than being silently awarded to Last-5.
    """
    if math.isclose(cgk_value, last5_value, rel_tol=1e-9, abs_tol=1e-12):
        return 'Tie'
    if lower_is_better:
        cgk_better = cgk_value < last5_value
    else:
        cgk_better = cgk_value > last5_value
    return 'CGK' if cgk_better else 'Last-5'


def _print_system_line(name, quality):
    """Print one per-query metrics line for a context-selection strategy."""
    print(f" {name}: tok={quality['prompt_tokens']:.0f}, purity={quality['block_purity']:.2f}, "
          f"anchor={quality['anchor_recall']:.2f}, term={quality['term_coverage']:.2f}, "
          f"other_topics={quality['other_topics']}")


def main():
    """Run every TEST_SEQ query, print per-query and summary metrics, save JSON."""
    print("="*70)
    print("Phase 4: End-to-End Context Quality Evaluation")
    print("="*70)
    print("评估维度:")
    print(" - block_purity: 目标话题块占比(1.0=纯目标话题)")
    print(" - anchor_recall: query锚点在上下文中的覆盖率")
    print(" - term_coverage: 目标话题关键术语在上下文中的覆盖率")
    print(" - prompt_tokens: 完整prompt token数(含格式化开销)")
    print()

    results = []
    for label, query, target in TEST_SEQ:
        # Rebuild the gate for every query so each one is evaluated against the
        # same stored history (in case select() changes gate state).
        gate = build_gate()
        r = evaluate_retrieval_quality(gate, query, target)
        results.append(r)
        cgk, last5 = r['cgk'], r['last5']

        print(f"\n[{label}] {query}")
        _print_system_line("CGK", cgk)
        _print_system_line("Last-5", last5)

        last5_tok = last5['prompt_tokens']
        # Guard the ratio: an empty baseline prompt would raise ZeroDivisionError.
        saving = (1 - cgk['prompt_tokens'] / last5_tok) * 100 if last5_tok else 0.0
        purity_delta = (cgk['block_purity'] - last5['block_purity']) * 100
        # The '+' sign flag renders "+20%" / "-20%"; a literal '+' gave "+-20%".
        print(f" → CGK节省{saving:.0f}%token, 纯度{purity_delta:+.0f}%, "
              f"混入话题{len(cgk['other_topics'])}个 vs Last-5 {len(last5['other_topics'])}个")

    n = len(TEST_SEQ)

    def avg(system, key):
        return sum(res[system][key] for res in results) / n

    def contamination(system):
        # Number of queries whose context contained at least one foreign topic.
        return sum(1 for res in results if res[system]['other_topics'])

    cgk_avg_tok, last5_avg_tok = avg('cgk', 'prompt_tokens'), avg('last5', 'prompt_tokens')
    cgk_avg_pur, last5_avg_pur = avg('cgk', 'block_purity'), avg('last5', 'block_purity')
    cgk_avg_anc, last5_avg_anc = avg('cgk', 'anchor_recall'), avg('last5', 'anchor_recall')
    cgk_avg_term, last5_avg_term = avg('cgk', 'term_coverage'), avg('last5', 'term_coverage')
    cgk_cont, last5_cont = contamination('cgk'), contamination('last5')

    print("\n" + "="*70)
    print("Summary: CGK vs Last-5 (avg over {} queries)".format(n))
    print("="*70)
    print(f"\n{'Metric':<25} {'CGK':>12} {'Last-5':>12} {'Winner':<10}")
    print("-"*62)
    # (metric label, CGK value, Last-5 value, format spec, lower is better)
    summary_rows = [
        ('Avg prompt tokens', cgk_avg_tok, last5_avg_tok, '.1f', True),
        ('Avg block purity', cgk_avg_pur, last5_avg_pur, '.3f', False),
        ('Avg anchor recall', cgk_avg_anc, last5_avg_anc, '.3f', False),
        ('Avg term coverage', cgk_avg_term, last5_avg_term, '.3f', False),
        ('Contamination episodes', cgk_cont, last5_cont, '', True),
    ]
    for metric, cgk_value, last5_value, spec, lower_is_better in summary_rows:
        winner = _winner(cgk_value, last5_value, lower_is_better)
        print(f"{metric:<25} {cgk_value:>12{spec}} {last5_value:>12{spec}} {winner:<10}")

    print("\n" + "="*70)
    print("Honest Limitations of This Evaluation")
    print("="*70)
    print("""
1. 没有真实 LLM 调用:评估的是"上下文块的质量",不是"模型答案的质量"。
   上下文好 ≠ 答案好,真正的答案质量需要实际调用 LLM。

2. 测试集仍是合成数据:真实对话中用户可能只打"那这个呢"或"为什么",
   短 query 的锚点覆盖率会显著低于本测试中的完整问题。

3. 污染只统计了"块级别"混入:即使上下文纯度 100%,LLM 的注意力机制
   仍可能跨块建立错误关联,这种"软污染"无法通过块级分析检测。
""")

    out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'phase4_quality_results.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump({
            'cgk_avg_tokens': cgk_avg_tok,
            'last5_avg_tokens': last5_avg_tok,
            'cgk_avg_purity': cgk_avg_pur,
            'last5_avg_purity': last5_avg_pur,
            'cgk_avg_anchor_recall': cgk_avg_anc,
            'last5_avg_anchor_recall': last5_avg_anc,
            'cgk_avg_term_coverage': cgk_avg_term,
            'last5_avg_term_coverage': last5_avg_term,
            'cgk_contamination_episodes': cgk_cont,
            'last5_contamination_episodes': last5_cont,
        }, f, indent=2)
    print(f"\nSaved to: {out_path}")


if __name__ == '__main__':
    main()