69 lines
1.7 KiB
Python
69 lines
1.7 KiB
Python
import asyncio
|
|
from pathlib import Path
|
|
|
|
from src.ai.memory import MemorySystem
|
|
|
|
|
|
def test_query_does_not_use_embedding_when_disabled(tmp_path: Path):
    """With ``use_query_embedding=False`` only ingestion calls the embedder.

    Storing one long-term memory must invoke ``embed_func`` exactly once,
    while the query paths (``get_context`` and ``search_long_term``) must not
    call it at all.
    """
    calls = {"count": 0}

    async def fake_embed(_text: str):
        # Count invocations so the test can tell which code paths embed text.
        calls["count"] += 1
        return [0.1] * 8

    memory = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=fake_embed,
        use_vector_db=False,
        use_query_embedding=False,
    )

    async def scenario():
        try:
            # The ingestion path still triggers an embedding.
            stored = await memory.add_long_term(
                user_id="u1",
                content="我喜欢苹果和香蕉",
                importance=0.9,
                metadata={"source": "test"},
            )
            assert stored is not None
            assert calls["count"] == 1

            # Neither building the context nor searching should embed the query.
            await memory.get_context(user_id="u1", query="苹果")
            await memory.search_long_term(user_id="u1", query="苹果", limit=5)
            assert calls["count"] == 1
        finally:
            # Release resources even when an assertion above fails.
            await memory.close()

    # Run every step on a single event loop: a separate asyncio.run() per call
    # would create a fresh loop each time, which can break any loop-bound
    # resources MemorySystem may hold.
    asyncio.run(scenario())
|
|
|
|
|
|
def test_query_uses_embedding_when_enabled(tmp_path: Path):
    """With ``use_query_embedding=True`` the query path also calls the embedder.

    Ingesting one long-term memory invokes ``embed_func`` once. A subsequent
    ``get_context`` call must embed the query, raising the count to at least 2.
    """
    calls = {"count": 0}

    async def fake_embed(_text: str):
        # Count invocations so the test can tell which code paths embed text.
        calls["count"] += 1
        return [0.1] * 8

    memory = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=fake_embed,
        use_vector_db=False,
        use_query_embedding=True,
    )

    async def scenario():
        try:
            # Ingestion embeds the stored content exactly once.
            await memory.add_long_term(
                user_id="u1",
                content="北京天气不错",
                importance=0.9,
                metadata={"source": "test"},
            )
            assert calls["count"] == 1

            # With query embedding enabled, building context embeds the query too.
            await memory.get_context(user_id="u1", query="北京")
            assert calls["count"] >= 2
        finally:
            # Release resources even when an assertion above fails.
            await memory.close()

    # Run every step on a single event loop: a separate asyncio.run() per call
    # would create a fresh loop each time, which can break any loop-bound
    # resources MemorySystem may hold.
    asyncio.run(scenario())
|