Implement SQLite support check for Chroma's trigram tokenizer, enhancing compatibility with cloud environments. Update README for model command syntax and add pysqlite3-binary to requirements for improved SQLite functionality.

This commit is contained in:
Mimikko-zeus
2026-03-03 13:29:05 +08:00
parent 774ea9d5e4
commit 46ff239f4c
6 changed files with 256 additions and 36 deletions

View File

@@ -73,8 +73,8 @@ python main.py
- `/models current`
- `/models add <model_name>`
- `/models add <key> <provider> <model_name> [api_base]`
- `/models switch <key>`
- `/models remove <key>`
- `/models switch <key|index>`
- `/models remove <key|index>`
说明:
- `/models add <model_name>` 只替换模型名,沿用当前 API Base 和 API Key

47
main.py
View File

@@ -8,6 +8,53 @@ from pathlib import Path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
def _sqlite_supports_trigram(sqlite_module) -> bool:
    """Probe whether the given sqlite module's FTS5 build ships the trigram tokenizer.

    Returns True when an in-memory FTS5 table using tokenize='trigram' can be
    created, False on any failure (missing FTS5, old sqlite, broken module).
    """
    connection = None
    try:
        connection = sqlite_module.connect(":memory:")
        # Creating a throwaway virtual table is the cheapest reliable capability probe.
        connection.execute("create virtual table t using fts5(content, tokenize='trigram')")
    except Exception:
        return False
    else:
        return True
    finally:
        if connection is not None:
            connection.close()
def _ensure_sqlite_for_chroma():
    """
    Make sure the sqlite runtime used by Chroma supports the FTS5 trigram
    tokenizer. On some cloud images the system sqlite lacks trigram support,
    so fall back to pysqlite3 when it is available and capable.
    """
    try:
        import sqlite3
    except Exception:
        # No sqlite at all; nothing useful can be done here.
        return
    if _sqlite_supports_trigram(sqlite3):
        return
    try:
        import pysqlite3
    except Exception as exc:
        print(
            f"[WARN] sqlite3 does not support trigram tokenizer and pysqlite3 is unavailable: {exc}"
        )
        print("[WARN] Chroma may fail and fallback to JSON storage.")
        return
    if not _sqlite_supports_trigram(pysqlite3):
        print("[WARN] pysqlite3 is installed but still lacks trigram tokenizer support.")
        print("[WARN] Chroma may fail and fallback to JSON storage.")
        return
    # From here on, `import sqlite3` anywhere (including inside chromadb)
    # resolves to pysqlite3 via the module cache.
    sys.modules["sqlite3"] = pysqlite3
    print("[INFO] sqlite3 switched to pysqlite3 for Chroma trigram support.")
# Must run before chromadb (or anything else importing sqlite3) is loaded,
# so the sys.modules swap takes effect for those imports.
_ensure_sqlite_for_chroma()
from src.core.bot import MyClient, build_intents
from src.core.config import Config
from src.utils.logger import setup_logger

View File

@@ -12,3 +12,4 @@ openai>=1.0.0
anthropic>=0.18.0
numpy>=1.24.0
chromadb>=0.4.0 # 向量数据库,用于记忆存储
pysqlite3-binary>=0.5.3; platform_system != "Windows" # 云端可用于补齐 sqlite trigram 支持

View File

@@ -91,12 +91,15 @@ class MemorySystem:
embed_func: Optional[callable] = None,
importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
importance_threshold: float = 0.6,
use_vector_db: bool = True
use_vector_db: bool = True,
use_query_embedding: bool = False,
):
self.short_term = ShortTermMemory()
self.embed_func = embed_func
self.importance_evaluator = importance_evaluator
self.importance_threshold = importance_threshold
# Only embed retrieval queries when explicitly enabled.
self.use_query_embedding = use_query_embedding
# 初始化向量存储
if use_vector_db:
@@ -117,11 +120,23 @@ class MemorySystem:
msg = str(error).lower()
return "table embeddings already exists" in msg
@staticmethod
def _is_chroma_trigram_error(error: Exception) -> bool:
    """Return True when the error message indicates sqlite lacks the FTS5 trigram tokenizer."""
    return "no such tokenizer: trigram" in str(error).lower()
def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
"""初始化 Chroma遇到已知 sqlite schema 冲突时尝试修复。"""
try:
return ChromaVectorStore(chroma_path)
except Exception as error:
if self._is_chroma_trigram_error(error):
logger.warning(
"Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。"
"请在运行环境升级 sqlite 或安装 pysqlite3-binary。"
)
return None
if not self._is_chroma_table_conflict(error):
logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}")
return None
@@ -327,7 +342,7 @@ class MemorySystem:
# 获取相关长期记忆
long_term_memories = []
if query:
if query and self.use_query_embedding:
try:
# 使用向量检索
query_embedding = await self._build_embedding(query)
@@ -430,15 +445,16 @@ class MemorySystem:
if not query:
return []
query_embedding = await self._build_embedding(query)
results = await self.vector_store.search(
user_id=user_id,
query_embedding=query_embedding,
limit=limit,
min_importance=0.0,
)
if results:
return results
if self.use_query_embedding:
query_embedding = await self._build_embedding(query)
results = await self.vector_store.search(
user_id=user_id,
query_embedding=query_embedding,
limit=limit,
min_importance=0.0,
)
if results:
return results
all_memories = await self.vector_store.get_all(user_id)
query_lower = query.lower()

View File

@@ -6,7 +6,7 @@ import asyncio
import json
from pathlib import Path
import re
from typing import Any, Dict
from typing import Any, Dict, Optional
from botpy.message import Message
@@ -33,7 +33,8 @@ class MessageHandler:
(re.compile(r"\*\*([^*]+)\*\*"), r"\1"),
(re.compile(r"\*([^*]+)\*"), r"\1"),
(re.compile(r"__([^_]+)__"), r"\1"),
(re.compile(r"_([^_]+)_"), r"\1"),
# Avoid stripping underscores inside identifiers like model keys.
(re.compile(r"(?<!\w)_([^_\n]+)_(?!\w)"), r"\1"),
(re.compile(r"^#{1,6}\s*", re.MULTILINE), ""),
(re.compile(r"^>\s?", re.MULTILINE), ""),
(re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"),
@@ -100,8 +101,8 @@ class MessageHandler:
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
f"{command_name} add <key> <json>\n"
" json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
f"{command_name} switch <key>\n"
f"{command_name} remove <key>"
f"{command_name} switch <key|index>\n"
f"{command_name} remove <key|index>"
)
@staticmethod
@@ -142,6 +143,66 @@ class MessageHandler:
key = f"model_{key}"
return key
@staticmethod
def _compact_model_key(raw_key: str) -> str:
    """Lowercase the key and drop every character outside [a-z0-9], for fuzzy matching."""
    lowered = (raw_key or "").strip().lower()
    return re.sub(r"[^a-z0-9]", "", lowered)
def _ordered_model_keys(self) -> list[str]:
    """Model profile keys in stable alphabetical order (basis for index-based selection)."""
    return sorted(self.model_profiles)
def _resolve_model_selector(self, selector: str) -> str:
    """Resolve a user-supplied selector to an existing model-profile key.

    Accepted forms, tried in this order:
      1. a 1-based index into the sorted key list,
      2. an exact key,
      3. the normalized form of the selector,
      4. a fuzzy match against normalized / compacted forms of existing keys.

    Raises:
        ValueError: empty selector, index out of range, ambiguous fuzzy
            match, or no matching profile.
    """
    raw = (selector or "").strip()
    if not raw:
        raise ValueError("模型 key 不能为空")
    ordered_keys = self._ordered_model_keys()
    # Numeric selector: 1-based index into the sorted key list.
    if raw.isdigit():
        index = int(raw)
        if index < 1 or index > len(ordered_keys):
            raise ValueError(
                f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}"
            )
        return ordered_keys[index - 1]
    # Exact key match wins before any normalization or fuzzy matching.
    if raw in self.model_profiles:
        return raw
    normalized_selector: Optional[str]
    try:
        normalized_selector = self._normalize_model_key(raw)
    except ValueError:
        normalized_selector = None
    if normalized_selector and normalized_selector in self.model_profiles:
        return normalized_selector
    # Lookup tables for fuzzy matching: normalized-key -> original keys,
    # compacted-key -> original keys. Lists keep every colliding key so
    # ambiguity can be detected below.
    normalized_candidates: Dict[str, list[str]] = {}
    compact_candidates: Dict[str, list[str]] = {}
    for key in ordered_keys:
        try:
            normalized_key = self._normalize_model_key(key)
        except ValueError:
            # Keys that fail normalization still participate via lowercase form.
            normalized_key = key.strip().lower()
        normalized_candidates.setdefault(normalized_key, []).append(key)
        compact_candidates.setdefault(
            self._compact_model_key(normalized_key), []
        ).append(key)
    if normalized_selector and normalized_selector in normalized_candidates:
        matches = normalized_candidates[normalized_selector]
        if len(matches) == 1:
            return matches[0]
        # Ambiguous: refuse to guess between multiple candidate keys.
        raise ValueError(f"匹配到多个模型 key请使用完整 key: {', '.join(matches)}")
    compact_selector = self._compact_model_key(normalized_selector or raw)
    if compact_selector in compact_candidates:
        matches = compact_candidates[compact_selector]
        if len(matches) == 1:
            return matches[0]
        raise ValueError(f"匹配到多个模型 key请使用完整 key: {', '.join(matches)}")
    raise ValueError(f"模型配置不存在: {raw}")
@classmethod
def _parse_provider(cls, raw_provider: str) -> ModelProvider:
provider = cls._provider_map().get(raw_provider.strip().lower())
@@ -249,9 +310,29 @@ class MessageHandler:
logger.warning(f"load model profiles failed, reset to defaults: {exc}")
payload = {}
profiles = payload.get("profiles")
if not isinstance(profiles, dict):
profiles = {}
raw_profiles = payload.get("profiles")
profiles: Dict[str, Dict[str, Any]] = {}
if isinstance(raw_profiles, dict):
for raw_key, raw_profile in raw_profiles.items():
if not isinstance(raw_profile, dict):
continue
key_text = str(raw_key or "").strip()
if not key_text:
continue
try:
normalized_key = self._normalize_model_key(key_text)
except ValueError:
continue
if normalized_key in profiles and profiles[normalized_key] != raw_profile:
logger.warning(
f"duplicate model key after normalization, keep first: {normalized_key}"
)
continue
profiles[normalized_key] = raw_profile
if not profiles:
profiles = {
@@ -260,8 +341,19 @@ class MessageHandler:
)
}
active = str(payload.get("active") or "")
if active not in profiles:
active_raw = str(payload.get("active") or "").strip()
active = ""
if active_raw in profiles:
active = active_raw
elif active_raw:
try:
normalized_active = self._normalize_model_key(active_raw)
except ValueError:
normalized_active = ""
if normalized_active in profiles:
active = normalized_active
if not active:
active = "default" if "default" in profiles else sorted(profiles.keys())[0]
self.model_profiles = profiles
@@ -918,12 +1010,16 @@ class MessageHandler:
if action in {"list", "ls"} and len(parts) <= 2:
lines = [f"当前模型配置: {self.active_model_key}"]
for key in sorted(self.model_profiles.keys()):
ordered_keys = self._ordered_model_keys()
for idx, key in enumerate(ordered_keys, start=1):
profile = self.model_profiles.get(key, {})
marker = "*" if key == self.active_model_key else "-"
provider = str(profile.get("provider") or "?")
model_name = str(profile.get("model_name") or "?")
lines.append(f"{marker} {key}: {provider}/{model_name}")
lines.append(f"{marker} {idx}. {key}: {provider}/{model_name}")
if ordered_keys:
lines.append(f"提示: 可用 /models switch <序号>,例如 /models switch 2")
lines.append(self._build_models_usage("/models"))
await self._reply_plain(message, "\n".join(lines))
@@ -947,15 +1043,11 @@ class MessageHandler:
return
try:
key = self._normalize_model_key(parts[2])
key = self._resolve_model_selector(parts[2])
except ValueError as exc:
await self._reply_plain(message, str(exc))
return
if key not in self.model_profiles:
await self._reply_plain(message, f"模型配置不存在: {key}")
return
try:
config = self._model_config_from_dict(
self.model_profiles[key], self.ai_client.config
@@ -1065,7 +1157,7 @@ class MessageHandler:
return
try:
key = self._normalize_model_key(parts[2])
key = self._resolve_model_selector(parts[2])
except ValueError as exc:
await self._reply_plain(message, str(exc))
return
@@ -1074,10 +1166,6 @@ class MessageHandler:
await self._reply_plain(message, "默认模型配置不能删除")
return
if key not in self.model_profiles:
await self._reply_plain(message, f"模型配置不存在: {key}")
return
del self.model_profiles[key]
switched_to = None
@@ -1146,8 +1234,8 @@ class MessageHandler:
"/models current\n"
"/models add <model_name>\n"
"/models add <key> <provider> <model_name> [api_base]\n"
"/models switch <key>\n"
"/models remove <key>\n"
"/models switch <key|index>\n"
"/models remove <key|index>\n"
"\n"
"记忆命令\n"
"/memory\n"

View File

@@ -0,0 +1,68 @@
import asyncio
from pathlib import Path
from src.ai.memory import MemorySystem
def test_query_does_not_use_embedding_when_disabled(tmp_path: Path):
    """Retrieval paths must not invoke the embedder when use_query_embedding=False."""
    embed_calls = {"count": 0}

    async def counting_embed(_text: str):
        embed_calls["count"] += 1
        return [0.1] * 8

    system = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=counting_embed,
        use_vector_db=False,
        use_query_embedding=False,
    )
    # Storing a long-term memory still embeds its content.
    record = asyncio.run(
        system.add_long_term(
            user_id="u1",
            content="我喜欢苹果和香蕉",
            importance=0.9,
            metadata={"source": "test"},
        )
    )
    assert record is not None
    assert embed_calls["count"] == 1
    # Neither context building nor search may add another embedding call.
    asyncio.run(system.get_context(user_id="u1", query="苹果"))
    asyncio.run(system.search_long_term(user_id="u1", query="苹果", limit=5))
    assert embed_calls["count"] == 1
    asyncio.run(system.close())
def test_query_uses_embedding_when_enabled(tmp_path: Path):
    """With use_query_embedding=True, context retrieval embeds the query text."""
    embed_calls = {"count": 0}

    async def counting_embed(_text: str):
        embed_calls["count"] += 1
        return [0.1] * 8

    system = MemorySystem(
        storage_path=tmp_path / "memory.json",
        embed_func=counting_embed,
        use_vector_db=False,
        use_query_embedding=True,
    )
    asyncio.run(
        system.add_long_term(
            user_id="u1",
            content="北京天气不错",
            importance=0.9,
            metadata={"source": "test"},
        )
    )
    # One call for storing the memory content.
    assert embed_calls["count"] == 1
    # The query path contributes at least one more embedding call.
    asyncio.run(system.get_context(user_id="u1", query="北京"))
    assert embed_calls["count"] >= 2
    asyncio.run(system.close())