Files
context-gatekeeper/src/anchor.py
Elaina 071f9ef418 feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
2026-04-22 01:09:35 +08:00

75 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
锚点提取模块 - 用规则从文本中提取有代表性的锚点
"""
import re
from collections import Counter
from typing import List, Tuple
class AnchorExtractor:
    """Extract anchors from text for topic detection and retrieval.

    Anchors are short surface strings pulled out by fixed rules: CJK
    2-/3-grams, ASCII words, identifiers inside backticks, quoted phrases,
    and numbers / version strings. The extractor also keeps a per-anchor
    weight table filled by :meth:`compute_idf`.
    """

    # Patterns are compiled once at class level instead of on every call.
    _CJK_RUN = re.compile(r'[\u4e00-\u9fff]+')
    _WORD = re.compile(r'[a-zA-Z_][a-zA-Z0-9_-]*')
    _CODE_SPAN = re.compile(r'`([^`]+)`')
    _QUOTED = re.compile(r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]')
    # The dotted tail is a NON-capturing group: with a capturing group
    # ``findall`` would return only the group text (e.g. '.3' or '') rather
    # than the whole match.
    _VERSION = re.compile(r'v?\d+(?:\.\d+)*')

    # Substrings whose presence marks the text as referring back to earlier
    # context. An empty string must never be listed here: ``'' in text`` is
    # always True and would flag every input.
    # NOTE(review): the previous list held an empty entry at the third
    # position, possibly a character lost in transit; confirm whether a
    # pattern is missing.
    _DEICTIC_PATTERNS = (
        '这个', '那个', '上面', '刚才', '继续', '展开',
        '为什么会这样', '怎么改', '然后', '还有', '另外',
    )

    def __init__(self):
        # Anchor -> weight table filled by compute_idf(); anchors that are
        # absent fall back to 1.0 in idf().
        self._idf_cache: dict[str, float] = {}
        # Number of documents seen by the most recent compute_idf() call.
        self._doc_count = 0

    def extract(self, text: str) -> List[str]:
        """Return the list of anchors found in *text*.

        The result may contain duplicates and is ordered by rule: CJK
        n-grams first (per run: all 2-grams, then all 3-grams), then
        lower-cased ASCII words of length >= 2, then identifiers found
        inside backtick spans (case preserved, no length filter), then
        quoted phrases, then numbers / version strings such as ``42`` or
        ``v1.2.3``.
        """
        anchors: List[str] = []

        # CJK 2-grams and 3-grams; runs shorter than the window yield nothing.
        for run in self._CJK_RUN.findall(text):
            for size in (2, 3):
                anchors.extend(
                    run[i:i + size] for i in range(len(run) - size + 1)
                )

        # ASCII words, lower-cased so matching is case-insensitive.
        anchors.extend(
            word.lower()
            for word in self._WORD.findall(text)
            if len(word) >= 2
        )

        # Identifiers inside `backtick` spans, kept with their original case.
        for span in self._CODE_SPAN.findall(text):
            anchors.extend(self._WORD.findall(span))

        # Phrases enclosed in straight or curly double quotes.
        anchors.extend(self._QUOTED.findall(text))

        # Numbers and version strings (whole match, see _VERSION).
        anchors.extend(self._VERSION.findall(text))

        return anchors

    def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
        """Return ``(anchors, has_deictic)`` for *text*.

        ``anchors`` is the result of :meth:`extract`. ``has_deictic`` is
        True when *text* contains at least one of the deictic /
        continuation substrings in ``_DEICTIC_PATTERNS``.
        """
        anchors = self.extract(text)
        has_deictic = any(
            pattern in text for pattern in self._DEICTIC_PATTERNS
        )
        return anchors, has_deictic

    def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
        """Fill the weight table from a batch of per-document anchor lists.

        Each anchor is counted once per document it appears in, and its
        weight is stored as ``document_count / document_frequency`` (a raw
        ratio, not log-scaled). Existing table entries for anchors that do
        not occur in this batch are left as they were.
        """
        self._doc_count = len(all_doc_anchors)
        doc_freq = Counter()
        for anchors in all_doc_anchors:
            # set(): an anchor repeated inside one document counts once.
            doc_freq.update(set(anchors))
        for anchor, df in doc_freq.items():
            self._idf_cache[anchor] = self._doc_count / df

    def idf(self, anchor: str) -> float:
        """Return the stored weight for *anchor*, or 1.0 if it is unknown."""
        return self._idf_cache.get(anchor, 1.0)