108 lines
4.0 KiB
Python
108 lines
4.0 KiB
Python
"""
|
||
Controlled experiment: with context gating vs. without context gating.

Uses the SiliconFlow Qwen/Qwen3-8B model.
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import requests
|
||
from src.gatekeeper import ContextGatekeeper
|
||
|
||
# SiliconFlow API configuration.
# The key is read from the SILICONFLOW_API_KEY environment variable so that no
# credential is stored in source control. It falls back to an empty string
# when the variable is unset, in which case requests are sent without a valid
# bearer token.
# NOTE(review): any key that was ever committed to this file should be
# treated as leaked and rotated.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
API_URL = "https://api.siliconflow.cn/v1/chat/completions"
|
||
|
||
def call_llm(prompt: str, model: str = "Qwen/Qwen3-8B") -> str:
    """Send one single-turn chat completion request and return the reply text.

    Args:
        prompt: Text sent as the only ``user`` message of the request.
        model: Model identifier placed in the request payload.

    Returns:
        The ``content`` field of the first choice in the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
    """
    response = requests.post(
        API_URL,
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 512,
            "temperature": 0.7,
        },
        timeout=60,
    )
    # Fail loudly on 4xx/5xx instead of trying to parse an error body.
    response.raise_for_status()
    first_choice = response.json()["choices"][0]
    return first_choice["message"]["content"]
|
||
|
||
|
||
def build_prompt_no_gatekeeper(query: str, history: list, max_turns: int = 3) -> str:
    """Build the ungated baseline prompt from the most recent turns.

    The last ``max_turns`` entries of ``history`` are rendered as
    user/assistant exchanges, separated by blank lines, followed by the
    current query. No relevance filtering is applied.

    Args:
        query: The current user question, appended as the final line.
        history: Sequence of dicts, each with ``"user"`` and ``"assistant"``
            keys, ordered oldest to newest.
        max_turns: Number of trailing turns to include. Defaults to 3.

    Returns:
        The assembled prompt. When no turns are included (empty history or
        ``max_turns == 0``) the prompt is just the query line, with no
        leading blank lines.

    Raises:
        ValueError: If ``max_turns`` is negative.
    """
    if max_turns < 0:
        raise ValueError("max_turns must be non-negative")

    # history[-0:] is the whole list, so a zero window is handled explicitly.
    recent_turns = history[-max_turns:] if max_turns else []

    sections = [
        f"用户: {turn['user']}\n助手: {turn['assistant']}" for turn in recent_turns
    ]
    sections.append(f"用户: {query}")
    return "\n\n".join(sections)
|
||
|
||
|
||
def main():
    """Run the gated vs. ungated prompt comparison and print both results.

    A fixed six-turn conversation is loaded into a ``ContextGatekeeper`` and
    into a plain history list. The same Redis question is then sent to the
    model twice: once with the last three turns concatenated verbatim, and
    once with only the turns returned by ``gk.select``.
    """
    gk = ContextGatekeeper(token_budget=1500)

    # Scripted dialogue that moves between topics:
    # Redis locks -> Python asyncio -> Redis data structures / Memcached.
    dialogue = [
        ("如何设计一个 Redis 分布式锁?",
         "分布式锁需要满足互斥性、死锁避免、性能要求。常用 Redisson 实现,核心是 SET if Not Exists + 过期时间。"),
        ("锁的 TTL 设置多少合适?",
         "TTL 取决于业务耗时,建议 3-5 倍 buffer。同时要 watchdog 续期机制。"),
        ("介绍一下 Python 的异步编程",
         "Python 异步编程用 async/await,配合事件循环。asyncio 是标准库,典型场景是 IO 密集型任务。"),
        ("asyncio 是怎么工作的?",
         "asyncio 基于协程和事件循环。调用 await 时协程挂起,事件循环调度其他协程执行。"),
        ("Redis 支持哪些数据结构?",
         "Redis 支持 String、Hash、List、Set、ZSet 五种基本类型,还有 Bitmap、HyperLogLog 等。"),
        ("它和 Memcached 有什么区别?",
         "Redis 是持久化数据库,Memcached 是纯内存缓存。Redis 支持更多数据结构。"),
    ]

    # Feed every turn to the gatekeeper and mirror it in a plain list for
    # the ungated baseline.
    history = []
    for user_msg, assistant_msg in dialogue:
        gk.add_turn(user_msg, assistant_msg)
        history.append({"user": user_msg, "assistant": assistant_msg})

    # The query is about Redis; the last three turns also contain an
    # asyncio exchange, so the recency-only baseline carries unrelated text.
    test_query = "如何保证 Redis 缓存和数据库一致性?"

    heavy_rule = "=" * 70
    light_rule = "-" * 50

    print(heavy_rule)
    print("对照实验:Qwen3-8B 有/无上下文门控")
    print(heavy_rule)
    print(f"\n测试Query: {test_query}\n")

    # --- Baseline: no gating, last three turns concatenated ---
    print("【无门控】最近3轮直接拼接")
    print(light_rule)
    prompt_no_gate = build_prompt_no_gatekeeper(test_query, history)
    print(f"[输入]\n{prompt_no_gate}\n")
    answer_no_gate = call_llm(prompt_no_gate)
    # Only the first 200 characters of the reply are shown.
    print(f"[输出] {answer_no_gate[:200]}...")

    print()

    # --- Gated: only the blocks chosen by the gatekeeper ---
    print("【有门控】上下文门控器选择相关片段")
    print(light_rule)
    selected = gk.select(test_query)
    print(f"召回 blocks: {[b['turn_id'] for b in selected]}")

    context_str = "\n\n".join(
        f"【轮次 {blk['turn_id']}】\n用户: {blk['user']}\n助手: {blk['assistant']}"
        for blk in selected
    )
    prompt_with_gate = (
        "你是一个有帮助的助手。\n\n"
        f"【相关上下文】\n{context_str}\n\n"
        f"【当前问题】\n用户: {test_query}"
    )

    print(f"[输入]\n{prompt_with_gate}\n")
    answer_with_gate = call_llm(prompt_with_gate)
    print(f"[输出] {answer_with_gate[:200]}...")

    # --- Summary ---
    print("\n" + heavy_rule)
    print("对比分析:")
    print("无门控 - 可能受最近Python话题干扰")
    print(f"有门控 - 仅召回Redis相关轮次 {[b['turn_id'] for b in selected]}")
    print(heavy_rule)
|
||
|
||
|
||
# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|