feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
106
src/topic_gate.py
Normal file
106
src/topic_gate.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
Topic gate module - decides whether the current query continues the active
topic or switches to a new one.
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
from .anchor import AnchorExtractor
|
||||
|
||||
|
||||
# Deictic (referential) words and phrases, i.e. expressions that point back
# at something said earlier ("this", "that", "it", "just now", "continue", ...).
DEICTIC_WORDS = {
    # pronouns / demonstratives
    "这个", "那个", "它", "她", "他",
    "这样", "那样", "如此",
    "前一个", "这一", "那一",
    # references to earlier content
    "上面", "刚才",
    # follow-up / continuation cues
    "继续", "展开", "然后", "还有", "另外",
    "为什么会这样", "怎么改",
}
|
||||
|
||||
|
||||
class TopicGate:
    """Topic gate: decide whether a query continues the active topic or switches.

    The decision compares the anchors extracted from the query with the
    anchors of the active topic through two weighted ratios:

    * ``overlap``   - weighted share of query anchors already in the topic.
    * ``new_ratio`` - weighted share of query anchors absent from the topic.

    Weights come from the topic's IDF table; an anchor missing from that table
    weighs 1.0. Every occurrence of an anchor in the query is counted, so the
    two ratios always add up to 1 for a non-empty query.
    """

    # Threshold parameters.
    # NOTE(review): OVERLAP_CONTINUE_THRESHOLD currently does not influence the
    # result, because the "middle ground" between the thresholds also resolves
    # to "continue". It is kept so existing references keep working.
    OVERLAP_CONTINUE_THRESHOLD = 0.45
    OVERLAP_SWITCH_THRESHOLD = 0.20
    NEW_RATIO_SWITCH_THRESHOLD = 0.70

    def __init__(self, anchor_extractor: "AnchorExtractor" = None):
        """Create a gate.

        Args:
            anchor_extractor: Object providing ``extract_with_deictic(query)``
                and ``idf(anchor)``. When ``None``, a default
                ``AnchorExtractor`` is created.
        """
        # Compare against None explicitly so that a supplied extractor that
        # happens to be falsy is not silently replaced.
        if anchor_extractor is None:
            anchor_extractor = AnchorExtractor()
        self.anchor_extractor = anchor_extractor
        self.active_topic_anchors: List[str] = []
        self.active_topic_idf: dict[str, float] = {}

    def update_active_topic(self, anchors: List[str], idf: dict[str, float]) -> None:
        """Replace the stored anchors and IDF table of the active topic."""
        self.active_topic_anchors = anchors
        self.active_topic_idf = idf

    def is_topic_switch(
        self, query: str, active_topic: Tuple[List[str], dict[str, float]] = None
    ) -> bool:
        """Tell whether ``query`` switches away from the active topic.

        Args:
            query: The current user query.
            active_topic: ``(topic_anchors, topic_idf)`` of the active topic,
                or ``None`` when there is no active topic yet.

        Returns:
            True  -> the query switches to a new topic.
            False -> the query continues the current topic.
        """
        anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)

        if active_topic is None:
            # Nothing to compare against: remember the query's anchors and
            # treat the query as a continuation.
            # NOTE(review): only the anchors are stored here; the stored IDF
            # table is left untouched - confirm this is intended.
            self.active_topic_anchors = anchors
            return False

        topic_anchors, topic_idf = active_topic

        # A deictic word strongly hints that the user refers to earlier content.
        if has_deictic:
            return False

        overlap = self._compute_overlap(anchors, topic_anchors, topic_idf)
        new_ratio = self._compute_new_ratio(anchors, topic_anchors, topic_idf)

        # Switch only when the query shares little with the topic AND is
        # dominated by new anchors; high overlap and the middle ground both
        # default to "continue".
        # NOTE(review): a query with no anchors yields overlap 0.0 and
        # new_ratio 1.0 and is therefore reported as a switch - confirm.
        return (
            overlap < self.OVERLAP_SWITCH_THRESHOLD
            and new_ratio > self.NEW_RATIO_SWITCH_THRESHOLD
        )

    @staticmethod
    def _split_weights(
        query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> Tuple[float, float, float]:
        """Return ``(shared, new, total)`` weights of the query anchors.

        Each occurrence in ``query_anchors`` contributes its weight either to
        ``shared`` (anchor is in the topic) or to ``new`` (it is not).
        """
        # Build the lookup set once instead of scanning the list per anchor.
        known = set(topic_anchors)
        shared = new = total = 0.0
        for anchor in query_anchors:
            weight = topic_idf.get(anchor, 1.0)
            total += weight
            if anchor in known:
                shared += weight
            else:
                new += weight
        return shared, new, total

    def _compute_overlap(
        self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> float:
        """Compute the weighted overlap ratio (0.0 for an empty query)."""
        if not query_anchors:
            return 0.0

        shared, _, total = self._split_weights(query_anchors, topic_anchors, topic_idf)
        return shared / total if total > 0 else 0.0

    def _compute_new_ratio(
        self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
    ) -> float:
        """Compute the weighted ratio of new anchors (1.0 for an empty query).

        New anchors are weighted per occurrence, exactly like the denominator,
        so repeated anchors are not under-counted.
        """
        if not query_anchors:
            return 1.0

        _, new, total = self._split_weights(query_anchors, topic_anchors, topic_idf)
        return new / total if total > 0 else 1.0

    def compute_idf_for_anchors(self, anchors: List[str]) -> dict[str, float]:
        """Map each anchor to the IDF value reported by the anchor extractor."""
        return {anchor: self.anchor_extractor.idf(anchor) for anchor in anchors}
|
||||
Reference in New Issue
Block a user