feat: 上下文门控器初始实现

- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
Elaina
2026-04-22 01:09:35 +08:00
commit 071f9ef418
12 changed files with 1061 additions and 0 deletions

75
src/anchor.py Normal file
View File

@@ -0,0 +1,75 @@
"""
锚点提取模块 - 用规则从文本中提取有代表性的锚点
"""
import re
from collections import Counter
from typing import List, Tuple
class AnchorExtractor:
    """Rule-based extractor of anchors (representative tokens) from text.

    Anchors are plain strings: Chinese 2-/3-grams, English words, identifiers
    found inside backticks, quoted phrases, and numbers / version strings.
    The class also keeps a per-anchor weight table that is filled by
    ``compute_idf`` and read by ``idf``.
    """

    # Substrings whose presence marks a text as referring back to earlier
    # context.  An empty string must never appear here: ``"" in text`` is
    # always True and would flag every input as deictic.
    _DEICTIC_PATTERNS = (
        '这个', '那个', '上面', '刚才', '继续', '展开',
        '为什么会这样', '怎么改', '然后', '还有', '另外',
    )

    def __init__(self):
        # Anchor -> weight, populated by compute_idf(); anchors missing from
        # this table fall back to 1.0 in idf().
        self._idf_cache: dict[str, float] = {}
        # Number of documents seen by the most recent compute_idf() call.
        self._doc_count = 0

    def extract(self, text: str) -> List[str]:
        """Return the anchors found in ``text``.

        The result may contain duplicates and is ordered by extraction rule:
        Chinese n-grams, lower-cased English words, backticked identifiers
        (original case), quoted phrases, then numbers / version strings.

        Args:
            text: The text to scan.

        Returns:
            A list of anchor strings; empty when nothing matches.
        """
        anchors = []

        # Chinese 2-grams followed by 3-grams for every run of CJK characters.
        # CJK ideographs have no letter case, so no lower-casing is needed.
        for chunk in re.findall(r'[\u4e00-\u9fff]+', text):
            for size in (2, 3):
                anchors.extend(
                    chunk[i:i + size] for i in range(len(chunk) - size + 1)
                )

        # English words of at least two characters, lower-cased for matching.
        anchors.extend(
            word.lower()
            for word in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', text)
            if len(word) >= 2
        )

        # Identifiers inside backtick spans, kept in their original case.
        for segment in re.findall(r'`([^`]+)`', text):
            anchors.extend(re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', segment))

        # Phrases enclosed in straight or curly double quotes.
        anchors.extend(
            re.findall(r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]', text)
        )

        # Numbers and version strings such as "42" or "v1.2.3".  The group is
        # non-capturing so that findall() yields the whole match; a capturing
        # group would make it return only the last ".N" part (or "").
        anchors.extend(re.findall(r'v?\d+(?:\.\d+)*', text))

        return anchors

    def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
        """Extract anchors and report whether ``text`` contains a deictic word.

        Args:
            text: The text to scan.

        Returns:
            A ``(anchors, has_deictic)`` tuple, where ``anchors`` is the result
            of ``extract(text)`` and ``has_deictic`` is True when any entry of
            ``_DEICTIC_PATTERNS`` occurs in ``text`` as a substring.
        """
        anchors = self.extract(text)
        has_deictic = any(
            pattern in text for pattern in self._DEICTIC_PATTERNS
        )
        return anchors, has_deictic

    def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
        """Compute per-anchor weights from a batch of documents.

        The weight stored for an anchor is ``doc_count / doc_frequency`` (a
        plain ratio, no logarithm).  Entries left over from earlier calls that
        do not occur in this batch are not touched.

        Args:
            all_doc_anchors: One anchor list per document.
        """
        self._doc_count = len(all_doc_anchors)

        # Count each anchor at most once per document.
        doc_freq = Counter()
        for anchors in all_doc_anchors:
            doc_freq.update(set(anchors))

        for anchor, df in doc_freq.items():
            self._idf_cache[anchor] = self._doc_count / df

    def idf(self, anchor: str) -> float:
        """Return the stored weight of ``anchor``, or 1.0 if it is unknown."""
        return self._idf_cache.get(anchor, 1.0)