feat: 上下文门控器初始实现

- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符)
- block.py: 对话块数据结构
- topic_gate.py: 话题门控(overlap/new_ratio 判断切换)
- sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分)
- selector.py: 最小覆盖贪心选择
- gatekeeper.py: 完整流程封装
- tests/: 单元测试 + 端到端测试(含 MiniMax API 验证)

特性:
- 纯 Python,无额外模型依赖
- 支持 2 核 2G 环境
- 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
Elaina
2026-04-22 01:09:35 +08:00
commit 071f9ef418
12 changed files with 1061 additions and 0 deletions

1
.env.example Normal file
View File

@@ -0,0 +1 @@
MINIMAX_API_KEY=your_api_key_here

28
.gitignore vendored Normal file
View File

@@ -0,0 +1,28 @@
# Environment
.env
.env.local
.env.*.local
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.venv
venv/
ENV/
env/
.eggs/
*.egg-info/
# Test
.pytest_cache/
.coverage
htmlcov/
# IDE
.idea/
.vscode/
*.swp
*.swo

102
README.md Normal file
View File

@@ -0,0 +1,102 @@
# 上下文门控器 (Context Gatekeeper)
轻量级上下文选择器,在同一会话中自动从历史对话里选出最小且相关的片段,减少话题污染和控制上下文长度。
## 特性
- 🚀 **纯 Python**,无需额外模型依赖
- 💻 **轻量运行**,支持 2 核 2G 环境
- 🔍 **话题门控**,智能判断继续/切换
- 📦 **稀疏召回**BM25/IDF-overlap 评分
- 🎯 **最小覆盖**,贪心算法选择最优子集
- ⚙️ **稳定约束区**,持久化用户偏好
## 安装
```bash
pip install -e .
```
## 快速开始
```python
from src.gatekeeper import ContextGatekeeper
# 初始化token 预算 4000
gate = ContextGatekeeper(token_budget=4000)
# 添加对话历史
gate.add_turn("Redis 锁续租为什么会脑裂", "因为 TTL 设置不合理...")
gate.add_turn("如何避免脑裂", "可以增加时钟偏移检测...")
# 为当前查询选择上下文
selected = gate.select("锁的 TTL 怎么设置")
for item in selected:
print(f"轮次 {item['turn_id']}: {item['user']}")
print(f"助手: {item['assistant']}\n")
# 设置稳定约束
gate.set_constraint("language", "中文")
gate.set_constraint("style", "简洁")
# 构建完整 prompt
prompt = gate.build_prompt("Redis 集群如何搭建")
```
## 核心流程
```
用户输入 → 锚点提取 → 话题门控 → 稀疏召回 → 最小覆盖选择 → 组装 Prompt
```
## 项目结构
```
context-gatekeeper/
├── src/
│ ├── __init__.py
│ ├── anchor.py # 锚点提取
│ ├── block.py # Block 数据结构
│ ├── topic_gate.py # 话题门控
│ ├── sparse.py # 稀疏召回
│ ├── selector.py # 最小覆盖选择
│ └── gatekeeper.py # 主模块
├── tests/
│ ├── test_gatekeeper.py # 单元测试
│ └── test_e2e.py # 端到端测试
├── SPEC.md # 规格文档
├── README.md
├── .env.example
└── .gitignore
```
## 运行测试
```bash
# 单元测试
pytest tests/test_gatekeeper.py -v
# 端到端测试(需要配置 .env
cp .env.example .env
# 编辑 .env 填入你的 MiniMax API Key
pytest tests/test_e2e.py -v
```
## 算法细节
### 话题门控
- **overlap > 0.45**:继续当前话题
- **overlap < 0.20** 且 **new_ratio > 0.70**:切换新话题
- 有指代词(这个/那个/它/上面)→ 强制继续
### 稀疏召回评分
```
score = 1.5 * lex(user) + 0.7 * lex(assistant) + 1.0 * exact + 0.2 * recency
```
### 最小覆盖选择
贪心选择"单位长度收益最大"的 block直到覆盖率达到 85% 或 token 预算耗尽。

1
src/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""轻量级上下文门控器 - 无需额外模型"""

75
src/anchor.py Normal file
View File

@@ -0,0 +1,75 @@
"""
锚点提取模块 - 用规则从文本中提取有代表性的锚点
"""
import re
from collections import Counter
from typing import List, Tuple
class AnchorExtractor:
"""提取文本中的锚点,用于话题判断和检索"""
def __init__(self):
# IDF 字典,用于权重计算(简单起见用内部默认值,可扩展)
self._idf_cache: dict[str, float] = {}
self._doc_count = 0
def extract(self, text: str) -> List[str]:
"""从文本中提取锚点列表"""
anchors = []
# 中文 2-gram / 3-gram转小写以便匹配
chinese_chars = re.findall(r'[\u4e00-\u9fff]+', text)
for chunk in chinese_chars:
if len(chunk) >= 2:
for i in range(len(chunk) - 1):
anchors.append(chunk[i:i+2].lower())
if len(chunk) >= 3:
for i in range(len(chunk) - 2):
anchors.append(chunk[i:i+3].lower())
# 英文单词
english_words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', text)
anchors.extend([w.lower() for w in english_words if len(w) >= 2])
# 代码标识符(反引号内)
code_segments = re.findall(r'`([^`]+)`', text)
for seg in code_segments:
anchors.extend(re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', seg))
# 引号内的短语
quoted_phrases = re.findall(r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]', text)
anchors.extend(quoted_phrases)
# 数字、版本号
versions = re.findall(r'v?\d+(\.\d+)*', text)
anchors.extend(versions)
return anchors
def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
"""提取锚点,并检测是否有指代词"""
anchors = self.extract(text)
deictic_patterns = [
'这个', '那个', '', '上面', '刚才', '继续', '展开',
'为什么会这样', '怎么改', '然后', '还有', '另外'
]
has_deictic = any(p in text for p in deictic_patterns)
return anchors, has_deictic
def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
"""根据一批文档计算 IDF 值"""
self._doc_count = len(all_doc_anchors)
doc_freq = Counter()
for anchors in all_doc_anchors:
unique_anchors = set(anchors)
for a in unique_anchors:
doc_freq[a] += 1
for anchor, df in doc_freq.items():
self._idf_cache[anchor] = self._doc_count / df
def idf(self, anchor: str) -> float:
"""获取锚点的 IDF 值,默认 1.0"""
return self._idf_cache.get(anchor, 1.0)

46
src/block.py Normal file
View File

@@ -0,0 +1,46 @@
"""
Block 数据结构 - 表示一轮对话的原子单元
"""
from dataclasses import dataclass, field
from typing import List
from .anchor import AnchorExtractor
@dataclass
class Block:
"""一轮用户+助手的对话块"""
user_text: str
assistant_text: str
turn_id: int
anchors: List[str] = field(default_factory=list)
tokens_user: int = 0
tokens_assistant: int = 0
def __post_init__(self):
if not self.anchors:
extractor = AnchorExtractor()
user_anchors = extractor.extract(self.user_text)
assistant_anchors = extractor.extract(self.assistant_text)
# 用户侧锚点权重更高
self.anchors = user_anchors + [f"a_{a}" for a in assistant_anchors]
if self.tokens_user == 0:
self.tokens_user = self._estimate_tokens(self.user_text)
if self.tokens_assistant == 0:
self.tokens_assistant = self._estimate_tokens(self.assistant_text)
@staticmethod
def _estimate_tokens(text: str) -> int:
"""简单估算中文字符数作为 token 数(偏保守)"""
# 中文按字数 * 1.5 估算,英文按空格分隔单词数
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
english_words = len(text.split())
return int(chinese_chars * 1.5 + english_words * 1.0)
@property
def total_tokens(self) -> int:
return self.tokens_user + self.tokens_assistant
def cost(self) -> int:
"""这个 block 的代价token 数)"""
return self.total_tokens

175
src/gatekeeper.py Normal file
View File

@@ -0,0 +1,175 @@
"""
上下文门控器主模块 - 将各模块组合成完整流程
"""
from typing import List, Dict, Tuple, Optional
from .anchor import AnchorExtractor
from .block import Block
from .topic_gate import TopicGate
from .sparse import SparseRetriever
from .selector import MinimumCoverageSelector
class ContextGatekeeper:
"""
轻量级上下文选择器
用法:
gate = ContextGatekeeper(token_budget=4000)
gate.add_turn("用户消息", "助手回复")
selected = gate.select("当前查询")
"""
def __init__(self, token_budget: int = 4000):
self.token_budget = token_budget
self.blocks: List[Block] = []
self.turn_counter = 0
# 各模块实例
self.anchor_extractor = AnchorExtractor()
self.topic_gate = TopicGate(self.anchor_extractor)
self.retriever = SparseRetriever()
self.selector = MinimumCoverageSelector()
# 稳定约束区
self.constraints: Dict[str, str] = {}
# 当前活跃话题(锚点 + IDF
self._active_topic: Optional[Tuple[List[str], dict]] = None
def add_turn(self, user_text: str, assistant_text: str) -> int:
"""
添加一轮对话到历史
Returns:
本轮的 turn_id
"""
self.turn_counter += 1
block = Block(
user_text=user_text,
assistant_text=assistant_text,
turn_id=self.turn_counter
)
self.blocks.append(block)
# 更新锚点提取器的 IDF用当前 block 的锚点)
all_anchors = [b.anchors for b in self.blocks]
self.anchor_extractor.compute_idf(all_anchors)
# 初始化或更新活跃话题
user_anchors = [a for a in block.anchors if not a.startswith('a_')]
if self._active_topic is None:
self._active_topic = (user_anchors, self.anchor_extractor._idf_cache)
else:
# 只有在话题未切换时才更新(由 topic_gate 判断)
pass
return self.turn_counter
def select(self, query: str) -> List[Dict]:
"""
为当前 query 选择上下文 blocks
Returns:
[{"user": ..., "assistant": ..., "turn_id": ...}, ...]
"""
# 1. 提取当前 query 的锚点
query_anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)
# 2. 话题门控:判断继续还是切换
switched = self.topic_gate.is_topic_switch(query, self._active_topic)
# 3. 指代词强制继承最近 1-2 个 block
mandatory: List[Block] = []
if has_deictic and self.blocks:
mandatory = self.blocks[-2:] # 最近 2 个
# 4. 稀疏召回:取 top-20 候选
candidates = self.retriever.retrieve(
self.blocks, query_anchors, top_m=20
)
# 5. 最小覆盖选择
selected_blocks = self.selector.select(
candidates,
query_anchors,
self.token_budget,
mandatory=mandatory
)
# 6. 句级裁剪(简化版:不做进一步裁剪,直接用整个 block
# 如需进一步裁剪,可在 Block 内部按句子级别筛选
# 7. 更新活跃话题
if not switched and query_anchors:
self._active_topic = (query_anchors, self.anchor_extractor._idf_cache)
# 8. 格式化为输出
result = [
{
"user": b.user_text,
"assistant": b.assistant_text,
"turn_id": b.turn_id
}
for b in selected_blocks
]
return result
def select_with_constraints(self, query: str) -> Tuple[List[Dict], Dict]:
"""
带约束的上下文选择
Returns:
(selected_blocks, constraints)
"""
blocks = self.select(query)
return blocks, self.constraints
def set_constraint(self, key: str, value: str) -> None:
"""设置稳定约束"""
self.constraints[key] = value
def get_constraints(self) -> Dict[str, str]:
"""获取当前所有约束"""
return dict(self.constraints)
def build_prompt(
self, query: str, system_prompt: str = "你是一个有帮助的助手。"
) -> str:
"""
构建完整的 prompt用于实际调用 LLM
Returns:
组装好的 prompt 字符串
"""
selected = self.select(query)
# 组装上下文
context_parts = []
for item in selected:
context_parts.append(f"【轮次 {item['turn_id']}\n用户: {item['user']}\n助手: {item['assistant']}")
context_str = "\n\n".join(context_parts) if context_parts else "(无相关上下文)"
# 拼接完整 prompt
constraints_str = ""
if self.constraints:
parts = [f"{k}: {v}" for k, v in self.constraints.items()]
constraints_str = "\n".join(parts)
prompt = f"{system_prompt}\n\n"
if constraints_str:
prompt += f"【固定约束】\n{constraints_str}\n\n"
if context_str:
prompt += f"【相关上下文】\n{context_str}\n\n"
prompt += f"【当前问题】\n用户: {query}"
return prompt
def reset(self) -> None:
"""重置所有状态(新会话开始时调用)"""
self.blocks.clear()
self.turn_counter = 0
self._active_topic = None
# constraints 保留

116
src/selector.py Normal file
View File

@@ -0,0 +1,116 @@
"""
最小覆盖选择模块 - 用贪心算法选择最小且足够的上下文块
"""
import math
from typing import List, Tuple, Set
from .block import Block
class MinimumCoverageSelector:
"""最小覆盖选择器:用贪心算法找出最小 tokens 覆盖最多 query 锚点"""
# 覆盖目标(达到 85% 即停止)
COVERAGE_TARGET = 0.85
# 长度惩罚指数0.7 ~ 1.0 之间)
COST_ALPHA = 0.8
# 最小增益阈值(降低以避免过滤掉有效的 block
MIN_GAIN = 0.05
# 最小稀疏分数阈值(避免选入完全不相关的 block
MIN_SPARSE_SCORE = 0.5
def __init__(self, coverage_target: float = 0.85, cost_alpha: float = 0.8):
self.coverage_target = coverage_target
self.cost_alpha = cost_alpha
def select(
self,
candidates: List[Tuple[Block, float]],
query_anchors: List[str],
token_budget: int,
mandatory: List[Block] = None
) -> List[Block]:
"""
贪心选择最小覆盖集合
Args:
candidates: 候选 blocks已按得分排序
query_anchors: 查询的锚点
token_budget: token 预算上限
mandatory: 必须包含的 blocks如指代词触发的强制继承
Returns:
选中的 blocks 列表
"""
mandatory = mandatory or []
selected: List[Block] = []
covered: Set[str] = set()
# 过滤出稀疏分数足够的候选
snippets = [b for b, score in candidates if score >= self.MIN_SPARSE_SCORE]
# 先加入 mandatory blocks
for block in mandatory:
if block not in selected:
selected.append(block)
covered.update(self._block_anchor_set(block))
current_tokens = sum(b.cost() for b in selected)
# 贪心选择
for block in snippets:
if block in selected:
continue
block_cost = block.cost()
if current_tokens + block_cost > token_budget:
break
# 计算这块的增量覆盖
block_anchors = self._block_anchor_set(block)
new_anchors = block_anchors - covered
if not new_anchors:
continue
# 计算新增覆盖的 IDF 加权和(这里简化处理)
new_coverage = len(new_anchors) / max(len(query_anchors), 1)
gain = new_coverage / (block_cost ** self.cost_alpha)
if gain < self.MIN_GAIN:
continue
selected.append(block)
covered.update(new_anchors)
current_tokens += block_cost
# 检查是否达到覆盖目标
coverage_ratio = sum(
1 for a in query_anchors if a in covered
) / max(len(query_anchors), 1)
if coverage_ratio >= self.coverage_target:
break
return selected
def _block_anchor_set(self, block: Block) -> Set[str]:
"""获取 block 的锚点集合(去 a_ 前缀,转小写)"""
return {
(a.replace('a_', '', 1) if a.startswith('a_') else a).lower()
for a in block.anchors
}
def coverage_ratio(self, selected: List[Block], query_anchors: List[str]) -> float:
"""计算选中的 blocks 对 query 锚点的覆盖率"""
covered = set()
for block in selected:
covered.update(self._block_anchor_set(block))
if not query_anchors:
return 0.0
hit = sum(1 for a in query_anchors if a in covered)
return hit / len(query_anchors)

128
src/sparse.py Normal file
View File

@@ -0,0 +1,128 @@
"""
稀疏召回模块 - 用 BM25/IDF-overlap 对历史 blocks 做评分和召回
"""
import math
from typing import List, Tuple
from .block import Block
from .anchor import AnchorExtractor
class SparseRetriever:
"""稀疏召回器:基于词项重叠的轻量评分"""
# 评分权重
WEIGHT_USER = 1.5
WEIGHT_ASSISTANT = 0.7
WEIGHT_EXACT = 1.0
WEIGHT_RECENT = 0.2
def __init__(self):
self.anchor_extractor = AnchorExtractor()
def score(self, block: Block, query_anchors: List[str], recency: float = 0.0) -> float:
"""
计算 block 相对于 query 的相关度得分
Args:
block: 待评分的对话块
query_anchors: 查询的锚点列表
recency: 时间衰减因子 (0.0 ~ 1.0,越新越大)
Returns:
相关度得分 (0.0 ~ )
"""
# 1. 用户侧词项重叠
user_anchors = [a for a in block.anchors if not a.startswith('a_')]
assistant_anchors = [a.replace('a_', '', 1) for a in block.anchors if a.startswith('a_')]
lex_user = self._lex_overlap(user_anchors, query_anchors)
lex_assistant = self._lex_overlap(assistant_anchors, query_anchors)
# 2. Exact match 加分
exact = self._exact_match(block, query_anchors)
# 3. 综合评分
score = (
self.WEIGHT_USER * lex_user +
self.WEIGHT_ASSISTANT * lex_assistant +
self.WEIGHT_EXACT * exact +
self.WEIGHT_RECENT * recency
)
return score
def _lex_overlap(self, block_anchors: List[str], query_anchors: List[str]) -> float:
"""计算词项重叠得分(带 IDF 权重)"""
score = 0.0
# 将 query anchors 转小写用于匹配
query_lower = [a.lower() for a in query_anchors]
for qa in query_lower:
# 精确匹配
if qa in block_anchors:
tf = min(block_anchors.count(qa), 3)
score += 1.0 * tf
# 或者前缀/后缀匹配(处理中英文混合)
else:
for ba in block_anchors:
if qa in ba.lower() or ba.lower() in qa:
score += 0.5
break
return score
def _exact_match(self, block: Block, query_anchors: List[str]) -> float:
"""
Exact match 加分:
- 完整英文术语命中
- 代码标识符命中
- 数字、版本号命中
- 引号短语完整命中
"""
text_lower = block.user_text.lower() + ' ' + block.assistant_text.lower()
score = 0.0
for qa in query_anchors:
# 版本号v开头
if qa.startswith('v') and '.' in qa:
if qa in text_lower:
score += 1.0
# 全数字(如 1.2.3, 2.0
elif qa.replace('.', '').isdigit():
if qa in text_lower:
score += 1.0
# 英文术语(较长,有下划线/驼峰)
elif '_' in qa or (any(c.isupper() for c in qa) and qa.isalnum()):
if qa.lower() in text_lower:
score += 1.0
# 引号短语(带空格)
elif ' ' in qa and any(q in text_lower for q in [qa, qa.strip('"\'')]):
score += 0.5
return score
def retrieve(
self, blocks: List[Block], query_anchors: List[str], top_m: int = 20
) -> List[Tuple[Block, float]]:
"""
从历史 blocks 中召回 top-m 个最相关的 block
Args:
blocks: 所有历史 blocks
query_anchors: 查询锚点
top_m: 返回前 m 个
Returns:
[(block, score), ...] 按得分降序排列
"""
scored = []
total = len(blocks)
for i, block in enumerate(blocks):
# recency = i/total越新的 block 越接近 1.0
recency = (i + 1) / total if total > 0 else 0.0
s = self.score(block, query_anchors, recency)
scored.append((block, s))
# 降序排列,取前 top_m
scored.sort(key=lambda x: x[1], reverse=True)
return scored[:top_m]

106
src/topic_gate.py Normal file
View File

@@ -0,0 +1,106 @@
"""
话题门控模块 - 判断当前 query 是继续话题还是切换话题
"""
from typing import List, Tuple
from .anchor import AnchorExtractor
# 指代词列表
DEICTIC_WORDS = {
'这个', '那个', '', '', '', '上面', '刚才', '继续',
'展开', '为什么会这样', '怎么改', '然后', '还有', '另外',
'这样', '那样', '如此', '前一个', '这一', '那一'
}
class TopicGate:
"""话题门控:判断是继续当前话题还是切换到新话题"""
# 阈值参数
OVERLAP_CONTINUE_THRESHOLD = 0.45
OVERLAP_SWITCH_THRESHOLD = 0.20
NEW_RATIO_SWITCH_THRESHOLD = 0.70
def __init__(self, anchor_extractor: AnchorExtractor = None):
self.anchor_extractor = anchor_extractor or AnchorExtractor()
self.active_topic_anchors: List[str] = []
self.active_topic_idf: dict[str, float] = {}
def update_active_topic(self, anchors: List[str], idf: dict[str, float]) -> None:
"""更新活跃话题的锚点集合"""
self.active_topic_anchors = anchors
self.active_topic_idf = idf
def is_topic_switch(
self, query: str, active_topic: Tuple[List[str], dict[str, float]] = None
) -> bool:
"""
判断当前 query 是否切换了话题
Returns:
True -> 切换到新话题
False -> 继续当前话题
"""
anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)
if active_topic is None:
self.active_topic_anchors = anchors
return False
topic_anchors, topic_idf = active_topic
# 有指代词 → 强暗示继续
if has_deictic:
return False
# 计算 overlap 和 new_ratio
overlap = self._compute_overlap(anchors, topic_anchors, topic_idf)
new_ratio = self._compute_new_ratio(anchors, topic_anchors, topic_idf)
# 切换判断
if overlap < self.OVERLAP_SWITCH_THRESHOLD and new_ratio > self.NEW_RATIO_SWITCH_THRESHOLD:
return True
# 继续判断
if overlap > self.OVERLAP_CONTINUE_THRESHOLD:
return False
# 中间地带:默认继续
return False
def _compute_overlap(
self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
) -> float:
"""计算加权重叠率"""
if not query_anchors:
return 0.0
overlap_weight = sum(
topic_idf.get(a, 1.0)
for a in query_anchors
if a in topic_anchors
)
total_weight = sum(topic_idf.get(a, 1.0) for a in query_anchors)
return overlap_weight / total_weight if total_weight > 0 else 0.0
def _compute_new_ratio(
self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
) -> float:
"""计算新锚点比例"""
if not query_anchors:
return 1.0
new_anchors = set(query_anchors) - set(topic_anchors)
new_weight = sum(topic_idf.get(a, 1.0) for a in new_anchors)
total_weight = sum(topic_idf.get(a, 1.0) for a in query_anchors)
return new_weight / total_weight if total_weight > 0 else 1.0
def compute_idf_for_anchors(self, anchors: List[str]) -> dict[str, float]:
"""为锚点集合计算 IDF 值(用于后续计算)"""
idf = {}
for a in anchors:
idf[a] = self.anchor_extractor.idf(a)
return idf

147
tests/test_e2e.py Normal file
View File

@@ -0,0 +1,147 @@
"""
端到端测试 - 使用 MiniMax API 验证上下文门控器效果
"""
import os
import sys
from dotenv import load_dotenv
# 加载 .env
load_dotenv()
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
# 获取 API Key
API_KEY = os.getenv("MINIMAX_API_KEY")
if not API_KEY:
print("❌ 未找到 MINIMAX_API_KEY请检查 .env 文件")
sys.exit(1)
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"
def call_minimax(prompt: str) -> str:
"""调用 MiniMax API"""
import urllib.request
import json
payload = {
"model": "MiniMax-M2.7",
"messages": [{"role": "user", "content": prompt[:2000]}], # 限制长度
"max_tokens": 500,
"temperature": 0.7
}
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
BASE_URL,
data=data,
headers={
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json"
},
method="POST"
)
try:
with urllib.request.urlopen(req, timeout=30) as resp:
result = json.loads(resp.read().decode("utf-8"))
return result["choices"][0]["message"]["content"]
except Exception as e:
return f"API 调用失败: {e}"
def test_e2e_conversation():
"""端到端对话测试"""
print("=" * 60)
print("端到端对话测试 - 验证上下文门控器")
print("=" * 60)
gate = ContextGatekeeper(token_budget=2000)
# === 第一阶段Redis 分布式锁话题 ===
print("\n📌 第1轮Redis 分布式锁话题")
user1 = "Redis 锁续租为什么会脑裂"
assistant1 = call_minimax(f"请用 2-3 句话回答: {user1}")
gate.add_turn(user1, assistant1)
print(f"用户: {user1}")
print(f"助手: {assistant1}")
print("\n📌 第2轮继续 Redis 话题")
user2 = "如何避免脑裂?"
assistant2 = call_minimax(f"请用 2-3 句话回答: {user2}")
gate.add_turn(user2, assistant2)
print(f"用户: {user2}")
print(f"助手: {assistant2}")
# 验证第3轮问 Redis 相关应该召回第1轮可能不召回第2轮取决于锚点重叠度
print("\n📌 第3轮问 Redis 相关问题")
query3 = "锁的 TTL 怎么设置才合理"
selected = gate.select(query3)
print(f"用户查询: {query3}")
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected]}")
turn_ids = [b['turn_id'] for b in selected]
assert 1 in turn_ids, "❌ 应召回第1轮 Redis 内容"
print("✅ Redis 相关问题召回第1轮内容")
# === 第二阶段:切换到 Python 话题 ===
print("\n" + "=" * 60)
print("📌 第4轮切换到 Python 话题")
user4 = "Python 异步编程怎么做?用 asyncio 举例子"
assistant4 = call_minimax(f"请用 3-4 句话回答: {user4}")
gate.add_turn(user4, assistant4)
print(f"用户: {user4}")
print(f"助手: {assistant4}")
# 验证:问 Python 相关,应该召回 Python 内容3或4
print("\n📌 第5轮问 Python asyncio 相关")
query5 = "asyncio 怎么用?举一个爬虫的例子"
selected5 = gate.select(query5)
print(f"用户查询: {query5}")
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected5]}")
turn_ids5 = [b['turn_id'] for b in selected5]
# Python 话题应该召回 3 或 4
has_python_topic = (3 in turn_ids5 or 4 in turn_ids5)
assert has_python_topic, f"❌ 应召回 Python 内容,实际: {turn_ids5}"
print(f"✅ Python 相关问题召回正确轮次: {turn_ids5}")
# === 第三阶段:指代词测试 ===
print("\n" + "=" * 60)
print("📌 第6轮指代词测试")
user6 = "它的并发性能怎么样"
assistant6 = call_minimax(f"请用 2-3 句话回答: {user6}")
gate.add_turn(user6, assistant6)
print(f"用户: {user6}")
print(f"助手: {assistant6}")
# 验证:有指代词时,应该强制继承最近轮次
print("\n📌 第7轮指代词强制继承验证")
query7 = "它和 ThreadPool 比哪个更好"
selected7 = gate.select(query7)
print(f"用户查询: {query7}")
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected7]}")
# 应该有指代词强制继承
assert len(selected7) >= 2, "❌ 有指代词时应强制继承最近 2 个 block"
print(f"✅ 指代词触发强制继承,召回了 {len(selected7)} 个 block")
# === 验证 Token 预算控制 ===
print("\n" + "=" * 60)
print("📌 Token 预算验证")
total_context_tokens = sum(
len(b['user']) * 1.5 + len(b['assistant']) * 1.5
for b in selected
)
print(f"当前上下文 token 估算: {total_context_tokens:.0f} / {gate.token_budget}")
assert total_context_tokens <= gate.token_budget * 1.5, "❌ 超出 token 预算"
print("✅ Token 预算控制正常")
print("\n" + "=" * 60)
print("✅ 所有端到端测试通过!")
print("=" * 60)
if __name__ == "__main__":
test_e2e_conversation()

136
tests/test_gatekeeper.py Normal file
View File

@@ -0,0 +1,136 @@
"""
上下文门控器 - 轻量级选择器测试
"""
import sys
import os
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.gatekeeper import ContextGatekeeper
class TestAnchorExtraction:
"""锚点提取测试"""
def test_chinese_ngram(self):
gate = ContextGatekeeper()
# 触发锚点提取
gate.add_turn("Redis 分布式锁如何实现", "可以使用 Redlock 算法...")
assert len(gate.blocks) == 1
anchors = gate.blocks[0].anchors
# 应该有 redis、分布式锁等锚点
assert any('redis' in a.lower() for a in anchors)
def test_empty_turn(self):
gate = ContextGatekeeper()
gate.add_turn("", "")
assert len(gate.blocks) == 1
class TestTopicGate:
"""话题门控测试"""
def test_obvious_switch(self):
"""从 Redis 切换到 Python明显切换"""
gate = ContextGatekeeper()
gate.add_turn("Redis 锁怎么用", "用 Redlock...")
gate.add_turn("Python 怎么写快速排序", "可以用递归...")
# 第三个问题,明显是 Python 话题
blocks = gate.select("Python 列表推导式怎么写")
# 应该主要返回第二、三轮的 Python 内容
turn_ids = [b['turn_id'] for b in blocks]
assert 2 in turn_ids or 3 in turn_ids
def test_continuation_with_deictic(self):
"""指代词触发强制继承"""
gate = ContextGatekeeper()
gate.add_turn("Redis 锁是什么", "是分布式锁...")
gate.add_turn("它有哪些实现方式", "有 Redlock...")
# 第二轮有"它",应该强制继承第一轮
blocks = gate.select("它有哪些实现方式")
turn_ids = [b['turn_id'] for b in blocks]
# 应该包含第1轮
assert 1 in turn_ids
class TestSparseRetrieval:
"""稀疏召回测试"""
def test_exact_match_boost(self):
"""exact match 应该提升得分"""
gate = ContextGatekeeper()
gate.add_turn("使用 Redis v3.0 集群", "Redis 集群配置...")
gate.add_turn("Python 装饰器用法", "装饰器是...")
blocks = gate.select("Redis v3.0 集群如何搭建")
turn_ids = [b['turn_id'] for b in blocks]
# 应该召回第一轮
assert 1 in turn_ids
class TestMinimumCoverage:
"""最小覆盖选择测试"""
def test_token_budget(self):
"""token 预算限制"""
gate = ContextGatekeeper(token_budget=300) # 调整预算100字符block约150tokens
gate.add_turn("A" * 100, "B" * 100)
gate.add_turn("C" * 100, "D" * 100)
blocks = gate.select("A")
# 预算300约等于2个block的代价应该只能选1个
assert len(blocks) <= 2, f"预算300应该只选1-2个block实际选了{len(blocks)}"
class TestIntegration:
"""端到端集成测试"""
def test_multi_turn_conversation(self):
"""模拟多轮对话"""
gate = ContextGatekeeper()
# 第1轮
gate.add_turn("Redis 锁续租为什么会脑裂", "因为锁过期时间设置不合理...")
# 第2轮
gate.add_turn("如何避免脑裂", "可以增加时钟偏移检测...")
# 第3轮切换话题
gate.add_turn("Python 异步编程怎么做", "用 asyncio 模块...")
# 问 Redis 相关问题,验证能召回 Redis 内容
redis_blocks = gate.select("Redis 锁的 TTL 怎么设")
turn_ids = [b['turn_id'] for b in redis_blocks]
assert 1 in turn_ids, f"第1轮 Redis 内容应该被召回,实际: {turn_ids}"
print(f" Redis 查询召回: {turn_ids}")
# 问 Python 相关问题,验证话题切换后召回正确内容
py_blocks = gate.select("asyncio 怎么用")
turn_ids_py = [b['turn_id'] for b in py_blocks]
assert 3 in turn_ids_py, f"第3轮 Python 内容应该被召回,实际: {turn_ids_py}"
print(f" Python 查询召回: {turn_ids_py}")
def test_constraints_preserved(self):
"""约束持久化测试"""
gate = ContextGatekeeper()
gate.set_constraint("language", "中文")
gate.set_constraint("style", "简洁")
constraints = gate.get_constraints()
assert constraints["language"] == "中文"
assert constraints["style"] == "简洁"
def test_reset(self):
"""重置测试"""
gate = ContextGatekeeper()
gate.add_turn("test", "test")
assert len(gate.blocks) == 1
gate.reset()
assert len(gate.blocks) == 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])