- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 可在 2 核 CPU、2 GB 内存的环境中运行
- 话题门控 + 稀疏召回 + 最小覆盖选择
75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
"""
|
||
锚点提取模块 - 用规则从文本中提取有代表性的锚点
|
||
"""
|
||
|
||
import re
|
||
from collections import Counter
|
||
from typing import List, Tuple
|
||
|
||
|
||
class AnchorExtractor:
    """Extract representative anchors from text for topic gating and retrieval.

    Anchors are produced by simple rules (no model dependency):

    * Chinese character 2-grams and 3-grams,
    * English words / identifiers (lower-cased, length >= 2),
    * identifiers found inside backtick code spans (case preserved),
    * phrases enclosed in straight or curly double quotes,
    * numbers and version strings such as ``3.10`` or ``v1.2.3``.
    """

    # Patterns are compiled once and shared by all instances.
    _CJK_RUN_RE = re.compile(r'[\u4e00-\u9fff]+')
    _WORD_RE = re.compile(r'[a-zA-Z_][a-zA-Z0-9_-]*')
    _CODE_SPAN_RE = re.compile(r'`([^`]+)`')
    _QUOTED_RE = re.compile(r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]')
    # The group must be non-capturing: with a capturing group re.findall would
    # return only the last ".N" fragment ("" for a plain integer) instead of
    # the whole number/version string.
    _VERSION_RE = re.compile(r'v?\d+(?:\.\d+)*')

    # Substrings whose presence marks the text as referring back to earlier
    # context ("this", "that", "continue", "also", ...).
    _DEICTIC_PATTERNS = (
        '这个', '那个', '它', '上面', '刚才', '继续', '展开',
        '为什么会这样', '怎么改', '然后', '还有', '另外',
    )

    def __init__(self):
        # anchor -> IDF weight, filled by compute_idf(); idf() falls back to 1.0.
        self._idf_cache: dict[str, float] = {}
        # Number of documents in the batch most recently passed to compute_idf().
        self._doc_count = 0

    def extract(self, text: str) -> List[str]:
        """Return the anchors found in ``text``, in extraction order.

        Duplicates are kept: for example a word inside a backtick span is
        emitted both as a lower-cased word and as a case-preserved identifier.
        """
        anchors: List[str] = []

        # Chinese 2-grams then 3-grams over each contiguous run of CJK
        # ideographs. Runs shorter than the gram size yield nothing. These
        # characters have no case, so no lower-casing is needed.
        for run in self._CJK_RUN_RE.findall(text):
            anchors.extend(run[i:i + 2] for i in range(len(run) - 1))
            anchors.extend(run[i:i + 3] for i in range(len(run) - 2))

        # English words / identifiers, lower-cased; single characters are dropped.
        anchors.extend(
            word.lower() for word in self._WORD_RE.findall(text) if len(word) >= 2
        )

        # Identifiers inside `backtick` code spans, with original case preserved.
        for span in self._CODE_SPAN_RE.findall(text):
            anchors.extend(self._WORD_RE.findall(span))

        # Phrases enclosed in straight (") or curly (\u201c \u201d) double quotes.
        anchors.extend(self._QUOTED_RE.findall(text))

        # Numbers and version strings, e.g. "42", "3.10", "v1.2.3".
        anchors.extend(self._VERSION_RE.findall(text))

        return anchors

    def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
        """Extract anchors and report whether ``text`` contains a deictic cue.

        Returns:
            ``(anchors, has_deictic)``. ``anchors`` is ``self.extract(text)``.
            ``has_deictic`` is True when any entry of ``_DEICTIC_PATTERNS``
            occurs in ``text``. This is a plain substring test, so e.g. '它'
            also matches inside '其它'.
        """
        anchors = self.extract(text)
        has_deictic = any(p in text for p in self._DEICTIC_PATTERNS)
        return anchors, has_deictic

    def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
        """Recompute IDF weights from a batch of documents.

        Args:
            all_doc_anchors: one anchor list per document.

        The weight of an anchor is ``N / df``. ``N`` is the number of
        documents, and ``df`` is the number of documents containing the
        anchor; repeats within one document count once. Results from earlier
        calls are discarded, so weights always describe the latest batch only.
        """
        self._doc_count = len(all_doc_anchors)

        doc_freq = Counter()
        for anchors in all_doc_anchors:
            doc_freq.update(set(anchors))  # count each anchor once per document

        # NOTE(review): this is the raw ratio N/df, not the common log(N/df).
        # Its minimum (an anchor present in every document) is 1.0, which is
        # also the value idf() returns for unseen anchors. Confirm intended.
        self._idf_cache = {
            anchor: self._doc_count / df for anchor, df in doc_freq.items()
        }

    def idf(self, anchor: str) -> float:
        """Return the IDF weight of ``anchor``.

        Returns 1.0 if the anchor was not present in the batch passed to the
        last ``compute_idf`` call.
        """
        return self._idf_cache.get(anchor, 1.0)