- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
106 lines
3.4 KiB
Python
106 lines
3.4 KiB
Python
"""
|
|
话题门控模块 - 判断当前 query 是继续话题还是切换话题
|
|
"""
|
|
|
|
from typing import List, Tuple
|
|
from .anchor import AnchorExtractor
|
|
|
|
|
|
# Deictic (referring) words and phrases.
# NOTE(review): nothing in the visible part of this module reads this set;
# presumably the anchor extractor consults it to detect deictic queries --
# confirm against anchor.py.
DEICTIC_WORDS = {
    # Pronouns and back-references
    "这个",
    "那个",
    "它",
    "她",
    "他",
    "上面",
    "刚才",
    "继续",
    # Follow-up requests and connectives
    "展开",
    "为什么会这样",
    "怎么改",
    "然后",
    "还有",
    "另外",
    # Demonstratives
    "这样",
    "那样",
    "如此",
    "前一个",
    "这一",
    "那一",
}
|
|
|
|
|
|
class TopicGate:
    """Topic gate: decide whether a query continues the active topic or switches.

    The decision uses two weighted statistics computed from the anchors
    extracted from the query and the anchors of the active topic:

    * ``overlap``   -- weight share of query anchors that also occur in the topic.
    * ``new_ratio`` -- weight share of query anchors that do not occur in the topic.

    Both statistics weight every anchor occurrence with ``topic_idf.get(anchor, 1.0)``,
    so an anchor missing from the IDF table counts with weight 1.0.
    """

    # Threshold parameters.
    # Overlap above this value is treated as a clear continuation.  The
    # middle band also resolves to "continue", so this constant does not
    # change the outcome of ``is_topic_switch``; it is kept for callers
    # that may read it.
    OVERLAP_CONTINUE_THRESHOLD = 0.45
    # A switch requires overlap strictly below this value ...
    OVERLAP_SWITCH_THRESHOLD = 0.20
    # ... and new_ratio strictly above this value.
    NEW_RATIO_SWITCH_THRESHOLD = 0.70

    def __init__(self, anchor_extractor: "AnchorExtractor | None" = None):
        """Create a gate.

        Args:
            anchor_extractor: Extractor used to obtain anchors, the deictic
                flag and IDF values.  A default ``AnchorExtractor()`` is
                created when a falsy value is passed.
        """
        self.anchor_extractor = anchor_extractor or AnchorExtractor()
        self.active_topic_anchors: List[str] = []
        self.active_topic_idf: dict[str, float] = {}

    def update_active_topic(self, anchors: List[str], idf: dict[str, float]) -> None:
        """Replace the stored anchors and IDF table of the active topic.

        The arguments are stored by reference, not copied.
        """
        self.active_topic_anchors = anchors
        self.active_topic_idf = idf

    def is_topic_switch(
        self,
        query: str,
        active_topic: "Tuple[List[str], dict[str, float]] | None" = None,
    ) -> bool:
        """Decide whether ``query`` switches away from the active topic.

        Args:
            query: The incoming query text.
            active_topic: ``(topic_anchors, topic_idf)`` of the topic to compare
                against.  When ``None``, the query's anchors are stored as the
                active topic anchors (the stored IDF table is left untouched)
                and the query is reported as a continuation.

        Returns:
            True  -> the query switches to a new topic.
            False -> the query continues the current topic.
        """
        anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)

        if active_topic is None:
            # No topic to compare against: seed the active anchors from this query.
            self.active_topic_anchors = anchors
            return False

        topic_anchors, topic_idf = active_topic

        # A deictic word is a strong hint that the user refers to the current topic.
        if has_deictic:
            return False

        overlap = self._compute_overlap(anchors, topic_anchors, topic_idf)
        new_ratio = self._compute_new_ratio(anchors, topic_anchors, topic_idf)

        # Only low overlap combined with a high share of new anchors is a switch.
        # High overlap and the undecided middle band both mean "continue".
        return (
            overlap < self.OVERLAP_SWITCH_THRESHOLD
            and new_ratio > self.NEW_RATIO_SWITCH_THRESHOLD
        )

    def _compute_overlap(
        self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> float:
        """Return the weighted share of query anchors that occur in the topic.

        Returns 0.0 when there are no query anchors or the total weight is
        not positive.
        """
        if not query_anchors:
            return 0.0

        # Build the lookup set once instead of scanning the list per anchor.
        known = set(topic_anchors)
        total_weight = 0.0
        overlap_weight = 0.0
        for anchor in query_anchors:
            weight = topic_idf.get(anchor, 1.0)
            total_weight += weight
            if anchor in known:
                overlap_weight += weight

        return overlap_weight / total_weight if total_weight > 0 else 0.0

    def _compute_new_ratio(
        self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> float:
        """Return the weighted share of query anchors that are absent from the topic.

        Returns 1.0 when there are no query anchors or the total weight is
        not positive.
        """
        if not query_anchors:
            return 1.0

        known = set(topic_anchors)
        total_weight = 0.0
        new_weight = 0.0
        # Count every occurrence in both sums.  De-duplicating only the new
        # anchors would under-report the ratio for queries with repeated
        # anchors and break the overlap + new_ratio == 1 relationship.
        for anchor in query_anchors:
            weight = topic_idf.get(anchor, 1.0)
            total_weight += weight
            if anchor not in known:
                new_weight += weight

        return new_weight / total_weight if total_weight > 0 else 1.0

    def compute_idf_for_anchors(self, anchors: List[str]) -> dict[str, float]:
        """Return a mapping from each anchor to ``anchor_extractor.idf(anchor)``."""
        return {anchor: self.anchor_extractor.idf(anchor) for anchor in anchors}