Implement a SQLite trigram-tokenizer support check for Chroma, improving compatibility with cloud environments. Update the README for the new model command syntax and add pysqlite3-binary to requirements.

Mimikko-zeus
2026-03-03 13:29:05 +08:00
parent 774ea9d5e4
commit 46ff239f4c
6 changed files with 256 additions and 36 deletions

README

@@ -73,8 +73,8 @@ python main.py
 - `/models current`
 - `/models add <model_name>`
 - `/models add <key> <provider> <model_name> [api_base]`
-- `/models switch <key>`
-- `/models remove <key>`
+- `/models switch <key|index>`
+- `/models remove <key|index>`
 Notes:
 - `/models add <model_name>` only replaces the model name; the current API Base and API Key stay in effect

main.py

@@ -8,6 +8,53 @@ from pathlib import Path
 project_root = Path(__file__).parent
 sys.path.insert(0, str(project_root))
+
+
+def _sqlite_supports_trigram(sqlite_module) -> bool:
+    conn = None
+    try:
+        conn = sqlite_module.connect(":memory:")
+        conn.execute("create virtual table t using fts5(content, tokenize='trigram')")
+        return True
+    except Exception:
+        return False
+    finally:
+        if conn is not None:
+            conn.close()
+
+
+def _ensure_sqlite_for_chroma():
+    """
+    Ensure the sqlite runtime supports the FTS5 trigram tokenizer for Chroma.
+    On some cloud images, the system sqlite lacks trigram support.
+    """
+    try:
+        import sqlite3
+    except Exception:
+        return
+
+    if _sqlite_supports_trigram(sqlite3):
+        return
+
+    try:
+        import pysqlite3
+    except Exception as exc:
+        print(
+            "[WARN] sqlite3 does not support trigram tokenizer and pysqlite3 is unavailable: "
+            f"{exc}"
+        )
+        print("[WARN] Chroma may fail and fall back to JSON storage.")
+        return
+
+    if _sqlite_supports_trigram(pysqlite3):
+        sys.modules["sqlite3"] = pysqlite3
+        print("[INFO] sqlite3 switched to pysqlite3 for Chroma trigram support.")
+    else:
+        print("[WARN] pysqlite3 is installed but still lacks trigram tokenizer support.")
+        print("[WARN] Chroma may fail and fall back to JSON storage.")
+
+
+_ensure_sqlite_for_chroma()
 
 from src.core.bot import MyClient, build_intents
 from src.core.config import Config
 from src.utils.logger import setup_logger
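The swap above works because Python caches modules in `sys.modules`: once `pysqlite3` is registered under the `"sqlite3"` key, any later `import sqlite3` (including the one inside chromadb) resolves to it. A minimal sketch of the pattern in isolation:

```python
import sys

try:
    import pysqlite3  # drop-in replacement built against a newer SQLite
    sys.modules["sqlite3"] = pysqlite3
except ImportError:
    pass  # fall back to the stdlib sqlite3

import sqlite3  # resolves to pysqlite3 if the swap above succeeded

print(sqlite3.sqlite_version)  # version of the linked SQLite library
```

This is also why `_ensure_sqlite_for_chroma()` runs before the `src.core` imports: the substitution only affects imports that happen after it.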

requirements.txt

@@ -12,3 +12,4 @@ openai>=1.0.0
 anthropic>=0.18.0
 numpy>=1.24.0
 chromadb>=0.4.0  # vector database, used for memory storage
+pysqlite3-binary>=0.5.3; platform_system != "Windows"  # on cloud hosts, backfills sqlite trigram support
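The trigram tokenizer requires a fairly recent SQLite (it was added in SQLite 3.34.0), which is why the wheel-packaged pysqlite3-binary helps on images with an older system library; the platform marker skips Windows, where pysqlite3-binary wheels are generally unavailable. A quick probe to compare what the two modules link against, assuming both are importable:

```python
import sqlite3

print("stdlib sqlite3 ->", sqlite3.sqlite_version)  # needs >= 3.34.0 for trigram

try:
    import pysqlite3
    print("pysqlite3    ->", pysqlite3.sqlite_version)
except ImportError:
    print("pysqlite3 not installed")
```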

src/ai/memory.py

@@ -91,12 +91,15 @@ class MemorySystem:
         embed_func: Optional[callable] = None,
         importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
         importance_threshold: float = 0.6,
-        use_vector_db: bool = True
+        use_vector_db: bool = True,
+        use_query_embedding: bool = False,
     ):
         self.short_term = ShortTermMemory()
         self.embed_func = embed_func
         self.importance_evaluator = importance_evaluator
         self.importance_threshold = importance_threshold
+        # Only embed retrieval queries when explicitly enabled.
+        self.use_query_embedding = use_query_embedding
 
         # initialize the vector store
         if use_vector_db:
@@ -117,11 +120,23 @@ class MemorySystem:
         msg = str(error).lower()
         return "table embeddings already exists" in msg
 
+    @staticmethod
+    def _is_chroma_trigram_error(error: Exception) -> bool:
+        msg = str(error).lower()
+        return "no such tokenizer: trigram" in msg
+
     def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
         """Initialize Chroma; on known sqlite schema conflicts, attempt a repair."""
         try:
             return ChromaVectorStore(chroma_path)
         except Exception as error:
+            if self._is_chroma_trigram_error(error):
+                logger.warning(
+                    "Chroma initialization failed, falling back to JSON storage: "
+                    "sqlite lacks the trigram tokenizer; upgrade sqlite in the runtime "
+                    "or install pysqlite3-binary."
+                )
+                return None
             if not self._is_chroma_table_conflict(error):
                 logger.warning(f"Chroma initialization failed, falling back to JSON storage: {error}")
                 return None
@@ -327,7 +342,7 @@ class MemorySystem:
         # fetch relevant long-term memories
         long_term_memories = []
-        if query:
+        if query and self.use_query_embedding:
             try:
                 # use vector retrieval
                 query_embedding = await self._build_embedding(query)
@@ -430,6 +445,7 @@ class MemorySystem:
         if not query:
             return []
 
-        query_embedding = await self._build_embedding(query)
-        results = await self.vector_store.search(
+        if self.use_query_embedding:
+            query_embedding = await self._build_embedding(query)
+            results = await self.vector_store.search(
                 user_id=user_id,
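Taken together, the flag makes query-time embedding opt-in: writes to long-term memory still embed, but `get_context` and `search_long_term` only embed the query when `use_query_embedding=True`. A usage sketch mirroring the new tests (the embedder below is a stand-in, not a real model call):

```python
import asyncio
from pathlib import Path

from src.ai.memory import MemorySystem

async def fake_embed(text: str) -> list[float]:
    # Stand-in embedder; a real deployment would call an embedding API.
    return [0.1] * 8

memory = MemorySystem(
    storage_path=Path("data/memory.json"),
    embed_func=fake_embed,
    use_vector_db=False,
    use_query_embedding=True,  # opt in to embedding retrieval queries
)

asyncio.run(memory.add_long_term(user_id="u1", content="likes green tea", importance=0.8))
print(asyncio.run(memory.search_long_term(user_id="u1", query="tea", limit=3)))
```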


@@ -6,7 +6,7 @@ import asyncio
 import json
 from pathlib import Path
 import re
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 from botpy.message import Message
@@ -33,7 +33,8 @@ class MessageHandler:
(re.compile(r"\*\*([^*]+)\*\*"), r"\1"), (re.compile(r"\*\*([^*]+)\*\*"), r"\1"),
(re.compile(r"\*([^*]+)\*"), r"\1"), (re.compile(r"\*([^*]+)\*"), r"\1"),
(re.compile(r"__([^_]+)__"), r"\1"), (re.compile(r"__([^_]+)__"), r"\1"),
(re.compile(r"_([^_]+)_"), r"\1"), # Avoid stripping underscores inside identifiers like model keys.
(re.compile(r"(?<!\w)_([^_\n]+)_(?!\w)"), r"\1"),
(re.compile(r"^#{1,6}\s*", re.MULTILINE), ""), (re.compile(r"^#{1,6}\s*", re.MULTILINE), ""),
(re.compile(r"^>\s?", re.MULTILINE), ""), (re.compile(r"^>\s?", re.MULTILINE), ""),
(re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"), (re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"),
@@ -100,8 +101,8 @@ class MessageHandler:
f"{command_name} add <key> <provider> <model_name> [api_base]\n" f"{command_name} add <key> <provider> <model_name> [api_base]\n"
f"{command_name} add <key> <json>\n" f"{command_name} add <key> <json>\n"
" json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n" " json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
f"{command_name} switch <key>\n" f"{command_name} switch <key|index>\n"
f"{command_name} remove <key>" f"{command_name} remove <key|index>"
) )
@staticmethod @staticmethod
@@ -142,6 +143,66 @@ class MessageHandler:
key = f"model_{key}" key = f"model_{key}"
return key return key
@staticmethod
def _compact_model_key(raw_key: str) -> str:
return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower())
def _ordered_model_keys(self) -> list[str]:
return sorted(self.model_profiles.keys())
def _resolve_model_selector(self, selector: str) -> str:
raw = (selector or "").strip()
if not raw:
raise ValueError("模型 key 不能为空")
ordered_keys = self._ordered_model_keys()
if raw.isdigit():
index = int(raw)
if index < 1 or index > len(ordered_keys):
raise ValueError(
f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}"
)
return ordered_keys[index - 1]
if raw in self.model_profiles:
return raw
normalized_selector: Optional[str]
try:
normalized_selector = self._normalize_model_key(raw)
except ValueError:
normalized_selector = None
if normalized_selector and normalized_selector in self.model_profiles:
return normalized_selector
normalized_candidates: Dict[str, list[str]] = {}
compact_candidates: Dict[str, list[str]] = {}
for key in ordered_keys:
try:
normalized_key = self._normalize_model_key(key)
except ValueError:
normalized_key = key.strip().lower()
normalized_candidates.setdefault(normalized_key, []).append(key)
compact_candidates.setdefault(
self._compact_model_key(normalized_key), []
).append(key)
if normalized_selector and normalized_selector in normalized_candidates:
matches = normalized_candidates[normalized_selector]
if len(matches) == 1:
return matches[0]
raise ValueError(f"匹配到多个模型 key请使用完整 key: {', '.join(matches)}")
compact_selector = self._compact_model_key(normalized_selector or raw)
if compact_selector in compact_candidates:
matches = compact_candidates[compact_selector]
if len(matches) == 1:
return matches[0]
raise ValueError(f"匹配到多个模型 key请使用完整 key: {', '.join(matches)}")
raise ValueError(f"模型配置不存在: {raw}")
@classmethod @classmethod
def _parse_provider(cls, raw_provider: str) -> ModelProvider: def _parse_provider(cls, raw_provider: str) -> ModelProvider:
provider = cls._provider_map().get(raw_provider.strip().lower()) provider = cls._provider_map().get(raw_provider.strip().lower())
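The compact form is what lets `switch` accept loose spellings: after lowercasing and dropping non-alphanumerics, variants such as `GPT-4o` and `gpt_4o` land in the same bucket, and a unique bucket resolves the selector. A standalone sketch with hypothetical profile keys:

```python
import re

def compact(raw_key: str) -> str:
    # Same idea as _compact_model_key: lowercase, keep only a-z and 0-9.
    return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower())

keys = ["gpt_4o", "claude-3", "deepseek_chat"]  # hypothetical profile keys
buckets: dict[str, list[str]] = {}
for key in keys:
    buckets.setdefault(compact(key), []).append(key)

print(buckets["gpt4o"])              # ['gpt_4o']
print(compact("GPT-4o") in buckets)  # True -> unambiguous match
```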
@@ -249,9 +310,29 @@ class MessageHandler:
logger.warning(f"load model profiles failed, reset to defaults: {exc}") logger.warning(f"load model profiles failed, reset to defaults: {exc}")
payload = {} payload = {}
profiles = payload.get("profiles") raw_profiles = payload.get("profiles")
if not isinstance(profiles, dict): profiles: Dict[str, Dict[str, Any]] = {}
profiles = {} if isinstance(raw_profiles, dict):
for raw_key, raw_profile in raw_profiles.items():
if not isinstance(raw_profile, dict):
continue
key_text = str(raw_key or "").strip()
if not key_text:
continue
try:
normalized_key = self._normalize_model_key(key_text)
except ValueError:
continue
if normalized_key in profiles and profiles[normalized_key] != raw_profile:
logger.warning(
f"duplicate model key after normalization, keep first: {normalized_key}"
)
continue
profiles[normalized_key] = raw_profile
if not profiles: if not profiles:
profiles = { profiles = {
@@ -260,8 +341,19 @@ class MessageHandler:
                 )
             }
 
-        active = str(payload.get("active") or "")
-        if active not in profiles:
+        active_raw = str(payload.get("active") or "").strip()
+        active = ""
+        if active_raw in profiles:
+            active = active_raw
+        elif active_raw:
+            try:
+                normalized_active = self._normalize_model_key(active_raw)
+            except ValueError:
+                normalized_active = ""
+            if normalized_active in profiles:
+                active = normalized_active
+        if not active:
             active = "default" if "default" in profiles else sorted(profiles.keys())[0]
 
         self.model_profiles = profiles
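A toy illustration of the keep-first policy during loading (`normalize` below is a hypothetical stand-in; the real `_normalize_model_key` is not shown in this diff):

```python
import re

def normalize(key: str) -> str:
    # Hypothetical stand-in for _normalize_model_key: lowercase and
    # collapse whitespace/hyphens to underscores, for illustration only.
    return re.sub(r"[\s\-]+", "_", key.strip().lower())

raw_profiles = {
    "GPT 4o": {"provider": "openai", "model_name": "gpt-4o"},
    "gpt_4o": {"provider": "openai", "model_name": "gpt-4o-mini"},
}

profiles: dict[str, dict] = {}
for raw_key, profile in raw_profiles.items():
    key = normalize(raw_key)
    if key in profiles and profiles[key] != profile:
        print(f"duplicate model key after normalization, keep first: {key}")
        continue  # later conflicting entries are dropped
    profiles[key] = profile

print(list(profiles))  # ['gpt_4o'] -> the first entry wins
```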
@@ -918,12 +1010,16 @@ class MessageHandler:
         if action in {"list", "ls"} and len(parts) <= 2:
             lines = [f"current model profile: {self.active_model_key}"]
-            for key in sorted(self.model_profiles.keys()):
+            ordered_keys = self._ordered_model_keys()
+            for idx, key in enumerate(ordered_keys, start=1):
                 profile = self.model_profiles.get(key, {})
                 marker = "*" if key == self.active_model_key else "-"
                 provider = str(profile.get("provider") or "?")
                 model_name = str(profile.get("model_name") or "?")
-                lines.append(f"{marker} {key}: {provider}/{model_name}")
+                lines.append(f"{marker} {idx}. {key}: {provider}/{model_name}")
+            if ordered_keys:
+                lines.append("tip: use /models switch <index>, e.g. /models switch 2")
             lines.append(self._build_models_usage("/models"))
             await self._reply_plain(message, "\n".join(lines))
@@ -947,15 +1043,11 @@ class MessageHandler:
             return
 
         try:
-            key = self._normalize_model_key(parts[2])
+            key = self._resolve_model_selector(parts[2])
         except ValueError as exc:
             await self._reply_plain(message, str(exc))
             return
 
-        if key not in self.model_profiles:
-            await self._reply_plain(message, f"model profile does not exist: {key}")
-            return
-
         try:
             config = self._model_config_from_dict(
                 self.model_profiles[key], self.ai_client.config
@@ -1065,7 +1157,7 @@ class MessageHandler:
             return
 
         try:
-            key = self._normalize_model_key(parts[2])
+            key = self._resolve_model_selector(parts[2])
         except ValueError as exc:
             await self._reply_plain(message, str(exc))
             return
@@ -1074,10 +1166,6 @@ class MessageHandler:
await self._reply_plain(message, "默认模型配置不能删除") await self._reply_plain(message, "默认模型配置不能删除")
return return
if key not in self.model_profiles:
await self._reply_plain(message, f"模型配置不存在: {key}")
return
del self.model_profiles[key] del self.model_profiles[key]
switched_to = None switched_to = None
@@ -1146,8 +1234,8 @@ class MessageHandler:
"/models current\n" "/models current\n"
"/models add <model_name>\n" "/models add <model_name>\n"
"/models add <key> <provider> <model_name> [api_base]\n" "/models add <key> <provider> <model_name> [api_base]\n"
"/models switch <key>\n" "/models switch <key|index>\n"
"/models remove <key>\n" "/models remove <key|index>\n"
"\n" "\n"
"记忆命令\n" "记忆命令\n"
"/memory\n" "/memory\n"


@@ -0,0 +1,68 @@
+import asyncio
+from pathlib import Path
+
+from src.ai.memory import MemorySystem
+
+
+def test_query_does_not_use_embedding_when_disabled(tmp_path: Path):
+    calls = {"count": 0}
+
+    async def fake_embed(_text: str):
+        calls["count"] += 1
+        return [0.1] * 8
+
+    memory = MemorySystem(
+        storage_path=tmp_path / "memory.json",
+        embed_func=fake_embed,
+        use_vector_db=False,
+        use_query_embedding=False,
+    )
+
+    # The storage path still triggers an embedding.
+    stored = asyncio.run(
+        memory.add_long_term(
+            user_id="u1",
+            content="我喜欢苹果和香蕉",
+            importance=0.9,
+            metadata={"source": "test"},
+        )
+    )
+    assert stored is not None
+    assert calls["count"] == 1
+
+    # Neither context retrieval nor search triggers an embedding.
+    asyncio.run(memory.get_context(user_id="u1", query="苹果"))
+    asyncio.run(memory.search_long_term(user_id="u1", query="苹果", limit=5))
+    assert calls["count"] == 1
+
+    asyncio.run(memory.close())
+
+
+def test_query_uses_embedding_when_enabled(tmp_path: Path):
+    calls = {"count": 0}
+
+    async def fake_embed(_text: str):
+        calls["count"] += 1
+        return [0.1] * 8
+
+    memory = MemorySystem(
+        storage_path=tmp_path / "memory.json",
+        embed_func=fake_embed,
+        use_vector_db=False,
+        use_query_embedding=True,
+    )
+
+    asyncio.run(
+        memory.add_long_term(
+            user_id="u1",
+            content="北京天气不错",
+            importance=0.9,
+            metadata={"source": "test"},
+        )
+    )
+    assert calls["count"] == 1
+
+    asyncio.run(memory.get_context(user_id="u1", query="北京"))
+    assert calls["count"] >= 2
+
+    asyncio.run(memory.close())