Files
context-gatekeeper/test_comparison.py

108 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
对照实验:有上下文门控 vs 无上下文门控
使用 SiliconFlow Qwen/Qwen3-8B 模型
"""
import os
import json
import requests
from src.gatekeeper import ContextGatekeeper
# SiliconFlow API configuration.
# The key is read from the environment rather than hard-coded: a secret committed
# to source control is readable by anyone with access to the repository (and its
# history). The key previously committed here must be treated as compromised and
# revoked in the SiliconFlow console.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
API_URL = "https://api.siliconflow.cn/v1/chat/completions"
def call_llm(
    prompt: str,
    model: str = "Qwen/Qwen3-8B",
    *,
    max_tokens: int = 512,
    temperature: float = 0.7,
    timeout: float = 60,
) -> str:
    """Send a single-turn chat-completion request to SiliconFlow and return the reply text.

    Args:
        prompt: Full prompt text, sent as one ``user`` message.
        model: Model identifier passed through to the API.
        max_tokens: Upper bound on generated tokens (default 512).
        temperature: Sampling temperature (default 0.7).
        timeout: Seconds to wait for the HTTP response (default 60).

    Returns:
        The ``content`` of the first choice's message, or ``""`` if the API
        returned a null content field.

    Raises:
        RuntimeError: If ``API_KEY`` is empty (no credentials configured).
        requests.HTTPError: If the API responds with a non-2xx status.
        requests.RequestException: On network errors or timeouts.
    """
    # Fail fast with an actionable message instead of sending "Bearer " and
    # getting an opaque 401 back from the server.
    if not API_KEY:
        raise RuntimeError(
            "SiliconFlow API key is not configured; set the SILICONFLOW_API_KEY "
            "environment variable"
        )
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    resp = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    content = resp.json()["choices"][0]["message"]["content"]
    # Normalize a null "content" to "" so callers can always slice/print the
    # result. NOTE(review): OpenAI-compatible APIs may return null content (e.g.
    # when only reasoning output is produced) — confirm the cases for Qwen3 here.
    return content or ""
def build_prompt_no_gatekeeper(query: str, history: list, max_turns: int = 3) -> str:
    """Build the baseline prompt by concatenating the most recent turns verbatim (no gating).

    Args:
        query: The current user question.
        history: Past turns, each a dict with "user" and "assistant" keys.
        max_turns: Number of most recent turns to include (default 3).
            Values <= 0 include no history.

    Returns:
        The selected turns rendered as "用户: ...\\n助手: ..." blocks separated by
        blank lines, followed by the current query line. When no history is
        included, the prompt is just the query line.
    """
    # Guard: history[-0:] returns the *whole* list, not an empty one.
    recent = history[-max_turns:] if max_turns > 0 else []
    blocks = [f"用户: {turn['user']}\n助手: {turn['assistant']}" for turn in recent]
    blocks.append(f"用户: {query}")
    return "\n\n".join(blocks)
def _preview(text: str, limit: int = 200) -> str:
    """Return ``text`` cut to ``limit`` characters, appending "..." only when it was cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def _build_prompt_with_gatekeeper(query: str, selected: list) -> str:
    """Render the gatekeeper-selected turns plus the current query into one prompt.

    Each item in ``selected`` is read via its "turn_id", "user" and "assistant"
    keys — presumably the block format returned by ``ContextGatekeeper.select``;
    verify against src/gatekeeper.py.
    """
    context_str = "\n\n".join(
        f"【轮次 {b['turn_id']}】\n用户: {b['user']}\n助手: {b['assistant']}"
        for b in selected
    )
    return f"你是一个有帮助的助手。\n\n【相关上下文】\n{context_str}\n\n【当前问题】\n用户: {query}"


def main():
    """Run the A/B comparison: the same query answered with vs. without the context gatekeeper.

    Builds a six-turn conversation that changes topic (Redis -> Python -> Redis) and
    feeds it both to a ``ContextGatekeeper`` and to a plain history list. It then
    asks a Redis question twice: once with the last 3 turns concatenated verbatim,
    once with only the turns the gatekeeper selects. Prompts and truncated answers
    are printed for manual comparison.

    Requires network access and a configured SiliconFlow API key (via ``call_llm``).
    """
    gk = ContextGatekeeper(token_budget=1500)

    # Conversation history with topic switches: the first two exchanges are about
    # Redis locks, the next two about Python asyncio, the last two about Redis
    # data types / Memcached.
    conversations = [
        ("如何设计一个 Redis 分布式锁?",
         "分布式锁需要满足互斥性、死锁避免、性能要求。常用 Redisson 实现,核心是 SET if Not Exists + 过期时间。"),
        ("锁的 TTL 设置多少合适?",
         "TTL 取决于业务耗时,建议 3-5 倍 buffer。同时要 watchdog 续期机制。"),
        ("介绍一下 Python 的异步编程",
         "Python 异步编程用 async/await,配合事件循环。asyncio 是标准库,典型场景是 IO 密集型任务。"),
        ("asyncio 是怎么工作的?",
         "asyncio 基于协程和事件循环。调用 await 时协程挂起,事件循环调度其他协程执行。"),
        ("Redis 支持哪些数据结构?",
         "Redis 支持 String、Hash、List、Set、ZSet 五种基本类型,还有 Bitmap、HyperLogLog 等。"),
        ("它和 Memcached 有什么区别?",
         "Redis 是持久化数据库,Memcached 是纯内存缓存。Redis 支持更多数据结构。"),
    ]
    history = []
    for user_msg, assistant_msg in conversations:
        gk.add_turn(user_msg, assistant_msg)
        history.append({"user": user_msg, "assistant": assistant_msg})

    # Test query about Redis. The conversation detoured through Python, and the
    # last 3 turns mix Python (asyncio) with Redis content, so a naive "recent N
    # turns" context risks pulling in irrelevant material (context pollution).
    test_query = "如何保证 Redis 缓存和数据库一致性?"
    print("=" * 70)
    print("对照实验:Qwen3-8B 有/无上下文门控")
    print("=" * 70)
    print(f"\n测试Query: {test_query}\n")

    # --- Baseline: no gatekeeper ---
    print("【无门控】最近3轮直接拼接")
    print("-" * 50)
    prompt_no_gate = build_prompt_no_gatekeeper(test_query, history)
    print(f"[输入]\n{prompt_no_gate}\n")
    answer_no_gate = call_llm(prompt_no_gate)
    print(f"[输出] {_preview(answer_no_gate)}")
    print()

    # --- With gatekeeper ---
    print("【有门控】上下文门控器选择相关片段")
    print("-" * 50)
    selected = gk.select(test_query)
    selected_ids = [b["turn_id"] for b in selected]
    print(f"召回 blocks: {selected_ids}")
    # NOTE(review): this prompt adds a system-style preamble and section headers
    # that the baseline prompt lacks, so the experiment varies two factors at once
    # (context selection AND prompt framing). For a clean A/B comparison, render
    # both prompts with the same template.
    prompt_with_gate = _build_prompt_with_gatekeeper(test_query, selected)
    print(f"[输入]\n{prompt_with_gate}\n")
    answer_with_gate = call_llm(prompt_with_gate)
    print(f"[输出] {_preview(answer_with_gate)}")

    # Summary. NOTE(review): the "only Redis turns recalled" line is a fixed label,
    # not a check — inspect selected_ids to confirm it actually holds.
    print("\n" + "=" * 70)
    print("对比分析:")
    print("无门控 - 可能受最近Python话题干扰")
    print(f"有门控 - 仅召回Redis相关轮次 {selected_ids}")
    print("=" * 70)


if __name__ == "__main__":
    main()