Implement SQLite support check for Chroma's trigram tokenizer, enhancing compatibility with cloud environments. Update README for model command syntax and add pysqlite3-binary to requirements for improved SQLite functionality.
This commit is contained in:
@@ -73,8 +73,8 @@ python main.py
|
|||||||
- `/models current`
|
- `/models current`
|
||||||
- `/models add <model_name>`
|
- `/models add <model_name>`
|
||||||
- `/models add <key> <provider> <model_name> [api_base]`
|
- `/models add <key> <provider> <model_name> [api_base]`
|
||||||
- `/models switch <key>`
|
- `/models switch <key|index>`
|
||||||
- `/models remove <key>`
|
- `/models remove <key|index>`
|
||||||
|
|
||||||
说明:
|
说明:
|
||||||
- `/models add <model_name>` 只替换模型名,沿用当前 API Base 和 API Key
|
- `/models add <model_name>` 只替换模型名,沿用当前 API Base 和 API Key
|
||||||
|
|||||||
47
main.py
47
main.py
@@ -8,6 +8,53 @@ from pathlib import Path
|
|||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
sys.path.insert(0, str(project_root))
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
|
||||||
|
def _sqlite_supports_trigram(sqlite_module) -> bool:
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = sqlite_module.connect(":memory:")
|
||||||
|
conn.execute("create virtual table t using fts5(content, tokenize='trigram')")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
if conn is not None:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_sqlite_for_chroma():
|
||||||
|
"""
|
||||||
|
Ensure sqlite runtime supports FTS5 trigram tokenizer for Chroma.
|
||||||
|
On some cloud images, system sqlite lacks trigram support.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import sqlite3
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
if _sqlite_supports_trigram(sqlite3):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pysqlite3
|
||||||
|
except Exception as exc:
|
||||||
|
print(
|
||||||
|
"[WARN] sqlite3 does not support trigram tokenizer and pysqlite3 is unavailable: "
|
||||||
|
f"{exc}"
|
||||||
|
)
|
||||||
|
print("[WARN] Chroma may fail and fallback to JSON storage.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if _sqlite_supports_trigram(pysqlite3):
|
||||||
|
sys.modules["sqlite3"] = pysqlite3
|
||||||
|
print("[INFO] sqlite3 switched to pysqlite3 for Chroma trigram support.")
|
||||||
|
else:
|
||||||
|
print("[WARN] pysqlite3 is installed but still lacks trigram tokenizer support.")
|
||||||
|
print("[WARN] Chroma may fail and fallback to JSON storage.")
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_sqlite_for_chroma()
|
||||||
|
|
||||||
from src.core.bot import MyClient, build_intents
|
from src.core.bot import MyClient, build_intents
|
||||||
from src.core.config import Config
|
from src.core.config import Config
|
||||||
from src.utils.logger import setup_logger
|
from src.utils.logger import setup_logger
|
||||||
|
|||||||
@@ -12,3 +12,4 @@ openai>=1.0.0
|
|||||||
anthropic>=0.18.0
|
anthropic>=0.18.0
|
||||||
numpy>=1.24.0
|
numpy>=1.24.0
|
||||||
chromadb>=0.4.0 # 向量数据库,用于记忆存储
|
chromadb>=0.4.0 # 向量数据库,用于记忆存储
|
||||||
|
pysqlite3-binary>=0.5.3; platform_system != "Windows" # 云端可用于补齐 sqlite trigram 支持
|
||||||
|
|||||||
@@ -91,12 +91,15 @@ class MemorySystem:
|
|||||||
embed_func: Optional[callable] = None,
|
embed_func: Optional[callable] = None,
|
||||||
importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
|
importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
|
||||||
importance_threshold: float = 0.6,
|
importance_threshold: float = 0.6,
|
||||||
use_vector_db: bool = True
|
use_vector_db: bool = True,
|
||||||
|
use_query_embedding: bool = False,
|
||||||
):
|
):
|
||||||
self.short_term = ShortTermMemory()
|
self.short_term = ShortTermMemory()
|
||||||
self.embed_func = embed_func
|
self.embed_func = embed_func
|
||||||
self.importance_evaluator = importance_evaluator
|
self.importance_evaluator = importance_evaluator
|
||||||
self.importance_threshold = importance_threshold
|
self.importance_threshold = importance_threshold
|
||||||
|
# Only embed retrieval queries when explicitly enabled.
|
||||||
|
self.use_query_embedding = use_query_embedding
|
||||||
|
|
||||||
# 初始化向量存储
|
# 初始化向量存储
|
||||||
if use_vector_db:
|
if use_vector_db:
|
||||||
@@ -117,11 +120,23 @@ class MemorySystem:
|
|||||||
msg = str(error).lower()
|
msg = str(error).lower()
|
||||||
return "table embeddings already exists" in msg
|
return "table embeddings already exists" in msg
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_chroma_trigram_error(error: Exception) -> bool:
|
||||||
|
msg = str(error).lower()
|
||||||
|
return "no such tokenizer: trigram" in msg
|
||||||
|
|
||||||
def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
|
def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
|
||||||
"""初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。"""
|
"""初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。"""
|
||||||
try:
|
try:
|
||||||
return ChromaVectorStore(chroma_path)
|
return ChromaVectorStore(chroma_path)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
if self._is_chroma_trigram_error(error):
|
||||||
|
logger.warning(
|
||||||
|
"Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。"
|
||||||
|
"请在运行环境升级 sqlite 或安装 pysqlite3-binary。"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
if not self._is_chroma_table_conflict(error):
|
if not self._is_chroma_table_conflict(error):
|
||||||
logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}")
|
logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}")
|
||||||
return None
|
return None
|
||||||
@@ -327,7 +342,7 @@ class MemorySystem:
|
|||||||
# 获取相关长期记忆
|
# 获取相关长期记忆
|
||||||
long_term_memories = []
|
long_term_memories = []
|
||||||
|
|
||||||
if query:
|
if query and self.use_query_embedding:
|
||||||
try:
|
try:
|
||||||
# 使用向量检索
|
# 使用向量检索
|
||||||
query_embedding = await self._build_embedding(query)
|
query_embedding = await self._build_embedding(query)
|
||||||
@@ -430,6 +445,7 @@ class MemorySystem:
|
|||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if self.use_query_embedding:
|
||||||
query_embedding = await self._build_embedding(query)
|
query_embedding = await self._build_embedding(query)
|
||||||
results = await self.vector_store.search(
|
results = await self.vector_store.search(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from botpy.message import Message
|
from botpy.message import Message
|
||||||
|
|
||||||
@@ -33,7 +33,8 @@ class MessageHandler:
|
|||||||
(re.compile(r"\*\*([^*]+)\*\*"), r"\1"),
|
(re.compile(r"\*\*([^*]+)\*\*"), r"\1"),
|
||||||
(re.compile(r"\*([^*]+)\*"), r"\1"),
|
(re.compile(r"\*([^*]+)\*"), r"\1"),
|
||||||
(re.compile(r"__([^_]+)__"), r"\1"),
|
(re.compile(r"__([^_]+)__"), r"\1"),
|
||||||
(re.compile(r"_([^_]+)_"), r"\1"),
|
# Avoid stripping underscores inside identifiers like model keys.
|
||||||
|
(re.compile(r"(?<!\w)_([^_\n]+)_(?!\w)"), r"\1"),
|
||||||
(re.compile(r"^#{1,6}\s*", re.MULTILINE), ""),
|
(re.compile(r"^#{1,6}\s*", re.MULTILINE), ""),
|
||||||
(re.compile(r"^>\s?", re.MULTILINE), ""),
|
(re.compile(r"^>\s?", re.MULTILINE), ""),
|
||||||
(re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"),
|
(re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"),
|
||||||
@@ -100,8 +101,8 @@ class MessageHandler:
|
|||||||
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
|
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
|
||||||
f"{command_name} add <key> <json>\n"
|
f"{command_name} add <key> <json>\n"
|
||||||
" json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
|
" json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
|
||||||
f"{command_name} switch <key>\n"
|
f"{command_name} switch <key|index>\n"
|
||||||
f"{command_name} remove <key>"
|
f"{command_name} remove <key|index>"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -142,6 +143,66 @@ class MessageHandler:
|
|||||||
key = f"model_{key}"
|
key = f"model_{key}"
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compact_model_key(raw_key: str) -> str:
|
||||||
|
return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower())
|
||||||
|
|
||||||
|
def _ordered_model_keys(self) -> list[str]:
|
||||||
|
return sorted(self.model_profiles.keys())
|
||||||
|
|
||||||
|
def _resolve_model_selector(self, selector: str) -> str:
    """Resolve a user-supplied model selector to a stored profile key.

    Selector forms are tried in order:
      1. a 1-based index into the alphabetically ordered key list,
      2. an exact profile key,
      3. the normalized form of the selector (via _normalize_model_key),
      4. a fuzzy match against normalized / compacted candidate keys.

    Raises ValueError for empty selectors, out-of-range indices,
    ambiguous fuzzy matches, and selectors matching no profile.
    """
    raw = (selector or "").strip()
    if not raw:
        raise ValueError("模型 key 不能为空")

    ordered_keys = self._ordered_model_keys()
    # Pure digits are treated as a 1-based position in the sorted key list.
    if raw.isdigit():
        index = int(raw)
        if index < 1 or index > len(ordered_keys):
            raise ValueError(
                f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}"
            )
        return ordered_keys[index - 1]

    # An exact key match wins before any normalization is attempted.
    if raw in self.model_profiles:
        return raw

    normalized_selector: Optional[str]
    try:
        normalized_selector = self._normalize_model_key(raw)
    except ValueError:
        # Selector cannot be normalized; fall through to fuzzy matching.
        normalized_selector = None

    if normalized_selector and normalized_selector in self.model_profiles:
        return normalized_selector

    # Build lookup tables: normalized key -> original keys, plus the same
    # with all non-alphanumeric characters stripped ("compact" form).
    normalized_candidates: Dict[str, list[str]] = {}
    compact_candidates: Dict[str, list[str]] = {}
    for key in ordered_keys:
        try:
            normalized_key = self._normalize_model_key(key)
        except ValueError:
            # Stored keys that fail normalization still participate,
            # keyed by their lowercased form instead.
            normalized_key = key.strip().lower()
        normalized_candidates.setdefault(normalized_key, []).append(key)
        compact_candidates.setdefault(
            self._compact_model_key(normalized_key), []
        ).append(key)

    if normalized_selector and normalized_selector in normalized_candidates:
        matches = normalized_candidates[normalized_selector]
        if len(matches) == 1:
            return matches[0]
        # Ambiguous: force the caller to spell out the full key.
        raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}")

    # Last resort: compare compacted forms (letters/digits only).
    compact_selector = self._compact_model_key(normalized_selector or raw)
    if compact_selector in compact_candidates:
        matches = compact_candidates[compact_selector]
        if len(matches) == 1:
            return matches[0]
        raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}")

    raise ValueError(f"模型配置不存在: {raw}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse_provider(cls, raw_provider: str) -> ModelProvider:
|
def _parse_provider(cls, raw_provider: str) -> ModelProvider:
|
||||||
provider = cls._provider_map().get(raw_provider.strip().lower())
|
provider = cls._provider_map().get(raw_provider.strip().lower())
|
||||||
@@ -249,9 +310,29 @@ class MessageHandler:
|
|||||||
logger.warning(f"load model profiles failed, reset to defaults: {exc}")
|
logger.warning(f"load model profiles failed, reset to defaults: {exc}")
|
||||||
payload = {}
|
payload = {}
|
||||||
|
|
||||||
profiles = payload.get("profiles")
|
raw_profiles = payload.get("profiles")
|
||||||
if not isinstance(profiles, dict):
|
profiles: Dict[str, Dict[str, Any]] = {}
|
||||||
profiles = {}
|
if isinstance(raw_profiles, dict):
|
||||||
|
for raw_key, raw_profile in raw_profiles.items():
|
||||||
|
if not isinstance(raw_profile, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
key_text = str(raw_key or "").strip()
|
||||||
|
if not key_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
normalized_key = self._normalize_model_key(key_text)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if normalized_key in profiles and profiles[normalized_key] != raw_profile:
|
||||||
|
logger.warning(
|
||||||
|
f"duplicate model key after normalization, keep first: {normalized_key}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
profiles[normalized_key] = raw_profile
|
||||||
|
|
||||||
if not profiles:
|
if not profiles:
|
||||||
profiles = {
|
profiles = {
|
||||||
@@ -260,8 +341,19 @@ class MessageHandler:
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
active = str(payload.get("active") or "")
|
active_raw = str(payload.get("active") or "").strip()
|
||||||
if active not in profiles:
|
active = ""
|
||||||
|
if active_raw in profiles:
|
||||||
|
active = active_raw
|
||||||
|
elif active_raw:
|
||||||
|
try:
|
||||||
|
normalized_active = self._normalize_model_key(active_raw)
|
||||||
|
except ValueError:
|
||||||
|
normalized_active = ""
|
||||||
|
if normalized_active in profiles:
|
||||||
|
active = normalized_active
|
||||||
|
|
||||||
|
if not active:
|
||||||
active = "default" if "default" in profiles else sorted(profiles.keys())[0]
|
active = "default" if "default" in profiles else sorted(profiles.keys())[0]
|
||||||
|
|
||||||
self.model_profiles = profiles
|
self.model_profiles = profiles
|
||||||
@@ -918,12 +1010,16 @@ class MessageHandler:
|
|||||||
|
|
||||||
if action in {"list", "ls"} and len(parts) <= 2:
|
if action in {"list", "ls"} and len(parts) <= 2:
|
||||||
lines = [f"当前模型配置: {self.active_model_key}"]
|
lines = [f"当前模型配置: {self.active_model_key}"]
|
||||||
for key in sorted(self.model_profiles.keys()):
|
ordered_keys = self._ordered_model_keys()
|
||||||
|
for idx, key in enumerate(ordered_keys, start=1):
|
||||||
profile = self.model_profiles.get(key, {})
|
profile = self.model_profiles.get(key, {})
|
||||||
marker = "*" if key == self.active_model_key else "-"
|
marker = "*" if key == self.active_model_key else "-"
|
||||||
provider = str(profile.get("provider") or "?")
|
provider = str(profile.get("provider") or "?")
|
||||||
model_name = str(profile.get("model_name") or "?")
|
model_name = str(profile.get("model_name") or "?")
|
||||||
lines.append(f"{marker} {key}: {provider}/{model_name}")
|
lines.append(f"{marker} {idx}. {key}: {provider}/{model_name}")
|
||||||
|
|
||||||
|
if ordered_keys:
|
||||||
|
lines.append(f"提示: 可用 /models switch <序号>,例如 /models switch 2")
|
||||||
|
|
||||||
lines.append(self._build_models_usage("/models"))
|
lines.append(self._build_models_usage("/models"))
|
||||||
await self._reply_plain(message, "\n".join(lines))
|
await self._reply_plain(message, "\n".join(lines))
|
||||||
@@ -947,15 +1043,11 @@ class MessageHandler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
key = self._normalize_model_key(parts[2])
|
key = self._resolve_model_selector(parts[2])
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
await self._reply_plain(message, str(exc))
|
await self._reply_plain(message, str(exc))
|
||||||
return
|
return
|
||||||
|
|
||||||
if key not in self.model_profiles:
|
|
||||||
await self._reply_plain(message, f"模型配置不存在: {key}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = self._model_config_from_dict(
|
config = self._model_config_from_dict(
|
||||||
self.model_profiles[key], self.ai_client.config
|
self.model_profiles[key], self.ai_client.config
|
||||||
@@ -1065,7 +1157,7 @@ class MessageHandler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
key = self._normalize_model_key(parts[2])
|
key = self._resolve_model_selector(parts[2])
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
await self._reply_plain(message, str(exc))
|
await self._reply_plain(message, str(exc))
|
||||||
return
|
return
|
||||||
@@ -1074,10 +1166,6 @@ class MessageHandler:
|
|||||||
await self._reply_plain(message, "默认模型配置不能删除")
|
await self._reply_plain(message, "默认模型配置不能删除")
|
||||||
return
|
return
|
||||||
|
|
||||||
if key not in self.model_profiles:
|
|
||||||
await self._reply_plain(message, f"模型配置不存在: {key}")
|
|
||||||
return
|
|
||||||
|
|
||||||
del self.model_profiles[key]
|
del self.model_profiles[key]
|
||||||
switched_to = None
|
switched_to = None
|
||||||
|
|
||||||
@@ -1146,8 +1234,8 @@ class MessageHandler:
|
|||||||
"/models current\n"
|
"/models current\n"
|
||||||
"/models add <model_name>\n"
|
"/models add <model_name>\n"
|
||||||
"/models add <key> <provider> <model_name> [api_base]\n"
|
"/models add <key> <provider> <model_name> [api_base]\n"
|
||||||
"/models switch <key>\n"
|
"/models switch <key|index>\n"
|
||||||
"/models remove <key>\n"
|
"/models remove <key|index>\n"
|
||||||
"\n"
|
"\n"
|
||||||
"记忆命令\n"
|
"记忆命令\n"
|
||||||
"/memory\n"
|
"/memory\n"
|
||||||
|
|||||||
68
tests/test_memory_embedding_policy.py
Normal file
68
tests/test_memory_embedding_policy.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.ai.memory import MemorySystem
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_does_not_use_embedding_when_disabled(tmp_path: Path):
    """With use_query_embedding=False, only ingestion embeds text;
    context retrieval and search must not call the embed function."""
    embed_calls: list = []

    async def fake_embed(_text: str):
        embed_calls.append(_text)
        return [0.1] * 8

    memory = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=fake_embed,
        use_vector_db=False,
        use_query_embedding=False,
    )

    # The write path still embeds the stored content.
    stored = asyncio.run(
        memory.add_long_term(
            user_id="u1",
            content="我喜欢苹果和香蕉",
            importance=0.9,
            metadata={"source": "test"},
        )
    )
    assert stored is not None
    assert len(embed_calls) == 1

    # Neither context retrieval nor long-term search embeds the query.
    asyncio.run(memory.get_context(user_id="u1", query="苹果"))
    asyncio.run(memory.search_long_term(user_id="u1", query="苹果", limit=5))
    assert len(embed_calls) == 1

    asyncio.run(memory.close())
||||||
|
|
||||||
|
|
||||||
|
def test_query_uses_embedding_when_enabled(tmp_path: Path):
    """With use_query_embedding=True, context retrieval embeds the query
    in addition to the ingestion-time embedding."""
    embed_calls: list = []

    async def fake_embed(_text: str):
        embed_calls.append(_text)
        return [0.1] * 8

    memory = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=fake_embed,
        use_vector_db=False,
        use_query_embedding=True,
    )

    asyncio.run(
        memory.add_long_term(
            user_id="u1",
            content="北京天气不错",
            importance=0.9,
            metadata={"source": "test"},
        )
    )
    assert len(embed_calls) == 1

    asyncio.run(memory.get_context(user_id="u1", query="北京"))
    assert len(embed_calls) >= 2

    asyncio.run(memory.close())
||||||
Reference in New Issue
Block a user