diff --git a/README.md b/README.md index 59d2099..fa0875f 100644 --- a/README.md +++ b/README.md @@ -73,8 +73,8 @@ python main.py - `/models current` - `/models add ` - `/models add [api_base]` -- `/models switch ` -- `/models remove ` +- `/models switch ` +- `/models remove ` 说明: - `/models add ` 只替换模型名,沿用当前 API Base 和 API Key diff --git a/main.py b/main.py index 70c6e94..2d1e162 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,53 @@ from pathlib import Path project_root = Path(__file__).parent sys.path.insert(0, str(project_root)) + +def _sqlite_supports_trigram(sqlite_module) -> bool: + conn = None + try: + conn = sqlite_module.connect(":memory:") + conn.execute("create virtual table t using fts5(content, tokenize='trigram')") + return True + except Exception: + return False + finally: + if conn is not None: + conn.close() + + +def _ensure_sqlite_for_chroma(): + """ + Ensure sqlite runtime supports FTS5 trigram tokenizer for Chroma. + On some cloud images, system sqlite lacks trigram support. 
+ """ + try: + import sqlite3 + except Exception: + return + + if _sqlite_supports_trigram(sqlite3): + return + + try: + import pysqlite3 + except Exception as exc: + print( + "[WARN] sqlite3 does not support trigram tokenizer and pysqlite3 is unavailable: " + f"{exc}" + ) + print("[WARN] Chroma may fail and fallback to JSON storage.") + return + + if _sqlite_supports_trigram(pysqlite3): + sys.modules["sqlite3"] = pysqlite3 + print("[INFO] sqlite3 switched to pysqlite3 for Chroma trigram support.") + else: + print("[WARN] pysqlite3 is installed but still lacks trigram tokenizer support.") + print("[WARN] Chroma may fail and fallback to JSON storage.") + + +_ensure_sqlite_for_chroma() + from src.core.bot import MyClient, build_intents from src.core.config import Config from src.utils.logger import setup_logger diff --git a/requirements.txt b/requirements.txt index 8a01ddf..a1a0525 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ openai>=1.0.0 anthropic>=0.18.0 numpy>=1.24.0 chromadb>=0.4.0 # 向量数据库,用于记忆存储 +pysqlite3-binary>=0.5.3; platform_system != "Windows" # 云端可用于补齐 sqlite trigram 支持 diff --git a/src/ai/memory.py b/src/ai/memory.py index 76151ae..6c20dee 100644 --- a/src/ai/memory.py +++ b/src/ai/memory.py @@ -91,12 +91,15 @@ class MemorySystem: embed_func: Optional[callable] = None, importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None, importance_threshold: float = 0.6, - use_vector_db: bool = True + use_vector_db: bool = True, + use_query_embedding: bool = False, ): self.short_term = ShortTermMemory() self.embed_func = embed_func self.importance_evaluator = importance_evaluator self.importance_threshold = importance_threshold + # Only embed retrieval queries when explicitly enabled. 
+ self.use_query_embedding = use_query_embedding # 初始化向量存储 if use_vector_db: @@ -117,11 +120,23 @@ class MemorySystem: msg = str(error).lower() return "table embeddings already exists" in msg + @staticmethod + def _is_chroma_trigram_error(error: Exception) -> bool: + msg = str(error).lower() + return "no such tokenizer: trigram" in msg + def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]: """初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。""" try: return ChromaVectorStore(chroma_path) except Exception as error: + if self._is_chroma_trigram_error(error): + logger.warning( + "Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。" + "请在运行环境升级 sqlite 或安装 pysqlite3-binary。" + ) + return None + if not self._is_chroma_table_conflict(error): logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}") return None @@ -327,7 +342,7 @@ class MemorySystem: # 获取相关长期记忆 long_term_memories = [] - if query: + if query and self.use_query_embedding: try: # 使用向量检索 query_embedding = await self._build_embedding(query) @@ -430,15 +445,16 @@ class MemorySystem: if not query: return [] - query_embedding = await self._build_embedding(query) - results = await self.vector_store.search( - user_id=user_id, - query_embedding=query_embedding, - limit=limit, - min_importance=0.0, - ) - if results: - return results + if self.use_query_embedding: + query_embedding = await self._build_embedding(query) + results = await self.vector_store.search( + user_id=user_id, + query_embedding=query_embedding, + limit=limit, + min_importance=0.0, + ) + if results: + return results all_memories = await self.vector_store.get_all(user_id) query_lower = query.lower() diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index dd8559c..0b2a2a7 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -6,7 +6,7 @@ import asyncio import json from pathlib import Path import re -from typing import Any, Dict +from typing import Any, Dict, Optional from 
botpy.message import Message @@ -33,7 +33,8 @@ class MessageHandler: (re.compile(r"\*\*([^*]+)\*\*"), r"\1"), (re.compile(r"\*([^*]+)\*"), r"\1"), (re.compile(r"__([^_]+)__"), r"\1"), - (re.compile(r"_([^_]+)_"), r"\1"), + # Avoid stripping underscores inside identifiers like model keys. + (re.compile(r"(?<!\w)_([^_]+)_(?!\w)"), r"\1"), (re.compile(r"^>\s?", re.MULTILINE), ""), (re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1: \2"), @@ -100,8 +101,8 @@ class MessageHandler: f"{command_name} add [api_base]\n" f"{command_name} add \n" " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n" - f"{command_name} switch \n" - f"{command_name} remove " + f"{command_name} switch \n" + f"{command_name} remove " ) @staticmethod @@ -142,6 +143,66 @@ class MessageHandler: key = f"model_{key}" return key + @staticmethod + def _compact_model_key(raw_key: str) -> str: + return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower()) + + def _ordered_model_keys(self) -> list[str]: + return sorted(self.model_profiles.keys()) + + def _resolve_model_selector(self, selector: str) -> str: + raw = (selector or "").strip() + if not raw: + raise ValueError("模型 key 不能为空") + + ordered_keys = self._ordered_model_keys() + if raw.isdigit(): + index = int(raw) + if index < 1 or index > len(ordered_keys): + raise ValueError( + f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}" + ) + return ordered_keys[index - 1] + + if raw in self.model_profiles: + return raw + + normalized_selector: Optional[str] + try: + normalized_selector = self._normalize_model_key(raw) + except ValueError: + normalized_selector = None + + if normalized_selector and normalized_selector in self.model_profiles: + return normalized_selector + + normalized_candidates: Dict[str, list[str]] = {} + compact_candidates: Dict[str, list[str]] = {} + for key in ordered_keys: + try: + normalized_key = self._normalize_model_key(key) + except ValueError: + normalized_key = key.strip().lower() + normalized_candidates.setdefault(normalized_key, []).append(key) + 
compact_candidates.setdefault( + self._compact_model_key(normalized_key), [] + ).append(key) + + if normalized_selector and normalized_selector in normalized_candidates: + matches = normalized_candidates[normalized_selector] + if len(matches) == 1: + return matches[0] + raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}") + + compact_selector = self._compact_model_key(normalized_selector or raw) + if compact_selector in compact_candidates: + matches = compact_candidates[compact_selector] + if len(matches) == 1: + return matches[0] + raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}") + + raise ValueError(f"模型配置不存在: {raw}") + @classmethod def _parse_provider(cls, raw_provider: str) -> ModelProvider: provider = cls._provider_map().get(raw_provider.strip().lower()) @@ -249,9 +310,29 @@ class MessageHandler: logger.warning(f"load model profiles failed, reset to defaults: {exc}") payload = {} - profiles = payload.get("profiles") - if not isinstance(profiles, dict): - profiles = {} + raw_profiles = payload.get("profiles") + profiles: Dict[str, Dict[str, Any]] = {} + if isinstance(raw_profiles, dict): + for raw_key, raw_profile in raw_profiles.items(): + if not isinstance(raw_profile, dict): + continue + + key_text = str(raw_key or "").strip() + if not key_text: + continue + + try: + normalized_key = self._normalize_model_key(key_text) + except ValueError: + continue + + if normalized_key in profiles and profiles[normalized_key] != raw_profile: + logger.warning( + f"duplicate model key after normalization, keep first: {normalized_key}" + ) + continue + + profiles[normalized_key] = raw_profile if not profiles: profiles = { @@ -260,8 +341,19 @@ class MessageHandler: ) } - active = str(payload.get("active") or "") - if active not in profiles: + active_raw = str(payload.get("active") or "").strip() + active = "" + if active_raw in profiles: + active = active_raw + elif active_raw: + try: + normalized_active = self._normalize_model_key(active_raw) + 
except ValueError: + normalized_active = "" + if normalized_active in profiles: + active = normalized_active + + if not active: active = "default" if "default" in profiles else sorted(profiles.keys())[0] self.model_profiles = profiles @@ -918,12 +1010,16 @@ class MessageHandler: if action in {"list", "ls"} and len(parts) <= 2: lines = [f"当前模型配置: {self.active_model_key}"] - for key in sorted(self.model_profiles.keys()): + ordered_keys = self._ordered_model_keys() + for idx, key in enumerate(ordered_keys, start=1): profile = self.model_profiles.get(key, {}) marker = "*" if key == self.active_model_key else "-" provider = str(profile.get("provider") or "?") model_name = str(profile.get("model_name") or "?") - lines.append(f"{marker} {key}: {provider}/{model_name}") + lines.append(f"{marker} {idx}. {key}: {provider}/{model_name}") + + if ordered_keys: + lines.append(f"提示: 可用 /models switch <序号>,例如 /models switch 2") lines.append(self._build_models_usage("/models")) await self._reply_plain(message, "\n".join(lines)) @@ -947,15 +1043,11 @@ class MessageHandler: return try: - key = self._normalize_model_key(parts[2]) + key = self._resolve_model_selector(parts[2]) except ValueError as exc: await self._reply_plain(message, str(exc)) return - if key not in self.model_profiles: - await self._reply_plain(message, f"模型配置不存在: {key}") - return - try: config = self._model_config_from_dict( self.model_profiles[key], self.ai_client.config @@ -1065,7 +1157,7 @@ class MessageHandler: return try: - key = self._normalize_model_key(parts[2]) + key = self._resolve_model_selector(parts[2]) except ValueError as exc: await self._reply_plain(message, str(exc)) return @@ -1074,10 +1166,6 @@ class MessageHandler: await self._reply_plain(message, "默认模型配置不能删除") return - if key not in self.model_profiles: - await self._reply_plain(message, f"模型配置不存在: {key}") - return - del self.model_profiles[key] switched_to = None @@ -1146,8 +1234,8 @@ class MessageHandler: "/models current\n" "/models add \n" 
"/models add [api_base]\n" - "/models switch \n" - "/models remove \n" + "/models switch \n" + "/models remove \n" "\n" "记忆命令\n" "/memory\n" diff --git a/tests/test_memory_embedding_policy.py b/tests/test_memory_embedding_policy.py new file mode 100644 index 0000000..ff574a4 --- /dev/null +++ b/tests/test_memory_embedding_policy.py @@ -0,0 +1,68 @@ +import asyncio +from pathlib import Path + +from src.ai.memory import MemorySystem + + +def test_query_does_not_use_embedding_when_disabled(tmp_path: Path): + calls = {"count": 0} + + async def fake_embed(_text: str): + calls["count"] += 1 + return [0.1] * 8 + + memory = MemorySystem( + storage_path=tmp_path / "memory.json", + embed_func=fake_embed, + use_vector_db=False, + use_query_embedding=False, + ) + + # 入库路径仍会触发 embedding + stored = asyncio.run( + memory.add_long_term( + user_id="u1", + content="我喜欢苹果和香蕉", + importance=0.9, + metadata={"source": "test"}, + ) + ) + assert stored is not None + assert calls["count"] == 1 + + # 查询上下文与搜索均不触发 embedding + asyncio.run(memory.get_context(user_id="u1", query="苹果")) + asyncio.run(memory.search_long_term(user_id="u1", query="苹果", limit=5)) + assert calls["count"] == 1 + + asyncio.run(memory.close()) + + +def test_query_uses_embedding_when_enabled(tmp_path: Path): + calls = {"count": 0} + + async def fake_embed(_text: str): + calls["count"] += 1 + return [0.1] * 8 + + memory = MemorySystem( + storage_path=tmp_path / "memory.json", + embed_func=fake_embed, + use_vector_db=False, + use_query_embedding=True, + ) + + asyncio.run( + memory.add_long_term( + user_id="u1", + content="北京天气不错", + importance=0.9, + metadata={"source": "test"}, + ) + ) + assert calls["count"] == 1 + + asyncio.run(memory.get_context(user_id="u1", query="北京")) + assert calls["count"] >= 2 + + asyncio.run(memory.close())