feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
75
src/anchor.py
Normal file
75
src/anchor.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
锚点提取模块 - 用规则从文本中提取有代表性的锚点
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class AnchorExtractor:
    """Rule-based extractor of anchors (n-grams, words, identifiers, ...).

    Anchors are short strings pulled out of a text with regular expressions.
    The class also keeps a per-anchor weight table filled by ``compute_idf``
    and read by ``idf``.
    """

    # Runs of CJK unified ideographs (U+4E00..U+9FFF).
    _CJK_RUN_RE = re.compile(r'[\u4e00-\u9fff]+')
    # ASCII words / identifiers; hyphens and digits allowed after the first char.
    _WORD_RE = re.compile(r'[a-zA-Z_][a-zA-Z0-9_-]*')
    # Content of a backtick-delimited code span.
    _CODE_SPAN_RE = re.compile(r'`([^`]+)`')
    # Content between ASCII or curly double quotes.
    _QUOTED_RE = re.compile(
        r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]'
    )
    # Plain numbers and dotted versions with an optional leading "v".
    # The group MUST be non-capturing: with a capturing group, re.findall
    # returns only the group's last repetition (e.g. ".3" or "") instead of
    # the whole match.
    _VERSION_RE = re.compile(r'v?\d+(?:\.\d+)*')

    # Substrings whose presence marks a text as referring back to context.
    _DEICTIC_PATTERNS = (
        '这个', '那个', '它', '上面', '刚才', '继续', '展开',
        '为什么会这样', '怎么改', '然后', '还有', '另外',
    )

    def __init__(self):
        # anchor -> weight, populated by compute_idf(); empty until then.
        self._idf_cache: dict[str, float] = {}
        # Number of documents passed to the most recent compute_idf() call.
        self._doc_count = 0

    def extract(self, text: str) -> List[str]:
        """Return the list of anchors found in *text*.

        Anchors are appended in this order, duplicates included:

        1. For every run of Chinese characters: all 2-grams, then all 3-grams.
        2. ASCII words of length >= 2, lower-cased.
        3. Identifiers inside backtick code spans, in their original case.
        4. Phrases enclosed in double quotes (ASCII or curly).
        5. Numbers and version strings such as ``42`` or ``v1.2.3``.

        Args:
            text: The text to scan.

        Returns:
            A list of anchor strings; empty if nothing matched.
        """
        anchors: List[str] = []

        # Chinese 2-grams and 3-grams. CJK ideographs have no letter case,
        # so no case folding is needed here.
        for run in self._CJK_RUN_RE.findall(text):
            for size in (2, 3):
                anchors.extend(
                    run[i:i + size] for i in range(len(run) - size + 1)
                )

        # English words, normalised to lower case; single letters are skipped.
        anchors.extend(
            word.lower()
            for word in self._WORD_RE.findall(text)
            if len(word) >= 2
        )

        # Identifiers inside backticks keep their original case.
        for segment in self._CODE_SPAN_RE.findall(text):
            anchors.extend(self._WORD_RE.findall(segment))

        # Quoted phrases are kept verbatim.
        anchors.extend(self._QUOTED_RE.findall(text))

        # Numbers and version strings, as full matches.
        anchors.extend(self._VERSION_RE.findall(text))

        return anchors

    def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
        """Extract anchors and report whether *text* contains a deictic marker.

        Args:
            text: The text to scan.

        Returns:
            A ``(anchors, has_deictic)`` tuple. ``anchors`` is the result of
            ``extract(text)``; ``has_deictic`` is True when any of the
            built-in marker substrings occurs in *text*.
        """
        anchors = self.extract(text)
        has_deictic = any(
            marker in text for marker in self._DEICTIC_PATTERNS
        )
        return anchors, has_deictic

    def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
        """Compute per-anchor weights from a batch of documents.

        Each anchor's weight is stored as ``doc_count / document_frequency``
        (a plain ratio, no logarithm), where each document counts an anchor
        at most once.

        Entries already in the cache are overwritten for anchors present in
        this batch; anchors absent from this batch keep their previous value.

        Args:
            all_doc_anchors: One anchor list per document.
        """
        self._doc_count = len(all_doc_anchors)

        doc_freq: Counter = Counter()
        for doc_anchors in all_doc_anchors:
            # De-duplicate so that each document contributes at most 1.
            doc_freq.update(set(doc_anchors))

        for anchor, df in doc_freq.items():
            self._idf_cache[anchor] = self._doc_count / df

    def idf(self, anchor: str) -> float:
        """Return the stored weight of *anchor*, or 1.0 if it is unknown."""
        return self._idf_cache.get(anchor, 1.0)
|
||||
Reference in New Issue
Block a user