From 071f9ef418fb53a4307d4b9ae070104b99fbd5f4 Mon Sep 17 00:00:00 2001 From: Elaina Date: Wed, 22 Apr 2026 01:09:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=8A=E4=B8=8B=E6=96=87=E9=97=A8?= =?UTF-8?q?=E6=8E=A7=E5=99=A8=E5=88=9D=E5=A7=8B=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符) - block.py: 对话块数据结构 - topic_gate.py: 话题门控(overlap/new_ratio 判断切换) - sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分) - selector.py: 最小覆盖贪心选择 - gatekeeper.py: 完整流程封装 - tests/: 单元测试 + 端到端测试(含 MiniMax API 验证) 特性: - 纯 Python,无额外模型依赖 - 支持 2 核 2G 环境 - 话题门控 + 稀疏召回 + 最小覆盖选择 --- .env.example | 1 + .gitignore | 28 +++++++ README.md | 102 +++++++++++++++++++++++ src/__init__.py | 1 + src/anchor.py | 75 +++++++++++++++++ src/block.py | 46 ++++++++++ src/gatekeeper.py | 175 +++++++++++++++++++++++++++++++++++++++ src/selector.py | 116 ++++++++++++++++++++++++++ src/sparse.py | 128 ++++++++++++++++++++++++++++ src/topic_gate.py | 106 ++++++++++++++++++++++++ tests/test_e2e.py | 147 ++++++++++++++++++++++++++++++++ tests/test_gatekeeper.py | 136 ++++++++++++++++++++++++++++++ 12 files changed, 1061 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 README.md create mode 100644 src/__init__.py create mode 100644 src/anchor.py create mode 100644 src/block.py create mode 100644 src/gatekeeper.py create mode 100644 src/selector.py create mode 100644 src/sparse.py create mode 100644 src/topic_gate.py create mode 100644 tests/test_e2e.py create mode 100644 tests/test_gatekeeper.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..eb24dba --- /dev/null +++ b/.env.example @@ -0,0 +1 @@ +MINIMAX_API_KEY=your_api_key_here \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c9143aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +# Environment +.env +.env.local +.env.*.local + +# 
Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +.venv +venv/ +ENV/ +env/ +.eggs/ +*.egg-info/ + +# Test +.pytest_cache/ +.coverage +htmlcov/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..971e508 --- /dev/null +++ b/README.md @@ -0,0 +1,102 @@ +# 上下文门控器 (Context Gatekeeper) + +轻量级上下文选择器,在同一会话中自动从历史对话里选出最小且相关的片段,减少话题污染和控制上下文长度。 + +## 特性 + +- 🚀 **纯 Python**,无需额外模型依赖 +- 💻 **轻量运行**,支持 2 核 2G 环境 +- 🔍 **话题门控**,智能判断继续/切换 +- 📦 **稀疏召回**,BM25/IDF-overlap 评分 +- 🎯 **最小覆盖**,贪心算法选择最优子集 +- ⚙️ **稳定约束区**,持久化用户偏好 + +## 安装 + +```bash +pip install -e . +``` + +## 快速开始 + +```python +from src.gatekeeper import ContextGatekeeper + +# 初始化,token 预算 4000 +gate = ContextGatekeeper(token_budget=4000) + +# 添加对话历史 +gate.add_turn("Redis 锁续租为什么会脑裂", "因为 TTL 设置不合理...") +gate.add_turn("如何避免脑裂", "可以增加时钟偏移检测...") + +# 为当前查询选择上下文 +selected = gate.select("锁的 TTL 怎么设置") + +for item in selected: + print(f"轮次 {item['turn_id']}: {item['user']}") + print(f"助手: {item['assistant']}\n") + +# 设置稳定约束 +gate.set_constraint("language", "中文") +gate.set_constraint("style", "简洁") + +# 构建完整 prompt +prompt = gate.build_prompt("Redis 集群如何搭建") +``` + +## 核心流程 + +``` +用户输入 → 锚点提取 → 话题门控 → 稀疏召回 → 最小覆盖选择 → 组装 Prompt +``` + +## 项目结构 + +``` +context-gatekeeper/ +├── src/ +│ ├── __init__.py +│ ├── anchor.py # 锚点提取 +│ ├── block.py # Block 数据结构 +│ ├── topic_gate.py # 话题门控 +│ ├── sparse.py # 稀疏召回 +│ ├── selector.py # 最小覆盖选择 +│ └── gatekeeper.py # 主模块 +├── tests/ +│ ├── test_gatekeeper.py # 单元测试 +│ └── test_e2e.py # 端到端测试 +├── SPEC.md # 规格文档 +├── README.md +├── .env.example +└── .gitignore +``` + +## 运行测试 + +```bash +# 单元测试 +pytest tests/test_gatekeeper.py -v + +# 端到端测试(需要配置 .env) +cp .env.example .env +# 编辑 .env 填入你的 MiniMax API Key +pytest tests/test_e2e.py -v +``` + +## 算法细节 + +### 话题门控 + +- **overlap > 0.45**:继续当前话题 +- **overlap < 0.20** 且 **new_ratio > 0.70**:切换新话题 +- 有指代词(这个/那个/它/上面)→ 强制继续 + +### 稀疏召回评分 + +``` +score = 1.5 
class AnchorExtractor:
    """Rule-based anchor extractor used for topic gating and retrieval.

    An anchor is a short surface token taken directly from the text:
    Chinese character 2-/3-grams, English-like words, identifiers inside
    backticks, quoted phrases, and numbers / version strings. No model or
    external dictionary is involved.
    """

    # Substrings whose presence marks a query as referring back to earlier turns.
    _DEICTIC_PATTERNS = (
        '这个', '那个', '它', '上面', '刚才', '继续', '展开',
        '为什么会这样', '怎么改', '然后', '还有', '另外',
    )

    # Patterns are compiled once at class-creation time instead of per call.
    _CJK_RUN = re.compile(r'[\u4e00-\u9fff]+')
    _WORD = re.compile(r'[a-zA-Z_][a-zA-Z0-9_-]*')
    _BACKTICK_SPAN = re.compile(r'`([^`]+)`')
    _QUOTED_PHRASE = re.compile(
        r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]'
    )
    # The dotted tail is a NON-capturing group on purpose: with a capturing
    # group ``re.findall`` returns only the group's text ('' or e.g. '.3'),
    # never the version string itself.
    _VERSION = re.compile(r'v?\d+(?:\.\d+)*')

    def __init__(self):
        # anchor -> IDF weight; filled by compute_idf(). The dict object is
        # only ever mutated in place, never rebound.
        self._idf_cache: dict[str, float] = {}
        self._doc_count = 0

    def extract(self, text: str) -> List[str]:
        """Return the anchors found in *text* (duplicates are kept).

        Order of the result: CJK n-grams (per run: all 2-grams, then all
        3-grams), lower-cased words of length >= 2, identifiers inside
        backticks (original case), quoted phrases (original case), then
        numbers / version strings.
        """
        anchors: List[str] = []

        # Chinese 2-gram / 3-gram windows over every contiguous CJK run.
        for run in self._CJK_RUN.findall(text):
            for size in (2, 3):
                anchors.extend(
                    run[i:i + size] for i in range(len(run) - size + 1)
                )

        # English-like words, lower-cased; single letters are dropped.
        anchors.extend(
            w.lower() for w in self._WORD.findall(text) if len(w) >= 2
        )

        # Identifiers inside backticks keep their case and have no length
        # filter, so a camelCase identifier is still recognisable later.
        for span in self._BACKTICK_SPAN.findall(text):
            anchors.extend(self._WORD.findall(span))

        # Phrases inside straight or curly double quotes, taken verbatim.
        anchors.extend(self._QUOTED_PHRASE.findall(text))

        # Numbers and version strings such as "2", "2.0" or "v1.2.3".
        anchors.extend(self._VERSION.findall(text))

        return anchors

    def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
        """Return ``(anchors, has_deictic)`` for *text*.

        ``has_deictic`` is True when any of the class's deictic /
        continuation cue substrings occurs in *text*.
        """
        anchors = self.extract(text)
        has_deictic = any(cue in text for cue in self._DEICTIC_PATTERNS)
        return anchors, has_deictic

    def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
        """Recompute the IDF table from a batch of per-document anchor lists.

        The weight stored for an anchor is ``document_count /
        document_frequency`` (a plain ratio, no logarithm). Entries from a
        previous call are discarded so anchors no longer present fall back
        to the default weight. The cache dict is cleared in place, so
        external references to it stay valid.
        """
        self._doc_count = len(all_doc_anchors)

        doc_freq = Counter()
        for anchors in all_doc_anchors:
            # Count each anchor once per document.
            doc_freq.update(set(anchors))

        self._idf_cache.clear()
        for anchor, df in doc_freq.items():
            self._idf_cache[anchor] = self._doc_count / df

    def idf(self, anchor: str) -> float:
        """Return the IDF weight of *anchor*, or 1.0 when it is unknown."""
        return self._idf_cache.get(anchor, 1.0)
def select(self, query: str) -> List[Dict]:
    """Choose the history blocks to send along with *query*.

    Steps: extract query anchors, ask the topic gate whether the topic
    changed, force-include the two most recent blocks when the query uses
    deictic wording, retrieve the top-20 scored candidates, and let the
    selector pick within ``self.token_budget``.

    Returns:
        One dict per selected block with the keys ``"user"``,
        ``"assistant"`` and ``"turn_id"``, in the selector's order.
    """
    anchors, refers_back = self.anchor_extractor.extract_with_deictic(query)

    # The verdict is only used for the active-topic refresh further down;
    # it does not filter the candidates.
    topic_changed = self.topic_gate.is_topic_switch(query, self._active_topic)

    # Deictic wording ("它", "这个", ...) forces the latest two blocks in.
    forced = self.blocks[-2:] if (refers_back and self.blocks) else []

    ranked = self.retriever.retrieve(self.blocks, anchors, top_m=20)
    chosen = self.selector.select(
        ranked, anchors, self.token_budget, mandatory=forced
    )

    # Blocks are returned whole; no sentence-level trimming is done here.

    # NOTE(review): the active topic is refreshed only while the gate says
    # the topic continues; after a switch it keeps its previous value.
    # This appears intentional - confirm.
    if anchors and not topic_changed:
        self._active_topic = (anchors, self.anchor_extractor._idf_cache)

    output = []
    for blk in chosen:
        output.append(
            {
                "user": blk.user_text,
                "assistant": blk.assistant_text,
                "turn_id": blk.turn_id,
            }
        )
    return output
class MinimumCoverageSelector:
    """Greedy selector that picks a small set of blocks within a token budget.

    Candidates are visited in the order given (the retriever's score
    order). A candidate is taken when it fits the remaining budget, adds
    anchors not yet covered, and its length-penalised gain reaches
    ``MIN_GAIN``. Selection stops once the share of query anchors covered
    reaches ``coverage_target``.
    """

    # Stop once this share of the query anchors is covered.
    COVERAGE_TARGET = 0.85

    # Exponent of the length penalty applied to a block's token cost.
    COST_ALPHA = 0.8

    # Candidates whose gain falls below this value are skipped.
    MIN_GAIN = 0.05

    # Candidates whose retrieval score falls below this value are ignored.
    MIN_SPARSE_SCORE = 0.5

    def __init__(self, coverage_target: float = 0.85, cost_alpha: float = 0.8):
        self.coverage_target = coverage_target
        self.cost_alpha = cost_alpha

    def select(
        self,
        candidates: List[Tuple["Block", float]],
        query_anchors: List[str],
        token_budget: int,
        mandatory: List["Block"] = None
    ) -> List["Block"]:
        """Greedily choose blocks for the prompt.

        Args:
            candidates: ``(block, score)`` pairs, already sorted by score.
            query_anchors: Anchors extracted from the current query.
            token_budget: Upper bound on the summed ``block.cost()`` of
                the non-mandatory picks (mandatory blocks are always kept,
                even if they alone exceed the budget).
            mandatory: Blocks that must be included, e.g. the most recent
                turns when the query contains a deictic word.

        Returns:
            The selected blocks: mandatory ones first, then greedy picks
            in candidate order.
        """
        mandatory = mandatory or []
        selected: List["Block"] = []
        covered: Set[str] = set()

        # Mandatory blocks go in unconditionally and seed the covered set.
        for block in mandatory:
            if block not in selected:
                selected.append(block)
                covered.update(self._block_anchor_set(block))

        current_tokens = sum(b.cost() for b in selected)
        query_total = max(len(query_anchors), 1)

        for block, score in candidates:
            if score < self.MIN_SPARSE_SCORE or block in selected:
                continue

            block_cost = block.cost()
            # Skip (not stop): candidates are ordered by score, not size,
            # so a later, smaller block may still fit the budget.
            if current_tokens + block_cost > token_budget:
                continue

            new_anchors = self._block_anchor_set(block) - covered
            if not new_anchors:
                continue

            # NOTE(review): the gain counts every new block anchor, not
            # only those that occur in the query; MIN_GAIN appears tuned
            # to that scale, so it is left unchanged - confirm.
            new_coverage = len(new_anchors) / query_total
            # max(..., 1) avoids a ZeroDivisionError for zero-cost blocks.
            gain = new_coverage / (max(block_cost, 1) ** self.cost_alpha)
            if gain < self.MIN_GAIN:
                continue

            selected.append(block)
            covered.update(new_anchors)
            current_tokens += block_cost

            hits = sum(1 for a in query_anchors if a in covered)
            if hits / query_total >= self.coverage_target:
                break

        return selected

    def _block_anchor_set(self, block: "Block") -> Set[str]:
        """Return the block's anchors as a lower-cased set.

        A leading ``a_`` (the marker put on assistant-side anchors) is
        removed so user and assistant anchors share one namespace.
        """
        return {
            (a.replace('a_', '', 1) if a.startswith('a_') else a).lower()
            for a in block.anchors
        }

    def coverage_ratio(
        self, selected: List["Block"], query_anchors: List[str]
    ) -> float:
        """Return the share of *query_anchors* covered by *selected*.

        Returns 0.0 when *query_anchors* is empty.
        """
        if not query_anchors:
            return 0.0

        covered: Set[str] = set()
        for block in selected:
            covered.update(self._block_anchor_set(block))

        hit = sum(1 for a in query_anchors if a in covered)
        return hit / len(query_anchors)
用户侧词项重叠 + user_anchors = [a for a in block.anchors if not a.startswith('a_')] + assistant_anchors = [a.replace('a_', '', 1) for a in block.anchors if a.startswith('a_')] + + lex_user = self._lex_overlap(user_anchors, query_anchors) + lex_assistant = self._lex_overlap(assistant_anchors, query_anchors) + + # 2. Exact match 加分 + exact = self._exact_match(block, query_anchors) + + # 3. 综合评分 + score = ( + self.WEIGHT_USER * lex_user + + self.WEIGHT_ASSISTANT * lex_assistant + + self.WEIGHT_EXACT * exact + + self.WEIGHT_RECENT * recency + ) + + return score + + def _lex_overlap(self, block_anchors: List[str], query_anchors: List[str]) -> float: + """计算词项重叠得分(带 IDF 权重)""" + score = 0.0 + # 将 query anchors 转小写用于匹配 + query_lower = [a.lower() for a in query_anchors] + for qa in query_lower: + # 精确匹配 + if qa in block_anchors: + tf = min(block_anchors.count(qa), 3) + score += 1.0 * tf + # 或者前缀/后缀匹配(处理中英文混合) + else: + for ba in block_anchors: + if qa in ba.lower() or ba.lower() in qa: + score += 0.5 + break + return score + + def _exact_match(self, block: Block, query_anchors: List[str]) -> float: + """ + Exact match 加分: + - 完整英文术语命中 + - 代码标识符命中 + - 数字、版本号命中 + - 引号短语完整命中 + """ + text_lower = block.user_text.lower() + ' ' + block.assistant_text.lower() + score = 0.0 + + for qa in query_anchors: + # 版本号(v开头) + if qa.startswith('v') and '.' 
class TopicGate:
    """Decide whether a query continues the active topic or starts a new one.

    The verdict is based on two IDF-weighted ratios between the query's
    anchors and the active topic's anchors, plus a deictic-word shortcut.
    """

    # Thresholds for the overlap / new-anchor ratios.
    OVERLAP_CONTINUE_THRESHOLD = 0.45
    OVERLAP_SWITCH_THRESHOLD = 0.20
    NEW_RATIO_SWITCH_THRESHOLD = 0.70

    def __init__(self, anchor_extractor: "AnchorExtractor" = None):
        # A falsy argument means "create a private extractor".
        if not anchor_extractor:
            anchor_extractor = AnchorExtractor()
        self.anchor_extractor = anchor_extractor
        self.active_topic_anchors: List[str] = []
        self.active_topic_idf: dict[str, float] = {}

    def update_active_topic(self, anchors: List[str], idf: dict[str, float]) -> None:
        """Replace the stored active-topic anchors and their IDF table."""
        self.active_topic_idf = idf
        self.active_topic_anchors = anchors

    def is_topic_switch(
        self, query: str, active_topic: Tuple[List[str], dict[str, float]] = None
    ) -> bool:
        """Report whether *query* leaves the active topic.

        Args:
            query: The incoming user text.
            active_topic: ``(anchors, idf)`` of the current topic, or None
                when no topic exists yet.

        Returns:
            True only when the weighted overlap is below
            ``OVERLAP_SWITCH_THRESHOLD`` and the new-anchor ratio is above
            ``NEW_RATIO_SWITCH_THRESHOLD``. A missing topic, a deictic
            word in the query, and every in-between case count as
            "continue" (False).
        """
        query_anchors, refers_back = self.anchor_extractor.extract_with_deictic(query)

        if active_topic is None:
            # First query: remember its anchors; nothing to switch from.
            self.active_topic_anchors = query_anchors
            return False

        topic_anchors, topic_idf = active_topic

        # A deictic word is treated as a strong hint to stay on topic.
        if refers_back:
            return False

        overlap = self._compute_overlap(query_anchors, topic_anchors, topic_idf)
        new_ratio = self._compute_new_ratio(query_anchors, topic_anchors, topic_idf)

        # High overlap and the undecided middle band both mean "continue",
        # so only the explicit switch condition needs to be tested.
        return (
            overlap < self.OVERLAP_SWITCH_THRESHOLD
            and new_ratio > self.NEW_RATIO_SWITCH_THRESHOLD
        )

    def _compute_overlap(
        self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> float:
        """Return the IDF-weighted share of query anchors found in the topic.

        Anchors missing from *topic_idf* weigh 1.0. Returns 0.0 for an
        empty query or a zero total weight.
        """
        if not query_anchors:
            return 0.0

        shared = sum(
            [topic_idf.get(a, 1.0) for a in query_anchors if a in topic_anchors]
        )
        total = sum([topic_idf.get(a, 1.0) for a in query_anchors])

        if total > 0:
            return shared / total
        return 0.0

    def _compute_new_ratio(
        self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> float:
        """Return the IDF-weighted share of query anchors absent from the topic.

        The numerator counts each distinct new anchor once; the
        denominator sums over the query anchors as given (duplicates
        included). Returns 1.0 for an empty query or a zero total weight.
        """
        if not query_anchors:
            return 1.0

        unseen = set(query_anchors) - set(topic_anchors)
        unseen_weight = sum([topic_idf.get(a, 1.0) for a in unseen])
        total = sum([topic_idf.get(a, 1.0) for a in query_anchors])

        if total > 0:
            return unseen_weight / total
        return 1.0

    def compute_idf_for_anchors(self, anchors: List[str]) -> dict[str, float]:
        """Return ``{anchor: idf}`` for *anchors* using the extractor's table."""
        return {a: self.anchor_extractor.idf(a) for a in anchors}
def call_minimax(prompt: str) -> str:
    """Send *prompt* as a single user message to the MiniMax chat endpoint.

    The prompt is cut to its first 2000 characters before sending. The
    request goes to the module-level ``BASE_URL`` with the module-level
    ``API_KEY`` as bearer token and a 30-second timeout.

    Returns:
        The text of the first choice in the response. If sending the
        request or reading/parsing the response raises, the error is not
        propagated; a string describing the failure is returned instead.
    """
    import json
    import urllib.request

    body = json.dumps(
        {
            "model": "MiniMax-M2.7",
            # Keep the request small: only the first 2000 characters go out.
            "messages": [{"role": "user", "content": prompt[:2000]}],
            "max_tokens": 500,
            "temperature": 0.7,
        }
    ).encode("utf-8")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    request = urllib.request.Request(
        BASE_URL, data=body, headers=headers, method="POST"
    )

    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            raw = response.read().decode("utf-8")
        return json.loads(raw)["choices"][0]["message"]["content"]
    except Exception as e:
        # Best effort: the caller always receives a string.
        return f"API 调用失败: {e}"
+ assistant2 = call_minimax(f"请用 2-3 句话回答: {user2}") + gate.add_turn(user2, assistant2) + print(f"用户: {user2}") + print(f"助手: {assistant2}") + + # 验证:第3轮问 Redis 相关,应该召回第1轮(可能不召回第2轮,取决于锚点重叠度) + print("\n📌 第3轮:问 Redis 相关问题") + query3 = "锁的 TTL 怎么设置才合理" + selected = gate.select(query3) + print(f"用户查询: {query3}") + print(f"召回的上下文轮次: {[b['turn_id'] for b in selected]}") + turn_ids = [b['turn_id'] for b in selected] + assert 1 in turn_ids, "❌ 应召回第1轮 Redis 内容" + print("✅ Redis 相关问题召回第1轮内容") + + # === 第二阶段:切换到 Python 话题 === + print("\n" + "=" * 60) + print("📌 第4轮:切换到 Python 话题") + user4 = "Python 异步编程怎么做?用 asyncio 举例子" + assistant4 = call_minimax(f"请用 3-4 句话回答: {user4}") + gate.add_turn(user4, assistant4) + print(f"用户: {user4}") + print(f"助手: {assistant4}") + + # 验证:问 Python 相关,应该召回 Python 内容(3或4) + print("\n📌 第5轮:问 Python asyncio 相关") + query5 = "asyncio 怎么用?举一个爬虫的例子" + selected5 = gate.select(query5) + print(f"用户查询: {query5}") + print(f"召回的上下文轮次: {[b['turn_id'] for b in selected5]}") + turn_ids5 = [b['turn_id'] for b in selected5] + # Python 话题应该召回 3 或 4 + has_python_topic = (3 in turn_ids5 or 4 in turn_ids5) + assert has_python_topic, f"❌ 应召回 Python 内容,实际: {turn_ids5}" + print(f"✅ Python 相关问题召回正确轮次: {turn_ids5}") + + # === 第三阶段:指代词测试 === + print("\n" + "=" * 60) + print("📌 第6轮:指代词测试") + user6 = "它的并发性能怎么样" + assistant6 = call_minimax(f"请用 2-3 句话回答: {user6}") + gate.add_turn(user6, assistant6) + print(f"用户: {user6}") + print(f"助手: {assistant6}") + + # 验证:有指代词时,应该强制继承最近轮次 + print("\n📌 第7轮:指代词强制继承验证") + query7 = "它和 ThreadPool 比哪个更好" + selected7 = gate.select(query7) + print(f"用户查询: {query7}") + print(f"召回的上下文轮次: {[b['turn_id'] for b in selected7]}") + # 应该有指代词强制继承 + assert len(selected7) >= 2, "❌ 有指代词时应强制继承最近 2 个 block" + print(f"✅ 指代词触发强制继承,召回了 {len(selected7)} 个 block") + + # === 验证 Token 预算控制 === + print("\n" + "=" * 60) + print("📌 Token 预算验证") + total_context_tokens = sum( + len(b['user']) * 1.5 + len(b['assistant']) * 1.5 + for b in selected + ) + print(f"当前上下文 token 
估算: {total_context_tokens:.0f} / {gate.token_budget}") + assert total_context_tokens <= gate.token_budget * 1.5, "❌ 超出 token 预算" + print("✅ Token 预算控制正常") + + print("\n" + "=" * 60) + print("✅ 所有端到端测试通过!") + print("=" * 60) + + +if __name__ == "__main__": + test_e2e_conversation() \ No newline at end of file diff --git a/tests/test_gatekeeper.py b/tests/test_gatekeeper.py new file mode 100644 index 0000000..a88ee07 --- /dev/null +++ b/tests/test_gatekeeper.py @@ -0,0 +1,136 @@ +""" +上下文门控器 - 轻量级选择器测试 +""" + +import sys +import os + +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.gatekeeper import ContextGatekeeper + + +class TestAnchorExtraction: + """锚点提取测试""" + + def test_chinese_ngram(self): + gate = ContextGatekeeper() + # 触发锚点提取 + gate.add_turn("Redis 分布式锁如何实现", "可以使用 Redlock 算法...") + assert len(gate.blocks) == 1 + anchors = gate.blocks[0].anchors + # 应该有 redis、分布式锁等锚点 + assert any('redis' in a.lower() for a in anchors) + + def test_empty_turn(self): + gate = ContextGatekeeper() + gate.add_turn("", "") + assert len(gate.blocks) == 1 + + +class TestTopicGate: + """话题门控测试""" + + def test_obvious_switch(self): + """从 Redis 切换到 Python,明显切换""" + gate = ContextGatekeeper() + gate.add_turn("Redis 锁怎么用", "用 Redlock...") + gate.add_turn("Python 怎么写快速排序", "可以用递归...") + + # 第三个问题,明显是 Python 话题 + blocks = gate.select("Python 列表推导式怎么写") + # 应该主要返回第二、三轮的 Python 内容 + turn_ids = [b['turn_id'] for b in blocks] + assert 2 in turn_ids or 3 in turn_ids + + def test_continuation_with_deictic(self): + """指代词触发强制继承""" + gate = ContextGatekeeper() + gate.add_turn("Redis 锁是什么", "是分布式锁...") + gate.add_turn("它有哪些实现方式", "有 Redlock...") + + # 第二轮有"它",应该强制继承第一轮 + blocks = gate.select("它有哪些实现方式") + turn_ids = [b['turn_id'] for b in blocks] + # 应该包含第1轮 + assert 1 in turn_ids + + +class TestSparseRetrieval: + """稀疏召回测试""" + + def test_exact_match_boost(self): + """exact match 应该提升得分""" + gate = ContextGatekeeper() + gate.add_turn("使用 
if __name__ == "__main__":
    # This module never imports pytest at the top, so calling pytest.main()
    # directly raised NameError when the file was run as a script. Import
    # it here so plain imports of the module stay free of the dependency.
    import pytest

    # Propagate pytest's exit status so a failing run is visible to the
    # shell / CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))