"""Topic gating module: decide whether the current query continues the active topic or switches to a new one."""
from typing import List, Optional, Tuple

from .anchor import AnchorExtractor

# Deictic / continuation words ("this", "that", "it", "continue", "also", ...).
# NOTE(review): not referenced in this module; deictic detection is delegated to
# AnchorExtractor.extract_with_deictic. Presumably kept for external use -- verify
# before removing.
DEICTIC_WORDS = {
    '这个', '那个', '它', '她', '他', '上面', '刚才', '继续', '展开',
    '为什么会这样', '怎么改', '然后', '还有', '另外', '这样', '那样',
    '如此', '前一个', '这一', '那一'
}


class TopicGate:
    """Topic gate: decide whether a query continues the current topic or starts a new one."""

    # Decision thresholds. With non-negative weights both ratios lie in [0, 1].
    # Weighted overlap above this is a clear "continue".
    OVERLAP_CONTINUE_THRESHOLD = 0.45
    # Weighted overlap below this (together with a high new-anchor ratio) means "switch".
    OVERLAP_SWITCH_THRESHOLD = 0.20
    # Weighted share of new anchors above this (together with low overlap) means "switch".
    NEW_RATIO_SWITCH_THRESHOLD = 0.70

    def __init__(self, anchor_extractor: Optional[AnchorExtractor] = None):
        """Create a gate.

        Args:
            anchor_extractor: used to extract anchors (plus a deictic flag) from
                queries and to look up IDF values. A default ``AnchorExtractor()``
                is constructed when omitted (or falsy).
        """
        self.anchor_extractor = anchor_extractor or AnchorExtractor()
        self.active_topic_anchors: List[str] = []
        self.active_topic_idf: dict[str, float] = {}

    def update_active_topic(self, anchors: List[str], idf: dict[str, float]) -> None:
        """Replace the stored active-topic anchors and their IDF map.

        The given objects are stored by reference, not copied.
        """
        self.active_topic_anchors = anchors
        self.active_topic_idf = idf

    def is_topic_switch(
        self,
        query: str,
        active_topic: Optional[Tuple[List[str], dict[str, float]]] = None,
    ) -> bool:
        """Decide whether ``query`` switches away from the active topic.

        Args:
            query: the incoming user query.
            active_topic: ``(topic_anchors, topic_idf)`` describing the current
                topic. When ``None``, the query's anchors are stored in
                ``self.active_topic_anchors`` and the call returns ``False``.

        Returns:
            True  -> switch to a new topic
            False -> continue the current topic
        """
        anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)

        if active_topic is None:
            # NOTE(review): only the anchors are recorded here (active_topic_idf is
            # left untouched), and previously stored state is not consulted when
            # active_topic is omitted. Confirm this is intended.
            self.active_topic_anchors = anchors
            return False

        topic_anchors, topic_idf = active_topic

        # A deictic / continuation word strongly implies the user is still on topic.
        if has_deictic:
            return False

        overlap = self._compute_overlap(anchors, topic_anchors, topic_idf)
        new_ratio = self._compute_new_ratio(anchors, topic_anchors, topic_idf)

        # Switch only when little content is shared AND most content is new.
        # Everything else -- a clear overlap (> OVERLAP_CONTINUE_THRESHOLD) or the
        # ambiguous middle band -- defaults to continuing the current topic.
        # NOTE(review): a query with no anchors gets overlap 0.0 / new_ratio 1.0
        # and is therefore reported as a switch.
        return (
            overlap < self.OVERLAP_SWITCH_THRESHOLD
            and new_ratio > self.NEW_RATIO_SWITCH_THRESHOLD
        )

    def _compute_overlap(
        self,
        query_anchors: List[str],
        topic_anchors: List[str],
        topic_idf: dict[str, float],
    ) -> float:
        """Return the IDF-weighted fraction of query anchors already in the topic.

        Each query anchor is weighted by ``topic_idf`` (1.0 when absent) and is
        counted once per occurrence. Returns 0.0 for an empty query or when the
        total weight is not positive.
        """
        if not query_anchors:
            return 0.0
        topic_set = set(topic_anchors)  # O(1) membership instead of a list scan per anchor
        total_weight = 0.0
        overlap_weight = 0.0
        for anchor in query_anchors:
            weight = topic_idf.get(anchor, 1.0)
            total_weight += weight
            if anchor in topic_set:
                overlap_weight += weight
        return overlap_weight / total_weight if total_weight > 0 else 0.0

    def _compute_new_ratio(
        self,
        query_anchors: List[str],
        topic_anchors: List[str],
        topic_idf: dict[str, float],
    ) -> float:
        """Return the IDF-weighted fraction of query anchors not in the topic.

        Weighting and duplicate handling match ``_compute_overlap`` (each
        occurrence counts), so the two ratios sum to 1 (up to rounding) whenever
        the total weight is positive. Returns 1.0 for an empty query or when the
        total weight is not positive.
        """
        if not query_anchors:
            return 1.0
        topic_set = set(topic_anchors)
        weights = [topic_idf.get(anchor, 1.0) for anchor in query_anchors]
        total_weight = sum(weights)
        # Count new anchors per occurrence; deduplicating only the numerator would
        # undercount repeated new anchors relative to the total.
        new_weight = sum(
            weight
            for anchor, weight in zip(query_anchors, weights)
            if anchor not in topic_set
        )
        return new_weight / total_weight if total_weight > 0 else 1.0

    def compute_idf_for_anchors(self, anchors: List[str]) -> dict[str, float]:
        """Map each anchor to ``self.anchor_extractor.idf(anchor)``.

        Intended for the IDF-weighted overlap / new-ratio calculations.
        """
        return {anchor: self.anchor_extractor.idf(anchor) for anchor in anchors}