Implement SQLite support check for Chroma's trigram tokenizer, enhancing compatibility with cloud environments. Update README for model command syntax and add pysqlite3-binary to requirements for improved SQLite functionality.
This commit is contained in:
@@ -91,12 +91,15 @@ class MemorySystem:
|
||||
embed_func: Optional[callable] = None,
|
||||
importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
|
||||
importance_threshold: float = 0.6,
|
||||
use_vector_db: bool = True
|
||||
use_vector_db: bool = True,
|
||||
use_query_embedding: bool = False,
|
||||
):
|
||||
self.short_term = ShortTermMemory()
|
||||
self.embed_func = embed_func
|
||||
self.importance_evaluator = importance_evaluator
|
||||
self.importance_threshold = importance_threshold
|
||||
# Only embed retrieval queries when explicitly enabled.
|
||||
self.use_query_embedding = use_query_embedding
|
||||
|
||||
# 初始化向量存储
|
||||
if use_vector_db:
|
||||
@@ -117,11 +120,23 @@ class MemorySystem:
|
||||
msg = str(error).lower()
|
||||
return "table embeddings already exists" in msg
|
||||
|
||||
@staticmethod
|
||||
def _is_chroma_trigram_error(error: Exception) -> bool:
|
||||
msg = str(error).lower()
|
||||
return "no such tokenizer: trigram" in msg
|
||||
|
||||
def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
|
||||
"""初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。"""
|
||||
try:
|
||||
return ChromaVectorStore(chroma_path)
|
||||
except Exception as error:
|
||||
if self._is_chroma_trigram_error(error):
|
||||
logger.warning(
|
||||
"Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。"
|
||||
"请在运行环境升级 sqlite 或安装 pysqlite3-binary。"
|
||||
)
|
||||
return None
|
||||
|
||||
if not self._is_chroma_table_conflict(error):
|
||||
logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}")
|
||||
return None
|
||||
@@ -327,7 +342,7 @@ class MemorySystem:
|
||||
# 获取相关长期记忆
|
||||
long_term_memories = []
|
||||
|
||||
if query:
|
||||
if query and self.use_query_embedding:
|
||||
try:
|
||||
# 使用向量检索
|
||||
query_embedding = await self._build_embedding(query)
|
||||
@@ -430,15 +445,16 @@ class MemorySystem:
|
||||
if not query:
|
||||
return []
|
||||
|
||||
query_embedding = await self._build_embedding(query)
|
||||
results = await self.vector_store.search(
|
||||
user_id=user_id,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
min_importance=0.0,
|
||||
)
|
||||
if results:
|
||||
return results
|
||||
if self.use_query_embedding:
|
||||
query_embedding = await self._build_embedding(query)
|
||||
results = await self.vector_store.search(
|
||||
user_id=user_id,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
min_importance=0.0,
|
||||
)
|
||||
if results:
|
||||
return results
|
||||
|
||||
all_memories = await self.vector_store.get_all(user_id)
|
||||
query_lower = query.lower()
|
||||
|
||||
@@ -6,7 +6,7 @@ import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from botpy.message import Message
|
||||
|
||||
@@ -33,7 +33,8 @@ class MessageHandler:
|
||||
(re.compile(r"\*\*([^*]+)\*\*"), r"\1"),
|
||||
(re.compile(r"\*([^*]+)\*"), r"\1"),
|
||||
(re.compile(r"__([^_]+)__"), r"\1"),
|
||||
(re.compile(r"_([^_]+)_"), r"\1"),
|
||||
# Avoid stripping underscores inside identifiers like model keys.
|
||||
(re.compile(r"(?<!\w)_([^_\n]+)_(?!\w)"), r"\1"),
|
||||
(re.compile(r"^#{1,6}\s*", re.MULTILINE), ""),
|
||||
(re.compile(r"^>\s?", re.MULTILINE), ""),
|
||||
(re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"),
|
||||
@@ -100,8 +101,8 @@ class MessageHandler:
|
||||
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
|
||||
f"{command_name} add <key> <json>\n"
|
||||
" json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
|
||||
f"{command_name} switch <key>\n"
|
||||
f"{command_name} remove <key>"
|
||||
f"{command_name} switch <key|index>\n"
|
||||
f"{command_name} remove <key|index>"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -142,6 +143,66 @@ class MessageHandler:
|
||||
key = f"model_{key}"
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def _compact_model_key(raw_key: str) -> str:
|
||||
return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower())
|
||||
|
||||
def _ordered_model_keys(self) -> list[str]:
|
||||
return sorted(self.model_profiles.keys())
|
||||
|
||||
def _resolve_model_selector(self, selector: str) -> str:
|
||||
raw = (selector or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("模型 key 不能为空")
|
||||
|
||||
ordered_keys = self._ordered_model_keys()
|
||||
if raw.isdigit():
|
||||
index = int(raw)
|
||||
if index < 1 or index > len(ordered_keys):
|
||||
raise ValueError(
|
||||
f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}"
|
||||
)
|
||||
return ordered_keys[index - 1]
|
||||
|
||||
if raw in self.model_profiles:
|
||||
return raw
|
||||
|
||||
normalized_selector: Optional[str]
|
||||
try:
|
||||
normalized_selector = self._normalize_model_key(raw)
|
||||
except ValueError:
|
||||
normalized_selector = None
|
||||
|
||||
if normalized_selector and normalized_selector in self.model_profiles:
|
||||
return normalized_selector
|
||||
|
||||
normalized_candidates: Dict[str, list[str]] = {}
|
||||
compact_candidates: Dict[str, list[str]] = {}
|
||||
for key in ordered_keys:
|
||||
try:
|
||||
normalized_key = self._normalize_model_key(key)
|
||||
except ValueError:
|
||||
normalized_key = key.strip().lower()
|
||||
normalized_candidates.setdefault(normalized_key, []).append(key)
|
||||
compact_candidates.setdefault(
|
||||
self._compact_model_key(normalized_key), []
|
||||
).append(key)
|
||||
|
||||
if normalized_selector and normalized_selector in normalized_candidates:
|
||||
matches = normalized_candidates[normalized_selector]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}")
|
||||
|
||||
compact_selector = self._compact_model_key(normalized_selector or raw)
|
||||
if compact_selector in compact_candidates:
|
||||
matches = compact_candidates[compact_selector]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}")
|
||||
|
||||
raise ValueError(f"模型配置不存在: {raw}")
|
||||
|
||||
@classmethod
|
||||
def _parse_provider(cls, raw_provider: str) -> ModelProvider:
|
||||
provider = cls._provider_map().get(raw_provider.strip().lower())
|
||||
@@ -249,9 +310,29 @@ class MessageHandler:
|
||||
logger.warning(f"load model profiles failed, reset to defaults: {exc}")
|
||||
payload = {}
|
||||
|
||||
profiles = payload.get("profiles")
|
||||
if not isinstance(profiles, dict):
|
||||
profiles = {}
|
||||
raw_profiles = payload.get("profiles")
|
||||
profiles: Dict[str, Dict[str, Any]] = {}
|
||||
if isinstance(raw_profiles, dict):
|
||||
for raw_key, raw_profile in raw_profiles.items():
|
||||
if not isinstance(raw_profile, dict):
|
||||
continue
|
||||
|
||||
key_text = str(raw_key or "").strip()
|
||||
if not key_text:
|
||||
continue
|
||||
|
||||
try:
|
||||
normalized_key = self._normalize_model_key(key_text)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if normalized_key in profiles and profiles[normalized_key] != raw_profile:
|
||||
logger.warning(
|
||||
f"duplicate model key after normalization, keep first: {normalized_key}"
|
||||
)
|
||||
continue
|
||||
|
||||
profiles[normalized_key] = raw_profile
|
||||
|
||||
if not profiles:
|
||||
profiles = {
|
||||
@@ -260,8 +341,19 @@ class MessageHandler:
|
||||
)
|
||||
}
|
||||
|
||||
active = str(payload.get("active") or "")
|
||||
if active not in profiles:
|
||||
active_raw = str(payload.get("active") or "").strip()
|
||||
active = ""
|
||||
if active_raw in profiles:
|
||||
active = active_raw
|
||||
elif active_raw:
|
||||
try:
|
||||
normalized_active = self._normalize_model_key(active_raw)
|
||||
except ValueError:
|
||||
normalized_active = ""
|
||||
if normalized_active in profiles:
|
||||
active = normalized_active
|
||||
|
||||
if not active:
|
||||
active = "default" if "default" in profiles else sorted(profiles.keys())[0]
|
||||
|
||||
self.model_profiles = profiles
|
||||
@@ -918,12 +1010,16 @@ class MessageHandler:
|
||||
|
||||
if action in {"list", "ls"} and len(parts) <= 2:
|
||||
lines = [f"当前模型配置: {self.active_model_key}"]
|
||||
for key in sorted(self.model_profiles.keys()):
|
||||
ordered_keys = self._ordered_model_keys()
|
||||
for idx, key in enumerate(ordered_keys, start=1):
|
||||
profile = self.model_profiles.get(key, {})
|
||||
marker = "*" if key == self.active_model_key else "-"
|
||||
provider = str(profile.get("provider") or "?")
|
||||
model_name = str(profile.get("model_name") or "?")
|
||||
lines.append(f"{marker} {key}: {provider}/{model_name}")
|
||||
lines.append(f"{marker} {idx}. {key}: {provider}/{model_name}")
|
||||
|
||||
if ordered_keys:
|
||||
lines.append(f"提示: 可用 /models switch <序号>,例如 /models switch 2")
|
||||
|
||||
lines.append(self._build_models_usage("/models"))
|
||||
await self._reply_plain(message, "\n".join(lines))
|
||||
@@ -947,15 +1043,11 @@ class MessageHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
key = self._normalize_model_key(parts[2])
|
||||
key = self._resolve_model_selector(parts[2])
|
||||
except ValueError as exc:
|
||||
await self._reply_plain(message, str(exc))
|
||||
return
|
||||
|
||||
if key not in self.model_profiles:
|
||||
await self._reply_plain(message, f"模型配置不存在: {key}")
|
||||
return
|
||||
|
||||
try:
|
||||
config = self._model_config_from_dict(
|
||||
self.model_profiles[key], self.ai_client.config
|
||||
@@ -1065,7 +1157,7 @@ class MessageHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
key = self._normalize_model_key(parts[2])
|
||||
key = self._resolve_model_selector(parts[2])
|
||||
except ValueError as exc:
|
||||
await self._reply_plain(message, str(exc))
|
||||
return
|
||||
@@ -1074,10 +1166,6 @@ class MessageHandler:
|
||||
await self._reply_plain(message, "默认模型配置不能删除")
|
||||
return
|
||||
|
||||
if key not in self.model_profiles:
|
||||
await self._reply_plain(message, f"模型配置不存在: {key}")
|
||||
return
|
||||
|
||||
del self.model_profiles[key]
|
||||
switched_to = None
|
||||
|
||||
@@ -1146,8 +1234,8 @@ class MessageHandler:
|
||||
"/models current\n"
|
||||
"/models add <model_name>\n"
|
||||
"/models add <key> <provider> <model_name> [api_base]\n"
|
||||
"/models switch <key>\n"
|
||||
"/models remove <key>\n"
|
||||
"/models switch <key|index>\n"
|
||||
"/models remove <key|index>\n"
|
||||
"\n"
|
||||
"记忆命令\n"
|
||||
"/memory\n"
|
||||
|
||||
Reference in New Issue
Block a user