Files
QQbot/tests/test_memory_embedding_policy.py

69 lines
1.7 KiB
Python

import asyncio
from pathlib import Path
from src.ai.memory import MemorySystem
def test_query_does_not_use_embedding_when_disabled(tmp_path: Path):
    """Query paths must not call the embedder when query embedding is off.

    Storing a long-term memory still embeds its content (exactly one call),
    but with ``use_query_embedding=False`` neither ``get_context`` nor
    ``search_long_term`` may trigger any further embedding calls.
    """
    calls = {"count": 0}

    async def fake_embed(_text: str):
        # Only the number of invocations matters; the vector is a dummy.
        calls["count"] += 1
        return [0.1] * 8

    memory = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=fake_embed,
        use_vector_db=False,
        use_query_embedding=False,
    )

    async def scenario():
        # Drive the whole lifecycle on ONE event loop: separate asyncio.run
        # calls would each create (and tear down) their own loop between
        # operations on the same MemorySystem instance.
        try:
            # The storage path still triggers embedding.
            stored = await memory.add_long_term(
                user_id="u1",
                content="我喜欢苹果和香蕉",
                importance=0.9,
                metadata={"source": "test"},
            )
            assert stored is not None
            assert calls["count"] == 1

            # Neither context retrieval nor search may trigger embedding.
            await memory.get_context(user_id="u1", query="苹果")
            await memory.search_long_term(user_id="u1", query="苹果", limit=5)
            assert calls["count"] == 1
        finally:
            # Close even when an assertion above fails, so nothing leaks.
            await memory.close()

    asyncio.run(scenario())
def test_query_uses_embedding_when_enabled(tmp_path: Path):
    """Query paths must call the embedder when query embedding is on.

    Storing a long-term memory embeds its content once. With
    ``use_query_embedding=True``, a subsequent ``get_context`` call must
    trigger at least one additional embedding call.
    """
    calls = {"count": 0}

    async def fake_embed(_text: str):
        # Only the number of invocations matters; the vector is a dummy.
        calls["count"] += 1
        return [0.1] * 8

    memory = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=fake_embed,
        use_vector_db=False,
        use_query_embedding=True,
    )

    async def scenario():
        # Drive the whole lifecycle on ONE event loop rather than spinning up
        # a fresh loop per operation on the same MemorySystem instance.
        try:
            stored = await memory.add_long_term(
                user_id="u1",
                content="北京天气不错",
                importance=0.9,
                metadata={"source": "test"},
            )
            # Guard against a vacuous pass if nothing was actually stored.
            assert stored is not None
            assert calls["count"] == 1

            await memory.get_context(user_id="u1", query="北京")
            # ">=" rather than "==": the test only requires that the query
            # path embeds at least once; the exact count is left to the
            # implementation.
            assert calls["count"] >= 2
        finally:
            # Close even when an assertion above fails, so nothing leaks.
            await memory.close()

    asyncio.run(scenario())