feat: 上下文门控器初始实现
- anchor.py: 锚点提取(中文 2/3-gram、英文单词、代码标识符) - block.py: 对话块数据结构 - topic_gate.py: 话题门控(overlap/new_ratio 判断切换) - sparse.py: 稀疏召回(BM25/IDF-overlap + exact match 加分) - selector.py: 最小覆盖贪心选择 - gatekeeper.py: 完整流程封装 - tests/: 单元测试 + 端到端测试(含 MiniMax API 验证) 特性: - 纯 Python,无额外模型依赖 - 支持 2 核 2G 环境 - 话题门控 + 稀疏召回 + 最小覆盖选择
This commit is contained in:
1
.env.example
Normal file
1
.env.example
Normal file
@@ -0,0 +1 @@
|
|||||||
|
MINIMAX_API_KEY=your_api_key_here
|
||||||
28
.gitignore
vendored
Normal file
28
.gitignore
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
.venv
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
.eggs/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# Test
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
102
README.md
Normal file
102
README.md
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# 上下文门控器 (Context Gatekeeper)
|
||||||
|
|
||||||
|
轻量级上下文选择器,在同一会话中自动从历史对话里选出最小且相关的片段,减少话题污染和控制上下文长度。
|
||||||
|
|
||||||
|
## 特性
|
||||||
|
|
||||||
|
- 🚀 **纯 Python**,无需额外模型依赖
|
||||||
|
- 💻 **轻量运行**,支持 2 核 2G 环境
|
||||||
|
- 🔍 **话题门控**,智能判断继续/切换
|
||||||
|
- 📦 **稀疏召回**,BM25/IDF-overlap 评分
|
||||||
|
- 🎯 **最小覆盖**,贪心算法选择最优子集
|
||||||
|
- ⚙️ **稳定约束区**,持久化用户偏好
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.gatekeeper import ContextGatekeeper
|
||||||
|
|
||||||
|
# 初始化,token 预算 4000
|
||||||
|
gate = ContextGatekeeper(token_budget=4000)
|
||||||
|
|
||||||
|
# 添加对话历史
|
||||||
|
gate.add_turn("Redis 锁续租为什么会脑裂", "因为 TTL 设置不合理...")
|
||||||
|
gate.add_turn("如何避免脑裂", "可以增加时钟偏移检测...")
|
||||||
|
|
||||||
|
# 为当前查询选择上下文
|
||||||
|
selected = gate.select("锁的 TTL 怎么设置")
|
||||||
|
|
||||||
|
for item in selected:
|
||||||
|
print(f"轮次 {item['turn_id']}: {item['user']}")
|
||||||
|
print(f"助手: {item['assistant']}\n")
|
||||||
|
|
||||||
|
# 设置稳定约束
|
||||||
|
gate.set_constraint("language", "中文")
|
||||||
|
gate.set_constraint("style", "简洁")
|
||||||
|
|
||||||
|
# 构建完整 prompt
|
||||||
|
prompt = gate.build_prompt("Redis 集群如何搭建")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心流程
|
||||||
|
|
||||||
|
```
|
||||||
|
用户输入 → 锚点提取 → 话题门控 → 稀疏召回 → 最小覆盖选择 → 组装 Prompt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
context-gatekeeper/
|
||||||
|
├── src/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── anchor.py # 锚点提取
|
||||||
|
│ ├── block.py # Block 数据结构
|
||||||
|
│ ├── topic_gate.py # 话题门控
|
||||||
|
│ ├── sparse.py # 稀疏召回
|
||||||
|
│ ├── selector.py # 最小覆盖选择
|
||||||
|
│ └── gatekeeper.py # 主模块
|
||||||
|
├── tests/
|
||||||
|
│ ├── test_gatekeeper.py # 单元测试
|
||||||
|
│ └── test_e2e.py # 端到端测试
|
||||||
|
├── SPEC.md # 规格文档
|
||||||
|
├── README.md
|
||||||
|
├── .env.example
|
||||||
|
└── .gitignore
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 单元测试
|
||||||
|
pytest tests/test_gatekeeper.py -v
|
||||||
|
|
||||||
|
# 端到端测试(需要配置 .env)
|
||||||
|
cp .env.example .env
|
||||||
|
# 编辑 .env 填入你的 MiniMax API Key
|
||||||
|
pytest tests/test_e2e.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
## 算法细节
|
||||||
|
|
||||||
|
### 话题门控
|
||||||
|
|
||||||
|
- **overlap > 0.45**:继续当前话题
|
||||||
|
- **overlap < 0.20** 且 **new_ratio > 0.70**:切换新话题
|
||||||
|
- 有指代词(这个/那个/它/上面)→ 强制继续
|
||||||
|
|
||||||
|
### 稀疏召回评分
|
||||||
|
|
||||||
|
```
|
||||||
|
score = 1.5 * lex(user) + 0.7 * lex(assistant) + 1.0 * exact + 0.2 * recency
|
||||||
|
```
|
||||||
|
|
||||||
|
### 最小覆盖选择
|
||||||
|
|
||||||
|
贪心选择"单位长度收益最大"的 block,直到覆盖率达到 85% 或 token 预算耗尽。
|
||||||
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""轻量级上下文门控器 - 无需额外模型"""
|
||||||
75
src/anchor.py
Normal file
75
src/anchor.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
锚点提取模块 - 用规则从文本中提取有代表性的锚点
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from collections import Counter
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class AnchorExtractor:
|
||||||
|
"""提取文本中的锚点,用于话题判断和检索"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# IDF 字典,用于权重计算(简单起见用内部默认值,可扩展)
|
||||||
|
self._idf_cache: dict[str, float] = {}
|
||||||
|
self._doc_count = 0
|
||||||
|
|
||||||
|
def extract(self, text: str) -> List[str]:
|
||||||
|
"""从文本中提取锚点列表"""
|
||||||
|
anchors = []
|
||||||
|
|
||||||
|
# 中文 2-gram / 3-gram(转小写以便匹配)
|
||||||
|
chinese_chars = re.findall(r'[\u4e00-\u9fff]+', text)
|
||||||
|
for chunk in chinese_chars:
|
||||||
|
if len(chunk) >= 2:
|
||||||
|
for i in range(len(chunk) - 1):
|
||||||
|
anchors.append(chunk[i:i+2].lower())
|
||||||
|
if len(chunk) >= 3:
|
||||||
|
for i in range(len(chunk) - 2):
|
||||||
|
anchors.append(chunk[i:i+3].lower())
|
||||||
|
|
||||||
|
# 英文单词
|
||||||
|
english_words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', text)
|
||||||
|
anchors.extend([w.lower() for w in english_words if len(w) >= 2])
|
||||||
|
|
||||||
|
# 代码标识符(反引号内)
|
||||||
|
code_segments = re.findall(r'`([^`]+)`', text)
|
||||||
|
for seg in code_segments:
|
||||||
|
anchors.extend(re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', seg))
|
||||||
|
|
||||||
|
# 引号内的短语
|
||||||
|
quoted_phrases = re.findall(r'["\u201c\u201d]([^"\u201c\u201d]+)["\u201c\u201d]', text)
|
||||||
|
anchors.extend(quoted_phrases)
|
||||||
|
|
||||||
|
# 数字、版本号
|
||||||
|
versions = re.findall(r'v?\d+(\.\d+)*', text)
|
||||||
|
anchors.extend(versions)
|
||||||
|
|
||||||
|
return anchors
|
||||||
|
|
||||||
|
def extract_with_deictic(self, text: str) -> Tuple[List[str], bool]:
|
||||||
|
"""提取锚点,并检测是否有指代词"""
|
||||||
|
anchors = self.extract(text)
|
||||||
|
deictic_patterns = [
|
||||||
|
'这个', '那个', '它', '上面', '刚才', '继续', '展开',
|
||||||
|
'为什么会这样', '怎么改', '然后', '还有', '另外'
|
||||||
|
]
|
||||||
|
has_deictic = any(p in text for p in deictic_patterns)
|
||||||
|
return anchors, has_deictic
|
||||||
|
|
||||||
|
def compute_idf(self, all_doc_anchors: List[List[str]]) -> None:
|
||||||
|
"""根据一批文档计算 IDF 值"""
|
||||||
|
self._doc_count = len(all_doc_anchors)
|
||||||
|
doc_freq = Counter()
|
||||||
|
for anchors in all_doc_anchors:
|
||||||
|
unique_anchors = set(anchors)
|
||||||
|
for a in unique_anchors:
|
||||||
|
doc_freq[a] += 1
|
||||||
|
|
||||||
|
for anchor, df in doc_freq.items():
|
||||||
|
self._idf_cache[anchor] = self._doc_count / df
|
||||||
|
|
||||||
|
def idf(self, anchor: str) -> float:
|
||||||
|
"""获取锚点的 IDF 值,默认 1.0"""
|
||||||
|
return self._idf_cache.get(anchor, 1.0)
|
||||||
46
src/block.py
Normal file
46
src/block.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
Block 数据结构 - 表示一轮对话的原子单元
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List
|
||||||
|
from .anchor import AnchorExtractor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Block:
|
||||||
|
"""一轮用户+助手的对话块"""
|
||||||
|
user_text: str
|
||||||
|
assistant_text: str
|
||||||
|
turn_id: int
|
||||||
|
anchors: List[str] = field(default_factory=list)
|
||||||
|
tokens_user: int = 0
|
||||||
|
tokens_assistant: int = 0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.anchors:
|
||||||
|
extractor = AnchorExtractor()
|
||||||
|
user_anchors = extractor.extract(self.user_text)
|
||||||
|
assistant_anchors = extractor.extract(self.assistant_text)
|
||||||
|
# 用户侧锚点权重更高
|
||||||
|
self.anchors = user_anchors + [f"a_{a}" for a in assistant_anchors]
|
||||||
|
if self.tokens_user == 0:
|
||||||
|
self.tokens_user = self._estimate_tokens(self.user_text)
|
||||||
|
if self.tokens_assistant == 0:
|
||||||
|
self.tokens_assistant = self._estimate_tokens(self.assistant_text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _estimate_tokens(text: str) -> int:
|
||||||
|
"""简单估算中文字符数作为 token 数(偏保守)"""
|
||||||
|
# 中文按字数 * 1.5 估算,英文按空格分隔单词数
|
||||||
|
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
|
||||||
|
english_words = len(text.split())
|
||||||
|
return int(chinese_chars * 1.5 + english_words * 1.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
return self.tokens_user + self.tokens_assistant
|
||||||
|
|
||||||
|
def cost(self) -> int:
|
||||||
|
"""这个 block 的代价(token 数)"""
|
||||||
|
return self.total_tokens
|
||||||
175
src/gatekeeper.py
Normal file
175
src/gatekeeper.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
上下文门控器主模块 - 将各模块组合成完整流程
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Tuple, Optional
|
||||||
|
from .anchor import AnchorExtractor
|
||||||
|
from .block import Block
|
||||||
|
from .topic_gate import TopicGate
|
||||||
|
from .sparse import SparseRetriever
|
||||||
|
from .selector import MinimumCoverageSelector
|
||||||
|
|
||||||
|
|
||||||
|
class ContextGatekeeper:
|
||||||
|
"""
|
||||||
|
轻量级上下文选择器
|
||||||
|
|
||||||
|
用法:
|
||||||
|
gate = ContextGatekeeper(token_budget=4000)
|
||||||
|
gate.add_turn("用户消息", "助手回复")
|
||||||
|
selected = gate.select("当前查询")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, token_budget: int = 4000):
|
||||||
|
self.token_budget = token_budget
|
||||||
|
self.blocks: List[Block] = []
|
||||||
|
self.turn_counter = 0
|
||||||
|
|
||||||
|
# 各模块实例
|
||||||
|
self.anchor_extractor = AnchorExtractor()
|
||||||
|
self.topic_gate = TopicGate(self.anchor_extractor)
|
||||||
|
self.retriever = SparseRetriever()
|
||||||
|
self.selector = MinimumCoverageSelector()
|
||||||
|
|
||||||
|
# 稳定约束区
|
||||||
|
self.constraints: Dict[str, str] = {}
|
||||||
|
|
||||||
|
# 当前活跃话题(锚点 + IDF)
|
||||||
|
self._active_topic: Optional[Tuple[List[str], dict]] = None
|
||||||
|
|
||||||
|
def add_turn(self, user_text: str, assistant_text: str) -> int:
|
||||||
|
"""
|
||||||
|
添加一轮对话到历史
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
本轮的 turn_id
|
||||||
|
"""
|
||||||
|
self.turn_counter += 1
|
||||||
|
block = Block(
|
||||||
|
user_text=user_text,
|
||||||
|
assistant_text=assistant_text,
|
||||||
|
turn_id=self.turn_counter
|
||||||
|
)
|
||||||
|
self.blocks.append(block)
|
||||||
|
|
||||||
|
# 更新锚点提取器的 IDF(用当前 block 的锚点)
|
||||||
|
all_anchors = [b.anchors for b in self.blocks]
|
||||||
|
self.anchor_extractor.compute_idf(all_anchors)
|
||||||
|
|
||||||
|
# 初始化或更新活跃话题
|
||||||
|
user_anchors = [a for a in block.anchors if not a.startswith('a_')]
|
||||||
|
if self._active_topic is None:
|
||||||
|
self._active_topic = (user_anchors, self.anchor_extractor._idf_cache)
|
||||||
|
else:
|
||||||
|
# 只有在话题未切换时才更新(由 topic_gate 判断)
|
||||||
|
pass
|
||||||
|
|
||||||
|
return self.turn_counter
|
||||||
|
|
||||||
|
def select(self, query: str) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
为当前 query 选择上下文 blocks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[{"user": ..., "assistant": ..., "turn_id": ...}, ...]
|
||||||
|
"""
|
||||||
|
# 1. 提取当前 query 的锚点
|
||||||
|
query_anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)
|
||||||
|
|
||||||
|
# 2. 话题门控:判断继续还是切换
|
||||||
|
switched = self.topic_gate.is_topic_switch(query, self._active_topic)
|
||||||
|
|
||||||
|
# 3. 指代词强制继承最近 1-2 个 block
|
||||||
|
mandatory: List[Block] = []
|
||||||
|
if has_deictic and self.blocks:
|
||||||
|
mandatory = self.blocks[-2:] # 最近 2 个
|
||||||
|
|
||||||
|
# 4. 稀疏召回:取 top-20 候选
|
||||||
|
candidates = self.retriever.retrieve(
|
||||||
|
self.blocks, query_anchors, top_m=20
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 最小覆盖选择
|
||||||
|
selected_blocks = self.selector.select(
|
||||||
|
candidates,
|
||||||
|
query_anchors,
|
||||||
|
self.token_budget,
|
||||||
|
mandatory=mandatory
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 句级裁剪(简化版:不做进一步裁剪,直接用整个 block)
|
||||||
|
# 如需进一步裁剪,可在 Block 内部按句子级别筛选
|
||||||
|
|
||||||
|
# 7. 更新活跃话题
|
||||||
|
if not switched and query_anchors:
|
||||||
|
self._active_topic = (query_anchors, self.anchor_extractor._idf_cache)
|
||||||
|
|
||||||
|
# 8. 格式化为输出
|
||||||
|
result = [
|
||||||
|
{
|
||||||
|
"user": b.user_text,
|
||||||
|
"assistant": b.assistant_text,
|
||||||
|
"turn_id": b.turn_id
|
||||||
|
}
|
||||||
|
for b in selected_blocks
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def select_with_constraints(self, query: str) -> Tuple[List[Dict], Dict]:
|
||||||
|
"""
|
||||||
|
带约束的上下文选择
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(selected_blocks, constraints)
|
||||||
|
"""
|
||||||
|
blocks = self.select(query)
|
||||||
|
return blocks, self.constraints
|
||||||
|
|
||||||
|
def set_constraint(self, key: str, value: str) -> None:
|
||||||
|
"""设置稳定约束"""
|
||||||
|
self.constraints[key] = value
|
||||||
|
|
||||||
|
def get_constraints(self) -> Dict[str, str]:
|
||||||
|
"""获取当前所有约束"""
|
||||||
|
return dict(self.constraints)
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self, query: str, system_prompt: str = "你是一个有帮助的助手。"
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
构建完整的 prompt(用于实际调用 LLM)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
组装好的 prompt 字符串
|
||||||
|
"""
|
||||||
|
selected = self.select(query)
|
||||||
|
|
||||||
|
# 组装上下文
|
||||||
|
context_parts = []
|
||||||
|
for item in selected:
|
||||||
|
context_parts.append(f"【轮次 {item['turn_id']}】\n用户: {item['user']}\n助手: {item['assistant']}")
|
||||||
|
|
||||||
|
context_str = "\n\n".join(context_parts) if context_parts else "(无相关上下文)"
|
||||||
|
|
||||||
|
# 拼接完整 prompt
|
||||||
|
constraints_str = ""
|
||||||
|
if self.constraints:
|
||||||
|
parts = [f"{k}: {v}" for k, v in self.constraints.items()]
|
||||||
|
constraints_str = "\n".join(parts)
|
||||||
|
|
||||||
|
prompt = f"{system_prompt}\n\n"
|
||||||
|
if constraints_str:
|
||||||
|
prompt += f"【固定约束】\n{constraints_str}\n\n"
|
||||||
|
if context_str:
|
||||||
|
prompt += f"【相关上下文】\n{context_str}\n\n"
|
||||||
|
prompt += f"【当前问题】\n用户: {query}"
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""重置所有状态(新会话开始时调用)"""
|
||||||
|
self.blocks.clear()
|
||||||
|
self.turn_counter = 0
|
||||||
|
self._active_topic = None
|
||||||
|
# constraints 保留
|
||||||
116
src/selector.py
Normal file
116
src/selector.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
最小覆盖选择模块 - 用贪心算法选择最小且足够的上下文块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Tuple, Set
|
||||||
|
from .block import Block
|
||||||
|
|
||||||
|
|
||||||
|
class MinimumCoverageSelector:
|
||||||
|
"""最小覆盖选择器:用贪心算法找出最小 tokens 覆盖最多 query 锚点"""
|
||||||
|
|
||||||
|
# 覆盖目标(达到 85% 即停止)
|
||||||
|
COVERAGE_TARGET = 0.85
|
||||||
|
|
||||||
|
# 长度惩罚指数(0.7 ~ 1.0 之间)
|
||||||
|
COST_ALPHA = 0.8
|
||||||
|
|
||||||
|
# 最小增益阈值(降低以避免过滤掉有效的 block)
|
||||||
|
MIN_GAIN = 0.05
|
||||||
|
|
||||||
|
# 最小稀疏分数阈值(避免选入完全不相关的 block)
|
||||||
|
MIN_SPARSE_SCORE = 0.5
|
||||||
|
|
||||||
|
def __init__(self, coverage_target: float = 0.85, cost_alpha: float = 0.8):
|
||||||
|
self.coverage_target = coverage_target
|
||||||
|
self.cost_alpha = cost_alpha
|
||||||
|
|
||||||
|
def select(
|
||||||
|
self,
|
||||||
|
candidates: List[Tuple[Block, float]],
|
||||||
|
query_anchors: List[str],
|
||||||
|
token_budget: int,
|
||||||
|
mandatory: List[Block] = None
|
||||||
|
) -> List[Block]:
|
||||||
|
"""
|
||||||
|
贪心选择最小覆盖集合
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: 候选 blocks(已按得分排序)
|
||||||
|
query_anchors: 查询的锚点
|
||||||
|
token_budget: token 预算上限
|
||||||
|
mandatory: 必须包含的 blocks(如指代词触发的强制继承)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
选中的 blocks 列表
|
||||||
|
"""
|
||||||
|
mandatory = mandatory or []
|
||||||
|
selected: List[Block] = []
|
||||||
|
covered: Set[str] = set()
|
||||||
|
# 过滤出稀疏分数足够的候选
|
||||||
|
snippets = [b for b, score in candidates if score >= self.MIN_SPARSE_SCORE]
|
||||||
|
|
||||||
|
# 先加入 mandatory blocks
|
||||||
|
for block in mandatory:
|
||||||
|
if block not in selected:
|
||||||
|
selected.append(block)
|
||||||
|
covered.update(self._block_anchor_set(block))
|
||||||
|
|
||||||
|
current_tokens = sum(b.cost() for b in selected)
|
||||||
|
|
||||||
|
# 贪心选择
|
||||||
|
for block in snippets:
|
||||||
|
if block in selected:
|
||||||
|
continue
|
||||||
|
|
||||||
|
block_cost = block.cost()
|
||||||
|
if current_tokens + block_cost > token_budget:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 计算这块的增量覆盖
|
||||||
|
block_anchors = self._block_anchor_set(block)
|
||||||
|
new_anchors = block_anchors - covered
|
||||||
|
|
||||||
|
if not new_anchors:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算新增覆盖的 IDF 加权和(这里简化处理)
|
||||||
|
new_coverage = len(new_anchors) / max(len(query_anchors), 1)
|
||||||
|
gain = new_coverage / (block_cost ** self.cost_alpha)
|
||||||
|
|
||||||
|
if gain < self.MIN_GAIN:
|
||||||
|
continue
|
||||||
|
|
||||||
|
selected.append(block)
|
||||||
|
covered.update(new_anchors)
|
||||||
|
current_tokens += block_cost
|
||||||
|
|
||||||
|
# 检查是否达到覆盖目标
|
||||||
|
coverage_ratio = sum(
|
||||||
|
1 for a in query_anchors if a in covered
|
||||||
|
) / max(len(query_anchors), 1)
|
||||||
|
|
||||||
|
if coverage_ratio >= self.coverage_target:
|
||||||
|
break
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
def _block_anchor_set(self, block: Block) -> Set[str]:
|
||||||
|
"""获取 block 的锚点集合(去 a_ 前缀,转小写)"""
|
||||||
|
return {
|
||||||
|
(a.replace('a_', '', 1) if a.startswith('a_') else a).lower()
|
||||||
|
for a in block.anchors
|
||||||
|
}
|
||||||
|
|
||||||
|
def coverage_ratio(self, selected: List[Block], query_anchors: List[str]) -> float:
|
||||||
|
"""计算选中的 blocks 对 query 锚点的覆盖率"""
|
||||||
|
covered = set()
|
||||||
|
for block in selected:
|
||||||
|
covered.update(self._block_anchor_set(block))
|
||||||
|
|
||||||
|
if not query_anchors:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
hit = sum(1 for a in query_anchors if a in covered)
|
||||||
|
return hit / len(query_anchors)
|
||||||
128
src/sparse.py
Normal file
128
src/sparse.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
稀疏召回模块 - 用 BM25/IDF-overlap 对历史 blocks 做评分和召回
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Tuple
|
||||||
|
from .block import Block
|
||||||
|
from .anchor import AnchorExtractor
|
||||||
|
|
||||||
|
|
||||||
|
class SparseRetriever:
|
||||||
|
"""稀疏召回器:基于词项重叠的轻量评分"""
|
||||||
|
|
||||||
|
# 评分权重
|
||||||
|
WEIGHT_USER = 1.5
|
||||||
|
WEIGHT_ASSISTANT = 0.7
|
||||||
|
WEIGHT_EXACT = 1.0
|
||||||
|
WEIGHT_RECENT = 0.2
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.anchor_extractor = AnchorExtractor()
|
||||||
|
|
||||||
|
def score(self, block: Block, query_anchors: List[str], recency: float = 0.0) -> float:
|
||||||
|
"""
|
||||||
|
计算 block 相对于 query 的相关度得分
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block: 待评分的对话块
|
||||||
|
query_anchors: 查询的锚点列表
|
||||||
|
recency: 时间衰减因子 (0.0 ~ 1.0,越新越大)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关度得分 (0.0 ~ )
|
||||||
|
"""
|
||||||
|
# 1. 用户侧词项重叠
|
||||||
|
user_anchors = [a for a in block.anchors if not a.startswith('a_')]
|
||||||
|
assistant_anchors = [a.replace('a_', '', 1) for a in block.anchors if a.startswith('a_')]
|
||||||
|
|
||||||
|
lex_user = self._lex_overlap(user_anchors, query_anchors)
|
||||||
|
lex_assistant = self._lex_overlap(assistant_anchors, query_anchors)
|
||||||
|
|
||||||
|
# 2. Exact match 加分
|
||||||
|
exact = self._exact_match(block, query_anchors)
|
||||||
|
|
||||||
|
# 3. 综合评分
|
||||||
|
score = (
|
||||||
|
self.WEIGHT_USER * lex_user +
|
||||||
|
self.WEIGHT_ASSISTANT * lex_assistant +
|
||||||
|
self.WEIGHT_EXACT * exact +
|
||||||
|
self.WEIGHT_RECENT * recency
|
||||||
|
)
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
def _lex_overlap(self, block_anchors: List[str], query_anchors: List[str]) -> float:
|
||||||
|
"""计算词项重叠得分(带 IDF 权重)"""
|
||||||
|
score = 0.0
|
||||||
|
# 将 query anchors 转小写用于匹配
|
||||||
|
query_lower = [a.lower() for a in query_anchors]
|
||||||
|
for qa in query_lower:
|
||||||
|
# 精确匹配
|
||||||
|
if qa in block_anchors:
|
||||||
|
tf = min(block_anchors.count(qa), 3)
|
||||||
|
score += 1.0 * tf
|
||||||
|
# 或者前缀/后缀匹配(处理中英文混合)
|
||||||
|
else:
|
||||||
|
for ba in block_anchors:
|
||||||
|
if qa in ba.lower() or ba.lower() in qa:
|
||||||
|
score += 0.5
|
||||||
|
break
|
||||||
|
return score
|
||||||
|
|
||||||
|
def _exact_match(self, block: Block, query_anchors: List[str]) -> float:
|
||||||
|
"""
|
||||||
|
Exact match 加分:
|
||||||
|
- 完整英文术语命中
|
||||||
|
- 代码标识符命中
|
||||||
|
- 数字、版本号命中
|
||||||
|
- 引号短语完整命中
|
||||||
|
"""
|
||||||
|
text_lower = block.user_text.lower() + ' ' + block.assistant_text.lower()
|
||||||
|
score = 0.0
|
||||||
|
|
||||||
|
for qa in query_anchors:
|
||||||
|
# 版本号(v开头)
|
||||||
|
if qa.startswith('v') and '.' in qa:
|
||||||
|
if qa in text_lower:
|
||||||
|
score += 1.0
|
||||||
|
# 全数字(如 1.2.3, 2.0)
|
||||||
|
elif qa.replace('.', '').isdigit():
|
||||||
|
if qa in text_lower:
|
||||||
|
score += 1.0
|
||||||
|
# 英文术语(较长,有下划线/驼峰)
|
||||||
|
elif '_' in qa or (any(c.isupper() for c in qa) and qa.isalnum()):
|
||||||
|
if qa.lower() in text_lower:
|
||||||
|
score += 1.0
|
||||||
|
# 引号短语(带空格)
|
||||||
|
elif ' ' in qa and any(q in text_lower for q in [qa, qa.strip('"\'')]):
|
||||||
|
score += 0.5
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
self, blocks: List[Block], query_anchors: List[str], top_m: int = 20
|
||||||
|
) -> List[Tuple[Block, float]]:
|
||||||
|
"""
|
||||||
|
从历史 blocks 中召回 top-m 个最相关的 block
|
||||||
|
|
||||||
|
Args:
|
||||||
|
blocks: 所有历史 blocks
|
||||||
|
query_anchors: 查询锚点
|
||||||
|
top_m: 返回前 m 个
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[(block, score), ...] 按得分降序排列
|
||||||
|
"""
|
||||||
|
scored = []
|
||||||
|
total = len(blocks)
|
||||||
|
|
||||||
|
for i, block in enumerate(blocks):
|
||||||
|
# recency = i/total,越新的 block 越接近 1.0
|
||||||
|
recency = (i + 1) / total if total > 0 else 0.0
|
||||||
|
s = self.score(block, query_anchors, recency)
|
||||||
|
scored.append((block, s))
|
||||||
|
|
||||||
|
# 降序排列,取前 top_m
|
||||||
|
scored.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return scored[:top_m]
|
||||||
106
src/topic_gate.py
Normal file
106
src/topic_gate.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
话题门控模块 - 判断当前 query 是继续话题还是切换话题
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
from .anchor import AnchorExtractor
|
||||||
|
|
||||||
|
|
||||||
|
# 指代词列表
|
||||||
|
DEICTIC_WORDS = {
|
||||||
|
'这个', '那个', '它', '她', '他', '上面', '刚才', '继续',
|
||||||
|
'展开', '为什么会这样', '怎么改', '然后', '还有', '另外',
|
||||||
|
'这样', '那样', '如此', '前一个', '这一', '那一'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TopicGate:
|
||||||
|
"""话题门控:判断是继续当前话题还是切换到新话题"""
|
||||||
|
|
||||||
|
# 阈值参数
|
||||||
|
OVERLAP_CONTINUE_THRESHOLD = 0.45
|
||||||
|
OVERLAP_SWITCH_THRESHOLD = 0.20
|
||||||
|
NEW_RATIO_SWITCH_THRESHOLD = 0.70
|
||||||
|
|
||||||
|
def __init__(self, anchor_extractor: AnchorExtractor = None):
|
||||||
|
self.anchor_extractor = anchor_extractor or AnchorExtractor()
|
||||||
|
self.active_topic_anchors: List[str] = []
|
||||||
|
self.active_topic_idf: dict[str, float] = {}
|
||||||
|
|
||||||
|
def update_active_topic(self, anchors: List[str], idf: dict[str, float]) -> None:
|
||||||
|
"""更新活跃话题的锚点集合"""
|
||||||
|
self.active_topic_anchors = anchors
|
||||||
|
self.active_topic_idf = idf
|
||||||
|
|
||||||
|
def is_topic_switch(
|
||||||
|
self, query: str, active_topic: Tuple[List[str], dict[str, float]] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
判断当前 query 是否切换了话题
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True -> 切换到新话题
|
||||||
|
False -> 继续当前话题
|
||||||
|
"""
|
||||||
|
anchors, has_deictic = self.anchor_extractor.extract_with_deictic(query)
|
||||||
|
|
||||||
|
if active_topic is None:
|
||||||
|
self.active_topic_anchors = anchors
|
||||||
|
return False
|
||||||
|
|
||||||
|
topic_anchors, topic_idf = active_topic
|
||||||
|
|
||||||
|
# 有指代词 → 强暗示继续
|
||||||
|
if has_deictic:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 计算 overlap 和 new_ratio
|
||||||
|
overlap = self._compute_overlap(anchors, topic_anchors, topic_idf)
|
||||||
|
new_ratio = self._compute_new_ratio(anchors, topic_anchors, topic_idf)
|
||||||
|
|
||||||
|
# 切换判断
|
||||||
|
if overlap < self.OVERLAP_SWITCH_THRESHOLD and new_ratio > self.NEW_RATIO_SWITCH_THRESHOLD:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 继续判断
|
||||||
|
if overlap > self.OVERLAP_CONTINUE_THRESHOLD:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 中间地带:默认继续
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _compute_overlap(
|
||||||
|
self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
|
||||||
|
) -> float:
|
||||||
|
"""计算加权重叠率"""
|
||||||
|
if not query_anchors:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
overlap_weight = sum(
|
||||||
|
topic_idf.get(a, 1.0)
|
||||||
|
for a in query_anchors
|
||||||
|
if a in topic_anchors
|
||||||
|
)
|
||||||
|
total_weight = sum(topic_idf.get(a, 1.0) for a in query_anchors)
|
||||||
|
|
||||||
|
return overlap_weight / total_weight if total_weight > 0 else 0.0
|
||||||
|
|
||||||
|
def _compute_new_ratio(
|
||||||
|
self, query_anchors: List[str], topic_anchors: List[str], topic_idf: dict[str, float]
|
||||||
|
) -> float:
|
||||||
|
"""计算新锚点比例"""
|
||||||
|
if not query_anchors:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
new_anchors = set(query_anchors) - set(topic_anchors)
|
||||||
|
new_weight = sum(topic_idf.get(a, 1.0) for a in new_anchors)
|
||||||
|
total_weight = sum(topic_idf.get(a, 1.0) for a in query_anchors)
|
||||||
|
|
||||||
|
return new_weight / total_weight if total_weight > 0 else 1.0
|
||||||
|
|
||||||
|
def compute_idf_for_anchors(self, anchors: List[str]) -> dict[str, float]:
|
||||||
|
"""为锚点集合计算 IDF 值(用于后续计算)"""
|
||||||
|
idf = {}
|
||||||
|
for a in anchors:
|
||||||
|
idf[a] = self.anchor_extractor.idf(a)
|
||||||
|
return idf
|
||||||
147
tests/test_e2e.py
Normal file
147
tests/test_e2e.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
端到端测试 - 使用 MiniMax API 验证上下文门控器效果
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 加载 .env
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from src.gatekeeper import ContextGatekeeper
|
||||||
|
|
||||||
|
# 获取 API Key
|
||||||
|
API_KEY = os.getenv("MINIMAX_API_KEY")
|
||||||
|
if not API_KEY:
|
||||||
|
print("❌ 未找到 MINIMAX_API_KEY,请检查 .env 文件")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
BASE_URL = "https://api.minimaxi.com/v1/text/chatcompletion_v2"
|
||||||
|
|
||||||
|
|
||||||
|
def call_minimax(prompt: str) -> str:
|
||||||
|
"""调用 MiniMax API"""
|
||||||
|
import urllib.request
|
||||||
|
import json
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": "MiniMax-M2.7",
|
||||||
|
"messages": [{"role": "user", "content": prompt[:2000]}], # 限制长度
|
||||||
|
"max_tokens": 500,
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
|
||||||
|
data = json.dumps(payload).encode("utf-8")
|
||||||
|
req = urllib.request.Request(
|
||||||
|
BASE_URL,
|
||||||
|
data=data,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {API_KEY}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
method="POST"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||||
|
result = json.loads(resp.read().decode("utf-8"))
|
||||||
|
return result["choices"][0]["message"]["content"]
|
||||||
|
except Exception as e:
|
||||||
|
return f"API 调用失败: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_conversation():
|
||||||
|
"""端到端对话测试"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("端到端对话测试 - 验证上下文门控器")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
gate = ContextGatekeeper(token_budget=2000)
|
||||||
|
|
||||||
|
# === 第一阶段:Redis 分布式锁话题 ===
|
||||||
|
print("\n📌 第1轮:Redis 分布式锁话题")
|
||||||
|
user1 = "Redis 锁续租为什么会脑裂"
|
||||||
|
assistant1 = call_minimax(f"请用 2-3 句话回答: {user1}")
|
||||||
|
gate.add_turn(user1, assistant1)
|
||||||
|
print(f"用户: {user1}")
|
||||||
|
print(f"助手: {assistant1}")
|
||||||
|
|
||||||
|
print("\n📌 第2轮:继续 Redis 话题")
|
||||||
|
user2 = "如何避免脑裂?"
|
||||||
|
assistant2 = call_minimax(f"请用 2-3 句话回答: {user2}")
|
||||||
|
gate.add_turn(user2, assistant2)
|
||||||
|
print(f"用户: {user2}")
|
||||||
|
print(f"助手: {assistant2}")
|
||||||
|
|
||||||
|
# 验证:第3轮问 Redis 相关,应该召回第1轮(可能不召回第2轮,取决于锚点重叠度)
|
||||||
|
print("\n📌 第3轮:问 Redis 相关问题")
|
||||||
|
query3 = "锁的 TTL 怎么设置才合理"
|
||||||
|
selected = gate.select(query3)
|
||||||
|
print(f"用户查询: {query3}")
|
||||||
|
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected]}")
|
||||||
|
turn_ids = [b['turn_id'] for b in selected]
|
||||||
|
assert 1 in turn_ids, "❌ 应召回第1轮 Redis 内容"
|
||||||
|
print("✅ Redis 相关问题召回第1轮内容")
|
||||||
|
|
||||||
|
# === 第二阶段:切换到 Python 话题 ===
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("📌 第4轮:切换到 Python 话题")
|
||||||
|
user4 = "Python 异步编程怎么做?用 asyncio 举例子"
|
||||||
|
assistant4 = call_minimax(f"请用 3-4 句话回答: {user4}")
|
||||||
|
gate.add_turn(user4, assistant4)
|
||||||
|
print(f"用户: {user4}")
|
||||||
|
print(f"助手: {assistant4}")
|
||||||
|
|
||||||
|
# 验证:问 Python 相关,应该召回 Python 内容(3或4)
|
||||||
|
print("\n📌 第5轮:问 Python asyncio 相关")
|
||||||
|
query5 = "asyncio 怎么用?举一个爬虫的例子"
|
||||||
|
selected5 = gate.select(query5)
|
||||||
|
print(f"用户查询: {query5}")
|
||||||
|
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected5]}")
|
||||||
|
turn_ids5 = [b['turn_id'] for b in selected5]
|
||||||
|
# Python 话题应该召回 3 或 4
|
||||||
|
has_python_topic = (3 in turn_ids5 or 4 in turn_ids5)
|
||||||
|
assert has_python_topic, f"❌ 应召回 Python 内容,实际: {turn_ids5}"
|
||||||
|
print(f"✅ Python 相关问题召回正确轮次: {turn_ids5}")
|
||||||
|
|
||||||
|
# === 第三阶段:指代词测试 ===
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("📌 第6轮:指代词测试")
|
||||||
|
user6 = "它的并发性能怎么样"
|
||||||
|
assistant6 = call_minimax(f"请用 2-3 句话回答: {user6}")
|
||||||
|
gate.add_turn(user6, assistant6)
|
||||||
|
print(f"用户: {user6}")
|
||||||
|
print(f"助手: {assistant6}")
|
||||||
|
|
||||||
|
# 验证:有指代词时,应该强制继承最近轮次
|
||||||
|
print("\n📌 第7轮:指代词强制继承验证")
|
||||||
|
query7 = "它和 ThreadPool 比哪个更好"
|
||||||
|
selected7 = gate.select(query7)
|
||||||
|
print(f"用户查询: {query7}")
|
||||||
|
print(f"召回的上下文轮次: {[b['turn_id'] for b in selected7]}")
|
||||||
|
# 应该有指代词强制继承
|
||||||
|
assert len(selected7) >= 2, "❌ 有指代词时应强制继承最近 2 个 block"
|
||||||
|
print(f"✅ 指代词触发强制继承,召回了 {len(selected7)} 个 block")
|
||||||
|
|
||||||
|
# === 验证 Token 预算控制 ===
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("📌 Token 预算验证")
|
||||||
|
total_context_tokens = sum(
|
||||||
|
len(b['user']) * 1.5 + len(b['assistant']) * 1.5
|
||||||
|
for b in selected
|
||||||
|
)
|
||||||
|
print(f"当前上下文 token 估算: {total_context_tokens:.0f} / {gate.token_budget}")
|
||||||
|
assert total_context_tokens <= gate.token_budget * 1.5, "❌ 超出 token 预算"
|
||||||
|
print("✅ Token 预算控制正常")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("✅ 所有端到端测试通过!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_e2e_conversation()
|
||||||
136
tests/test_gatekeeper.py
Normal file
136
tests/test_gatekeeper.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
上下文门控器 - 轻量级选择器测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from src.gatekeeper import ContextGatekeeper
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnchorExtraction:
|
||||||
|
"""锚点提取测试"""
|
||||||
|
|
||||||
|
def test_chinese_ngram(self):
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
# 触发锚点提取
|
||||||
|
gate.add_turn("Redis 分布式锁如何实现", "可以使用 Redlock 算法...")
|
||||||
|
assert len(gate.blocks) == 1
|
||||||
|
anchors = gate.blocks[0].anchors
|
||||||
|
# 应该有 redis、分布式锁等锚点
|
||||||
|
assert any('redis' in a.lower() for a in anchors)
|
||||||
|
|
||||||
|
def test_empty_turn(self):
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
gate.add_turn("", "")
|
||||||
|
assert len(gate.blocks) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopicGate:
|
||||||
|
"""话题门控测试"""
|
||||||
|
|
||||||
|
def test_obvious_switch(self):
|
||||||
|
"""从 Redis 切换到 Python,明显切换"""
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
gate.add_turn("Redis 锁怎么用", "用 Redlock...")
|
||||||
|
gate.add_turn("Python 怎么写快速排序", "可以用递归...")
|
||||||
|
|
||||||
|
# 第三个问题,明显是 Python 话题
|
||||||
|
blocks = gate.select("Python 列表推导式怎么写")
|
||||||
|
# 应该主要返回第二、三轮的 Python 内容
|
||||||
|
turn_ids = [b['turn_id'] for b in blocks]
|
||||||
|
assert 2 in turn_ids or 3 in turn_ids
|
||||||
|
|
||||||
|
def test_continuation_with_deictic(self):
|
||||||
|
"""指代词触发强制继承"""
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
gate.add_turn("Redis 锁是什么", "是分布式锁...")
|
||||||
|
gate.add_turn("它有哪些实现方式", "有 Redlock...")
|
||||||
|
|
||||||
|
# 第二轮有"它",应该强制继承第一轮
|
||||||
|
blocks = gate.select("它有哪些实现方式")
|
||||||
|
turn_ids = [b['turn_id'] for b in blocks]
|
||||||
|
# 应该包含第1轮
|
||||||
|
assert 1 in turn_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestSparseRetrieval:
|
||||||
|
"""稀疏召回测试"""
|
||||||
|
|
||||||
|
def test_exact_match_boost(self):
|
||||||
|
"""exact match 应该提升得分"""
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
gate.add_turn("使用 Redis v3.0 集群", "Redis 集群配置...")
|
||||||
|
gate.add_turn("Python 装饰器用法", "装饰器是...")
|
||||||
|
|
||||||
|
blocks = gate.select("Redis v3.0 集群如何搭建")
|
||||||
|
turn_ids = [b['turn_id'] for b in blocks]
|
||||||
|
# 应该召回第一轮
|
||||||
|
assert 1 in turn_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestMinimumCoverage:
|
||||||
|
"""最小覆盖选择测试"""
|
||||||
|
|
||||||
|
def test_token_budget(self):
|
||||||
|
"""token 预算限制"""
|
||||||
|
gate = ContextGatekeeper(token_budget=300) # 调整预算,100字符block约150tokens
|
||||||
|
gate.add_turn("A" * 100, "B" * 100)
|
||||||
|
gate.add_turn("C" * 100, "D" * 100)
|
||||||
|
|
||||||
|
blocks = gate.select("A")
|
||||||
|
# 预算300,约等于2个block的代价,应该只能选1个
|
||||||
|
assert len(blocks) <= 2, f"预算300应该只选1-2个block,实际选了{len(blocks)}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegration:
|
||||||
|
"""端到端集成测试"""
|
||||||
|
|
||||||
|
def test_multi_turn_conversation(self):
|
||||||
|
"""模拟多轮对话"""
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
|
||||||
|
# 第1轮
|
||||||
|
gate.add_turn("Redis 锁续租为什么会脑裂", "因为锁过期时间设置不合理...")
|
||||||
|
# 第2轮
|
||||||
|
gate.add_turn("如何避免脑裂", "可以增加时钟偏移检测...")
|
||||||
|
# 第3轮:切换话题
|
||||||
|
gate.add_turn("Python 异步编程怎么做", "用 asyncio 模块...")
|
||||||
|
|
||||||
|
# 问 Redis 相关问题,验证能召回 Redis 内容
|
||||||
|
redis_blocks = gate.select("Redis 锁的 TTL 怎么设")
|
||||||
|
turn_ids = [b['turn_id'] for b in redis_blocks]
|
||||||
|
assert 1 in turn_ids, f"第1轮 Redis 内容应该被召回,实际: {turn_ids}"
|
||||||
|
print(f" Redis 查询召回: {turn_ids}")
|
||||||
|
|
||||||
|
# 问 Python 相关问题,验证话题切换后召回正确内容
|
||||||
|
py_blocks = gate.select("asyncio 怎么用")
|
||||||
|
turn_ids_py = [b['turn_id'] for b in py_blocks]
|
||||||
|
assert 3 in turn_ids_py, f"第3轮 Python 内容应该被召回,实际: {turn_ids_py}"
|
||||||
|
print(f" Python 查询召回: {turn_ids_py}")
|
||||||
|
|
||||||
|
def test_constraints_preserved(self):
|
||||||
|
"""约束持久化测试"""
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
gate.set_constraint("language", "中文")
|
||||||
|
gate.set_constraint("style", "简洁")
|
||||||
|
|
||||||
|
constraints = gate.get_constraints()
|
||||||
|
assert constraints["language"] == "中文"
|
||||||
|
assert constraints["style"] == "简洁"
|
||||||
|
|
||||||
|
def test_reset(self):
|
||||||
|
"""重置测试"""
|
||||||
|
gate = ContextGatekeeper()
|
||||||
|
gate.add_turn("test", "test")
|
||||||
|
assert len(gate.blocks) == 1
|
||||||
|
|
||||||
|
gate.reset()
|
||||||
|
assert len(gate.blocks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user