From a754e7843f1ee790dead7d1bf75289538c3e9e43 Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Tue, 3 Mar 2026 21:56:33 +0800 Subject: [PATCH] Refactor configuration and enhance AI capabilities Updated .env.example to improve clarity and added new configuration options for memory and reliability settings. Refactored main.py to streamline the bot's entry point and improved error handling. Enhanced README to reflect new features and command structure. Removed deprecated cmd_zip_skill and skills_creator modules to clean up the codebase. Updated AIClient and MemorySystem for better performance and flexibility in handling user interactions. --- .env.example | 36 +- README.md | 131 +- config/mcp.json | 6 - main.py | 41 +- requirements.txt | 15 +- skills/cmd_zip_skill/README.md | 7 - skills/cmd_zip_skill/__init__.py | 0 skills/cmd_zip_skill/main.py | 13 - skills/cmd_zip_skill/skill.json | 8 - skills/cmd_zip_skill_1772465404375/README.md | 7 - .../cmd_zip_skill_1772465404375/__init__.py | 0 skills/cmd_zip_skill_1772465404375/main.py | 13 - skills/cmd_zip_skill_1772465404375/skill.json | 8 - skills/cmd_zip_skill_1772465434774/README.md | 7 - .../cmd_zip_skill_1772465434774/__init__.py | 0 skills/cmd_zip_skill_1772465434774/main.py | 13 - skills/cmd_zip_skill_1772465434774/skill.json | 8 - skills/cmd_zip_skill_1772465467809/README.md | 7 - .../cmd_zip_skill_1772465467809/__init__.py | 0 skills/cmd_zip_skill_1772465467809/main.py | 13 - skills/cmd_zip_skill_1772465467809/skill.json | 8 - skills/cmd_zip_skill_1772465652075/README.md | 7 - .../cmd_zip_skill_1772465652075/__init__.py | 0 skills/cmd_zip_skill_1772465652075/main.py | 13 - skills/cmd_zip_skill_1772465652075/skill.json | 8 - skills/cmd_zip_skill_1772465685352/README.md | 7 - .../cmd_zip_skill_1772465685352/__init__.py | 0 skills/cmd_zip_skill_1772465685352/main.py | 13 - skills/cmd_zip_skill_1772465685352/skill.json | 8 - skills/cmd_zip_skill_1772465936294/README.md | 7 - .../cmd_zip_skill_1772465936294/__init__.py | 0 skills/cmd_zip_skill_1772465936294/main.py | 13 - skills/cmd_zip_skill_1772465936294/skill.json | 8 - skills/cmd_zip_skill_1772465966322/README.md | 7 - .../cmd_zip_skill_1772465966322/__init__.py | 0 skills/cmd_zip_skill_1772465966322/main.py | 13 - skills/cmd_zip_skill_1772465966322/skill.json | 8 - skills/cmd_zip_skill_1772466071278/README.md | 7 - .../cmd_zip_skill_1772466071278/__init__.py | 0 skills/cmd_zip_skill_1772466071278/main.py | 13 - skills/cmd_zip_skill_1772466071278/skill.json | 8 - skills/skills_creator/README.md | 12 - skills/skills_creator/__init__.py | 0 skills/skills_creator/main.py | 147 -- skills/skills_creator/skill.json | 8 - src/ai/__init__.py | 19 +- src/ai/client.py | 863 +++------- src/ai/mcp/__init__.py | 13 - src/ai/mcp/base.py | 230 --- src/ai/mcp/servers/__init__.py | 6 - src/ai/mcp/servers/filesystem.py | 123 -- src/ai/memory.py | 365 ++--- src/ai/models/anthropic_model.py | 91 +- src/ai/models/openai_model.py | 9 +- src/ai/personality.py | 252 ++- src/ai/skills/__init__.py | 6 - src/ai/skills/base.py | 750 --------- src/ai/task_manager.py | 327 ---- src/ai/vector_store/json_store.py | 255 +-- src/core/bot.py | 102 +- src/core/config.py | 147 +- src/handlers/__init__.py | 9 +- src/handlers/message_handler_ai.py | 1383 ++++++----------- src/utils/logger.py | 126 +- tests/conftest.py | 7 + tests/test_ai.py | 628 -------- tests/test_ai_client_forced_tool.py | 73 - tests/test_mcp_tool_registration.py | 44 - tests/test_message_dedup.py | 22 + tests/test_message_handler_text_sanitize.py | 6 +- tests/test_personality_scope_priority.py | 44 + tests/test_skills_install_source.py | 74 - 72 files changed, 1607 insertions(+), 5015 deletions(-) delete mode 100644 config/mcp.json delete mode 100644 skills/cmd_zip_skill/README.md delete mode 100644 skills/cmd_zip_skill/__init__.py delete mode 100644 skills/cmd_zip_skill/main.py delete mode 100644 skills/cmd_zip_skill/skill.json delete mode 100644 skills/cmd_zip_skill_1772465404375/README.md delete mode 100644 skills/cmd_zip_skill_1772465404375/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465404375/main.py delete mode 100644 skills/cmd_zip_skill_1772465404375/skill.json delete mode 100644 skills/cmd_zip_skill_1772465434774/README.md delete mode 100644 skills/cmd_zip_skill_1772465434774/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465434774/main.py delete mode 100644 skills/cmd_zip_skill_1772465434774/skill.json delete mode 100644 skills/cmd_zip_skill_1772465467809/README.md delete mode 100644 skills/cmd_zip_skill_1772465467809/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465467809/main.py delete mode 100644 skills/cmd_zip_skill_1772465467809/skill.json delete mode 100644 skills/cmd_zip_skill_1772465652075/README.md delete mode 100644 skills/cmd_zip_skill_1772465652075/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465652075/main.py delete mode 100644 skills/cmd_zip_skill_1772465652075/skill.json delete mode 100644 skills/cmd_zip_skill_1772465685352/README.md delete mode 100644 skills/cmd_zip_skill_1772465685352/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465685352/main.py delete mode 100644 skills/cmd_zip_skill_1772465685352/skill.json delete mode 100644 skills/cmd_zip_skill_1772465936294/README.md delete mode 100644 skills/cmd_zip_skill_1772465936294/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465936294/main.py delete mode 100644 skills/cmd_zip_skill_1772465936294/skill.json delete mode 100644 skills/cmd_zip_skill_1772465966322/README.md delete mode 100644 skills/cmd_zip_skill_1772465966322/__init__.py delete mode 100644 skills/cmd_zip_skill_1772465966322/main.py delete mode 100644 skills/cmd_zip_skill_1772465966322/skill.json delete mode 100644 skills/cmd_zip_skill_1772466071278/README.md delete mode 100644 skills/cmd_zip_skill_1772466071278/__init__.py delete mode 100644 skills/cmd_zip_skill_1772466071278/main.py delete mode 100644 skills/cmd_zip_skill_1772466071278/skill.json delete mode 100644 skills/skills_creator/README.md delete mode 100644 skills/skills_creator/__init__.py delete mode 100644 skills/skills_creator/main.py delete mode 100644 skills/skills_creator/skill.json delete mode 100644 src/ai/mcp/__init__.py delete mode 100644 src/ai/mcp/base.py delete mode 100644 src/ai/mcp/servers/__init__.py delete mode 100644 src/ai/mcp/servers/filesystem.py delete mode 100644 src/ai/skills/__init__.py delete mode 100644 src/ai/skills/base.py delete mode 100644 src/ai/task_manager.py create mode 100644 tests/conftest.py delete mode 100644 tests/test_ai.py delete mode 100644 tests/test_ai_client_forced_tool.py delete mode 100644 tests/test_mcp_tool_registration.py create mode 100644 tests/test_message_dedup.py create mode 100644 tests/test_personality_scope_priority.py delete mode 100644 tests/test_skills_install_source.py diff --git a/.env.example b/.env.example index 104e807..e1d4ea3 100644 --- a/.env.example +++ b/.env.example @@ -1,36 +1,36 @@ -# QQ 机器人配置 -# 可从 https://bot.q.qq.com/open 获取 - -# 机器人 AppID(必填) +# QQ Bot credentials BOT_APPID=your_app_id_here - -# 机器人 AppSecret(必填) BOT_SECRET=your_app_secret_here -# 日志级别: DEBUG / INFO / WARNING / ERROR +# Runtime +APP_ENV=dev LOG_LEVEL=INFO - -# 是否启用沙箱环境 +LOG_FORMAT=text SANDBOX_MODE=False -# ==================== AI 配置 ==================== +# Optional admin allow-list (comma separated user IDs). +# Empty means all users are treated as admin. +BOT_ADMIN_IDS= -# 主模型配置 -# 可选 provider: openai / anthropic / deepseek / qwen +# AI chat model AI_PROVIDER=openai AI_MODEL=gpt-4 AI_API_KEY=your_api_key_here -# 可选,自定义 API 地址 AI_API_BASE=https://api.openai.com/v1 -# 嵌入模型配置(用于长期记忆检索) -# 不配置时将回退为主模型的 embedding 能力(如果可用) +# Embedding model (optional) AI_EMBED_PROVIDER=openai AI_EMBED_MODEL=text-embedding-3-small AI_EMBED_API_KEY= AI_EMBED_API_BASE= -# 向量数据库配置 -# true: 使用 Chroma(推荐) -# false: 使用 JSON 存储 +# Memory storage and retrieval AI_USE_VECTOR_DB=true +AI_USE_QUERY_EMBEDDING=false +AI_MEMORY_SCOPE=session + +# Reliability +AI_CHAT_RETRIES=1 +AI_CHAT_RETRY_BACKOFF_SECONDS=0.8 +MESSAGE_DEDUP_SECONDS=30 +MESSAGE_DEDUP_MAX_SIZE=4096 diff --git a/README.md b/README.md index fa0875f..276fec5 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,16 @@ -# QQbot(AI 聊天机器人) +# QQbot (Memory + Persona Core) -一个基于 `botpy` 的 QQ 机器人项目,支持多模型切换、长期/短期记忆、人设管理、Skills 插件与 MCP 能力。 +QQ 机器人项目,保留并强化两大核心能力: +- `Memory`:短期/长期记忆、检索、清理、会话作用域 +- `Persona`:角色配置、作用域优先级(session > group > user > global) -## 功能概览 +## 主要能力 -- 多模型配置与运行时切换(`/models`) -- 人设增删改切换(`/personality`) -- 短期/长期记忆管理(`/clear`、`/memory`) -- Skills 本地与网络安装/卸载/重载(`/skills`) -- 自动去除 Markdown 格式后再回复(适配 QQ 聊天) +- 多模型配置与运行时切换:`/models` +- 人设管理:`/personality` +- 记忆管理:`/memory`、`/clear` +- QQ 消息安全输出:自动清理 Markdown/URL +- 工程增强:消息去重、失败重试、权限边界、结构化日志 ## 快速开始 @@ -20,7 +22,9 @@ pip install -r requirements.txt 2. 配置环境变量 -复制 `.env.example` 为 `.env`,填写 QQ 机器人和 AI 配置。 +```bash +copy .env.example .env +``` 3. 启动 @@ -28,89 +32,42 @@ pip install -r requirements.txt python main.py ``` -## 命令说明 +## 命令 -### 通用 +- 基础 + - `/help` + - `/clear` `/clear short` `/clear long` `/clear all` +- 人设 + - `/personality` + - `/personality list` + - `/personality set [global|user|group|session]` + - `/personality add ` + - `/personality remove ` +- 模型 + - `/models` + - `/models current` + - `/models add ` + - `/models add [api_base]` + - `/models switch ` + - `/models remove ` +- 记忆 + - `/memory` + - `/memory get ` + - `/memory add ` + - `/memory update ` + - `/memory delete ` + - `/memory search [limit]` -- `/help` -- `/clear`(默认等价 `/clear short`) -- `/clear short` -- `/clear long` -- `/clear all` +## 关键配置 -### 人设命令 - -- `/personality` -- `/personality list` -- `/personality set ` -- `/personality add ` -- `/personality remove ` - -说明: -- `add` 会新增并切换到该人设 -- `Introduction` 会作为人设简介与自定义指令 - -### Skills 命令 - -- `/skills` -- `/skills install [skill_name]` -- `/skills uninstall ` -- `/skills reload ` - -`source` 支持: -- 本地技能名(如 `weather`) -- URL(zip 包) -- GitHub 简写(`owner/repo` 或 `owner/repo#branch`) -- GitHub 仓库 URL(如 `https://github.com/op7418/Humanizer-zh.git`) - -兼容说明: -- 若源中包含标准技能结构(`skill.json` + `main.py`),按原方式安装 -- 若仅包含 `SKILL.md`,会自动生成适配技能并提供 `read_skill_doc` 工具读取文档内容 - -### 模型命令 - -- `/models` -- `/models current` -- `/models add ` -- `/models add [api_base]` -- `/models switch ` -- `/models remove ` - -说明: -- `/models add ` 只替换模型名,沿用当前 API Base 和 API Key - -### 长期记忆命令 - -- `/memory` -- `/memory get ` -- `/memory add ` -- `/memory update ` -- `/memory delete ` -- `/memory search [limit]` - -## 目录结构 - -```text -QQbot/ -├─ src/ -│ ├─ ai/ -│ ├─ handlers/ -│ ├─ core/ -│ └─ utils/ -├─ skills/ -├─ config/ -├─ docs/ -└─ tests/ -``` +- `AI_MEMORY_SCOPE=user|session`:记忆作用域 +- `BOT_ADMIN_IDS`:管理员白名单(逗号分隔) +- `AI_CHAT_RETRIES` / `AI_CHAT_RETRY_BACKOFF_SECONDS`:聊天失败重试 +- `MESSAGE_DEDUP_SECONDS` / `MESSAGE_DEDUP_MAX_SIZE`:消息去重窗口 +- `LOG_FORMAT=text|json`:日志输出格式 ## 测试 ```bash -python -m pytest -q -``` - -如果你使用 conda 环境,请先执行: - -```bash -conda activate qqbot +pytest -q ``` diff --git a/config/mcp.json b/config/mcp.json deleted file mode 100644 index d13402b..0000000 --- a/config/mcp.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "filesystem": { - "enabled": true, - "root_path": "data" - } -} diff --git a/main.py b/main.py index 2d1e162..9be020d 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,12 @@ """ -QQ机器人主入口 +Project entrypoint. """ + +from __future__ import annotations + import sys from pathlib import Path -# 添加项目根目录到Python路径 project_root = Path(__file__).parent sys.path.insert(0, str(project_root)) @@ -23,10 +25,6 @@ def _sqlite_supports_trigram(sqlite_module) -> bool: def _ensure_sqlite_for_chroma(): - """ - Ensure sqlite runtime supports FTS5 trigram tokenizer for Chroma. - On some cloud images, system sqlite lacks trigram support. - """ try: import sqlite3 except Exception: @@ -55,35 +53,8 @@ def _ensure_sqlite_for_chroma(): _ensure_sqlite_for_chroma() -from src.core.bot import MyClient, build_intents -from src.core.config import Config -from src.utils.logger import setup_logger - - -def main(): - """主函数""" - # 设置日志 - logger = setup_logger() - - try: - # 验证配置 - Config.validate() - logger.info("配置验证通过") - - # 创建并启动机器人(最小权限,避免 4014 disallowed intents) - logger.info("正在启动QQ机器人...") - intents = build_intents() - client = MyClient(intents=intents) - client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET) - - except ValueError as e: - logger.error(f"配置错误: {e}") - logger.error("请检查 .env 文件配置") - sys.exit(1) - except Exception as e: - logger.error(f"启动失败: {e}", exc_info=True) - sys.exit(1) +from src.core.bot import main as run_bot_main if __name__ == "__main__": - main() + run_bot_main() diff --git a/requirements.txt b/requirements.txt index a1a0525..295ae36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,12 @@ -# ============================================ -# 核心依赖(必须安装) -# ============================================ +# Core runtime qq-botpy python-dotenv>=1.0.0 -# ============================================ -# AI 功能依赖(可选) -# 如果不需要 AI 对话功能,可以注释掉下面的依赖 -# ============================================ +# AI providers openai>=1.0.0 anthropic>=0.18.0 + +# Memory storage numpy>=1.24.0 -chromadb>=0.4.0 # 向量数据库,用于记忆存储 -pysqlite3-binary>=0.5.3; platform_system != "Windows" # 云端可用于补齐 sqlite trigram 支持 +chromadb>=0.4.0 +pysqlite3-binary>=0.5.3; platform_system != "Windows" diff --git a/skills/cmd_zip_skill/README.md b/skills/cmd_zip_skill/README.md deleted file mode 100644 index f196536..0000000 --- a/skills/cmd_zip_skill/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill/__init__.py b/skills/cmd_zip_skill/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill/main.py b/skills/cmd_zip_skill/main.py deleted file mode 100644 index e673dff..0000000 --- a/skills/cmd_zip_skill/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkillSkill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill/skill.json b/skills/cmd_zip_skill/skill.json deleted file mode 100644 index 7acae47..0000000 --- a/skills/cmd_zip_skill/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465404375/README.md b/skills/cmd_zip_skill_1772465404375/README.md deleted file mode 100644 index 7b863a6..0000000 --- a/skills/cmd_zip_skill_1772465404375/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465404375 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465404375/__init__.py b/skills/cmd_zip_skill_1772465404375/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465404375/main.py b/skills/cmd_zip_skill_1772465404375/main.py deleted file mode 100644 index 37808f8..0000000 --- a/skills/cmd_zip_skill_1772465404375/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465404375 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465404375Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465404375 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465404375/skill.json b/skills/cmd_zip_skill_1772465404375/skill.json deleted file mode 100644 index c046c86..0000000 --- a/skills/cmd_zip_skill_1772465404375/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465404375", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465434774/README.md b/skills/cmd_zip_skill_1772465434774/README.md deleted file mode 100644 index d0516de..0000000 --- a/skills/cmd_zip_skill_1772465434774/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465434774 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465434774/__init__.py b/skills/cmd_zip_skill_1772465434774/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465434774/main.py b/skills/cmd_zip_skill_1772465434774/main.py deleted file mode 100644 index 2af7b12..0000000 --- a/skills/cmd_zip_skill_1772465434774/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465434774 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465434774Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465434774 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465434774/skill.json b/skills/cmd_zip_skill_1772465434774/skill.json deleted file mode 100644 index 60148e6..0000000 --- a/skills/cmd_zip_skill_1772465434774/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465434774", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465467809/README.md b/skills/cmd_zip_skill_1772465467809/README.md deleted file mode 100644 index f42c1d7..0000000 --- a/skills/cmd_zip_skill_1772465467809/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465467809 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465467809/__init__.py b/skills/cmd_zip_skill_1772465467809/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465467809/main.py b/skills/cmd_zip_skill_1772465467809/main.py deleted file mode 100644 index a61bd41..0000000 --- a/skills/cmd_zip_skill_1772465467809/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465467809 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465467809Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465467809 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465467809/skill.json b/skills/cmd_zip_skill_1772465467809/skill.json deleted file mode 100644 index dc7fc87..0000000 --- a/skills/cmd_zip_skill_1772465467809/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465467809", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465652075/README.md b/skills/cmd_zip_skill_1772465652075/README.md deleted file mode 100644 index 39ed797..0000000 --- a/skills/cmd_zip_skill_1772465652075/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465652075 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465652075/__init__.py b/skills/cmd_zip_skill_1772465652075/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465652075/main.py b/skills/cmd_zip_skill_1772465652075/main.py deleted file mode 100644 index c0dfa8e..0000000 --- a/skills/cmd_zip_skill_1772465652075/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465652075 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465652075Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465652075 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465652075/skill.json b/skills/cmd_zip_skill_1772465652075/skill.json deleted file mode 100644 index 24e2772..0000000 --- a/skills/cmd_zip_skill_1772465652075/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465652075", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465685352/README.md b/skills/cmd_zip_skill_1772465685352/README.md deleted file mode 100644 index 42f156b..0000000 --- a/skills/cmd_zip_skill_1772465685352/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465685352 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465685352/__init__.py b/skills/cmd_zip_skill_1772465685352/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465685352/main.py b/skills/cmd_zip_skill_1772465685352/main.py deleted file mode 100644 index cff99cf..0000000 --- a/skills/cmd_zip_skill_1772465685352/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465685352 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465685352Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465685352 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465685352/skill.json b/skills/cmd_zip_skill_1772465685352/skill.json deleted file mode 100644 index 23cbd37..0000000 --- a/skills/cmd_zip_skill_1772465685352/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465685352", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465936294/README.md b/skills/cmd_zip_skill_1772465936294/README.md deleted file mode 100644 index 43b734c..0000000 --- a/skills/cmd_zip_skill_1772465936294/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465936294 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465936294/__init__.py b/skills/cmd_zip_skill_1772465936294/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465936294/main.py b/skills/cmd_zip_skill_1772465936294/main.py deleted file mode 100644 index 1c25f33..0000000 --- a/skills/cmd_zip_skill_1772465936294/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465936294 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465936294Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465936294 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465936294/skill.json b/skills/cmd_zip_skill_1772465936294/skill.json deleted file mode 100644 index dcb35e7..0000000 --- a/skills/cmd_zip_skill_1772465936294/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465936294", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465966322/README.md b/skills/cmd_zip_skill_1772465966322/README.md deleted file mode 100644 index d8449f7..0000000 --- a/skills/cmd_zip_skill_1772465966322/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465966322 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465966322/__init__.py b/skills/cmd_zip_skill_1772465966322/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465966322/main.py b/skills/cmd_zip_skill_1772465966322/main.py deleted file mode 100644 index f0b63f6..0000000 --- a/skills/cmd_zip_skill_1772465966322/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465966322 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465966322Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465966322 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465966322/skill.json b/skills/cmd_zip_skill_1772465966322/skill.json deleted file mode 100644 index 6635930..0000000 --- a/skills/cmd_zip_skill_1772465966322/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465966322", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772466071278/README.md b/skills/cmd_zip_skill_1772466071278/README.md deleted file mode 100644 index a8b7ecf..0000000 --- a/skills/cmd_zip_skill_1772466071278/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772466071278 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772466071278/__init__.py b/skills/cmd_zip_skill_1772466071278/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772466071278/main.py b/skills/cmd_zip_skill_1772466071278/main.py deleted file mode 100644 index e1180e3..0000000 --- a/skills/cmd_zip_skill_1772466071278/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772466071278 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772466071278Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772466071278 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772466071278/skill.json b/skills/cmd_zip_skill_1772466071278/skill.json deleted file mode 100644 index e1fa141..0000000 --- a/skills/cmd_zip_skill_1772466071278/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772466071278", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/skills_creator/README.md b/skills/skills_creator/README.md deleted file mode 100644 index 5e9eb9d..0000000 --- a/skills/skills_creator/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# skills_creator - -Tools exposed to AI: -- create_skill -- write_skill_main -- update_skill_metadata -- delete_skill -- load_skill -- reload_skill -- list_skills - -This skill enables creating and managing other skills from chat-driven AI tool calls. \ No newline at end of file diff --git a/skills/skills_creator/__init__.py b/skills/skills_creator/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/skills_creator/main.py b/skills/skills_creator/main.py deleted file mode 100644 index a101f7f..0000000 --- a/skills/skills_creator/main.py +++ /dev/null @@ -1,147 +0,0 @@ -"""skills_creator skill""" - -import json -from pathlib import Path -import shutil -from typing import Any, Dict - -from src.ai.skills.base import Skill, SkillsManager, create_skill_template - - -class SkillsCreatorSkill(Skill): - """Provide tools to create and manage local skills from chat.""" - - def __init__(self): - super().__init__() - self.skills_root = Path("skills") - self.protected_skills = {"skills_creator"} - - async def initialize(self): - self.register_tool("create_skill", self.create_skill) - self.register_tool("write_skill_main", self.write_skill_main) - self.register_tool("update_skill_metadata", self.update_skill_metadata) - self.register_tool("delete_skill", self.delete_skill) - self.register_tool("load_skill", self.load_skill) - self.register_tool("reload_skill", self.reload_skill) - self.register_tool("list_skills", self.list_skills) - - def _skill_key(self, skill_name: str) -> str: - return SkillsManager.normalize_skill_key(skill_name) - - def _skill_path(self, skill_name: str) -> Path: - return self.skills_root / self._skill_key(skill_name) - - async def create_skill( - self, - skill_name: str, - description: str = "custom skill", - author: str = "chat_user", - auto_load: bool = True, - overwrite: bool = False, - ) -> str: - skill_key = self._skill_key(skill_name) - skill_path = self._skill_path(skill_key) - - if skill_path.exists() and not overwrite: - return f"技能已存在: {skill_key}" - - if skill_path.exists() and overwrite: - shutil.rmtree(skill_path) - - create_skill_template(skill_key, self.skills_root, description=description, author=author) - - if auto_load and self.manager: - await self.manager.load_skill(skill_key) - - return f"已创建技能: {skill_key}" - - async def write_skill_main(self, skill_name: str, code: str, auto_reload: bool = True) -> str: - skill_key = self._skill_key(skill_name) - skill_path = self._skill_path(skill_key) - if not skill_path.exists(): - return f"技能不存在: {skill_key}" - - main_file = skill_path / "main.py" - main_file.write_text(code, encoding="utf-8") - - if auto_reload and self.manager: - await self.manager.reload_skill(skill_key) - - return f"已更新技能代码: {skill_key}/main.py" - - async def update_skill_metadata(self, skill_name: str, fields: Dict[str, Any]) -> str: - skill_key = self._skill_key(skill_name) - skill_path = self._skill_path(skill_key) - if not skill_path.exists(): - return f"技能不存在: {skill_key}" - - metadata_file = skill_path / "skill.json" - if metadata_file.exists(): - metadata = json.loads(metadata_file.read_text(encoding="utf-8")) - else: - metadata = { - "name": skill_key, - "version": "1.0.0", - "description": "", - "author": "chat_user", - "dependencies": [], - "enabled": True, - } - - for key, value in fields.items(): - metadata[key] = value - - metadata.setdefault("name", skill_key) - metadata.setdefault("version", "1.0.0") - metadata.setdefault("description", "") - metadata.setdefault("author", "chat_user") - metadata.setdefault("dependencies", []) - metadata.setdefault("enabled", True) - - metadata_file.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8") - return f"已更新技能元数据: {skill_key}/skill.json" - - async def delete_skill(self, skill_name: str, delete_files: bool = True) -> str: - skill_key = self._skill_key(skill_name) - if skill_key in self.protected_skills: - return f"拒绝删除受保护技能: {skill_key}" - - if self.manager: - removed = await self.manager.uninstall_skill(skill_key, delete_files=delete_files) - return f"删除结果({skill_key}): {removed}" - - skill_path = self._skill_path(skill_key) - if delete_files and skill_path.exists(): - shutil.rmtree(skill_path) - return f"已删除技能目录: {skill_key}" - - return f"技能不存在或未删除: {skill_key}" - - async def load_skill(self, skill_name: str) -> str: - skill_key = self._skill_key(skill_name) - if not self.manager: - return "SkillsManager 不可用" - - success = await self.manager.load_skill(skill_key) - return f"加载技能 {skill_key}: {success}" - - async def reload_skill(self, skill_name: str) -> str: - skill_key = self._skill_key(skill_name) - if not self.manager: - return "SkillsManager 不可用" - - success = await self.manager.reload_skill(skill_key) - return f"重载技能 {skill_key}: {success}" - - async def list_skills(self) -> str: - if not self.manager: - return "SkillsManager 不可用" - - payload = { - "loaded": self.manager.list_skills(), - "available": self.manager.list_available_skills(), - } - return json.dumps(payload, ensure_ascii=False) - - async def cleanup(self): - pass \ No newline at end of file diff --git a/skills/skills_creator/skill.json b/skills/skills_creator/skill.json deleted file mode 100644 index 7875637..0000000 --- a/skills/skills_creator/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "skills_creator", - "version": "1.0.0", - "description": "Create, edit, load, and remove skills from chat", - "author": "QQBot", - "dependencies": [], - "enabled": true -} \ No newline at end of file diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 7650318..32a0906 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,14 +1,7 @@ -""" -AI模块 - 提供AI模型接入、人格系统、记忆系统和长任务处理能力 -""" -from .client import AIClient -from .personality import PersonalitySystem -from .memory import MemorySystem -from .task_manager import LongTaskManager +"""AI package exports.""" -__all__ = [ - 'AIClient', - 'PersonalitySystem', - 'MemorySystem', - 'LongTaskManager' -] +from .client import AIClient +from .memory import MemorySystem +from .personality import PersonalitySystem + +__all__ = ["AIClient", "MemorySystem", "PersonalitySystem"] diff --git a/src/ai/client.py b/src/ai/client.py index 6838621..ac92a8f 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -1,98 +1,97 @@ -""" -AI瀹㈡埛绔?- 鏁村悎鎵€鏈堿I鍔熻兘 """ -import inspect +Unified AI client for chat, memory and persona. +""" + +from __future__ import annotations + +import asyncio import json import re -from typing import List, Optional, Dict, Any, AsyncIterator, Tuple from pathlib import Path -from .base import ModelConfig, ModelProvider, Message, ToolRegistry -from .models import OpenAIModel, AnthropicModel -from .personality import PersonalitySystem +from typing import Any, Dict, List, Optional + +import httpx + +from .base import Message, ModelConfig, ModelProvider from .memory import MemorySystem -from .task_manager import LongTaskManager +from .models import AnthropicModel, OpenAIModel +from .personality import PersonalitySystem from src.utils.logger import setup_logger -logger = setup_logger('AIClient') +logger = setup_logger("AIClient") class AIClient: - """AI瀹㈡埛绔?- 缁熶竴鎺ュ彛""" - + """High-level application service for chat and memory/persona orchestration.""" + def __init__( self, model_config: ModelConfig, embed_config: Optional[ModelConfig] = None, data_dir: Path = Path("data/ai"), - use_vector_db: bool = True + use_vector_db: bool = True, + use_query_embedding: bool = False, + chat_retries: int = 1, + chat_retry_backoff: float = 0.8, ): self.config = model_config self.data_dir = data_dir self.data_dir.mkdir(parents=True, exist_ok=True) - - # 初始化主模型 + + self.chat_retries = max(0, int(chat_retries)) + self.chat_retry_backoff = max(0.0, float(chat_retry_backoff)) + self.model = self._create_model(model_config) - - # 初始化嵌入模型(如果提供) - self.embed_model = None - if embed_config: - self.embed_model = self._create_model(embed_config) - logger.info( - f"嵌入模型初始化完成: {embed_config.provider.value}/{embed_config.model_name}" - ) - - # 初始化工具注册表 - self.tools = ToolRegistry() - self._tool_sources: Dict[str, str] = {} - - # 初始化人格系统 + self.embed_model = self._create_model(embed_config) if embed_config else None + self.personality = PersonalitySystem( - config_path=data_dir / "personalities.json" + config_path=data_dir / "personalities.json", + state_path=data_dir / "personality_state.json", ) - - # 初始化记忆系统 + self.memory = MemorySystem( storage_path=data_dir / "long_term_memory.json", embed_func=self._embed_wrapper, importance_evaluator=self._evaluate_memory_importance, - use_vector_db=use_vector_db + use_vector_db=use_vector_db, + use_query_embedding=use_query_embedding, ) - - # 初始化长任务管理器 - self.task_manager = LongTaskManager( - storage_path=data_dir / "tasks.json" - ) - + logger.info( - f"AI 客户端初始化完成: {model_config.provider.value}/{model_config.model_name}" + "AI client initialized", + extra={ + "provider": model_config.provider.value, + "model": model_config.model_name, + "use_vector_db": use_vector_db, + "use_query_embedding": use_query_embedding, + "chat_retries": self.chat_retries, + }, ) - - def _create_model(self, config: ModelConfig): - """创建模型实例。""" - if config.provider == ModelProvider.OPENAI: + + def _create_model(self, config: Optional[ModelConfig]): + if config is None: + return None + + if config.provider in { + ModelProvider.OPENAI, + ModelProvider.DEEPSEEK, + ModelProvider.QWEN, + }: return OpenAIModel(config) - elif config.provider == ModelProvider.ANTHROPIC: + if config.provider == ModelProvider.ANTHROPIC: return AnthropicModel(config) - elif config.provider in [ModelProvider.DEEPSEEK, ModelProvider.QWEN]: - # DeepSeek 和 Qwen 使用 OpenAI 兼容接口 - return OpenAIModel(config) - else: - raise ValueError(f"不支持的模型提供商: {config.provider}") - + raise ValueError(f"Unsupported model provider: {config.provider}") + async def _embed_wrapper(self, text: str) -> List[float]: - """嵌入向量包装器。""" try: - # 如果有独立的嵌入模型,优先使用 if self.embed_model: return await self.embed_model.embed(text) - # 否则尝试使用主模型 return await self.model.embed(text) except NotImplementedError: - # 如果都不支持嵌入,返回 None(记忆系统会降级) - logger.warning("Current model does not support embeddings; vector retrieval disabled") + logger.warning("Current model does not support embeddings; fallback to local embedding.") return None - except Exception as e: - logger.error(f"生成嵌入向量失败: {e}") + except Exception as exc: + logger.warning(f"Embedding generation failed: {exc}") return None @staticmethod @@ -120,17 +119,12 @@ class AIClient: async def _evaluate_memory_importance( self, content: str, metadata: Optional[Dict] = None ) -> float: - """ - 调用主模型评估记忆重要性,返回 [0, 1] 分值。 - """ system_prompt = ( - "你是记忆重要性评估器。请根据输入内容判断该信息是否值得长期记忆。" - "输出一个 0 到 1 的数字,数字越大表示越重要。" - "只输出数字,不要输出任何解释、单位或多余文本。" + "You evaluate if content should be kept as long-term memory. " + "Return only a float between 0 and 1." ) payload = json.dumps( - {"content": content, "metadata": metadata or {}}, - ensure_ascii=False, + {"content": content, "metadata": metadata or {}}, ensure_ascii=False ) messages = [ Message(role="system", content=system_prompt), @@ -146,636 +140,181 @@ class AIClient: ) score = self._parse_importance_score(response.content) return max(0.0, min(1.0, score)) - except Exception as e: - logger.warning(f"memory importance evaluation failed, fallback to neutral score: {e}") + except Exception as exc: + logger.warning(f"Memory importance evaluation failed, fallback to 0.5: {exc}") return 0.5 - + + @staticmethod + def _preview_log_payload(payload: Any, max_len: int = 240) -> str: + try: + text = json.dumps(payload, ensure_ascii=False, default=str) + except Exception: + text = str(payload) + if len(text) > max_len: + return text[:max_len] + "..." + return text + + @staticmethod + def _is_retriable_error(exc: Exception) -> bool: + if isinstance(exc, (httpx.ReadTimeout, httpx.ConnectError, asyncio.TimeoutError)): + return True + + lower = str(exc).lower() + retriable_signals = [ + "timed out", + "timeout", + "temporarily unavailable", + "connection reset", + "service unavailable", + "rate limit", + "429", + "502", + "503", + "504", + ] + return any(signal in lower for signal in retriable_signals) + + async def _chat_with_retry(self, messages: List[Message], **kwargs): + attempts = 1 + self.chat_retries + last_error: Optional[Exception] = None + + for attempt in range(1, attempts + 1): + try: + return await self.model.chat(messages, None, **kwargs) + except Exception as exc: + last_error = exc + if attempt >= attempts or not self._is_retriable_error(exc): + break + + sleep_seconds = self.chat_retry_backoff * (2 ** (attempt - 1)) + logger.warning( + "Chat request failed, retrying", + extra={ + "attempt": attempt, + "max_attempts": attempts, + "sleep_seconds": sleep_seconds, + "error": str(exc), + }, + ) + await asyncio.sleep(sleep_seconds) + + if last_error: + raise last_error + raise RuntimeError("chat retry loop ended unexpectedly") + async def chat( self, user_id: str, user_message: str, system_prompt: Optional[str] = None, use_memory: bool = True, - use_tools: bool = True, stream: bool = False, - **kwargs + group_id: Optional[str] = None, + session_id: Optional[str] = None, + memory_key: Optional[str] = None, + **kwargs, ) -> str: - """对话接口。""" - try: - # 构建消息列表 - messages = [] - - # 系统提示词 - if system_prompt is None: - system_prompt = self.personality.get_system_prompt() - - # 注入记忆上下文 - if use_memory: - short_term, long_term = await self.memory.get_context( - user_id=user_id, - query=user_message - ) - - if short_term or long_term: - memory_context = self.memory.format_context(short_term, long_term) - system_prompt += f"\n\n{memory_context}" - - messages.append(Message(role="system", content=system_prompt)) - - # 添加用户消息 - messages.append(Message(role="user", content=user_message)) - - # 准备工具 - tools = None - tool_names: List[str] = [] - if use_tools and self.tools.list(): - tools = self.tools.to_openai_format() - tool_names = [tool.name for tool in self.tools.list()] - forced_tool_name = self._extract_forced_tool_name(user_message, tool_names) - if forced_tool_name: - kwargs = dict(kwargs) - kwargs["forced_tool_name"] = forced_tool_name - if tools: - before_count = len(tools) - tools = [ - tool - for tool in tools - if ((tool.get("function") or {}).get("name") == forced_tool_name) - ] - if len(tools) == 1: - tool_names = [forced_tool_name] - logger.info( - "显式工具调用已收敛工具列表: " - f"{before_count} -> {len(tools)}" - ) - logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") + if stream: + raise NotImplementedError("stream mode is not supported by AIClient.chat") - logger.info( - "LLM请求: " - f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, " - f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, " - f"tool_names={self._preview_log_payload(tool_names)}, " - f"forced_tool={forced_tool_name or '-'}" + memory_user_key = memory_key or user_id + messages: List[Message] = [] + + if system_prompt is None: + system_prompt = self.personality.get_system_prompt( + user_id=user_id, + group_id=group_id, + session_id=session_id, ) - logger.info( - "LLM输入: " - f"user_message={self._preview_log_payload(user_message)}" + + if use_memory: + short_term, long_term = await self.memory.get_context( + user_id=memory_user_key, + query=user_message, ) - - # 调用模型 - if stream: - return self._chat_stream(messages, tools, **kwargs) - else: - response = await self.model.chat(messages, tools, **kwargs) - response_tool_count = len(response.tool_calls or []) - response_tool_names = [] - for tool_call in response.tool_calls or []: - if isinstance(tool_call, dict): - function_info = tool_call.get("function") or {} - response_tool_names.append(function_info.get("name")) - else: - function_info = getattr(tool_call, "function", None) - response_tool_names.append( - getattr(function_info, "name", None) if function_info else None - ) - logger.info( - "LLM首轮输出: " - f"tool_calls={response_tool_count}, " - f"tool_names={self._preview_log_payload(response_tool_names)}, " - f"content={self._preview_log_payload(response.content)}" - ) - - # 处理工具调用 - if response.tool_calls: - response = await self._handle_tool_calls( - messages, response, tools, **kwargs - ) - elif forced_tool_name: - forced_response = await self._run_forced_tool_fallback( - forced_tool_name=forced_tool_name, - user_message=user_message, - ) - if forced_response is not None: - response = forced_response - - # 写入记忆 - if use_memory: - stored_memory = await self.memory.add_qa_pair( - user_id=user_id, - question=user_message, - answer=response.content, - metadata={"source": "chat"}, - ) - if stored_memory: - logger.info( - "已写入长期记忆问答对:\n" - f"{stored_memory.content}\n" - f"memory_id={stored_memory.id}, " - f"importance={stored_memory.importance:.2f}" - ) - - return response.content - - except Exception as e: - logger.error(f"对话失败: {type(e).__name__}: {e!r}") - raise + if short_term or long_term: + memory_context = self.memory.format_context(short_term, long_term) + system_prompt = f"{system_prompt}\n\n{memory_context}".strip() - async def _run_forced_tool_fallback( - self, forced_tool_name: str, user_message: str - ) -> Optional[Message]: - """Execute forced tool locally when model did not emit tool_calls.""" - tool_def = self.tools.get(forced_tool_name) - tool_source = self._tool_sources.get(forced_tool_name, "custom") - if not tool_def: - logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}") - return None - - logger.warning( - "模型未返回 tool_calls,启用本地强制工具执行: " - f"source={tool_source}, name={forced_tool_name}" - ) - - try: - result = tool_def.function() - if inspect.isawaitable(result): - result = await result - except TypeError as exc: - logger.warning( - "本地强制工具执行失败(参数不匹配): " - f"name={forced_tool_name}, error={exc}" - ) - return None - except Exception as exc: - logger.warning( - "本地强制工具执行失败: " - f"name={forced_tool_name}, error={exc}" - ) - return None - - result_text = str(result) - pipelined_text = await self._run_skill_doc_pipeline( - forced_tool_name=forced_tool_name, - skill_doc=result_text, - user_message=user_message, - ) - if pipelined_text is not None: - result_text = pipelined_text - - prefix_limit = self._extract_prefix_limit(user_message) - if prefix_limit: - result_text = result_text[:prefix_limit] + messages.append(Message(role="system", content=system_prompt)) + messages.append(Message(role="user", content=user_message)) logger.info( - "本地强制工具执行成功: " - f"source={tool_source}, name={forced_tool_name}, " - f"result={self._preview_log_payload(result_text)}" + "LLM request", + extra={ + "user_id": user_id, + "group_id": group_id, + "session_id": session_id, + "memory_key": memory_user_key, + "use_memory": use_memory, + "message_preview": self._preview_log_payload(user_message), + }, ) - return Message(role="assistant", content=result_text) - - async def _run_skill_doc_pipeline( - self, forced_tool_name: str, skill_doc: str, user_message: str - ) -> Optional[str]: - """Run an extra model step: execute instructions from skill doc on user text.""" - if not forced_tool_name.endswith(".read_skill_doc"): - return None - - target_text = self._extract_processing_payload(user_message) - if not target_text: - return None + response = await self._chat_with_retry(messages, **kwargs) logger.info( - "强制工具后续处理开始: " - f"name={forced_tool_name}, target_len={len(target_text)}" + "LLM response", + extra={"content_preview": self._preview_log_payload(response.content)}, ) - messages = [ - Message( - role="system", - content=( - "你是技能执行器。请严格按下面技能文档处理用户文本。" - "不要复述技能文档,不要解释工具调用过程,只输出最终处理结果。\n\n" - "[技能文档开始]\n" - f"{skill_doc}\n" - "[技能文档结束]" - ), - ), - Message( - role="user", - content=( - "请根据技能文档处理以下文本,保持原意并提升自然度:\n" - f"{target_text}" - ), - ), - ] - - try: - response = await self.model.chat(messages=messages, tools=None) - content = (response.content or "").strip() - if not content: - return None - - logger.info( - "强制工具后续处理完成: " - f"name={forced_tool_name}, output_len={len(content)}" + if use_memory: + stored_memory = await self.memory.add_qa_pair( + user_id=memory_user_key, + question=user_message, + answer=response.content, + metadata={ + "source": "chat", + "user_id": user_id, + "group_id": group_id, + "session_id": session_id, + }, ) - return content - except Exception as exc: - logger.warning( - "强制工具后续处理失败,回退为工具原始输出: " - f"name={forced_tool_name}, error={exc}" - ) - return None - - async def _chat_stream( - self, - messages: List[Message], - tools: Optional[List[Dict]], - **kwargs - ) -> AsyncIterator[str]: - """流式对话。""" - async for chunk in self.model.chat_stream(messages, tools, **kwargs): - yield chunk - - async def _handle_tool_calls( - self, - messages: List[Message], - response: Message, - tools: Optional[List[Dict]], - **kwargs - ) -> Message: - """处理工具调用。""" - messages.append(response) - total_calls = len(response.tool_calls or []) - if total_calls: - logger.info(f"检测到工具调用请求: {total_calls} 个") - - # 执行工具调用 - for tool_call in response.tool_calls or []: - try: - tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call) - except Exception as e: - logger.warning(f"解析工具调用失败: {e}") - fallback_id = tool_call.get('id') if isinstance(tool_call, dict) else getattr(tool_call, 'id', None) - if fallback_id: - messages.append(Message( - role="tool", - content=f"工具参数解析失败: {str(e)}", - tool_call_id=fallback_id, - name="tool" - )) - continue - if not tool_name: - logger.warning(f"跳过无效工具调用: {tool_call}") - continue - - tool_def = self.tools.get(tool_name) - tool_source = self._tool_sources.get(tool_name, "custom") - if not tool_def: - error_msg = f"未找到工具: {tool_name}" - logger.warning(error_msg) - messages.append(Message( - role="tool", - name=tool_name, - content=error_msg, - tool_call_id=tool_call_id - )) - continue - - try: + if stored_memory: logger.info( - "工具调用开始: " - f"source={tool_source}, name={tool_name}, " - f"args={self._preview_log_payload(tool_args)}" + "Long-term memory stored", + extra={ + "memory_id": stored_memory.id, + "importance": stored_memory.importance, + }, ) - result = tool_def.function(**tool_args) - if inspect.isawaitable(result): - result = await result - logger.info( - "工具调用成功: " - f"source={tool_source}, name={tool_name}, " - f"result={self._preview_log_payload(result)}" - ) - messages.append(Message( - role="tool", - name=tool_name, - content=str(result), - tool_call_id=tool_call_id - )) - except Exception as e: - logger.warning( - "工具调用失败: " - f"source={tool_source}, name={tool_name}, error={e}" - ) - messages.append(Message( - role="tool", - name=tool_name, - content=f"工具执行失败: {str(e)}", - tool_call_id=tool_call_id - )) - # 再次调用模型获取最终响应 - final_kwargs = dict(kwargs) - # Force only the first model turn, avoid recursive force after tool result. - final_kwargs.pop("forced_tool_name", None) - final_response = await self.model.chat(messages, tools, **final_kwargs) - logger.info( - "LLM最终输出: " - f"content={self._preview_log_payload(final_response.content)}" + return response.content + + def set_personality( + self, personality_name: str, scope: str = "global", scope_id: Optional[str] = None + ) -> bool: + return self.personality.set_personality( + key=personality_name, + scope=scope, + scope_id=scope_id, ) - return final_response - def _parse_tool_call(self, tool_call: Any) -> Tuple[Optional[str], Dict[str, Any], Optional[str]]: - """兼容不同 SDK 返回的工具调用结构。""" - if isinstance(tool_call, dict): - tool_call_id = tool_call.get('id') - function = tool_call.get('function') or {} - tool_name = function.get('name') - raw_args = function.get('arguments') - else: - tool_call_id = getattr(tool_call, 'id', None) - function = getattr(tool_call, 'function', None) - tool_name = getattr(function, 'name', None) if function else None - raw_args = getattr(function, 'arguments', None) if function else None - - tool_args = self._normalize_tool_args(raw_args) - return tool_name, tool_args, tool_call_id - - def _normalize_tool_args(self, raw_args: Any) -> Dict[str, Any]: - """将工具参数统一转换为字典。""" - if raw_args is None: - return {} - - if isinstance(raw_args, dict): - return raw_args - - if isinstance(raw_args, str): - raw_args = raw_args.strip() - if not raw_args: - return {} - parsed = json.loads(raw_args) - if not isinstance(parsed, dict): - raise ValueError(f"工具参数必须是 JSON 对象,实际类型: {type(parsed)}") - return parsed - - if hasattr(raw_args, 'model_dump'): - parsed = raw_args.model_dump() - if isinstance(parsed, dict): - return parsed - - raise ValueError(f"不支持的工具参数类型: {type(raw_args)}") - - @staticmethod - def _preview_log_payload(payload: Any, max_len: int = 240) -> str: - """日志中展示参数/结果时使用的简短预览。""" - try: - text = json.dumps(payload, ensure_ascii=False, default=str) - except Exception: - text = str(payload) - - if len(text) > max_len: - return text[:max_len] + "..." - return text - - @staticmethod - def _extract_prefix_limit(user_message: str) -> Optional[int]: - """Extract requested output prefix length like '前100字'.""" - if not user_message: - return None - - match = re.search(r"前\s*(\d{1,4})\s*字", user_message) - if not match: - return None - - try: - limit = int(match.group(1)) - except (TypeError, ValueError): - return None - - if limit <= 0: - return None - return min(limit, 5000) - - @staticmethod - def _extract_processing_payload(user_message: str) -> Optional[str]: - """Extract text payload like '处理以下文本:...' from user message.""" - if not user_message: - return None - - text = user_message.strip() - markers = [ - "以下文本:", - "以下文本:", - "文本:", - "文本:", - ] - for marker in markers: - idx = text.find(marker) - if idx < 0: - continue - payload = text[idx + len(marker) :].strip() - if payload: - return payload - - pattern = re.compile( - r"(?:处理|润色|改写|人性化处理|优化)[\s\S]{0,32}(?:如下|以下)[::]\s*([\s\S]+)$" - ) - match = pattern.search(text) - if match: - payload = (match.group(1) or "").strip() - if payload: - return payload - - return None - - @staticmethod - def _compact_identifier(text: str) -> str: - """Compact identifier for fuzzy matching (e.g. humanizer_zh -> humanizerzh).""" - return re.sub(r"[^a-z0-9]+", "", (text or "").lower()) - - @staticmethod - def _extract_forced_tool_name( - user_message: str, available_tool_names: List[str] - ) -> Optional[str]: - if not user_message or not available_tool_names: - return None - - triggers = [ - "调用工具", - "使用工具", - "只调用", - "务必调用", - "必须调用", - "调用", - "使用", - "tool", - ] - if not any(trigger in user_message for trigger in triggers): - return None - - pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)") - explicit_matches = [ - name for name in pattern.findall(user_message) if name in available_tool_names - ] - if len(explicit_matches) == 1: - return explicit_matches[0] - if len(explicit_matches) > 1: - return None - - contained = [name for name in available_tool_names if name in user_message] - if len(contained) == 1: - return contained[0] - - # 允许只写 skill/tool 前缀(如 humanizer_zh),前提是前缀下只有一个工具。 - prefixes = sorted( - {name.split(".", 1)[0] for name in available_tool_names}, - key=len, - reverse=True, - ) - matched_prefixes = [ - prefix - for prefix in prefixes - if re.search(rf"\b{re.escape(prefix)}\b", user_message) - ] - if len(matched_prefixes) == 1: - prefix_tools = [ - name - for name in available_tool_names - if name.startswith(f"{matched_prefixes[0]}.") - ] - if len(prefix_tools) == 1: - return prefix_tools[0] - - # 模糊匹配:支持省略下划线/点号的写法(如 humanizerzh)。 - compact_message = AIClient._compact_identifier(user_message) - if compact_message: - compact_full_matches = [] - for tool_name in available_tool_names: - compact_tool_name = AIClient._compact_identifier(tool_name) - if compact_tool_name and compact_tool_name in compact_message: - compact_full_matches.append(tool_name) - - if len(compact_full_matches) == 1: - return compact_full_matches[0] - if len(compact_full_matches) > 1: - return None - - compact_prefix_map: Dict[str, List[str]] = {} - for tool_name in available_tool_names: - prefix = tool_name.split(".", 1)[0] - compact_prefix = AIClient._compact_identifier(prefix) - if not compact_prefix: - continue - compact_prefix_map.setdefault(compact_prefix, []).append(tool_name) - - compact_prefix_matches = [ - compact_prefix - for compact_prefix in compact_prefix_map - if compact_prefix in compact_message - ] - if len(compact_prefix_matches) == 1: - matched_tools = compact_prefix_map[compact_prefix_matches[0]] - if len(matched_tools) == 1: - return matched_tools[0] - - return None - - def set_personality(self, personality_name: str) -> bool: - """设置人格。""" - return self.personality.set_personality(personality_name) - def list_personalities(self) -> List[str]: - """列出所有人格。""" return self.personality.list_personalities() def switch_model(self, model_config: ModelConfig) -> bool: - """Runtime switch for primary chat model.""" - new_model = self._create_model(model_config) - self.model = new_model + self.model = self._create_model(model_config) self.config = model_config logger.info( - f"已切换主模型: {model_config.provider.value}/{model_config.model_name}" + "Primary model switched", + extra={ + "provider": model_config.provider.value, + "model": model_config.model_name, + }, ) return True - async def create_long_task( - self, - user_id: str, - title: str, - description: str, - steps: List[Dict], - metadata: Optional[Dict] = None - ) -> str: - """创建长任务。""" - return self.task_manager.create_task( - user_id=user_id, - title=title, - description=description, - steps=steps, - metadata=metadata - ) - - async def start_task( - self, - task_id: str, - progress_callback: Optional[callable] = None - ): - """启动任务。""" - await self.task_manager.start_task(task_id, progress_callback) - - def get_task_status(self, task_id: str) -> Optional[Dict]: - """获取任务状态。""" - return self.task_manager.get_task_status(task_id) - - def register_tool( - self, - name: str, - description: str, - parameters: Dict, - function: callable, - source: str = "custom", - ): - """注册工具。""" - from .base import ToolDefinition - tool = ToolDefinition( - name=name, - description=description, - parameters=parameters, - function=function - ) - self.tools.register(tool) - self._tool_sources[name] = source - logger.info(f"已注册工具: {name} (source={source})") - - def unregister_tool(self, name: str) -> bool: - """卸载工具。""" - removed = self.tools.unregister(name) - if removed: - self._tool_sources.pop(name, None) - logger.info(f"已卸载工具: {name}") - return removed - - def unregister_tools_by_prefix(self, prefix: str) -> int: - """按前缀批量卸载工具。""" - removed_count = self.tools.unregister_by_prefix(prefix) - for tool_name in list(self._tool_sources.keys()): - if tool_name.startswith(prefix): - self._tool_sources.pop(tool_name, None) - if removed_count: - logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}") - return removed_count - def clear_memory(self, user_id: str): - """清除用户短期记忆。""" self.memory.clear_short_term(user_id) - logger.info(f"Cleared short-term memory for user {user_id}") + logger.info(f"Cleared short-term memory for {user_id}") async def clear_long_term_memory(self, user_id: str) -> bool: try: await self.memory.clear_long_term(user_id) - logger.info(f"Cleared long-term memory for user {user_id}") + logger.info(f"Cleared long-term memory for {user_id}") return True - except Exception as e: - logger.warning(f"Failed to clear long-term memory for user {user_id}: {e}") + except Exception as exc: + logger.warning(f"Failed to clear long-term memory for {user_id}: {exc}") return False async def list_long_term_memories(self, user_id: str, limit: int = 20): @@ -823,10 +362,18 @@ class AIClient: return await self.memory.delete_long_term(user_id, memory_id) async def clear_all_memory(self, user_id: str) -> bool: - """清除用户全部记忆(短期 + 长期)。""" self.clear_memory(user_id) - try: - return await self.clear_long_term_memory(user_id) - except Exception: - return False + return await self.clear_long_term_memory(user_id) + async def close(self): + await self.memory.close() + + for model in [self.model, self.embed_model]: + if model is None: + continue + close = getattr(model, "close", None) + if close is None: + continue + maybe_awaitable = close() + if asyncio.iscoroutine(maybe_awaitable): + await maybe_awaitable diff --git a/src/ai/mcp/__init__.py b/src/ai/mcp/__init__.py deleted file mode 100644 index 4bf4500..0000000 --- a/src/ai/mcp/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -MCP模块 -""" -from .base import MCPServer, MCPClient, MCPManager, MCPResource, MCPTool, MCPPrompt - -__all__ = [ - 'MCPServer', - 'MCPClient', - 'MCPManager', - 'MCPResource', - 'MCPTool', - 'MCPPrompt' -] diff --git a/src/ai/mcp/base.py b/src/ai/mcp/base.py deleted file mode 100644 index d375176..0000000 --- a/src/ai/mcp/base.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -MCP (Model Context Protocol) 支持 -""" -import asyncio -import json -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass, asdict -from pathlib import Path -from src.utils.logger import setup_logger - -logger = setup_logger('MCPSystem') - - -@dataclass -class MCPResource: - """MCP资源""" - uri: str - name: str - description: str - mime_type: str - - -@dataclass -class MCPTool: - """MCP工具""" - name: str - description: str - input_schema: Dict[str, Any] - - -@dataclass -class MCPPrompt: - """MCP提示词""" - name: str - description: str - arguments: List[Dict[str, Any]] - - -class MCPServer: - """MCP服务器基类""" - - def __init__(self, name: str, version: str): - self.name = name - self.version = version - self.resources: Dict[str, MCPResource] = {} - self.tools: Dict[str, Callable] = {} - self.tool_specs: Dict[str, MCPTool] = {} - self.prompts: Dict[str, MCPPrompt] = {} - - async def initialize(self): - """初始化服务器""" - pass - - async def shutdown(self): - """关闭服务器""" - pass - - def register_resource(self, resource: MCPResource): - """注册资源""" - self.resources[resource.uri] = resource - - def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable): - """注册工具""" - tool = MCPTool(name=name, description=description, input_schema=input_schema) - self.tool_specs[name] = tool - self.tools[name] = handler - logger.info(f"✅ MCP工具注册: {self.name}.{name}") - - def register_prompt(self, prompt: MCPPrompt): - """注册提示词""" - self.prompts[prompt.name] = prompt - - async def list_resources(self) -> List[MCPResource]: - """列出资源""" - return list(self.resources.values()) - - async def read_resource(self, uri: str) -> Optional[str]: - """读取资源""" - raise NotImplementedError - - async def list_tools(self) -> List[MCPTool]: - """列出工具""" - return list(self.tool_specs.values()) - - async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - if name not in self.tools: - raise ValueError(f"工具不存在: {name}") - - handler = self.tools[name] - logger.info( - "MCP工具调用开始: " - f"server={self.name}, tool={name}, " - f"args={json.dumps(arguments, ensure_ascii=False, default=str)}" - ) - try: - result = await handler(**arguments) - except Exception as exc: - logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}") - raise - logger.info(f"MCP工具调用成功: server={self.name}, tool={name}") - return result - - async def list_prompts(self) -> List[MCPPrompt]: - """列出提示词""" - return list(self.prompts.values()) - - async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]: - """获取提示词""" - raise NotImplementedError - - -class MCPClient: - """MCP客户端""" - - def __init__(self): - self.servers: Dict[str, MCPServer] = {} - - async def connect_server(self, server: MCPServer): - """连接服务器""" - await server.initialize() - self.servers[server.name] = server - logger.info(f"✅ 连接MCP服务器: {server.name} v{server.version}") - - async def disconnect_server(self, server_name: str): - """断开服务器""" - if server_name in self.servers: - await self.servers[server_name].shutdown() - del self.servers[server_name] - logger.info(f"✅ 断开MCP服务器: {server_name}") - - def get_server(self, name: str) -> Optional[MCPServer]: - """获取服务器""" - return self.servers.get(name) - - def list_servers(self) -> List[str]: - """列出所有服务器""" - return list(self.servers.keys()) - - async def list_all_resources(self) -> Dict[str, List[MCPResource]]: - """列出所有资源""" - result = {} - for name, server in self.servers.items(): - result[name] = await server.list_resources() - return result - - async def list_all_tools(self) -> Dict[str, List[MCPTool]]: - """列出所有工具""" - result = {} - for name, server in self.servers.items(): - result[name] = await server.list_tools() - return result - - async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - server = self.get_server(server_name) - if not server: - raise ValueError(f"服务器不存在: {server_name}") - - return await server.call_tool(tool_name, arguments) - - -class MCPManager: - """MCP管理器""" - - def __init__(self, config_path: Path): - self.config_path = config_path - self.client = MCPClient() - self.server_configs: Dict[str, Dict] = {} - self._load_config() - - def _load_config(self): - """加载配置""" - if self.config_path.exists(): - with open(self.config_path, 'r', encoding='utf-8') as f: - self.server_configs = json.load(f) - - def _save_config(self): - """保存配置""" - self.config_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.config_path, 'w', encoding='utf-8') as f: - json.dump(self.server_configs, f, ensure_ascii=False, indent=2) - - async def register_server(self, server: MCPServer, config: Optional[Dict] = None): - """注册服务器""" - await self.client.connect_server(server) - - if config: - self.server_configs[server.name] = config - self._save_config() - - async def unregister_server(self, server_name: str): - """注销服务器""" - await self.client.disconnect_server(server_name) - - if server_name in self.server_configs: - del self.server_configs[server_name] - self._save_config() - - def get_client(self) -> MCPClient: - """获取客户端""" - return self.client - - async def get_all_tools_for_ai(self) -> List[Dict]: - """获取所有工具(AI格式)""" - all_tools = [] - tools_by_server = await self.client.list_all_tools() - - for server_name, tools in tools_by_server.items(): - for tool in tools: - all_tools.append({ - "type": "function", - "function": { - "name": f"{server_name}.{tool.name}", - "description": tool.description, - "parameters": tool.input_schema - } - }) - - return all_tools - - async def execute_tool(self, full_tool_name: str, arguments: Dict) -> Any: - """执行工具""" - parts = full_tool_name.split('.', 1) - if len(parts) != 2: - raise ValueError(f"工具名格式错误: {full_tool_name}") - - server_name, tool_name = parts - logger.info(f"MCP执行请求: {full_tool_name}") - return await self.client.call_tool(server_name, tool_name, arguments) diff --git a/src/ai/mcp/servers/__init__.py b/src/ai/mcp/servers/__init__.py deleted file mode 100644 index 2b5a18c..0000000 --- a/src/ai/mcp/servers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -MCP服务器实现 -""" -from .filesystem import FileSystemMCPServer - -__all__ = ['FileSystemMCPServer'] diff --git a/src/ai/mcp/servers/filesystem.py b/src/ai/mcp/servers/filesystem.py deleted file mode 100644 index 8f793da..0000000 --- a/src/ai/mcp/servers/filesystem.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -MCP示例服务器 - 文件系统访问 -""" -from pathlib import Path -from typing import Optional -from ..base import MCPServer, MCPResource - - -class FileSystemMCPServer(MCPServer): - """文件系统MCP服务器""" - - def __init__(self, root_path: Path): - super().__init__(name="filesystem", version="1.0.0") - self.root_path = root_path - - async def initialize(self): - """初始化""" - # 注册工具 - self.register_tool( - name="read_file", - description="读取文件内容", - input_schema={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "文件路径" - } - }, - "required": ["path"] - }, - handler=self.read_file - ) - - self.register_tool( - name="write_file", - description="写入文件内容", - input_schema={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "文件路径" - }, - "content": { - "type": "string", - "description": "文件内容" - } - }, - "required": ["path", "content"] - }, - handler=self.write_file - ) - - self.register_tool( - name="list_directory", - description="列出目录内容", - input_schema={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "目录路径" - } - }, - "required": ["path"] - }, - handler=self.list_directory - ) - - def _resolve_path(self, path: str) -> Path: - """解析路径""" - full_path = (self.root_path / path).resolve() - - # 安全检查:确保路径在root_path内 - if not str(full_path).startswith(str(self.root_path)): - raise ValueError("路径超出允许范围") - - return full_path - - async def read_file(self, path: str) -> str: - """读取文件""" - file_path = self._resolve_path(path) - - if not file_path.exists(): - raise FileNotFoundError(f"文件不存在: {path}") - - if not file_path.is_file(): - raise ValueError(f"不是文件: {path}") - - with open(file_path, 'r', encoding='utf-8') as f: - return f.read() - - async def write_file(self, path: str, content: str) -> str: - """写入文件""" - file_path = self._resolve_path(path) - - file_path.parent.mkdir(parents=True, exist_ok=True) - - with open(file_path, 'w', encoding='utf-8') as f: - f.write(content) - - return f"文件已写入: {path}" - - async def list_directory(self, path: str) -> list: - """列出目录""" - dir_path = self._resolve_path(path) - - if not dir_path.exists(): - raise FileNotFoundError(f"目录不存在: {path}") - - if not dir_path.is_dir(): - raise ValueError(f"不是目录: {path}") - - items = [] - for item in dir_path.iterdir(): - items.append({ - "name": item.name, - "type": "directory" if item.is_dir() else "file", - "size": item.stat().st_size if item.is_file() else None - }) - - return items diff --git a/src/ai/memory.py b/src/ai/memory.py index 6c20dee..1918197 100644 --- a/src/ai/memory.py +++ b/src/ai/memory.py @@ -1,180 +1,166 @@ -""" -记忆系统:短期记忆、长期记忆与 RAG 检索(向量数据库)。 """ -import asyncio +Memory system: short-term window + long-term retrieval. +""" + +from __future__ import annotations + import hashlib import shutil import time import uuid -from typing import List, Dict, Optional, Tuple, Callable, Awaitable +from collections import deque from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path -from collections import deque -from .vector_store import VectorStore, VectorMemory, ChromaVectorStore, JSONVectorStore +from typing import Awaitable, Callable, Dict, List, Optional, Tuple + +from .vector_store import ChromaVectorStore, JSONVectorStore, VectorMemory, VectorStore from src.utils.logger import setup_logger -logger = setup_logger('MemorySystem') +logger = setup_logger("MemorySystem") @dataclass class MemoryItem: - """记忆项(用于短期记忆)。""" + """In-memory short-term record.""" + content: str timestamp: datetime user_id: str importance: float = 0.5 metadata: Dict = field(default_factory=dict) - + def to_dict(self) -> Dict: - """转换为字典。""" return { - 'content': self.content, - 'timestamp': self.timestamp.isoformat(), - 'user_id': self.user_id, - 'importance': self.importance, - 'metadata': self.metadata + "content": self.content, + "timestamp": self.timestamp.isoformat(), + "user_id": self.user_id, + "importance": self.importance, + "metadata": self.metadata, } class ShortTermMemory: - """短期记忆(滑动窗口)。""" - + """Short-term memory window.""" + def __init__(self, max_size: int = 20, max_age_minutes: int = 30): self.max_size = max_size self.max_age = timedelta(minutes=max_age_minutes) - self.memories: Dict[str, deque] = {} # user_id -> deque of MemoryItem - + self.memories: Dict[str, deque] = {} + def add(self, user_id: str, content: str, metadata: Optional[Dict] = None): - """添加短期记忆。""" if user_id not in self.memories: self.memories[user_id] = deque(maxlen=self.max_size) - - memory = MemoryItem( - content=content, - timestamp=datetime.now(), - user_id=user_id, - metadata=metadata or {} + + self.memories[user_id].append( + MemoryItem( + content=content, + timestamp=datetime.now(), + user_id=user_id, + metadata=metadata or {}, + ) ) - - self.memories[user_id].append(memory) - + def get(self, user_id: str, limit: Optional[int] = None) -> List[MemoryItem]: - """获取短期记忆。""" - if user_id not in self.memories: + items = list(self.memories.get(user_id, [])) + if not items: return [] - - # 过滤过期记忆 + now = datetime.now() - valid_memories = [ - m for m in self.memories[user_id] - if now - m.timestamp <= self.max_age - ] - - if limit: - valid_memories = valid_memories[-limit:] - - return valid_memories - + valid = [m for m in items if now - m.timestamp <= self.max_age] + self.memories[user_id] = deque(valid, maxlen=self.max_size) + + if limit and limit > 0: + return valid[-limit:] + return valid + def clear(self, user_id: str): - """清除用户短期记忆。""" - if user_id in self.memories: - self.memories.pop(user_id, None) + self.memories.pop(user_id, None) class MemorySystem: - """记忆系统:整合短期记忆与长期记忆。""" - + """Memory system: short-term + long-term storage.""" + def __init__( self, storage_path: Path, - embed_func: Optional[callable] = None, + embed_func: Optional[Callable[[str], Awaitable[List[float]]]] = None, importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None, importance_threshold: float = 0.6, use_vector_db: bool = True, use_query_embedding: bool = False, + max_long_term_per_user: int = 500, + dedup_window_seconds: int = 300, ): self.short_term = ShortTermMemory() self.embed_func = embed_func self.importance_evaluator = importance_evaluator self.importance_threshold = importance_threshold - # Only embed retrieval queries when explicitly enabled. self.use_query_embedding = use_query_embedding - - # 初始化向量存储 + self.max_long_term_per_user = max(10, int(max_long_term_per_user)) + self.dedup_window = timedelta(seconds=max(1, int(dedup_window_seconds))) + if use_vector_db: chroma_path = storage_path.parent / "chroma_db" chroma_store = self._init_chroma_store(chroma_path) if chroma_store is not None: self.vector_store = chroma_store - logger.info("Using Chroma vector store") else: self.vector_store = JSONVectorStore(storage_path) else: - # 使用 JSON 存储(向后兼容) self.vector_store = JSONVectorStore(storage_path) - logger.info("使用 JSON 存储") @staticmethod def _is_chroma_table_conflict(error: Exception) -> bool: - msg = str(error).lower() - return "table embeddings already exists" in msg + return "table embeddings already exists" in str(error).lower() @staticmethod def _is_chroma_trigram_error(error: Exception) -> bool: - msg = str(error).lower() - return "no such tokenizer: trigram" in msg + return "no such tokenizer: trigram" in str(error).lower() def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]: - """初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。""" try: return ChromaVectorStore(chroma_path) except Exception as error: if self._is_chroma_trigram_error(error): logger.warning( - "Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。" - "请在运行环境升级 sqlite 或安装 pysqlite3-binary。" + "Chroma unavailable (sqlite trigram unsupported), fallback to JSON store." ) return None - if not self._is_chroma_table_conflict(error): - logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}") + logger.warning(f"Chroma init failed, fallback to JSON: {error}") return None - # 先做一次短暂重试,处理并发启动时的瞬时冲突。 - logger.warning(f"Chroma 初始化出现 schema 冲突,正在重试: {error}") + logger.warning(f"Chroma schema conflict, retry once: {error}") time.sleep(0.2) try: return ChromaVectorStore(chroma_path) except Exception as retry_error: if not self._is_chroma_table_conflict(retry_error): - logger.warning(f"Chroma 重试失败,降级为 JSON 存储: {retry_error}") + logger.warning(f"Chroma retry failed, fallback to JSON: {retry_error}") return None - backup_name = ( + backup_path = chroma_path.parent / ( f"{chroma_path.name}_backup_conflict_" f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" ) - backup_path = chroma_path.parent / backup_name - try: if chroma_path.exists(): shutil.move(str(chroma_path), str(backup_path)) chroma_path.mkdir(parents=True, exist_ok=True) repaired = ChromaVectorStore(chroma_path) logger.warning( - f"检测到 Chroma 元数据库冲突,已重建目录并保留备份: {backup_path}" + f"Chroma metadata repaired by rebuilding directory. Backup: {backup_path}" ) return repaired except Exception as repair_error: - logger.warning(f"Chroma 修复失败,降级为 JSON 存储: {repair_error}") + logger.warning(f"Chroma repair failed, fallback to JSON: {repair_error}") return None - + @staticmethod def _normalize_embedding(values: List[float], dim: int = 1024) -> List[float]: if not values: return [0.0] * dim - normalized = [float(v) for v in values[:dim]] if len(normalized) < dim: normalized.extend([0.0] * (dim - len(normalized))) @@ -191,14 +177,11 @@ class MemorySystem: return vec for idx, byte in enumerate(encoded): - bucket = idx % dim - vec[bucket] += (byte / 255.0) + vec[idx % dim] += byte / 255.0 digest = hashlib.sha256(encoded).digest() for idx, byte in enumerate(digest): - bucket = idx % dim - vec[bucket] += ((byte / 255.0) - 0.5) * 0.1 - + vec[idx % dim] += ((byte / 255.0) - 0.5) * 0.1 return vec async def _build_embedding(self, text: str) -> List[float]: @@ -207,9 +190,8 @@ class MemorySystem: embedding = await self.embed_func(text) if embedding: return [float(v) for v in list(embedding)] - except Exception as e: - logger.warning(f"embedding generation failed: {e}") - + except Exception as exc: + logger.warning(f"Embedding generation failed, fallback to local: {exc}") return self._local_embedding(text) async def _add_vector_memory( @@ -231,10 +213,8 @@ class MemorySystem: ): return True - # Chroma collection may have a fixed historical embedding dimension. - candidate_dims = [] - base_len = len(embedding or []) - for dim in [base_len, 1024, 1536, 768, 384, 3072]: + candidate_dims: List[int] = [] + for dim in [len(embedding or []), 1024, 1536, 768, 384, 3072]: if dim and dim > 0 and dim not in candidate_dims: candidate_dims.append(dim) @@ -250,7 +230,6 @@ class MemorySystem: ) if ok: return True - return False @staticmethod @@ -261,15 +240,53 @@ class MemorySystem: value = 0.5 return max(0.0, min(1.0, value)) + async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float: + if not content or not content.strip(): + return 0.0 + + if self.importance_evaluator: + try: + score = await self.importance_evaluator(content, metadata) + return self._normalize_importance(score) + except Exception as exc: + logger.warning(f"Importance evaluation failed, fallback to 0.5: {exc}") + return 0.5 + + async def _is_duplicate_long_term(self, user_id: str, content: str) -> bool: + normalized = " ".join((content or "").split()) + if not normalized: + return True + + all_memories = await self.vector_store.get_all(user_id) + recent_cutoff = datetime.now() - self.dedup_window + for memory in all_memories[-20:]: + if ( + " ".join((memory.content or "").split()) == normalized + and memory.timestamp >= recent_cutoff + ): + return True + return False + + async def _trim_user_long_term(self, user_id: str): + all_memories = await self.vector_store.get_all(user_id) + if len(all_memories) <= self.max_long_term_per_user: + return + + all_memories.sort(key=lambda m: (m.importance, m.timestamp)) + to_delete = len(all_memories) - self.max_long_term_per_user + for memory in all_memories[:to_delete]: + await self.vector_store.delete(memory.id) + async def add_message( self, user_id: str, role: str, content: str, - metadata: Optional[Dict] = None + metadata: Optional[Dict] = None, ): - """向短期记忆添加单条消息(不做长期记忆评分)。""" - self.short_term.add(user_id, content, metadata) + payload = dict(metadata or {}) + payload.setdefault("role", role) + self.short_term.add(user_id, content, payload) async def add_qa_pair( self, @@ -278,9 +295,6 @@ class MemorySystem: answer: str, metadata: Optional[Dict] = None, ) -> Optional[VectorMemory]: - """ - 添加最新问答对,并仅对该问答对做模型重要性评估。 - """ user_meta = {"role": "user"} assistant_meta = {"role": "assistant"} if isinstance(metadata, dict): @@ -297,6 +311,8 @@ class MemorySystem: importance = await self._evaluate_importance(qa_content, qa_metadata) if importance < self.importance_threshold: return None + if await self._is_duplicate_long_term(user_id, qa_content): + return None embedding = await self._build_embedding(qa_content) memory_id = str(uuid.uuid4()) @@ -311,97 +327,88 @@ class MemorySystem: if not ok: return None + await self._trim_user_long_term(user_id) return await self.get_long_term(user_id, memory_id) - - async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float: - """评估记忆重要性。""" - if not content or not content.strip(): + + @staticmethod + def _simple_text_score(query: str, text: str) -> float: + query_tokens = [token for token in query.lower().split() if token] + if not query_tokens: return 0.0 + text_lower = text.lower() + hit = sum(1 for token in query_tokens if token in text_lower) + return hit / len(query_tokens) - if self.importance_evaluator: - try: - score = await self.importance_evaluator(content, metadata) - return self._normalize_importance(score) - except Exception as e: - logger.warning(f"importance evaluation failed, fallback to neutral score: {e}") - - # 当模型评估不可用时,使用中性分数作为兜底。 - return 0.5 - async def get_context( self, user_id: str, query: Optional[str] = None, max_short_term: int = 10, - max_long_term: int = 5 + max_long_term: int = 5, ) -> Tuple[List[MemoryItem], List[VectorMemory]]: - """获取上下文(短期 + 长期记忆)。""" - # 获取短期记忆 short_term_memories = self.short_term.get(user_id, limit=max_short_term) - - # 获取相关长期记忆 - long_term_memories = [] - + long_term_memories: List[VectorMemory] = [] + if query and self.use_query_embedding: try: - # 使用向量检索 query_embedding = await self._build_embedding(query) - if query_embedding: - long_term_memories = await self.vector_store.search( - user_id=user_id, - query_embedding=query_embedding, - limit=max_long_term - ) - except Exception as e: - logger.warning(f"向量检索失败,改用重要性检索: {e}") + long_term_memories = await self.vector_store.search( + user_id=user_id, + query_embedding=query_embedding, + limit=max_long_term, + ) + except Exception as exc: + logger.warning(f"Vector search failed, fallback to lexical: {exc}") if query and not long_term_memories: - query_lower = query.lower() - try: - candidates = await self.vector_store.get_all(user_id) - matches = [m for m in candidates if query_lower in m.content.lower()] - matches.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) - long_term_memories = matches[:max_long_term] - except Exception: - pass - - # 濡傛灉鍚戦噺妫€绱㈠け璐ユ垨娌℃湁缁撴灉锛屼娇鐢ㄩ噸瑕佹€ф绱? + candidates = await self.vector_store.get_all(user_id) + scored = [] + for memory in candidates: + score = self._simple_text_score(query, memory.content) + if score <= 0: + continue + combined = score * 0.7 + memory.importance * 0.3 + scored.append((combined, memory)) + scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True) + long_term_memories = [item[1] for item in scored[:max_long_term]] + if not long_term_memories: long_term_memories = await self.vector_store.get_by_importance( user_id=user_id, - limit=max_long_term + limit=max_long_term, ) - - # 更新长期记忆访问记录 + for memory in long_term_memories: await self.vector_store.update_access(memory.id) - - return short_term_memories, long_term_memories - - def format_context( - self, - short_term: List[MemoryItem], - long_term: List[VectorMemory] - ) -> str: - """格式化上下文为文本。""" - context = "" - - if long_term: - context += "## 相关历史记忆\n" - for i, memory in enumerate(long_term, 1): - context += f"{i}. {memory.content}\n" - context += "\n" - - if short_term: - context += "## 最近对话\n" - for memory in short_term: - context += f"- {memory.content}\n" - - return context - async def list_long_term( - self, user_id: str, limit: int = 20 - ) -> List[VectorMemory]: + return short_term_memories, long_term_memories + + def format_context( + self, short_term: List[MemoryItem], long_term: List[VectorMemory] + ) -> str: + lines: List[str] = [] + + if long_term: + lines.append("## 相关历史记忆") + for idx, memory in enumerate(long_term, 1): + lines.append(f"{idx}. {memory.content}") + lines.append("") + + if short_term: + lines.append("## 最近对话") + for item in short_term: + role = str((item.metadata or {}).get("role") or "").strip().lower() + if role == "user": + prefix = "用户" + elif role == "assistant": + prefix = "助手" + else: + prefix = "对话" + lines.append(f"- {prefix}: {item.content}") + + return "\n".join(lines).strip() + + async def list_long_term(self, user_id: str, limit: int = 20) -> List[VectorMemory]: memories = await self.vector_store.get_all(user_id) memories.sort(key=lambda m: m.timestamp, reverse=True) if limit > 0: @@ -422,21 +429,24 @@ class MemorySystem: importance: float = 0.8, metadata: Optional[Dict] = None, ) -> Optional[VectorMemory]: - memory_id = str(uuid.uuid4()) - importance = self._normalize_importance(importance) - embedding = await self._build_embedding(content) + if await self._is_duplicate_long_term(user_id, content): + return None + memory_id = str(uuid.uuid4()) + normalized_importance = self._normalize_importance(importance) + embedding = await self._build_embedding(content) ok = await self._add_vector_memory( memory_id=memory_id, user_id=user_id, content=content, embedding=embedding, - importance=importance, + importance=normalized_importance, metadata=metadata or {}, ) if not ok: return None + await self._trim_user_long_term(user_id) return await self.get_long_term(user_id, memory_id) async def search_long_term( @@ -457,10 +467,15 @@ class MemorySystem: return results all_memories = await self.vector_store.get_all(user_id) - query_lower = query.lower() - matched = [m for m in all_memories if query_lower in m.content.lower()] - matched.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) - return matched[:limit] + scored = [] + for memory in all_memories: + score = self._simple_text_score(query, memory.content) + if score <= 0: + continue + combined = score * 0.7 + memory.importance * 0.3 + scored.append((combined, memory)) + scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True) + return [item[1] for item in scored[:limit]] async def update_long_term( self, @@ -503,7 +518,6 @@ class MemorySystem: ) if not added: return None - return await self.get_long_term(user_id, memory_id) async def delete_long_term(self, user_id: str, memory_id: str) -> bool: @@ -511,15 +525,12 @@ class MemorySystem: if not memory: return False return await self.vector_store.delete(memory_id) - + def clear_short_term(self, user_id: str): - """清除短期记忆。""" self.short_term.clear(user_id) - + async def clear_long_term(self, user_id: str): - """清除长期记忆。""" await self.vector_store.clear_user(user_id) - + async def close(self): - """关闭记忆系统。""" await self.vector_store.close() diff --git a/src/ai/models/anthropic_model.py b/src/ai/models/anthropic_model.py index 3854b0e..059e88b 100644 --- a/src/ai/models/anthropic_model.py +++ b/src/ai/models/anthropic_model.py @@ -1,100 +1,94 @@ """ -Anthropic Claude模型实现 +Anthropic Claude model implementation. """ -from typing import List, Optional, AsyncIterator + +from __future__ import annotations + +from typing import AsyncIterator, List, Optional + from anthropic import AsyncAnthropic + from ..base import BaseAIModel, Message, ModelConfig class AnthropicModel(BaseAIModel): - """Anthropic Claude模型实现""" - + """Anthropic Claude model implementation.""" + def __init__(self, config: ModelConfig): super().__init__(config) self.client = AsyncAnthropic( api_key=config.api_key, base_url=config.api_base, - timeout=config.timeout + timeout=config.timeout, ) - + async def chat( self, messages: List[Message], tools: Optional[List[dict]] = None, - **kwargs + **kwargs, ) -> Message: - """同步对话""" - # 分离system消息 system_message = None formatted_messages = [] - + for msg in messages: if msg.role == "system": system_message = msg.content else: - formatted_messages.append({ - "role": msg.role, - "content": msg.content - }) - + formatted_messages.append({"role": msg.role, "content": msg.content}) + params = { "model": self.config.model_name, "messages": formatted_messages, "max_tokens": self.config.max_tokens, "temperature": self.config.temperature, } - if system_message: params["system"] = system_message - if tools: params["tools"] = tools - params.update(kwargs) - + response = await self.client.messages.create(**params) - + content = "" tool_calls = [] - for block in response.content: if block.type == "text": content += block.text elif block.type == "tool_use": - tool_calls.append({ - "id": block.id, - "type": "function", - "function": { - "name": block.name, - "arguments": block.input + tool_calls.append( + { + "id": block.id, + "type": "function", + "function": { + "name": block.name, + "arguments": block.input, + }, } - }) - + ) + return Message( role="assistant", content=content, - tool_calls=tool_calls if tool_calls else None + tool_calls=tool_calls if tool_calls else None, ) - + async def chat_stream( self, messages: List[Message], tools: Optional[List[dict]] = None, - **kwargs + **kwargs, ) -> AsyncIterator[str]: - """流式对话""" system_message = None formatted_messages = [] - + for msg in messages: if msg.role == "system": system_message = msg.content else: - formatted_messages.append({ - "role": msg.role, - "content": msg.content - }) - + formatted_messages.append({"role": msg.role, "content": msg.content}) + params = { "model": self.config.model_name, "messages": formatted_messages, @@ -102,19 +96,24 @@ class AnthropicModel(BaseAIModel): "temperature": self.config.temperature, "stream": True, } - if system_message: params["system"] = system_message - if tools: params["tools"] = tools - params.update(kwargs) - + async with self.client.messages.stream(**params) as stream: async for text in stream.text_stream: yield text - + async def embed(self, text: str) -> List[float]: - """文本嵌入(Anthropic不直接提供,需要使用其他服务)""" - raise NotImplementedError("Anthropic不提供嵌入API,请使用OpenAI或其他服务") + raise NotImplementedError( + "Anthropic does not provide embedding API; use OpenAI-compatible embedding model." + ) + + async def close(self): + close_fn = getattr(self.client, "close", None) + if callable(close_fn): + maybe_awaitable = close_fn() + if hasattr(maybe_awaitable, "__await__"): + await maybe_awaitable diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index 40c5707..a40b0e9 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -24,7 +24,7 @@ class OpenAIModel(BaseAIModel): self.logger = logger self._embedding_token_limit: Optional[int] = None - http_client = httpx.AsyncClient( + self._http_client = httpx.AsyncClient( timeout=config.timeout, limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), ) @@ -33,7 +33,7 @@ class OpenAIModel(BaseAIModel): api_key=config.api_key, base_url=config.api_base, timeout=config.timeout, - http_client=http_client, + http_client=self._http_client, ) self._supports_tools = False @@ -590,3 +590,8 @@ class OpenAIModel(BaseAIModel): self.logger.error(f"text preview: {repr(candidate_text[:100])}") self.logger.error(f"full traceback:\n{traceback.format_exc()}") raise + + async def close(self): + """Release network resources.""" + if self._http_client: + await self._http_client.aclose() diff --git a/src/ai/personality.py b/src/ai/personality.py index 0f1d458..caeb940 100644 --- a/src/ai/personality.py +++ b/src/ai/personality.py @@ -1,4 +1,6 @@ -"""Personality system for role-play profiles.""" +"""Personality system for role-play profiles with scope priority.""" + +from __future__ import annotations from dataclasses import dataclass, field from enum import Enum @@ -32,8 +34,6 @@ class PersonalityProfile: custom_instructions: str = "" def to_system_prompt(self) -> str: - """Build plain-text system prompt.""" - traits_text = ", ".join([t.value for t in self.traits]) if self.traits else "Friendly" lines = [ "Role Setting", @@ -55,13 +55,29 @@ class PersonalityProfile: class PersonalitySystem: - """Personality management and persistence.""" + """Personality management, persistence and scope resolution.""" - def __init__(self, config_path: Optional[Path] = None): + def __init__( + self, + config_path: Optional[Path] = None, + state_path: Optional[Path] = None, + ): self.config_path = config_path or Path("config/personalities.json") + self.state_path = state_path or self.config_path.with_name("personality_state.json") + self.personalities: Dict[str, PersonalityProfile] = {} - self.current_personality: Optional[PersonalityProfile] = None + self._active_global_key: str = "default" + self._active_user_keys: Dict[str, str] = {} + self._active_group_keys: Dict[str, str] = {} + self._active_session_keys: Dict[str, str] = {} + self._load_personalities() + self._load_state() + self._ensure_valid_active_keys() + + @property + def current_personality(self) -> Optional[PersonalityProfile]: + return self.personalities.get(self._active_global_key) def _dict_to_profile(self, config: Dict) -> PersonalityProfile: trait_names = config.get("traits", []) @@ -83,27 +99,30 @@ class PersonalitySystem: custom_instructions=str(config.get("custom_instructions", "")), ) - def _load_personalities(self): - """Load personality config from disk or create defaults.""" + def _profile_to_dict(self, profile: PersonalityProfile) -> Dict: + return { + "name": profile.name, + "description": profile.description, + "traits": [trait.name for trait in profile.traits], + "speaking_style": profile.speaking_style, + "example_responses": profile.example_responses, + "custom_instructions": profile.custom_instructions, + } + def _load_personalities(self): if self.config_path.exists(): with open(self.config_path, "r", encoding="utf-8") as f: data = json.load(f) for key, config in data.items(): - self.personalities[key] = self._dict_to_profile(config) + if isinstance(config, dict): + self.personalities[key] = self._dict_to_profile(config) - if "default" in self.personalities: - self.current_personality = self.personalities["default"] - elif self.personalities: - first_key = next(iter(self.personalities.keys())) - self.current_personality = self.personalities[first_key] - return + if self.personalities: + return self._create_default_personalities() def _create_default_personalities(self): - """Create and persist built-in default profiles.""" - default = PersonalityProfile( name="Assistant", description="A friendly and practical AI assistant.", @@ -146,87 +165,200 @@ class PersonalitySystem: "tech_expert": tech_expert, "creative": creative, } - self.current_personality = default + self._active_global_key = "default" self._save_personalities() def _save_personalities(self): - """Persist personalities to disk.""" - self.config_path.parent.mkdir(parents=True, exist_ok=True) - data = {} - - for key, profile in self.personalities.items(): - data[key] = { - "name": profile.name, - "description": profile.description, - "traits": [trait.name for trait in profile.traits], - "speaking_style": profile.speaking_style, - "example_responses": profile.example_responses, - "custom_instructions": profile.custom_instructions, - } - + data = { + key: self._profile_to_dict(profile) + for key, profile in self.personalities.items() + } with open(self.config_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - def set_personality(self, key: str) -> bool: - """Switch active personality by key.""" + def _load_state(self): + if not self.state_path.exists(): + return + try: + with open(self.state_path, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception: + return + + self._active_global_key = str(data.get("global", self._active_global_key)) + self._active_user_keys = { + str(k): str(v) + for k, v in (data.get("user") or {}).items() + if k is not None and v is not None + } + self._active_group_keys = { + str(k): str(v) + for k, v in (data.get("group") or {}).items() + if k is not None and v is not None + } + self._active_session_keys = { + str(k): str(v) + for k, v in (data.get("session") or {}).items() + if k is not None and v is not None + } + + def _save_state(self): + self.state_path.parent.mkdir(parents=True, exist_ok=True) + data = { + "global": self._active_global_key, + "user": self._active_user_keys, + "group": self._active_group_keys, + "session": self._active_session_keys, + } + with open(self.state_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def _ensure_valid_active_keys(self): + if self._active_global_key not in self.personalities: + if "default" in self.personalities: + self._active_global_key = "default" + elif self.personalities: + self._active_global_key = next(iter(self.personalities.keys())) + else: + self._active_global_key = "" + + self._active_user_keys = { + scope_id: key + for scope_id, key in self._active_user_keys.items() + if key in self.personalities + } + self._active_group_keys = { + scope_id: key + for scope_id, key in self._active_group_keys.items() + if key in self.personalities + } + self._active_session_keys = { + scope_id: key + for scope_id, key in self._active_session_keys.items() + if key in self.personalities + } + self._save_state() + + def set_personality( + self, key: str, scope: str = "global", scope_id: Optional[str] = None + ) -> bool: if key not in self.personalities: return False - self.current_personality = self.personalities[key] + scope_normalized = (scope or "global").strip().lower() + if scope_normalized == "global": + self._active_global_key = key + self._save_state() + return True + + if not scope_id: + return False + + if scope_normalized == "user": + self._active_user_keys[scope_id] = key + elif scope_normalized == "group": + self._active_group_keys[scope_id] = key + elif scope_normalized == "session": + self._active_session_keys[scope_id] = key + else: + return False + + self._save_state() return True - def get_system_prompt(self) -> str: - """Get current personality prompt.""" + def get_active_personality( + self, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> Optional[PersonalityProfile]: + # Priority: session > group > user > global > default + if session_id: + key = self._active_session_keys.get(session_id) + if key in self.personalities: + return self.personalities[key] - if self.current_personality: - return self.current_personality.to_system_prompt() - return "" + if group_id: + key = self._active_group_keys.get(group_id) + if key in self.personalities: + return self.personalities[key] + + if user_id: + key = self._active_user_keys.get(user_id) + if key in self.personalities: + return self.personalities[key] + + if self._active_global_key in self.personalities: + return self.personalities[self._active_global_key] + + return self.personalities.get("default") + + def get_system_prompt( + self, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> str: + profile = self.get_active_personality( + user_id=user_id, + group_id=group_id, + session_id=session_id, + ) + return profile.to_system_prompt() if profile else "" def add_personality(self, key: str, profile: PersonalityProfile) -> bool: - """Add a new personality profile.""" - key = key.strip() if not key: return False self.personalities[key] = profile - if not self.current_personality: - self.current_personality = profile + if not self._active_global_key: + self._active_global_key = key + self._save_state() self._save_personalities() return True def remove_personality(self, key: str) -> bool: - """Remove a personality profile.""" - if key == "default": return False - if key not in self.personalities: return False - removed_profile = self.personalities[key] del self.personalities[key] + if not self.personalities: + self._create_default_personalities() - if self.current_personality == removed_profile: - if "default" in self.personalities: - self.current_personality = self.personalities["default"] - elif self.personalities: - first_key = next(iter(self.personalities.keys())) - self.current_personality = self.personalities[first_key] - else: - self.current_personality = None + if self._active_global_key == key: + self._active_global_key = ( + "default" + if "default" in self.personalities + else next(iter(self.personalities.keys())) + ) + + self._active_user_keys = { + scope_id: active_key + for scope_id, active_key in self._active_user_keys.items() + if active_key != key + } + self._active_group_keys = { + scope_id: active_key + for scope_id, active_key in self._active_group_keys.items() + if active_key != key + } + self._active_session_keys = { + scope_id: active_key + for scope_id, active_key in self._active_session_keys.items() + if active_key != key + } self._save_personalities() + self._save_state() return True def list_personalities(self) -> List[str]: - """List all personality keys.""" - return sorted(self.personalities.keys()) def get_personality(self, key: str) -> Optional[PersonalityProfile]: - """Get personality by key.""" - return self.personalities.get(key) diff --git a/src/ai/skills/__init__.py b/src/ai/skills/__init__.py deleted file mode 100644 index 40ee6ef..0000000 --- a/src/ai/skills/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Skills系统初始化 -""" -from .base import Skill, SkillsManager, SkillMetadata, create_skill_template - -__all__ = ['Skill', 'SkillsManager', 'SkillMetadata', 'create_skill_template'] diff --git a/src/ai/skills/base.py b/src/ai/skills/base.py deleted file mode 100644 index e9d4878..0000000 --- a/src/ai/skills/base.py +++ /dev/null @@ -1,750 +0,0 @@ -""" -Skills 系统 - 可扩展技能插件框架。 -""" - -from dataclasses import dataclass -import importlib -import inspect -import json -from pathlib import Path -import re -import shutil -import sys -import tempfile -import time -from typing import Any, Callable, Dict, List, Optional, Tuple -import urllib.parse -import urllib.request -import zipfile -import os -import stat - -from src.utils.logger import setup_logger - -logger = setup_logger("SkillsSystem") - - -@dataclass -class SkillMetadata: - """技能元数据。""" - - name: str - version: str - description: str - author: str - dependencies: List[str] - enabled: bool = True - - -class Skill: - """技能基类。""" - - def __init__(self): - self.metadata: Optional[SkillMetadata] = None - self.tools: Dict[str, Callable] = {} - self.manager = None - - async def initialize(self): - """初始化技能。""" - - async def cleanup(self): - """清理技能。""" - - def get_tools(self) -> Dict[str, Callable]: - """获取技能提供的工具。""" - - return self.tools - - def register_tool(self, name: str, func: Callable): - """注册工具。""" - - self.tools[name] = func - - -class SkillsManager: - """技能管理器。""" - - _SKILL_KEY_PATTERN = re.compile(r"[^a-zA-Z0-9_]") - _GITHUB_SHORTCUT_PATTERN = re.compile( - r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+(?:#[A-Za-z0-9_.-]+)?$" - ) - - def __init__(self, skills_dir: Path): - self.skills_dir = skills_dir - self.skills: Dict[str, Skill] = {} - self.skills_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"✅ Skills 目录: {skills_dir}") - - @classmethod - def normalize_skill_key(cls, raw_name: str) -> str: - """将任意输入规范化为可导入的 Python 包名。""" - - key = raw_name.strip().lower().replace("-", "_").replace(" ", "_") - key = cls._SKILL_KEY_PATTERN.sub("_", key) - key = re.sub(r"_+", "_", key).strip("_") - - if not key: - raise ValueError("技能名不能为空") - - if key[0].isdigit(): - key = f"skill_{key}" - - return key - - def _get_skill_path(self, skill_name: str) -> Path: - return self.skills_dir / self.normalize_skill_key(skill_name) - - @staticmethod - def _on_rmtree_error(func, path, exc_info): - """Handle Windows readonly/locked file deletion errors.""" - try: - os.chmod(path, stat.S_IWRITE) - func(path) - except Exception: - # Keep original failure path for upper retry logic. - pass - - def _read_metadata(self, skill_path: Path, fallback_name: str) -> Dict[str, Any]: - metadata_file = skill_path / "skill.json" - if metadata_file.exists(): - with open(metadata_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - else: - metadata = {} - - metadata.setdefault("name", fallback_name) - metadata.setdefault("version", "1.0.0") - metadata.setdefault("description", f"{fallback_name} skill") - metadata.setdefault("author", "unknown") - metadata.setdefault("dependencies", []) - metadata.setdefault("enabled", True) - - with open(metadata_file, "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - - return metadata - - def _ensure_skill_package_layout(self, skill_path: Path, skill_key: str): - """确保技能目录满足运行最小结构。""" - - skill_path.mkdir(parents=True, exist_ok=True) - - init_file = skill_path / "__init__.py" - if not init_file.exists(): - init_file.write_text("", encoding="utf-8") - - main_file = skill_path / "main.py" - if not main_file.exists(): - template = f'''"""{skill_key} skill""" -from src.ai.skills.base import Skill - - -class {"".join(p.capitalize() for p in skill_key.split("_"))}Skill(Skill): - async def initialize(self): - self.register_tool("ping", self.ping) - - async def ping(self, text: str = "ok") -> str: - return text - - async def cleanup(self): - pass -''' - main_file.write_text(template, encoding="utf-8") - - self._read_metadata(skill_path, skill_key) - - async def load_skill(self, skill_name: str) -> bool: - """加载技能。""" - - try: - skill_name = self.normalize_skill_key(skill_name) - - if skill_name in self.skills: - logger.info(f"✅ 技能已加载: {skill_name}") - return True - - skill_path = self._get_skill_path(skill_name) - if not skill_path.exists(): - logger.error(f"❌ 技能不存在: {skill_name}") - return False - - metadata_file = skill_path / "skill.json" - if not metadata_file.exists(): - logger.error(f"❌ 技能元数据不存在: {skill_name}") - return False - - with open(metadata_file, "r", encoding="utf-8") as f: - metadata_dict = json.load(f) - - metadata = SkillMetadata(**metadata_dict) - - if not metadata.enabled: - logger.info(f"⏸️ 技能已禁用: {skill_name}") - return False - - module_path = f"skills.{skill_name}.main" - importlib.invalidate_caches() - - try: - old_dont_write = sys.dont_write_bytecode - sys.dont_write_bytecode = True - try: - if module_path in sys.modules: - module = importlib.reload(sys.modules[module_path]) - else: - module = importlib.import_module(module_path) - finally: - sys.dont_write_bytecode = old_dont_write - except Exception as exc: - logger.error(f"❌ 无法导入技能模块 {module_path}: {exc}") - return False - - skill_class = None - for _, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, Skill) and obj != Skill: - skill_class = obj - break - - if not skill_class: - logger.error(f"❌ 技能中未找到 Skill 子类: {skill_name}") - return False - - skill = skill_class() - skill.metadata = metadata - skill.manager = self - - await skill.initialize() - self.skills[skill_name] = skill - - logger.info(f"✅ 加载技能: {skill_name} v{metadata.version}") - return True - - except Exception as exc: - logger.error(f"❌ 加载技能失败 {skill_name}: {exc}") - return False - - async def load_all_skills(self): - """加载所有可用技能。""" - - for skill_name in self.list_available_skills(): - await self.load_skill(skill_name) - - async def unload_skill(self, skill_name: str) -> bool: - """仅卸载内存中的技能。""" - - skill_name = self.normalize_skill_key(skill_name) - if skill_name not in self.skills: - return False - - skill = self.skills[skill_name] - await skill.cleanup() - del self.skills[skill_name] - - sys.modules.pop(f"skills.{skill_name}.main", None) - sys.modules.pop(f"skills.{skill_name}", None) - importlib.invalidate_caches() - - logger.info(f"✅ 卸载技能: {skill_name}") - return True - - async def uninstall_skill(self, skill_name: str, delete_files: bool = True) -> bool: - """卸载技能并可选删除文件。""" - - skill_name = self.normalize_skill_key(skill_name) - - if skill_name in self.skills: - await self.unload_skill(skill_name) - - if not delete_files: - return True - - skill_path = self._get_skill_path(skill_name) - if not skill_path.exists(): - return False - - removed = False - for _ in range(3): - try: - shutil.rmtree(skill_path, ignore_errors=False, onerror=self._on_rmtree_error) - except PermissionError: - pass - - if not skill_path.exists(): - removed = True - break - - time.sleep(0.2) - - if not removed: - try: - metadata_file = skill_path / "skill.json" - metadata = {} - if metadata_file.exists(): - with open(metadata_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - metadata["enabled"] = False - with open(metadata_file, "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - logger.warning(f"⚠️ 删除目录失败,已软卸载技能: {skill_name}") - return True - except Exception: - return False - - importlib.invalidate_caches() - logger.info(f"✅ 删除技能目录: {skill_name}") - return True - - def get_skill(self, skill_name: str) -> Optional[Skill]: - """获取已加载技能实例。""" - - skill_name = self.normalize_skill_key(skill_name) - return self.skills.get(skill_name) - - def list_skills(self) -> List[str]: - """列出已加载技能。""" - - return sorted(self.skills.keys()) - - def list_available_skills(self) -> List[str]: - """列出可加载技能目录。""" - - if not self.skills_dir.exists(): - return [] - - available: List[str] = [] - for skill_dir in self.skills_dir.iterdir(): - if not skill_dir.is_dir() or skill_dir.name.startswith("_"): - continue - - if (skill_dir / "skill.json").exists() and (skill_dir / "main.py").exists(): - try: - with open(skill_dir / "skill.json", "r", encoding="utf-8") as f: - metadata = json.load(f) - if not metadata.get("enabled", True): - continue - available.append(self.normalize_skill_key(skill_dir.name)) - except ValueError: - continue - except Exception: - continue - - return sorted(set(available)) - - def get_all_tools(self) -> Dict[str, Callable]: - """获取全部技能工具。""" - - all_tools: Dict[str, Callable] = {} - for skill_name, skill in self.skills.items(): - for tool_name, tool_func in skill.get_tools().items(): - all_tools[f"{skill_name}.{tool_name}"] = tool_func - return all_tools - - async def reload_skill(self, skill_name: str) -> bool: - """重载技能。""" - - skill_name = self.normalize_skill_key(skill_name) - if skill_name in self.skills: - await self.unload_skill(skill_name) - return await self.load_skill(skill_name) - - def _parse_github_repo_source( - self, - source: str, - ) -> Optional[Tuple[str, str, str, Optional[str]]]: - """解析 GitHub 仓库 URL,返回 owner/repo/branch/subpath。""" - - try: - parsed = urllib.parse.urlparse(source.strip()) - except Exception: - return None - - if parsed.scheme not in {"http", "https"}: - return None - - if parsed.netloc.lower() != "github.com": - return None - - parts = [part for part in parsed.path.split("/") if part] - if len(parts) < 2: - return None - - owner = parts[0] - repo = parts[1] - if repo.endswith(".git"): - repo = repo[:-4] - - if not owner or not repo: - return None - - branch = "main" - subpath: Optional[str] = None - - if len(parts) >= 4 and parts[2] == "tree": - branch = parts[3] - if len(parts) > 4: - subpath = "/".join(parts[4:]) - elif len(parts) > 2: - # 不支持 /issues、/blob 等深链接 - return None - - return owner, repo, branch, subpath - - def _resolve_network_source( - self, - source: str, - ) -> Tuple[str, Optional[str], Optional[str]]: - """将网络来源解析为下载 URL,并返回安装提示信息。""" - - source = source.strip() - - github_repo = self._parse_github_repo_source(source) - if github_repo: - owner, repo, branch, subpath = github_repo - codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}" - return codeload_url, self.normalize_skill_key(repo), subpath - - if source.startswith(("http://", "https://")): - return source, None, None - - if self._GITHUB_SHORTCUT_PATTERN.match(source): - repo_ref, _, branch = source.partition("#") - owner, repo = repo_ref.split("/", 1) - if repo.endswith(".git"): - repo = repo[:-4] - branch = branch or "main" - codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}" - return codeload_url, self.normalize_skill_key(repo), None - - raise ValueError("source 必须是 URL 或 owner/repo[#branch]") - - def _download_zip(self, url: str, output_zip: Path): - """下载 zip 包到本地。""" - - req = urllib.request.Request(url, headers={"User-Agent": "QQBot-Skills/1.0"}) - with urllib.request.urlopen(req, timeout=30) as resp: - data = resp.read() - output_zip.write_bytes(data) - - def _find_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]: - """在目录中扫描技能候选项。""" - - candidates: List[Tuple[str, Path]] = [] - for metadata_file in root_dir.rglob("skill.json"): - candidate_dir = metadata_file.parent - if not (candidate_dir / "main.py").exists(): - continue - - try: - with open(metadata_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - raw_name = str(metadata.get("name") or candidate_dir.name) - except Exception: - raw_name = candidate_dir.name - - try: - skill_key = self.normalize_skill_key(raw_name) - except ValueError: - continue - - candidates.append((skill_key, candidate_dir)) - - uniq: Dict[str, Path] = {} - for key, path in candidates: - uniq[key] = path - - return sorted(uniq.items(), key=lambda x: x[0]) - - def _find_codex_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]: - """在目录中扫描仅包含 SKILL.md 的候选项。""" - - candidates: List[Tuple[str, Path]] = [] - for skill_doc in root_dir.rglob("SKILL.md"): - candidate_dir = skill_doc.parent - if (candidate_dir / "skill.json").exists() and (candidate_dir / "main.py").exists(): - continue - - try: - skill_key = self.normalize_skill_key(candidate_dir.name) - except ValueError: - continue - - candidates.append((skill_key, candidate_dir)) - - uniq: Dict[str, Path] = {} - for key, path in candidates: - uniq[key] = path - - return sorted(uniq.items(), key=lambda x: x[0]) - - def _scope_extract_root(self, extract_root: Path, subpath: Optional[str]) -> Path: - """将解压目录缩小到 repo 子路径(若指定)。""" - - if not subpath: - return extract_root - - clean_subpath = Path(subpath.strip("/\\")) - if str(clean_subpath) == ".": - return extract_root - - direct = extract_root / clean_subpath - if direct.exists(): - return direct - - for child in extract_root.iterdir(): - candidate = child / clean_subpath - if candidate.exists(): - return candidate - - return extract_root - - @staticmethod - def _extract_markdown_title(markdown_content: str) -> str: - for line in markdown_content.splitlines(): - stripped = line.strip() - if stripped.startswith("#"): - return stripped.lstrip("#").strip() - return "" - - def _install_codex_skill_adapter(self, source_dir: Path, target_path: Path, skill_key: str): - """将 Codex SKILL.md 转换为可加载的项目技能。""" - - skill_doc = source_dir / "SKILL.md" - if not skill_doc.exists(): - raise FileNotFoundError(f"SKILL.md not found in {source_dir}") - - content = skill_doc.read_text(encoding="utf-8") - title = self._extract_markdown_title(content) - description = ( - f"Imported from SKILL.md: {title}" if title else f"Imported from SKILL.md ({skill_key})" - ) - - target_path.mkdir(parents=True, exist_ok=True) - (target_path / "SKILL.md").write_text(content, encoding="utf-8") - (target_path / "__init__.py").write_text("", encoding="utf-8") - - metadata = { - "name": skill_key, - "version": "1.0.0", - "description": description, - "author": "imported", - "dependencies": [], - "enabled": True, - } - with open(target_path / "skill.json", "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - - class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill" - main_code = f'''"""{skill_key} adapter skill (generated from SKILL.md)""" -from pathlib import Path -from src.ai.skills.base import Skill - - -class {class_name}(Skill): - async def initialize(self): - self.register_tool("read_skill_doc", self.read_skill_doc) - - async def read_skill_doc(self) -> str: - skill_doc = Path(__file__).with_name("SKILL.md") - if not skill_doc.exists(): - return "SKILL.md not found." - return skill_doc.read_text(encoding="utf-8") - - async def cleanup(self): - pass -''' - (target_path / "main.py").write_text(main_code, encoding="utf-8") - - def install_skill_from_source( - self, - source: str, - skill_name: Optional[str] = None, - overwrite: bool = False, - ) -> Tuple[bool, str]: - """从网络或本地源安装技能目录(仅落盘,不自动加载)。""" - - desired_key = self.normalize_skill_key(skill_name) if skill_name else None - - with tempfile.TemporaryDirectory(prefix="qqbot_skill_") as tmp: - tmp_dir = Path(tmp) - extract_root: Optional[Path] = None - source_hint_key: Optional[str] = None - source_subpath: Optional[str] = None - - source_path = Path(source) - if source_path.exists(): - if source_path.is_dir(): - extract_root = source_path - elif source_path.is_file() and source_path.suffix.lower() == ".zip": - with zipfile.ZipFile(source_path, "r") as zf: - zf.extractall(tmp_dir / "extract") - extract_root = tmp_dir / "extract" - else: - return False, "本地 source 仅支持目录或 zip 文件" - else: - try: - url, source_hint_key, source_subpath = self._resolve_network_source(source) - except ValueError as exc: - return False, str(exc) - - download_zip = tmp_dir / "download.zip" - try: - self._download_zip(url, download_zip) - except Exception as exc: - # GitHub 简写默认 main 失败时尝试 master - if "codeload.github.com" in url and url.endswith("/main"): - fallback = url[:-4] + "master" - try: - self._download_zip(fallback, download_zip) - except Exception: - return False, f"下载技能失败: {exc}" - else: - return False, f"下载技能失败: {exc}" - - try: - with zipfile.ZipFile(download_zip, "r") as zf: - zf.extractall(tmp_dir / "extract") - except Exception as exc: - return False, f"解压技能失败: {exc}" - - extract_root = tmp_dir / "extract" - - extract_root = self._scope_extract_root(extract_root, source_subpath) - - candidates = self._find_skill_candidates(extract_root) - use_codex_adapter = False - - selected_key: Optional[str] = None - selected_path: Optional[Path] = None - - if candidates: - if desired_key: - for key, path in candidates: - if key == desired_key: - selected_key, selected_path = key, path - break - - if not selected_path: - names = ", ".join([k for k, _ in candidates]) - return False, f"源中未找到技能 {desired_key},可选: {names}" - else: - if len(candidates) > 1: - names = ", ".join([k for k, _ in candidates]) - return False, f"检测到多个技能,请指定 skill_name。可选: {names}" - selected_key, selected_path = candidates[0] - else: - codex_candidates = self._find_codex_skill_candidates(extract_root) - if not codex_candidates: - return False, "未找到可安装技能(需包含 skill.json 与 main.py,或 SKILL.md)" - - use_codex_adapter = True - if desired_key: - for key, path in codex_candidates: - if key == desired_key: - selected_key, selected_path = key, path - break - - if not selected_path: - if len(codex_candidates) == 1: - selected_key = desired_key - selected_path = codex_candidates[0][1] - else: - names = ", ".join([k for k, _ in codex_candidates]) - return False, f"源中未找到技能 {desired_key},可选: {names}" - else: - if source_hint_key: - for key, path in codex_candidates: - if key == source_hint_key: - selected_key, selected_path = key, path - break - - if not selected_path: - if len(codex_candidates) > 1: - names = ", ".join([k for k, _ in codex_candidates]) - return False, f"检测到多个 SKILL.md 技能,请指定 skill_name。可选: {names}" - selected_key, selected_path = codex_candidates[0] - if source_hint_key: - selected_key = source_hint_key - - assert selected_key is not None and selected_path is not None - target_path = self._get_skill_path(selected_key) - - if target_path.exists(): - if not overwrite: - return False, f"技能已存在: {selected_key}" - shutil.rmtree(target_path) - - if use_codex_adapter: - self._install_codex_skill_adapter(selected_path, target_path, selected_key) - else: - shutil.copytree( - selected_path, - target_path, - ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".git", ".github"), - ) - self._ensure_skill_package_layout(target_path, selected_key) - - importlib.invalidate_caches() - - logger.info(f"✅ 安装技能成功: {selected_key} <- {source}") - return True, selected_key - - -def create_skill_template( - skill_name: str, - output_dir: Path, - description: str = "技能描述", - author: str = "QQBot", -): - """创建技能模板。""" - - skill_key = SkillsManager.normalize_skill_key(skill_name) - skill_dir = output_dir / skill_key - skill_dir.mkdir(parents=True, exist_ok=True) - - metadata = { - "name": skill_key, - "version": "1.0.0", - "description": description, - "author": author, - "dependencies": [], - "enabled": True, - } - - with open(skill_dir / "skill.json", "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - - class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill" - main_code = f'''"""{skill_key} skill""" -from src.ai.skills.base import Skill - - -class {class_name}(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"{skill_key} 收到: {{text}}" - - async def cleanup(self): - pass -''' - - with open(skill_dir / "main.py", "w", encoding="utf-8") as f: - f.write(main_code) - - with open(skill_dir / "__init__.py", "w", encoding="utf-8") as f: - f.write("") - - readme = f"""# {skill_key} - -## 描述 -{description} - -## 工具 -- example_tool(text) -""" - - with open(skill_dir / "README.md", "w", encoding="utf-8") as f: - f.write(readme) - - logger.info(f"✅ 创建技能模板: {skill_dir}") diff --git a/src/ai/task_manager.py b/src/ai/task_manager.py deleted file mode 100644 index 2a0d45e..0000000 --- a/src/ai/task_manager.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -长任务管理器 - 处理需要多步骤的复杂任务 -""" -import asyncio -from typing import List, Dict, Optional, Callable, Any -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -import json -from pathlib import Path -import uuid - - -class TaskStatus(Enum): - """任务状态""" - PENDING = "pending" - RUNNING = "running" - PAUSED = "paused" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@dataclass -class TaskStep: - """任务步骤""" - step_id: str - description: str - action: str - params: Dict[str, Any] - status: TaskStatus = TaskStatus.PENDING - result: Optional[Any] = None - error: Optional[str] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - def to_dict(self) -> Dict: - return { - 'step_id': self.step_id, - 'description': self.description, - 'action': self.action, - 'params': self.params, - 'status': self.status.value, - 'result': self.result, - 'error': self.error, - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None - } - - -@dataclass -class LongTask: - """长任务""" - task_id: str - user_id: str - title: str - description: str - steps: List[TaskStep] = field(default_factory=list) - status: TaskStatus = TaskStatus.PENDING - created_at: datetime = field(default_factory=datetime.now) - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - progress: float = 0.0 - metadata: Dict = field(default_factory=dict) - - def to_dict(self) -> Dict: - return { - 'task_id': self.task_id, - 'user_id': self.user_id, - 'title': self.title, - 'description': self.description, - 'steps': [step.to_dict() for step in self.steps], - 'status': self.status.value, - 'created_at': self.created_at.isoformat(), - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'progress': self.progress, - 'metadata': self.metadata - } - - @classmethod - def from_dict(cls, data: Dict) -> 'LongTask': - steps = [ - TaskStep( - step_id=s['step_id'], - description=s['description'], - action=s['action'], - params=s['params'], - status=TaskStatus(s['status']), - result=s.get('result'), - error=s.get('error'), - started_at=datetime.fromisoformat(s['started_at']) if s.get('started_at') else None, - completed_at=datetime.fromisoformat(s['completed_at']) if s.get('completed_at') else None - ) - for s in data['steps'] - ] - - return cls( - task_id=data['task_id'], - user_id=data['user_id'], - title=data['title'], - description=data['description'], - steps=steps, - status=TaskStatus(data['status']), - created_at=datetime.fromisoformat(data['created_at']), - started_at=datetime.fromisoformat(data['started_at']) if data.get('started_at') else None, - completed_at=datetime.fromisoformat(data['completed_at']) if data.get('completed_at') else None, - progress=data.get('progress', 0.0), - metadata=data.get('metadata', {}) - ) - - -class LongTaskManager: - """长任务管理器""" - - def __init__(self, storage_path: Path): - self.storage_path = storage_path - self.tasks: Dict[str, LongTask] = {} - self.action_handlers: Dict[str, Callable] = {} - self.running_tasks: Dict[str, asyncio.Task] = {} - self._load() - - def _load(self): - """加载任务""" - if self.storage_path.exists(): - with open(self.storage_path, 'r', encoding='utf-8') as f: - data = json.load(f) - for task_data in data: - task = LongTask.from_dict(task_data) - self.tasks[task.task_id] = task - - def _save(self): - """保存任务""" - self.storage_path.parent.mkdir(parents=True, exist_ok=True) - data = [task.to_dict() for task in self.tasks.values()] - with open(self.storage_path, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - def register_action(self, action_name: str, handler: Callable): - """注册动作处理器""" - self.action_handlers[action_name] = handler - - def create_task( - self, - user_id: str, - title: str, - description: str, - steps: List[Dict], - metadata: Optional[Dict] = None - ) -> str: - """创建任务""" - task_id = str(uuid.uuid4()) - - task_steps = [ - TaskStep( - step_id=str(uuid.uuid4()), - description=step['description'], - action=step['action'], - params=step.get('params', {}) - ) - for step in steps - ] - - task = LongTask( - task_id=task_id, - user_id=user_id, - title=title, - description=description, - steps=task_steps, - metadata=metadata or {} - ) - - self.tasks[task_id] = task - self._save() - - return task_id - - async def execute_task( - self, - task_id: str, - progress_callback: Optional[Callable[[str, float, str], None]] = None - ) -> bool: - """执行任务""" - if task_id not in self.tasks: - return False - - task = self.tasks[task_id] - - if task.status == TaskStatus.RUNNING: - return False - - task.status = TaskStatus.RUNNING - task.started_at = datetime.now() - self._save() - - try: - total_steps = len(task.steps) - - for i, step in enumerate(task.steps): - # 检查是否被取消 - if task.status == TaskStatus.CANCELLED: - break - - step.status = TaskStatus.RUNNING - step.started_at = datetime.now() - - # 执行步骤 - try: - handler = self.action_handlers.get(step.action) - if not handler: - raise ValueError(f"未找到动作处理器: {step.action}") - - result = await handler(**step.params) - step.result = result - step.status = TaskStatus.COMPLETED - - except Exception as e: - step.error = str(e) - step.status = TaskStatus.FAILED - task.status = TaskStatus.FAILED - self._save() - return False - - finally: - step.completed_at = datetime.now() - - # 更新进度 - task.progress = (i + 1) / total_steps - self._save() - - if progress_callback: - await progress_callback( - task_id, - task.progress, - f"完成步骤 {i+1}/{total_steps}: {step.description}" - ) - - task.status = TaskStatus.COMPLETED - task.completed_at = datetime.now() - task.progress = 1.0 - self._save() - - return True - - except Exception as e: - task.status = TaskStatus.FAILED - self._save() - raise e - - async def start_task( - self, - task_id: str, - progress_callback: Optional[Callable[[str, float, str], None]] = None - ): - """启动任务(异步)""" - if task_id in self.running_tasks: - return - - async def run(): - try: - await self.execute_task(task_id, progress_callback) - finally: - if task_id in self.running_tasks: - del self.running_tasks[task_id] - - self.running_tasks[task_id] = asyncio.create_task(run()) - - def pause_task(self, task_id: str) -> bool: - """暂停任务""" - if task_id not in self.tasks: - return False - - task = self.tasks[task_id] - if task.status == TaskStatus.RUNNING: - task.status = TaskStatus.PAUSED - self._save() - return True - - return False - - def cancel_task(self, task_id: str) -> bool: - """取消任务""" - if task_id not in self.tasks: - return False - - task = self.tasks[task_id] - task.status = TaskStatus.CANCELLED - self._save() - - # 取消正在运行的任务 - if task_id in self.running_tasks: - self.running_tasks[task_id].cancel() - - return True - - def get_task(self, task_id: str) -> Optional[LongTask]: - """获取任务""" - return self.tasks.get(task_id) - - def get_user_tasks(self, user_id: str) -> List[LongTask]: - """获取用户的所有任务""" - return [ - task for task in self.tasks.values() - if task.user_id == user_id - ] - - def get_task_status(self, task_id: str) -> Optional[Dict]: - """获取任务状态""" - task = self.get_task(task_id) - if not task: - return None - - completed_steps = sum(1 for step in task.steps if step.status == TaskStatus.COMPLETED) - total_steps = len(task.steps) - - return { - 'task_id': task.task_id, - 'title': task.title, - 'status': task.status.value, - 'progress': task.progress, - 'completed_steps': completed_steps, - 'total_steps': total_steps, - 'current_step': next( - (step.description for step in task.steps if step.status == TaskStatus.RUNNING), - None - ) - } diff --git a/src/ai/vector_store/json_store.py b/src/ai/vector_store/json_store.py index 87fea5b..7abd2b7 100644 --- a/src/ai/vector_store/json_store.py +++ b/src/ai/vector_store/json_store.py @@ -1,64 +1,78 @@ """ -JSON文件存储实现(向后兼容) +JSON-backed vector store implementation. """ + +from __future__ import annotations + +import asyncio import json import uuid -from typing import List, Dict, Optional -from pathlib import Path from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + import numpy as np -from .base import VectorStore, VectorMemory + +from .base import VectorMemory, VectorStore from src.utils.logger import setup_logger -logger = setup_logger('JSONStore') +logger = setup_logger("JSONStore") class JSONVectorStore(VectorStore): - """JSON文件存储实现(向后兼容旧版本)""" - + """JSON file storage implementation.""" + def __init__(self, storage_path: Path): - """初始化JSON存储""" self.storage_path = storage_path - self.memories: Dict[str, List[VectorMemory]] = {} # user_id -> List[VectorMemory] + self.memories: Dict[str, List[VectorMemory]] = {} + self._lock = asyncio.Lock() self._load() - logger.info(f"✅ JSON存储初始化: {storage_path}") - + logger.info(f"JSON storage initialized: {storage_path}") + def _load(self): - """加载记忆""" - if self.storage_path.exists(): - try: - with open(self.storage_path, 'r', encoding='utf-8') as f: - data = json.load(f) - - for user_id, items in data.items(): - self.memories[user_id] = [] - for item in items: - # 兼容旧格式 - if 'id' not in item: - item['id'] = str(uuid.uuid4()) - - memory = VectorMemory.from_dict(item) - self.memories[user_id].append(memory) - - logger.info(f"加载了 {sum(len(v) for v in self.memories.values())} 条记忆") - except Exception as e: - logger.error(f"加载记忆失败: {e}") - self.memories = {} - - def _save(self): - """保存记忆""" + if not self.storage_path.exists(): + return + try: - self.storage_path.parent.mkdir(parents=True, exist_ok=True) - data = { - user_id: [memory.to_dict() for memory in memories] - for user_id, memories in self.memories.items() - } - - with open(self.storage_path, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"保存记忆失败: {e}") - + with open(self.storage_path, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception as exc: + logger.error(f"Failed to load memory file: {exc}") + self.memories = {} + return + + loaded = 0 + memories: Dict[str, List[VectorMemory]] = {} + for user_id, items in (data or {}).items(): + if not isinstance(items, list): + continue + + normalized: List[VectorMemory] = [] + for item in items: + if not isinstance(item, dict): + continue + if "id" not in item: + item["id"] = str(uuid.uuid4()) + try: + normalized.append(VectorMemory.from_dict(item)) + loaded += 1 + except Exception: + continue + + memories[str(user_id)] = normalized + + self.memories = memories + logger.info(f"Loaded {loaded} memories from JSON store") + + def _save_locked(self): + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + data = { + user_id: [memory.to_dict() for memory in user_memories] + for user_id, user_memories in self.memories.items() + } + with open(self.storage_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + async def add( self, id: str, @@ -66,9 +80,8 @@ class JSONVectorStore(VectorStore): content: str, embedding: List[float], importance: float, - metadata: Optional[Dict] = None + metadata: Optional[Dict] = None, ) -> bool: - """添加记忆""" try: memory = VectorMemory( id=id, @@ -79,120 +92,108 @@ class JSONVectorStore(VectorStore): timestamp=datetime.now(), metadata=metadata or {}, access_count=0, - last_access=None + last_access=None, ) - - if user_id not in self.memories: - self.memories[user_id] = [] - - self.memories[user_id].append(memory) - self._save() - - logger.debug(f"添加记忆: {id} (用户: {user_id})") + async with self._lock: + self.memories.setdefault(user_id, []).append(memory) + self._save_locked() return True - except Exception as e: - logger.error(f"添加记忆失败: {e}") + except Exception as exc: + logger.error(f"Failed to add memory: {exc}") return False - + async def search( self, user_id: str, query_embedding: List[float], limit: int = 5, - min_importance: float = 0.3 + min_importance: float = 0.3, ) -> List[VectorMemory]: - """搜索相似记忆""" - if user_id not in self.memories: + async with self._lock: + source = list(self.memories.get(user_id, [])) + + candidates = [m for m in source if m.importance >= min_importance] + if not candidates: return [] - - memories = self.memories[user_id] - - # 过滤重要性 - memories = [m for m in memories if m.importance >= min_importance] - - if not memories: - return [] - - # 使用向量相似度排序 + scored_memories = [] - for memory in memories: + for memory in candidates: if memory.embedding: similarity = self._cosine_similarity(query_embedding, memory.embedding) - scored_memories.append((similarity, memory)) - + if similarity is not None: + scored_memories.append((similarity, memory)) + if not scored_memories: - # 如果没有嵌入向量,按重要性排序 return await self.get_by_importance(user_id, limit, min_importance) - + scored_memories.sort(reverse=True, key=lambda x: x[0]) - return [m for _, m in scored_memories[:limit]] - + return [memory for _, memory in scored_memories[:limit]] + async def get_by_importance( self, user_id: str, limit: int = 5, - min_importance: float = 0.3 + min_importance: float = 0.3, ) -> List[VectorMemory]: - """按重要性获取记忆""" - if user_id not in self.memories: - return [] - - memories = [m for m in self.memories[user_id] if m.importance >= min_importance] + async with self._lock: + source = list(self.memories.get(user_id, [])) + + memories = [m for m in source if m.importance >= min_importance] memories.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) return memories[:limit] - - def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: - """计算余弦相似度""" - vec1 = np.array(vec1) - vec2 = np.array(vec2) - return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) - + + @staticmethod + def _cosine_similarity(vec1: List[float], vec2: List[float]) -> Optional[float]: + arr1 = np.array(vec1, dtype=float) + arr2 = np.array(vec2, dtype=float) + denom = np.linalg.norm(arr1) * np.linalg.norm(arr2) + if denom == 0: + return None + return float(np.dot(arr1, arr2) / denom) + async def update_access(self, memory_id: str) -> bool: - """更新访问记录""" try: - for memories in self.memories.values(): - for memory in memories: - if memory.id == memory_id: - memory.access_count += 1 - memory.last_access = datetime.now() - self._save() - return True + async with self._lock: + for memories in self.memories.values(): + for memory in memories: + if memory.id == memory_id: + memory.access_count += 1 + memory.last_access = datetime.now() + self._save_locked() + return True return False - except Exception as e: - logger.error(f"更新访问记录失败: {e}") + except Exception as exc: + logger.error(f"Failed to update access: {exc}") return False - + async def delete(self, memory_id: str) -> bool: - """删除记忆""" try: - for user_id, memories in self.memories.items(): - for i, memory in enumerate(memories): - if memory.id == memory_id: - del self.memories[user_id][i] - self._save() - return True + async with self._lock: + for user_id, memories in self.memories.items(): + for idx, memory in enumerate(memories): + if memory.id == memory_id: + del self.memories[user_id][idx] + self._save_locked() + return True return False - except Exception as e: - logger.error(f"删除记忆失败: {e}") + except Exception as exc: + logger.error(f"Failed to delete memory: {exc}") return False - + async def get_all(self, user_id: str) -> List[VectorMemory]: - """获取用户所有记忆""" - return self.memories.get(user_id, []) - + async with self._lock: + return list(self.memories.get(user_id, [])) + async def clear_user(self, user_id: str) -> bool: - """清除用户所有记忆""" try: - if user_id in self.memories: - del self.memories[user_id] - self._save() - logger.info(f"清除用户记忆: {user_id}") + async with self._lock: + self.memories.pop(user_id, None) + self._save_locked() return True - except Exception as e: - logger.error(f"清除用户记忆失败: {e}") + except Exception as exc: + logger.error(f"Failed to clear user memories: {exc}") return False - + async def close(self): - """关闭连接""" - self._save() - logger.info("JSON存储已关闭") + async with self._lock: + self._save_locked() diff --git a/src/core/bot.py b/src/core/bot.py index f05f9e2..e51208e 100644 --- a/src/core/bot.py +++ b/src/core/bot.py @@ -1,105 +1,71 @@ """ -QQ机器人主程序 -基于官方SDK: https://github.com/tencent-connect/botpy -官方文档: https://bot.q.qq.com/wiki/develop/api-v2/ +QQ bot application entry module. """ + +from __future__ import annotations + import botpy from botpy.message import Message -from src.core.config import Config -from src.utils.logger import setup_logger -from src.handlers.message_handler_ai import MessageHandler -logger = setup_logger('QQBot') +from src.core.config import Config +from src.handlers.message_handler_ai import MessageHandler +from src.utils.logger import setup_logger + +logger = setup_logger("QQBot") def build_intents() -> botpy.Intents: - """ - 构建最小可用的 intents。 - - - public_guild_messages: 频道公域 @机器人 消息 - - public_messages: 群聊@ + C2C 私聊(好友单聊)消息 - """ intents = botpy.Intents.none() intents.public_guild_messages = True - # 新版 botpy 中,QQ 群聊@ / C2C 私聊依赖 public_messages(GROUP_AND_C2C_EVENT)。 if hasattr(intents, "public_messages"): intents.public_messages = True - logger.info("✅ 已启用 public_messages(支持群聊@与 C2C 私聊)") + logger.info("Enabled public_messages for group/C2C events") else: - logger.warning("⚠️ 当前 botpy 不支持 public_messages,可能无法接收 C2C 私聊事件") + logger.warning("Current botpy version does not expose public_messages") return intents class MyClient(botpy.Client): - """QQ机器人客户端""" - + """QQ bot client wrapper.""" + def __init__(self, intents: botpy.Intents): super().__init__(intents=intents) self.message_handler = MessageHandler(self) - + async def on_ready(self): - """机器人启动完成事件""" - logger.info(f"🤖 机器人已启动: {self.robot.name} (ID: {self.robot.id})") - + logger.info(f"Bot is ready: {self.robot.name} (ID: {self.robot.id})") + async def on_at_message_create(self, message: Message): - """处理@机器人的消息(频道公域消息)""" await self.message_handler.handle_at_message(message) - + async def on_message_create(self, message: Message): - """处理普通消息(需要私域权限)""" await self.message_handler.handle_at_message(message) - + async def on_direct_message_create(self, message: Message): - """处理私信消息""" await self.message_handler.handle_at_message(message) - + async def on_group_at_message_create(self, message: Message): - """处理群聊@消息""" await self.message_handler.handle_at_message(message) - + async def on_c2c_message_create(self, message: Message): - """处理C2C消息(单聊)""" await self.message_handler.handle_at_message(message) - - async def on_guild_create(self, guild): - """机器人加入频道事件""" - logger.info(f"➕ 加入频道: {guild.name} (ID: {guild.id})") - - async def on_guild_delete(self, guild): - """机器人离开频道事件""" - logger.info(f"➖ 离开频道: {guild.name} (ID: {guild.id})") - + async def on_error(self, error): - """错误处理""" - logger.error(f"❌ 发生错误: {error}") + logger.error(f"Bot error: {error}", exc_info=True) + + async def on_close(self): + ai_client = getattr(self.message_handler, "ai_client", None) + if ai_client: + try: + await ai_client.close() + except Exception as exc: + logger.warning(f"Failed to close AI resources cleanly: {exc}") def main(): - """主函数""" - try: - # 验证配置 - Config.validate() - logger.info("✅ 配置验证通过") - - # 创建机器人实例 - 使用最小可用 intents,避免 disallowed intents(4014) - intents = build_intents() - logger.info("✅ Intents 配置完成(最小权限模式)") - - client = MyClient(intents=intents) - - # 启动机器人 - logger.info("🚀 正在启动机器人...") - client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET) - - except ValueError as e: - logger.error(f"❌ 配置错误: {e}") - logger.error("请检查 .env 文件配置") - except Exception as e: - logger.error(f"❌ 启动失败: {e}") - raise - - -if __name__ == "__main__": - main() + Config.validate() + intents = build_intents() + client = MyClient(intents=intents) + client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET) diff --git a/src/core/config.py b/src/core/config.py index b4a6591..8c18007 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,72 +1,123 @@ """ -QQ机器人配置管理模块 +Centralized project configuration. """ + +from __future__ import annotations + import os -from typing import Optional +from typing import Optional, Set + from dotenv import load_dotenv -# 加载环境变量 load_dotenv() def _read_env(name: str, default: Optional[str] = None) -> Optional[str]: - """ - 读取并清洗环境变量。 - - 去除首尾空白 - - 空字符串视为未设置 - - 以 # 开头的值视为注释占位,视为未设置 - """ value = os.getenv(name) if value is None: return default - + value = value.strip() - if not value or value.startswith('#'): + if not value or value.startswith("#"): return default - return value +def _read_bool(name: str, default: bool) -> bool: + raw = _read_env(name, None) + if raw is None: + return default + lowered = raw.lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + +def _read_int(name: str, default: int) -> int: + raw = _read_env(name, None) + if raw is None: + return default + try: + return int(raw) + except ValueError: + return default + + +def _read_float(name: str, default: float) -> float: + raw = _read_env(name, None) + if raw is None: + return default + try: + return float(raw) + except ValueError: + return default + + +def _read_csv_set(name: str) -> Set[str]: + raw = _read_env(name, "") or "" + return {item.strip() for item in raw.split(",") if item.strip()} + + class Config: - """机器人配置类""" - - # 机器人基本信息 - BOT_APPID = _read_env('BOT_APPID', '') or '' - BOT_SECRET = _read_env('BOT_SECRET', '') or '' - - # 日志配置 - LOG_LEVEL = _read_env('LOG_LEVEL', 'INFO') or 'INFO' - - # 沙箱模式 - SANDBOX_MODE = os.getenv('SANDBOX_MODE', 'False').lower() == 'true' - - # AI配置 - AI_PROVIDER = _read_env('AI_PROVIDER', 'openai') or 'openai' - AI_MODEL = _read_env('AI_MODEL', 'gpt-4') or 'gpt-4' - AI_API_KEY = _read_env('AI_API_KEY', '') or '' - AI_API_BASE = _read_env('AI_API_BASE', None) - - # AI嵌入模型配置(用于RAG) - AI_EMBED_PROVIDER = _read_env('AI_EMBED_PROVIDER', 'openai') or 'openai' - AI_EMBED_MODEL = _read_env('AI_EMBED_MODEL', 'text-embedding-3-small') or 'text-embedding-3-small' - AI_EMBED_API_KEY = _read_env('AI_EMBED_API_KEY', None) # 留空则使用 AI_API_KEY - AI_EMBED_API_BASE = _read_env('AI_EMBED_API_BASE', None) # 留空则使用 AI_API_BASE - - # 向量数据库配置 - AI_USE_VECTOR_DB = os.getenv('AI_USE_VECTOR_DB', 'true').lower() == 'true' - + """Application runtime configuration.""" + + BOT_APPID = _read_env("BOT_APPID", "") or "" + BOT_SECRET = _read_env("BOT_SECRET", "") or "" + + ENV = _read_env("APP_ENV", "dev") or "dev" + LOG_LEVEL = _read_env("LOG_LEVEL", "INFO") or "INFO" + LOG_FORMAT = (_read_env("LOG_FORMAT", "text") or "text").lower() + + SANDBOX_MODE = _read_bool("SANDBOX_MODE", False) + + AI_PROVIDER = _read_env("AI_PROVIDER", "openai") or "openai" + AI_MODEL = _read_env("AI_MODEL", "gpt-4") or "gpt-4" + AI_API_KEY = _read_env("AI_API_KEY", "") or "" + AI_API_BASE = _read_env("AI_API_BASE", None) + + AI_EMBED_PROVIDER = _read_env("AI_EMBED_PROVIDER", "openai") or "openai" + AI_EMBED_MODEL = ( + _read_env("AI_EMBED_MODEL", "text-embedding-3-small") + or "text-embedding-3-small" + ) + AI_EMBED_API_KEY = _read_env("AI_EMBED_API_KEY", None) + AI_EMBED_API_BASE = _read_env("AI_EMBED_API_BASE", None) + + AI_USE_VECTOR_DB = _read_bool("AI_USE_VECTOR_DB", True) + AI_USE_QUERY_EMBEDDING = _read_bool("AI_USE_QUERY_EMBEDDING", False) + + # `user`: one memory bucket per user + # `session`: memory bucket scoped to chat session (user + channel/group) + AI_MEMORY_SCOPE = (_read_env("AI_MEMORY_SCOPE", "session") or "session").lower() + + AI_CHAT_RETRIES = max(0, _read_int("AI_CHAT_RETRIES", 1)) + AI_CHAT_RETRY_BACKOFF_SECONDS = max( + 0.0, _read_float("AI_CHAT_RETRY_BACKOFF_SECONDS", 0.8) + ) + + MESSAGE_DEDUP_SECONDS = max(1, _read_int("MESSAGE_DEDUP_SECONDS", 30)) + MESSAGE_DEDUP_MAX_SIZE = max(128, _read_int("MESSAGE_DEDUP_MAX_SIZE", 4096)) + + BOT_ADMIN_IDS = _read_csv_set("BOT_ADMIN_IDS") + @classmethod - def validate(cls): - """验证配置是否完整""" + def is_admin(cls, user_id: Optional[str]) -> bool: + if not cls.BOT_ADMIN_IDS: + return True + if not user_id: + return False + return str(user_id) in cls.BOT_ADMIN_IDS + + @classmethod + def validate(cls) -> bool: if not cls.BOT_APPID: raise ValueError("BOT_APPID 未配置") if not cls.BOT_SECRET: raise ValueError("BOT_SECRET 未配置") - - # AI配置验证(可选) - if cls.AI_API_KEY: - print(f"✅ AI配置: {cls.AI_PROVIDER}/{cls.AI_MODEL}") - else: - print("⚠️ AI_API_KEY 未设置,AI功能将不可用") - + + if cls.AI_MEMORY_SCOPE not in {"user", "session"}: + raise ValueError("AI_MEMORY_SCOPE 仅支持 user 或 session") + return True diff --git a/src/handlers/__init__.py b/src/handlers/__init__.py index 0e5c565..15e0907 100644 --- a/src/handlers/__init__.py +++ b/src/handlers/__init__.py @@ -1,6 +1,5 @@ -""" -消息处理模块 -""" -from .message_handler import MessageHandler +"""Message handlers.""" -__all__ = ['MessageHandler'] +from .message_handler_ai import MessageHandler + +__all__ = ["MessageHandler"] diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 55da2ef..874b742 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -1,45 +1,54 @@ """ -集成 AI 能力的 QQ 消息处理器。 +QQ message handler with AI, memory and persona capabilities. """ +from __future__ import annotations + import asyncio import json -from pathlib import Path import re -from typing import Any, Dict, Optional +import time +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional, Tuple import httpx - -from botpy.message import Message +try: + from botpy.message import Message +except Exception: # pragma: no cover - used for test environments without botpy + class Message: # type: ignore[no-redef] + pass from src.ai import AIClient from src.ai.base import ModelConfig, ModelProvider -from src.ai.mcp import MCPManager -from src.ai.mcp.servers import FileSystemMCPServer from src.ai.personality import PersonalityProfile, PersonalityTrait -from src.ai.skills import SkillsManager from src.core.config import Config from src.utils.logger import setup_logger logger = setup_logger("MessageHandler") +@dataclass +class ChatScope: + user_id: str + group_id: Optional[str] + session_id: str + memory_key: str + + class MessageHandler: - """消息处理器(集成 AI 功能)。""" - MODEL_KEY_PATTERN = re.compile(r"[^a-zA-Z0-9_]") - + MENTION_PATTERNS = [re.compile(r"<@!\d+>"), re.compile(r"<@\d+>")] MARKDOWN_PATTERNS = [ (re.compile(r"```[\s\S]*?```"), lambda m: m.group(0).replace("```", "")), (re.compile(r"`([^`]+)`"), r"\1"), (re.compile(r"\*\*([^*]+)\*\*"), r"\1"), (re.compile(r"\*([^*]+)\*"), r"\1"), (re.compile(r"__([^_]+)__"), r"\1"), - # Avoid stripping underscores inside identifiers like model keys. (re.compile(r"(?\s?", re.MULTILINE), ""), - # Keep link label only to avoid QQ URL-blocking in private messages. (re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1"), (re.compile(r"^[-*]\s+", re.MULTILINE), "- "), (re.compile(r"^\d+\.\s+", re.MULTILINE), "- "), @@ -52,82 +61,16 @@ class MessageHandler: def __init__(self, bot): self.bot = bot - self.ai_client = None - self.skills_manager = None - self.mcp_manager = None + self.ai_client: Optional[AIClient] = None self.model_profiles_path = Path("config/models.json") self.model_profiles: Dict[str, Dict[str, Any]] = {} self.active_model_key = "default" self._ai_initialized = False self._init_lock = asyncio.Lock() - @staticmethod - def _get_user_id(message: Message) -> str: - author = getattr(message, "author", None) - if not author: - return "unknown" - - return ( - getattr(author, "id", None) - or getattr(author, "user_openid", None) - or getattr(author, "member_openid", None) - or getattr(author, "union_openid", None) - or "unknown" - ) - - @staticmethod - def _build_skills_usage(command_name: str = "/skills") -> str: - return ( - "技能命令:\n" - f"{command_name} 或 {command_name} list\n" - f"{command_name} install [skill_name]\n" - f" source 支持:本地技能名、URL、owner/repo、owner/repo#branch、GitHub 仓库 URL(.git)\n" - f"{command_name} uninstall \n" - f"{command_name} reload " - ) - - @staticmethod - def _build_personality_usage() -> str: - return ( - "人设命令:\n" - "/personality\n" - "/personality list\n" - "/personality set \n" - "/personality add \n" - " 兼容旧格式: 或 管道 name|description|speaking_style|TRAIT1,TRAIT2|custom_instructions\n" - "/personality remove " - ) - - @staticmethod - def _build_models_usage(command_name: str = "/models") -> str: - return ( - "模型命令:\n" - f"{command_name} 或 {command_name} list\n" - f"{command_name} current\n" - f"{command_name} add (保持 provider/api_base/api_key 不变)\n" - f"{command_name} add [api_base]\n" - f"{command_name} add \n" - " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n" - f"{command_name} switch \n" - f"{command_name} remove " - ) - - @staticmethod - def _build_memory_usage(command_name: str = "/memory") -> str: - return ( - "记忆命令:\n" - f"{command_name} 或 {command_name} list [limit]\n" - f"{command_name} get \n" - f"{command_name} add \n" - f"{command_name} add \n" - f"{command_name} update \n" - f"{command_name} delete \n" - f"{command_name} search [limit]\n" - "/clear (默认等价 /clear short)\n" - "/clear short\n" - "/clear long\n" - "/clear all" - ) + self._seen_messages: OrderedDict[str, float] = OrderedDict() + self._dedup_window_seconds = Config.MESSAGE_DEDUP_SECONDS + self._dedup_max_size = Config.MESSAGE_DEDUP_MAX_SIZE @staticmethod def _provider_map() -> Dict[str, ModelProvider]: @@ -139,83 +82,34 @@ class MessageHandler: "siliconflow": ModelProvider.OPENAI, } - @classmethod - def _normalize_model_key(cls, raw_key: str) -> str: - key = raw_key.strip().lower().replace("-", "_").replace(" ", "_") - key = cls.MODEL_KEY_PATTERN.sub("_", key) - key = re.sub(r"_+", "_", key).strip("_") - if not key: - raise ValueError("模型 key 不能为空") - if key[0].isdigit(): - key = f"model_{key}" - return key + @staticmethod + def _get_user_id(message: Message) -> str: + author = getattr(message, "author", None) + if not author: + return "unknown" + return ( + getattr(author, "id", None) + or getattr(author, "user_openid", None) + or getattr(author, "member_openid", None) + or getattr(author, "union_openid", None) + or "unknown" + ) @staticmethod - def _compact_model_key(raw_key: str) -> str: - return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower()) + def _get_group_id(message: Message) -> Optional[str]: + return ( + getattr(message, "group_openid", None) + or getattr(message, "group_id", None) + or getattr(message, "channel_id", None) + or getattr(message, "guild_id", None) + ) - def _ordered_model_keys(self) -> list[str]: - return sorted(self.model_profiles.keys()) - - def _resolve_model_selector(self, selector: str) -> str: - raw = (selector or "").strip() - if not raw: - raise ValueError("模型 key 不能为空") - - ordered_keys = self._ordered_model_keys() - if raw.isdigit(): - index = int(raw) - if index < 1 or index > len(ordered_keys): - raise ValueError( - f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}" - ) - return ordered_keys[index - 1] - - if raw in self.model_profiles: - return raw - - normalized_selector: Optional[str] - try: - normalized_selector = self._normalize_model_key(raw) - except ValueError: - normalized_selector = None - - if normalized_selector and normalized_selector in self.model_profiles: - return normalized_selector - - normalized_candidates: Dict[str, list[str]] = {} - compact_candidates: Dict[str, list[str]] = {} - for key in ordered_keys: - try: - normalized_key = self._normalize_model_key(key) - except ValueError: - normalized_key = key.strip().lower() - normalized_candidates.setdefault(normalized_key, []).append(key) - compact_candidates.setdefault( - self._compact_model_key(normalized_key), [] - ).append(key) - - if normalized_selector and normalized_selector in normalized_candidates: - matches = normalized_candidates[normalized_selector] - if len(matches) == 1: - return matches[0] - raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}") - - compact_selector = self._compact_model_key(normalized_selector or raw) - if compact_selector in compact_candidates: - matches = compact_candidates[compact_selector] - if len(matches) == 1: - return matches[0] - raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}") - - raise ValueError(f"模型配置不存在: {raw}") - - @classmethod - def _parse_provider(cls, raw_provider: str) -> ModelProvider: - provider = cls._provider_map().get(raw_provider.strip().lower()) - if not provider: - raise ValueError(f"不支持的 provider: {raw_provider}") - return provider + def _build_scope(self, message: Message) -> ChatScope: + user_id = self._get_user_id(message) + group_id = self._get_group_id(message) + session_id = f"{group_id}:{user_id}" if group_id else user_id + memory_key = user_id if Config.AI_MEMORY_SCOPE == "user" else session_id + return ChatScope(user_id=user_id, group_id=group_id, session_id=session_id, memory_key=memory_key) @staticmethod def _coerce_float(value: Any, default: float) -> float: @@ -246,10 +140,64 @@ class MessageHandler: return default return bool(value) + @classmethod + def _normalize_model_key(cls, raw_key: str) -> str: + key = raw_key.strip().lower().replace("-", "_").replace(" ", "_") + key = cls.MODEL_KEY_PATTERN.sub("_", key) + key = re.sub(r"_+", "_", key).strip("_") + if not key: + raise ValueError("模型 key 不能为空") + if key[0].isdigit(): + key = f"model_{key}" + return key + @staticmethod - def _model_config_to_dict( - config: ModelConfig, include_api_key: bool = True - ) -> Dict[str, Any]: + def _compact_model_key(raw_key: str) -> str: + return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower()) + + def _ordered_model_keys(self) -> list[str]: + return sorted(self.model_profiles.keys()) + + @classmethod + def _parse_provider(cls, raw_provider: str) -> ModelProvider: + provider = cls._provider_map().get(raw_provider.strip().lower()) + if not provider: + raise ValueError(f"不支持的 provider: {raw_provider}") + return provider + + def _resolve_model_selector(self, selector: str) -> str: + raw = (selector or "").strip() + if not raw: + raise ValueError("模型 key 不能为空") + + ordered_keys = self._ordered_model_keys() + if raw.isdigit(): + index = int(raw) + if index < 1 or index > len(ordered_keys): + raise ValueError(f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}") + return ordered_keys[index - 1] + + if raw in self.model_profiles: + return raw + + normalized_selector: Optional[str] + try: + normalized_selector = self._normalize_model_key(raw) + except ValueError: + normalized_selector = None + + if normalized_selector and normalized_selector in self.model_profiles: + return normalized_selector + + compact_selector = self._compact_model_key(normalized_selector or raw) + for key in ordered_keys: + if compact_selector == self._compact_model_key(key): + return key + + raise ValueError(f"模型配置不存在: {raw}") + + @staticmethod + def _model_config_to_dict(config: ModelConfig, include_api_key: bool = True) -> Dict[str, Any]: data: Dict[str, Any] = { "provider": config.provider.value, "model_name": config.model_name, @@ -267,9 +215,7 @@ class MessageHandler: data["api_key"] = config.api_key return data - def _model_config_from_dict( - self, raw: Dict[str, Any], fallback: ModelConfig - ) -> ModelConfig: + def _model_config_from_dict(self, raw: Dict[str, Any], fallback: ModelConfig) -> ModelConfig: provider = self._parse_provider(str(raw.get("provider") or fallback.provider.value)) model_name = str(raw.get("model_name") or fallback.model_name).strip() if not model_name: @@ -291,12 +237,8 @@ class MessageHandler: temperature=self._coerce_float(raw.get("temperature"), fallback.temperature), max_tokens=self._coerce_int(raw.get("max_tokens"), fallback.max_tokens), top_p=self._coerce_float(raw.get("top_p"), fallback.top_p), - frequency_penalty=self._coerce_float( - raw.get("frequency_penalty"), fallback.frequency_penalty - ), - presence_penalty=self._coerce_float( - raw.get("presence_penalty"), fallback.presence_penalty - ), + frequency_penalty=self._coerce_float(raw.get("frequency_penalty"), fallback.frequency_penalty), + presence_penalty=self._coerce_float(raw.get("presence_penalty"), fallback.presence_penalty), timeout=self._coerce_int(raw.get("timeout"), fallback.timeout), stream=self._coerce_bool(raw.get("stream"), fallback.stream), ) @@ -323,43 +265,20 @@ class MessageHandler: for raw_key, raw_profile in raw_profiles.items(): if not isinstance(raw_profile, dict): continue - key_text = str(raw_key or "").strip() if not key_text: continue - try: normalized_key = self._normalize_model_key(key_text) except ValueError: continue - - if normalized_key in profiles and profiles[normalized_key] != raw_profile: - logger.warning( - f"duplicate model key after normalization, keep first: {normalized_key}" - ) - continue - profiles[normalized_key] = raw_profile if not profiles: - profiles = { - "default": self._model_config_to_dict( - default_config, include_api_key=False - ) - } + profiles = {"default": self._model_config_to_dict(default_config, include_api_key=False)} active_raw = str(payload.get("active") or "").strip() - active = "" - if active_raw in profiles: - active = active_raw - elif active_raw: - try: - normalized_active = self._normalize_model_key(active_raw) - except ValueError: - normalized_active = "" - if normalized_active in profiles: - active = normalized_active - + active = active_raw if active_raw in profiles else "" if not active: active = "default" if "default" in profiles else sorted(profiles.keys())[0] @@ -367,44 +286,21 @@ class MessageHandler: self.active_model_key = active try: - active_config = self._model_config_from_dict( - self.model_profiles[self.active_model_key], default_config - ) - except Exception as exc: - logger.warning(f"active model profile invalid, fallback to default: {exc}") - self.model_profiles = { - "default": self._model_config_to_dict( - default_config, include_api_key=False - ) - } - self.active_model_key = "default" + active_config = self._model_config_from_dict(self.model_profiles[active], default_config) + except Exception: active_config = default_config - self._save_model_profiles() return active_config def _ensure_model_profiles_ready(self): if not self.ai_client: return - if self.model_profiles and self.active_model_key in self.model_profiles: return - active_config = self._load_model_profiles(self.ai_client.config) if active_config != self.ai_client.config: self.ai_client.switch_model(active_config) - def _plain_text(self, text: str) -> str: - if not text: - return "" - - result = text - for pattern, replacement in self.MARKDOWN_PATTERNS: - result = pattern.sub(replacement, result) - result = self._strip_urls(result) - - return result.strip() - @classmethod def _strip_urls(cls, text: str) -> str: result = text @@ -412,99 +308,85 @@ class MessageHandler: result = pattern.sub("[链接已省略]", result) return result + def _plain_text(self, text: str) -> str: + if not text: + return "" + result = text + for pattern, replacement in self.MARKDOWN_PATTERNS: + result = pattern.sub(replacement, result) + return self._strip_urls(result).strip() + + @staticmethod + def _looks_like_url_block_error(exc: Exception) -> bool: + lower = str(exc).lower() + return "url" in lower and ( + "not allow" in lower or "not allowed" in lower or "不允许" in lower or "forbidden" in lower + ) + async def _reply_plain(self, message: Message, text: str): - content = self._plain_text(text) + content = self._plain_text(text) or " " try: await message.reply(content=content) + return except Exception as exc: - # QQ C2C may reject any message containing URL. - if "不允许发送url" not in str(exc).lower(): - raise + if not self._looks_like_url_block_error(exc): + logger.warning(f"reply failed once, retrying: {exc}") + await asyncio.sleep(0.2) + await message.reply(content=content) + return + fallback = self._strip_urls(content).strip() or "内容包含受限链接,已省略。" + await message.reply(content=fallback) - logger.warning("消息被平台判定包含 URL,尝试二次清洗后重发") - fallback = self._strip_urls(content).strip() or "内容包含受限链接,已省略。" - await message.reply(content=fallback) + def _extract_message_identity(self, message: Message) -> str: + msg_id = ( + getattr(message, "id", None) + or getattr(message, "message_id", None) + or getattr(message, "msg_id", None) + ) + if msg_id: + return str(msg_id) + author = self._get_user_id(message) + group = self._get_group_id(message) or "-" + content = (getattr(message, "content", "") or "").strip() + return f"fallback:{author}:{group}:{hash(content)}" - def _register_skill_tools(self, skill_name: str) -> int: - if not self.skills_manager or not self.ai_client: - return 0 + def _is_duplicate_message(self, message: Message) -> bool: + now = time.time() + message_key = self._extract_message_identity(message) + expired_before = now - self._dedup_window_seconds + while self._seen_messages: + first_key = next(iter(self._seen_messages)) + if self._seen_messages[first_key] >= expired_before: + break + self._seen_messages.popitem(last=False) + if message_key in self._seen_messages: + return True + self._seen_messages[message_key] = now + if len(self._seen_messages) > self._dedup_max_size: + self._seen_messages.popitem(last=False) + return False - skill = self.skills_manager.get_skill(skill_name) - if not skill: - return 0 + def _is_admin(self, user_id: str) -> bool: + return Config.is_admin(user_id) - count = 0 - for tool_name, tool_func in skill.get_tools().items(): - full_tool_name = f"{skill_name}.{tool_name}" - self.ai_client.register_tool( - name=full_tool_name, - description=f"技能工具: {full_tool_name}", - parameters={"type": "object", "properties": {}}, - function=tool_func, - source="skills", - ) - count += 1 - - return count - - async def _register_mcp_tools(self) -> int: - if not self.mcp_manager or not self.ai_client: - return 0 - - tools = await self.mcp_manager.get_all_tools_for_ai() - count = 0 - - for item in tools: - function_info = item.get("function") if isinstance(item, dict) else None - if not isinstance(function_info, dict): - continue - - full_tool_name = function_info.get("name") - if not full_tool_name: - continue - - parameters = function_info.get("parameters") - if not isinstance(parameters, dict): - parameters = {"type": "object", "properties": {}} - - async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs): - if not self.mcp_manager: - raise RuntimeError("MCP manager not initialized") - return await self.mcp_manager.execute_tool(_full_tool_name, kwargs) - - self.ai_client.register_tool( - name=full_tool_name, - description=f"MCP工具: {full_tool_name}", - parameters=parameters, - function=_mcp_proxy, - source="mcp", - ) - count += 1 - - return count + async def _require_admin(self, message: Message, user_id: str, action: str) -> bool: + if self._is_admin(user_id): + return True + await self._reply_plain(message, f"权限不足,无法执行: {action}") + return False def _parse_traits(self, traits_raw) -> list[PersonalityTrait]: - traits = [] - if isinstance(traits_raw, str): trait_names = [name.strip().upper() for name in traits_raw.split(",") if name.strip()] elif isinstance(traits_raw, list): trait_names = [str(name).strip().upper() for name in traits_raw if str(name).strip()] else: trait_names = [] - - for name in trait_names: - if name in PersonalityTrait.__members__: - traits.append(PersonalityTrait[name]) - - if not traits: - traits = [PersonalityTrait.FRIENDLY] - - return traits + traits = [PersonalityTrait[name] for name in trait_names if name in PersonalityTrait.__members__] + return traits or [PersonalityTrait.FRIENDLY] def _parse_personality_payload(self, key: str, payload: str) -> PersonalityProfile: payload = payload.strip() - if payload.startswith("{"): data = json.loads(payload) return PersonalityProfile( @@ -517,98 +399,129 @@ class MessageHandler: ) if "|" not in payload: - introduction = payload return PersonalityProfile( name=key, - description=introduction, + description=payload, traits=[PersonalityTrait.FRIENDLY], speaking_style="自然口语", - custom_instructions=introduction, + custom_instructions=payload, ) parts = [part.strip() for part in payload.split("|")] - name = parts[0] if len(parts) >= 1 and parts[0] else key - description = parts[1] if len(parts) >= 2 and parts[1] else "自定义人设" - speaking_style = parts[2] if len(parts) >= 3 and parts[2] else "自然口语" - traits = self._parse_traits(parts[3] if len(parts) >= 4 else "") - custom_instructions = parts[4] if len(parts) >= 5 else "" - return PersonalityProfile( - name=name, - description=description, - traits=traits, - speaking_style=speaking_style, - custom_instructions=custom_instructions, + name=parts[0] if len(parts) >= 1 and parts[0] else key, + description=parts[1] if len(parts) >= 2 and parts[1] else "自定义人设", + traits=self._parse_traits(parts[3] if len(parts) >= 4 else ""), + speaking_style=parts[2] if len(parts) >= 3 and parts[2] else "自然口语", + custom_instructions=parts[4] if len(parts) >= 5 else "", + ) + + @staticmethod + def _parse_scope_token(token: Optional[str], scope: ChatScope) -> Tuple[str, Optional[str]]: + raw = (token or "").strip().lower() + if not raw: + return "global", None + if ":" in raw: + scope_name, scope_id = raw.split(":", 1) + return scope_name, scope_id or None + if raw == "global": + return "global", None + if raw == "user": + return "user", scope.user_id + if raw == "group": + return "group", scope.group_id + if raw == "session": + return "session", scope.session_id + return "global", None + + @staticmethod + def _build_personality_usage() -> str: + return ( + "人设命令:\n" + "/personality\n" + "/personality list\n" + "/personality set [global|user|group|session]\n" + "/personality add \n" + "/personality remove " + ) + + @staticmethod + def _build_models_usage(command_name: str = "/models") -> str: + return ( + "模型命令:\n" + f"{command_name} 或 {command_name} list\n" + f"{command_name} current\n" + f"{command_name} add \n" + f"{command_name} add [api_base]\n" + f"{command_name} add \n" + f"{command_name} switch \n" + f"{command_name} remove " + ) + + @staticmethod + def _build_memory_usage(command_name: str = "/memory") -> str: + return ( + "记忆命令:\n" + f"{command_name} 或 {command_name} list [limit]\n" + f"{command_name} get \n" + f"{command_name} add \n" + f"{command_name} add \n" + f"{command_name} update \n" + f"{command_name} delete \n" + f"{command_name} search [limit]\n" + "/clear (默认等价 /clear short)\n" + "/clear short\n" + "/clear long\n" + "/clear all" ) async def _init_ai(self): - if self._ai_initialized: - return + provider = self._provider_map().get(Config.AI_PROVIDER.lower(), ModelProvider.OPENAI) + config = ModelConfig( + provider=provider, + model_name=Config.AI_MODEL, + api_key=Config.AI_API_KEY, + api_base=Config.AI_API_BASE, + temperature=0.7, + ) + config = self._load_model_profiles(config) - try: - provider = self._provider_map().get( - Config.AI_PROVIDER.lower(), ModelProvider.OPENAI + embed_config = None + if Config.AI_EMBED_PROVIDER and Config.AI_EMBED_MODEL: + embed_provider = self._provider_map().get(Config.AI_EMBED_PROVIDER.lower(), ModelProvider.OPENAI) + embed_config = ModelConfig( + provider=embed_provider, + model_name=Config.AI_EMBED_MODEL, + api_key=Config.AI_EMBED_API_KEY or Config.AI_API_KEY, + api_base=Config.AI_EMBED_API_BASE or Config.AI_API_BASE, ) - config = ModelConfig( - provider=provider, - model_name=Config.AI_MODEL, - api_key=Config.AI_API_KEY, - api_base=Config.AI_API_BASE, - temperature=0.7, - ) - config = self._load_model_profiles(config) + self.ai_client = AIClient( + model_config=config, + embed_config=embed_config, + data_dir=Path("data/ai"), + use_vector_db=Config.AI_USE_VECTOR_DB, + use_query_embedding=Config.AI_USE_QUERY_EMBEDDING, + chat_retries=Config.AI_CHAT_RETRIES, + chat_retry_backoff=Config.AI_CHAT_RETRY_BACKOFF_SECONDS, + ) + self._ai_initialized = True + logger.info("AI system initialized") - embed_config = None - if Config.AI_EMBED_PROVIDER and Config.AI_EMBED_MODEL: - embed_provider = self._provider_map().get( - Config.AI_EMBED_PROVIDER.lower(), ModelProvider.OPENAI - ) - - embed_config = ModelConfig( - provider=embed_provider, - model_name=Config.AI_EMBED_MODEL, - api_key=Config.AI_EMBED_API_KEY or Config.AI_API_KEY, - api_base=Config.AI_EMBED_API_BASE or Config.AI_API_BASE, - ) - - self.ai_client = AIClient( - model_config=config, - embed_config=embed_config, - data_dir=Path("data/ai"), - use_vector_db=Config.AI_USE_VECTOR_DB, - ) - - self.skills_manager = SkillsManager(Path("skills")) - await self.skills_manager.load_all_skills() - - total_tools = 0 - for skill_name in self.skills_manager.list_skills(): - total_tools += self._register_skill_tools(skill_name) - logger.info( - f"技能系统初始化完成: {len(self.skills_manager.list_skills())} skills, {total_tools} tools" - ) - - try: - self.mcp_manager = MCPManager(Path("config/mcp.json")) - fs_server = FileSystemMCPServer(root_path=Path("data")) - await self.mcp_manager.register_server(fs_server) - mcp_tool_count = await self._register_mcp_tools() - logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools") - except Exception as exc: - logger.warning(f"MCP 初始化失败: {exc}") - - self._ai_initialized = True - logger.info("AI 系统初始化完成") - - except Exception as exc: - logger.error(f"AI 初始化失败: {exc}") - import traceback - - logger.error(traceback.format_exc()) + def _strip_mentions(self, content: str) -> str: + cleaned = content or "" + for pattern in self.MENTION_PATTERNS: + cleaned = pattern.sub("", cleaned) + if getattr(self.bot, "robot", None) and getattr(self.bot.robot, "name", None): + cleaned = cleaned.replace(f"@{self.bot.robot.name}", "") + return cleaned.strip() async def handle_at_message(self, message: Message): try: + if self._is_duplicate_message(message): + logger.info("duplicate message skipped") + return + if not self._ai_initialized: async with self._init_lock: if not self._ai_initialized: @@ -618,191 +531,83 @@ class MessageHandler: await self._reply_plain(message, "AI 初始化失败,请检查配置") return - content = (message.content or "").strip() - user_id = self._get_user_id(message) - - if content.startswith(f"<@!{self.bot.robot.id}>"): - content = content.replace(f"<@!{self.bot.robot.id}>", "").strip() - + scope = self._build_scope(message) + content = self._strip_mentions((message.content or "").strip()) if not content: return if content.startswith("/"): - await self._handle_command(message, content) + await self._handle_command(message, content, scope) return response = await self.ai_client.chat( - user_id=user_id, + user_id=scope.user_id, user_message=content, use_memory=True, - use_tools=True, + group_id=scope.group_id, + session_id=scope.session_id, + memory_key=scope.memory_key, ) await self._reply_plain(message, response) - except Exception as exc: - logger.error(f"处理消息失败: {exc}") - import traceback - - logger.error(traceback.format_exc()) + logger.error(f"message handling failed: {exc}", exc_info=True) if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)): - await self._reply_plain( - message, - "模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。", - ) + await self._reply_plain(message, "模型响应超时,请稍后重试,或将 timeout 调大(建议 120-180 秒)。") return await self._reply_plain(message, "消息处理失败,请稍后重试") - async def _handle_skills_command(self, message: Message, command: str): - if not self.skills_manager or not self.ai_client: - await self._reply_plain(message, "技能系统未初始化") + async def _handle_personality_command(self, message: Message, command: str, scope: ChatScope): + if not self.ai_client: + await self._reply_plain(message, "AI 客户端未初始化") return - parts = command.split() - command_name = parts[0].lower() - action = parts[1].lower() if len(parts) > 1 else "list" + parts = command.split(maxsplit=4) + action = parts[1].lower() if len(parts) > 1 else "show" - if action in {"list", "ls"} and len(parts) <= 2: - loaded = self.skills_manager.list_skills() - available = self.skills_manager.list_available_skills() - unloaded = [name for name in available if name not in loaded] - - lines = ["技能状态:"] - lines.append("已加载: " + (", ".join(loaded) if loaded else "无")) - lines.append("可安装: " + (", ".join(unloaded) if unloaded else "无")) - lines.append(self._build_skills_usage(command_name)) - await self._reply_plain(message, "\n".join(lines)) - return - - if action not in {"install", "uninstall", "remove", "reload"}: - await self._reply_plain(message, self._build_skills_usage(command_name)) - return - - if len(parts) < 3: - await self._reply_plain(message, self._build_skills_usage(command_name)) - return - - if action == "install": - source = parts[2] - desired_name = parts[3] if len(parts) > 3 else None - - try: - source_key = self.skills_manager.normalize_skill_key(source) - except Exception: - source_key = None - - if source_key and source_key in self.skills_manager.list_available_skills(): - success = await self.skills_manager.load_skill(source_key) - if not success: - await self._reply_plain(message, f"加载技能失败: {source_key}") - return - - tool_count = self._register_skill_tools(source_key) - await self._reply_plain( - message, - f"已加载本地技能: {source_key}\n注册工具: {tool_count}", - ) - return - - ok, result = self.skills_manager.install_skill_from_source( - source=source, - skill_name=desired_name, - overwrite=False, + if action in {"show", "current"} and len(parts) <= 2: + profile = self.ai_client.personality.get_active_personality( + user_id=scope.user_id, + group_id=scope.group_id, + session_id=scope.session_id, ) - if not ok: - await self._reply_plain(message, f"安装失败: {result}") - return - - installed_key = result - success = await self.skills_manager.load_skill(installed_key) - if not success: - await self._reply_plain(message, f"安装成功但加载失败: {installed_key}") - return - - tool_count = self._register_skill_tools(installed_key) - await self._reply_plain( - message, - f"已从来源安装并加载技能: {installed_key}\n注册工具: {tool_count}", - ) - return - - skill_name = parts[2] - try: - skill_key = self.skills_manager.normalize_skill_key(skill_name) - except Exception: - await self._reply_plain(message, f"非法技能名: {skill_name}") - return - - if action in {"uninstall", "remove"}: - removed_tools = self.ai_client.unregister_tools_by_prefix(f"{skill_key}.") - removed = await self.skills_manager.uninstall_skill(skill_key, delete_files=True) - if not removed: - await self._reply_plain(message, f"卸载失败或技能不存在: {skill_key}") - return - - await self._reply_plain( - message, - f"已卸载技能: {skill_key}\n注销工具: {removed_tools}", - ) - return - - self.ai_client.unregister_tools_by_prefix(f"{skill_key}.") - success = await self.skills_manager.reload_skill(skill_key) - if not success: - await self._reply_plain(message, f"重载失败: {skill_key}") - return - - tool_count = self._register_skill_tools(skill_key) - await self._reply_plain(message, f"已重载技能: {skill_key}\n注册工具: {tool_count}") - - async def _handle_personality_command(self, message: Message, command: str): - parts = command.split(maxsplit=3) - - if len(parts) == 1: - current = self.ai_client.personality.current_personality names = self.ai_client.list_personalities() - if current: - await self._reply_plain( - message, - f"当前人设: {current.name}\n简介: {current.description}\n可用: {', '.join(names)}", - ) + if profile: + await self._reply_plain(message, f"当前人设: {profile.name}\n简介: {profile.description}\n可用: {', '.join(names)}") else: - await self._reply_plain(message, "当前没有激活的人设") + await self._reply_plain(message, "当前没有可用人设") return - action = parts[1].lower() - if action == "list": - names = self.ai_client.list_personalities() - await self._reply_plain(message, "可用人设: " + ", ".join(names)) + await self._reply_plain(message, "可用人设: " + ", ".join(self.ai_client.list_personalities())) return if action in {"set", "use"}: if len(parts) < 3: await self._reply_plain(message, self._build_personality_usage()) return - - key = parts[2] - if self.ai_client.set_personality(key): + key = parts[2].strip() + scope_name, scope_id = self._parse_scope_token(parts[3] if len(parts) >= 4 else None, scope) + if scope_name != "global" and not await self._require_admin(message, scope.user_id, "personality scoped set"): + return + if self.ai_client.set_personality(key, scope=scope_name, scope_id=scope_id): await self._reply_plain(message, f"已切换人设: {key}") else: - await self._reply_plain(message, f"人设不存在: {key}") + await self._reply_plain(message, f"人设不存在或 scope 参数非法: {key}") return if action in {"add", "create"}: + if not await self._require_admin(message, scope.user_id, "personality add"): + return if len(parts) < 4: await self._reply_plain(message, self._build_personality_usage()) return - - key = parts[2] - payload = parts[3] - + key = parts[2].strip() + payload = parts[3].strip() try: profile = self._parse_personality_payload(key, payload) - ok = self.ai_client.personality.add_personality(key, profile) - if not ok: + if not self.ai_client.personality.add_personality(key, profile): await self._reply_plain(message, f"新增人设失败: {key}") return - self.ai_client.set_personality(key) await self._reply_plain(message, f"已新增并切换人设: {key}") except Exception as exc: @@ -810,33 +615,32 @@ class MessageHandler: return if action in {"remove", "delete"}: + if not await self._require_admin(message, scope.user_id, "personality remove"): + return if len(parts) < 3: await self._reply_plain(message, self._build_personality_usage()) return - - key = parts[2] - removed = self.ai_client.personality.remove_personality(key) - if removed: + key = parts[2].strip() + if self.ai_client.personality.remove_personality(key): await self._reply_plain(message, f"已删除人设: {key}") else: await self._reply_plain(message, f"删除失败(可能是默认人设或不存在): {key}") return - # 兼容旧命令: /personality if self.ai_client.set_personality(parts[1]): await self._reply_plain(message, f"已切换人设: {parts[1]}") return await self._reply_plain(message, self._build_personality_usage()) - async def _handle_memory_command(self, message: Message, command: str): + async def _handle_memory_command(self, message: Message, command: str, scope: ChatScope): if not self.ai_client: await self._reply_plain(message, "AI 客户端未初始化") return - user_id = self._get_user_id(message) parts = command.split(maxsplit=3) action = parts[1].lower() if len(parts) > 1 else "list" + memory_user_id = scope.memory_key if action in {"list", "ls"}: limit = 10 @@ -846,179 +650,11 @@ class MessageHandler: except ValueError: await self._reply_plain(message, self._build_memory_usage()) return - - memories = await self.ai_client.list_long_term_memories(user_id, limit=limit) + memories = await self.ai_client.list_long_term_memories(memory_user_id, limit=limit) if not memories: await self._reply_plain(message, "暂无长期记忆") return - lines = ["长期记忆列表:"] - for memory in memories: - content = self._plain_text(memory.content).replace("\n", " ").strip() - if len(content) > 60: - content = content[:57] + "..." - lines.append( - f"- {memory.id} | 重要性={memory.importance:.2f} | {memory.timestamp.isoformat()} | {content}" - ) - await self._reply_plain(message, "\n".join(lines)) - return - - if action == "get": - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - memory_id = parts[2].strip() - memory = await self.ai_client.get_long_term_memory(user_id, memory_id) - if not memory: - await self._reply_plain(message, f"记忆不存在: {memory_id}") - return - - meta_text = json.dumps(memory.metadata or {}, ensure_ascii=False) - await self._reply_plain( - message, - "记忆详情:\n" - f"id: {memory.id}\n" - f"重要性: {memory.importance:.2f}\n" - f"时间: {memory.timestamp.isoformat()}\n" - f"访问次数: {memory.access_count}\n" - f"内容: {memory.content}\n" - f"元数据: {meta_text}", - ) - return - - if action == "add": - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - payload = " ".join(parts[2:]).strip() - content = payload - importance = 0.8 - metadata = None - - if payload.startswith("{"): - try: - data = json.loads(payload) - if not isinstance(data, dict): - raise ValueError("payload 必须是对象") - content = str(data.get("content") or "").strip() - importance = float(data.get("importance", 0.8)) - raw_meta = data.get("metadata") - metadata = raw_meta if isinstance(raw_meta, dict) else None - except Exception as exc: - await self._reply_plain(message, f"JSON 解析失败: {exc}") - return - - if not content: - await self._reply_plain(message, "内容不能为空") - return - - memory = await self.ai_client.add_long_term_memory( - user_id=user_id, - content=content, - importance=importance, - metadata=metadata, - ) - if not memory: - await self._reply_plain(message, "新增长期记忆失败") - return - - await self._reply_plain( - message, f"已新增长期记忆: {memory.id} (重要性={memory.importance:.2f})" - ) - return - - if action in {"update", "set"}: - if len(parts) < 4: - await self._reply_plain(message, self._build_memory_usage()) - return - - memory_id = parts[2].strip() - payload = parts[3].strip() - content = payload - importance = None - metadata = None - - if payload.startswith("{"): - try: - data = json.loads(payload) - if not isinstance(data, dict): - raise ValueError("payload 必须是对象") - if "content" in data: - content = str(data.get("content") or "") - else: - content = None - if "importance" in data: - importance = float(data.get("importance")) - raw_meta = data.get("metadata") - if raw_meta is not None: - if not isinstance(raw_meta, dict): - raise ValueError("metadata 必须是对象") - metadata = raw_meta - except Exception as exc: - await self._reply_plain(message, f"JSON 解析失败: {exc}") - return - - updated = await self.ai_client.update_long_term_memory( - user_id=user_id, - memory_id=memory_id, - content=content, - importance=importance, - metadata=metadata, - ) - if not updated: - await self._reply_plain(message, f"更新失败或记忆不存在: {memory_id}") - return - - await self._reply_plain( - message, - f"已更新长期记忆: {memory_id} (重要性={updated.importance:.2f})", - ) - return - - if action in {"delete", "remove", "rm"}: - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - memory_id = parts[2].strip() - deleted = await self.ai_client.delete_long_term_memory(user_id, memory_id) - if not deleted: - await self._reply_plain(message, f"删除失败或记忆不存在: {memory_id}") - return - - await self._reply_plain(message, f"已删除长期记忆: {memory_id}") - return - - if action in {"search", "find"}: - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - query_payload = " ".join(parts[2:]).strip() - limit = 10 - query = query_payload - if " " in query_payload: - possible_query, possible_limit = query_payload.rsplit(" ", 1) - if possible_limit.isdigit(): - query = possible_query - limit = max(1, min(100, int(possible_limit))) - - if not query: - await self._reply_plain(message, "搜索词不能为空") - return - - memories = await self.ai_client.search_long_term_memories( - user_id=user_id, - query=query, - limit=limit, - ) - if not memories: - await self._reply_plain(message, f"没有匹配到相关记忆: {query}") - return - - lines = [f"搜索结果({len(memories)} 条):"] for memory in memories: content = self._plain_text(memory.content).replace("\n", " ").strip() if len(content) > 60: @@ -1027,13 +663,123 @@ class MessageHandler: await self._reply_plain(message, "\n".join(lines)) return + if action == "get": + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + memory_id = parts[2].strip() + memory = await self.ai_client.get_long_term_memory(memory_user_id, memory_id) + if not memory: + await self._reply_plain(message, f"记忆不存在: {memory_id}") + return + meta_text = json.dumps(memory.metadata or {}, ensure_ascii=False) + await self._reply_plain( + message, + "记忆详情:\n" + f"id: {memory.id}\n重要性: {memory.importance:.2f}\n时间: {memory.timestamp.isoformat()}\n" + f"访问次数: {memory.access_count}\n内容: {memory.content}\n元数据: {meta_text}", + ) + return + + if action == "add": + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + payload = " ".join(parts[2:]).strip() + content = payload + importance = 0.8 + metadata = None + if payload.startswith("{"): + try: + data = json.loads(payload) + content = str(data.get("content") or "").strip() + importance = float(data.get("importance", 0.8)) + raw_meta = data.get("metadata") + metadata = raw_meta if isinstance(raw_meta, dict) else None + except Exception as exc: + await self._reply_plain(message, f"JSON 解析失败: {exc}") + return + if not content: + await self._reply_plain(message, "内容不能为空") + return + memory = await self.ai_client.add_long_term_memory(memory_user_id, content, importance, metadata) + if not memory: + await self._reply_plain(message, "新增长期记忆失败(可能命中去重策略)") + return + await self._reply_plain(message, f"已新增长期记忆: {memory.id} (重要性={memory.importance:.2f})") + return + + if action in {"update", "set"}: + if len(parts) < 4: + await self._reply_plain(message, self._build_memory_usage()) + return + memory_id = parts[2].strip() + payload = parts[3].strip() + content = payload + importance = None + metadata = None + if payload.startswith("{"): + try: + data = json.loads(payload) + content = str(data.get("content") or "") if "content" in data else None + importance = float(data.get("importance")) if "importance" in data else None + raw_meta = data.get("metadata") + if raw_meta is not None and not isinstance(raw_meta, dict): + raise ValueError("metadata 必须是对象") + metadata = raw_meta + except Exception as exc: + await self._reply_plain(message, f"JSON 解析失败: {exc}") + return + + updated = await self.ai_client.update_long_term_memory( + user_id=memory_user_id, + memory_id=memory_id, + content=content, + importance=importance, + metadata=metadata, + ) + if not updated: + await self._reply_plain(message, f"更新失败或记忆不存在: {memory_id}") + return + await self._reply_plain(message, f"已更新长期记忆: {memory_id} (重要性={updated.importance:.2f})") + return + + if action in {"delete", "remove", "rm"}: + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + memory_id = parts[2].strip() + deleted = await self.ai_client.delete_long_term_memory(memory_user_id, memory_id) + await self._reply_plain(message, f"已删除长期记忆: {memory_id}" if deleted else f"删除失败或记忆不存在: {memory_id}") + return + + if action in {"search", "find"}: + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + payload = " ".join(parts[2:]).strip() + query, limit = payload, 10 + if " " in payload: + maybe_query, maybe_limit = payload.rsplit(" ", 1) + if maybe_limit.isdigit(): + query, limit = maybe_query, max(1, min(100, int(maybe_limit))) + memories = await self.ai_client.search_long_term_memories(memory_user_id, query, limit) + if not memories: + await self._reply_plain(message, f"没有匹配到相关记忆: {query}") + return + lines = [f"搜索结果({len(memories)} 条):"] + for memory in memories: + content = self._plain_text(memory.content).replace("\n", " ").strip() + lines.append(f"- {memory.id} | 重要性={memory.importance:.2f} | {content[:60]}") + await self._reply_plain(message, "\n".join(lines)) + return + await self._reply_plain(message, self._build_memory_usage()) - async def _handle_models_command(self, message: Message, command: str): + async def _handle_models_command(self, message: Message, command: str, scope: ChatScope): if not self.ai_client: await self._reply_plain(message, "AI 客户端未初始化") return - self._ensure_model_profiles_ready() parts = command.split(maxsplit=3) @@ -1041,69 +787,40 @@ class MessageHandler: if action in {"list", "ls"} and len(parts) <= 2: lines = [f"当前模型配置: {self.active_model_key}"] - ordered_keys = self._ordered_model_keys() - for idx, key in enumerate(ordered_keys, start=1): + for idx, key in enumerate(self._ordered_model_keys(), start=1): profile = self.model_profiles.get(key, {}) marker = "*" if key == self.active_model_key else "-" - provider = str(profile.get("provider") or "?") - model_name = str(profile.get("model_name") or "?") - lines.append(f"{marker} {idx}. {key}: {provider}/{model_name}") - - if ordered_keys: - lines.append(f"提示: 可用 /models switch <序号>,例如 /models switch 2") - + lines.append(f"{marker} {idx}. {key}: {profile.get('provider')}/{profile.get('model_name')}") lines.append(self._build_models_usage("/models")) await self._reply_plain(message, "\n".join(lines)) return if action in {"current", "show"}: - config = self.ai_client.config - await self._reply_plain( - message, - "当前模型:\n" - f"配置名: {self.active_model_key}\n" - f"供应商: {config.provider.value}\n" - f"模型: {config.model_name}\n" - f"API 地址: {config.api_base or '-'}", - ) + cfg = self.ai_client.config + await self._reply_plain(message, f"当前模型: {self.active_model_key}\n{cfg.provider.value}/{cfg.model_name}") + return + + if not await self._require_admin(message, scope.user_id, f"models {action}"): return if action in {"switch", "set", "use"}: if len(parts) < 3: await self._reply_plain(message, self._build_models_usage()) return - - try: - key = self._resolve_model_selector(parts[2]) - except ValueError as exc: - await self._reply_plain(message, str(exc)) - return - - try: - config = self._model_config_from_dict( - self.model_profiles[key], self.ai_client.config - ) - self.ai_client.switch_model(config) - except Exception as exc: - await self._reply_plain(message, f"切换模型失败: {exc}") - return - + key = self._resolve_model_selector(parts[2]) + config = self._model_config_from_dict(self.model_profiles[key], self.ai_client.config) + self.ai_client.switch_model(config) self.active_model_key = key self._save_model_profiles() - await self._reply_plain( - message, f"已切换模型: {key} ({config.provider.value}/{config.model_name})" - ) + await self._reply_plain(message, f"已切换模型: {key} ({config.provider.value}/{config.model_name})") return if action in {"add", "create"}: - # 快捷方式: /models add - # 仅替换 model_name,保留 provider/api_base/api_key。 + if len(parts) < 3: + await self._reply_plain(message, self._build_models_usage()) + return if len(parts) == 3: model_name = parts[2].strip() - if not model_name: - await self._reply_plain(message, self._build_models_usage()) - return - key = self._normalize_model_key(model_name) config = ModelConfig( provider=self.ai_client.config.provider, @@ -1115,241 +832,79 @@ class MessageHandler: top_p=self.ai_client.config.top_p, frequency_penalty=self.ai_client.config.frequency_penalty, presence_penalty=self.ai_client.config.presence_penalty, - timeout=self.ai_client.config.outtime, + timeout=self.ai_client.config.timeout, stream=self.ai_client.config.stream, ) - - self.model_profiles[key] = self._model_config_to_dict( - config, include_api_key=False - ) - self.active_model_key = key - self._save_model_profiles() - self.ai_client.switch_model(config) - - await self._reply_plain( - message, - f"已保存并切换模型: {key} ({config.provider.value}/{config.model_name})", - ) - return - - if len(parts) < 4: - await self._reply_plain(message, self._build_models_usage()) - return - - try: + else: key = self._normalize_model_key(parts[2]) - except ValueError as exc: - await self._reply_plain(message, str(exc)) - return - - payload = parts[3].strip() - include_api_key = False - - try: + payload = parts[3].strip() if payload.startswith("{"): raw_profile = json.loads(payload) - if not isinstance(raw_profile, dict): - raise ValueError("模型参数必须是 JSON 对象") else: payload_parts = payload.split() if len(payload_parts) < 2: - raise ValueError( - "用法: /models add [api_base]" - ) + await self._reply_plain(message, self._build_models_usage()) + return raw_profile = { "provider": payload_parts[0], "model_name": payload_parts[1], } if len(payload_parts) >= 3: raw_profile["api_base"] = payload_parts[2] - - include_api_key = bool(raw_profile.get("api_key")) config = self._model_config_from_dict(raw_profile, self.ai_client.config) - except Exception as exc: - await self._reply_plain(message, f"新增模型失败: {exc}") - return - self.model_profiles[key] = self._model_config_to_dict( - config, include_api_key=include_api_key - ) + self.model_profiles[key] = self._model_config_to_dict(config, include_api_key=False) self.active_model_key = key self._save_model_profiles() self.ai_client.switch_model(config) - - await self._reply_plain( - message, - f"已保存并切换模型: {key} ({config.provider.value}/{config.model_name})", - ) + await self._reply_plain(message, f"已保存并切换模型: {key} ({config.provider.value}/{config.model_name})") return if action in {"remove", "delete", "rm"}: if len(parts) < 3: await self._reply_plain(message, self._build_models_usage()) return - - try: - key = self._resolve_model_selector(parts[2]) - except ValueError as exc: - await self._reply_plain(message, str(exc)) - return - + key = self._resolve_model_selector(parts[2]) if key == "default": await self._reply_plain(message, "默认模型配置不能删除") return - - del self.model_profiles[key] - switched_to = None - - if self.active_model_key == key: - fallback_key = ( - "default" - if "default" in self.model_profiles - else (sorted(self.model_profiles.keys())[0] if self.model_profiles else None) - ) - - if not fallback_key: - self.model_profiles["default"] = self._model_config_to_dict( - self.ai_client.config, include_api_key=False - ) - fallback_key = "default" - - fallback_config = self._model_config_from_dict( - self.model_profiles[fallback_key], self.ai_client.config - ) - self.ai_client.switch_model(fallback_config) - self.active_model_key = fallback_key - switched_to = fallback_key - + self.model_profiles.pop(key, None) self._save_model_profiles() - - if switched_to: - await self._reply_plain( - message, f"已删除模型: {key},已切换到: {switched_to}" - ) - else: - await self._reply_plain(message, f"已删除模型: {key}") + await self._reply_plain(message, f"已删除模型: {key}") return await self._reply_plain(message, self._build_models_usage()) - async def _handle_command(self, message: Message, command: str): - user_id = self._get_user_id(message) - + async def _handle_command(self, message: Message, command: str, scope: ChatScope): if command == "/help": - await self._reply_plain( - message, - "命令帮助\n" - "====================\n" - "基础命令\n" - "/help\n" - "/clear (默认等价 /clear short)\n" - "/clear short\n" - "/clear long\n" - "/clear all\n" - "\n" - "人设命令\n" - "/personality\n" - "/personality list\n" - "/personality set \n" - "/personality add \n" - "/personality remove \n" - "\n" - "技能命令\n" - "/skills\n" - "/skills install [skill_name]\n" - "/skills uninstall \n" - "/skills reload \n" - "\n" - "模型命令\n" - "/models\n" - "/models current\n" - "/models add \n" - "/models add [api_base]\n" - "/models switch \n" - "/models remove \n" - "\n" - "记忆命令\n" - "/memory\n" - "/memory get \n" - "/memory add \n" - "/memory update \n" - "/memory delete \n" - "/memory search [limit]\n" - "\n" - "任务命令\n" - "/task ", - ) + await self._reply_plain(message, "输入 /personality /models /memory /clear 查看对应能力") return if command.startswith("/clear"): - clear_parts = command.split(maxsplit=1) - clear_scope = clear_parts[1].strip().lower() if len(clear_parts) > 1 else "short" - + clear_scope = command.split(maxsplit=1)[1].strip().lower() if len(command.split(maxsplit=1)) > 1 else "short" if clear_scope in {"all", ""}: - cleared_all = await self.ai_client.clear_all_memory(user_id) - if cleared_all: - await self._reply_plain(message, "已清除短期记忆和长期记忆") - else: - await self._reply_plain( - message, - "短期记忆已清除,但长期记忆清除失败", - ) + ok = await self.ai_client.clear_all_memory(scope.memory_key) + await self._reply_plain(message, "已清除短期记忆和长期记忆" if ok else "短期记忆已清除,但长期记忆清除失败") return - if clear_scope in {"short", "short_term"}: - self.ai_client.clear_memory(user_id) + self.ai_client.clear_memory(scope.memory_key) await self._reply_plain(message, "已清除短期记忆") return - if clear_scope in {"long", "long_term"}: - cleared_long = await self.ai_client.clear_long_term_memory(user_id) - if cleared_long: - await self._reply_plain(message, "已清除长期记忆") - else: - await self._reply_plain(message, "清除长期记忆失败") + ok = await self.ai_client.clear_long_term_memory(scope.memory_key) + await self._reply_plain(message, "已清除长期记忆" if ok else "清除长期记忆失败") return - - await self._reply_plain( - message, - "用法:\n/clear\n/clear short\n/clear long\n/clear all", - ) + await self._reply_plain(message, "用法:\n/clear\n/clear short\n/clear long\n/clear all") return if command.startswith("/personality"): - await self._handle_personality_command(message, command) + await self._handle_personality_command(message, command, scope) return - - if command.startswith("/skills") or command.startswith("/skill"): - await self._handle_skills_command(message, command) - return - if command.startswith("/models") or command.startswith("/model"): - await self._handle_models_command(message, command) + await self._handle_models_command(message, command, scope) return - if command.startswith("/memory") or command.startswith("/mem"): - await self._handle_memory_command(message, command) - return - - if command.startswith("/task"): - parts = command.split(maxsplit=1) - if len(parts) == 2: - task_id = parts[1] - status = self.ai_client.get_task_status(task_id) - if status: - await self._reply_plain( - message, - "任务状态:\n" - f"标题: {status['title']}\n" - f"状态: {status['status']}\n" - f"进度: {status['progress'] * 100:.1f}%\n" - f"步骤: {status['completed_steps']}/{status['total_steps']}", - ) - else: - await self._reply_plain(message, "任务不存在") - else: - await self._reply_plain(message, "用法: /task ") + await self._handle_memory_command(message, command, scope) return await self._reply_plain(message, "未知命令,请输入 /help 查看帮助") - diff --git a/src/utils/logger.py b/src/utils/logger.py index feef2b0..f7a0da6 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -1,63 +1,101 @@ """ -日志配置模块 +Logging helpers. """ + +from __future__ import annotations + +import json import logging import os +from datetime import datetime, timezone from pathlib import Path +from typing import Any, Dict -def setup_logger(name='QQBot', level=None): - """ - 设置日志记录器 - - Args: - name: 日志记录器名称 - level: 日志级别,默认从环境变量读取 - - Returns: - logging.Logger: 配置好的日志记录器 - """ - # 创建logs目录 - log_dir = Path(__file__).parent.parent.parent / 'logs' +class _JsonFormatter(logging.Formatter): + """Simple JSON formatter for structured logs.""" + + _BASE_FIELDS = { + "name", + "msg", + "args", + "levelname", + "levelno", + "pathname", + "filename", + "module", + "exc_info", + "exc_text", + "stack_info", + "lineno", + "funcName", + "created", + "msecs", + "relativeCreated", + "thread", + "threadName", + "processName", + "process", + } + + def format(self, record: logging.LogRecord) -> str: + payload: Dict[str, Any] = { + "ts": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "file": record.filename, + "line": record.lineno, + } + + for key, value in record.__dict__.items(): + if key in self._BASE_FIELDS or key.startswith("_"): + continue + payload[key] = value + + if record.exc_info: + payload["exc"] = self.formatException(record.exc_info) + + return json.dumps(payload, ensure_ascii=False) + + +def setup_logger(name: str = "QQBot", level: str | None = None) -> logging.Logger: + log_dir = Path(__file__).parent.parent.parent / "logs" log_dir.mkdir(exist_ok=True) - - # 设置日志级别 + if level is None: - level = os.getenv('LOG_LEVEL', 'INFO') - - # 创建日志记录器 + level = os.getenv("LOG_LEVEL", "INFO") + log_format = (os.getenv("LOG_FORMAT", "text") or "text").lower() + logger = logging.getLogger(name) - logger.setLevel(getattr(logging, level.upper())) - # 避免向 root logger 传播导致重复输出 + logger.setLevel(getattr(logging, str(level).upper(), logging.INFO)) logger.propagate = False - - # 避免重复添加处理器 + if logger.handlers: return logger - - # 控制台处理器 + console_handler = logging.StreamHandler() + file_handler = logging.FileHandler(log_dir / "bot.log", encoding="utf-8") + + if log_format == "json": + formatter = _JsonFormatter() + console_handler.setFormatter(formatter) + file_handler.setFormatter(formatter) + else: + console_fmt = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + file_fmt = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + console_handler.setFormatter(console_fmt) + file_handler.setFormatter(file_fmt) + console_handler.setLevel(logging.INFO) - console_format = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - console_handler.setFormatter(console_format) - - # 文件处理器 - file_handler = logging.FileHandler( - log_dir / 'bot.log', - encoding='utf-8' - ) file_handler.setLevel(logging.DEBUG) - file_format = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - file_handler.setFormatter(file_format) - - # 添加处理器 + logger.addHandler(console_handler) logger.addHandler(file_handler) - return logger diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c5c7f8c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +from pathlib import Path +import sys + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) diff --git a/tests/test_ai.py b/tests/test_ai.py deleted file mode 100644 index 0638f2c..0000000 --- a/tests/test_ai.py +++ /dev/null @@ -1,628 +0,0 @@ -"""AI integration tests.""" - -import asyncio -import json -import os -from pathlib import Path -import shutil -import stat -import sys -import tempfile -import time -import zipfile - -from dotenv import load_dotenv - -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from src.ai import AIClient -from src.ai.base import ModelConfig, ModelProvider -from src.ai.memory import MemorySystem -from src.ai.skills import SkillsManager, create_skill_template -from src.handlers.message_handler_ai import MessageHandler - -load_dotenv(project_root / ".env") - - -TEST_DATA_DIR = Path("data/ai_test") - - -def _safe_rmtree(path: Path): - if not path.exists(): - return - - def _onerror(func, target, exc_info): - try: - os.chmod(target, stat.S_IWRITE) - func(target) - except Exception: - pass - - for _ in range(3): - try: - shutil.rmtree(path, onerror=_onerror) - return - except PermissionError: - time.sleep(0.2) - - -def _safe_unlink(path: Path): - if not path.exists(): - return - - for _ in range(3): - try: - path.unlink() - return - except PermissionError: - time.sleep(0.2) - - -def _read_env(name: str, default=None): - value = os.getenv(name) - if value is None: - return default - - value = value.strip() - if not value or value.startswith("#"): - return default - - return value - - -def get_ai_config() -> ModelConfig: - provider_map = { - "openai": ModelProvider.OPENAI, - "anthropic": ModelProvider.ANTHROPIC, - "deepseek": ModelProvider.DEEPSEEK, - "qwen": ModelProvider.QWEN, - "siliconflow": ModelProvider.OPENAI, - } - - provider_str = (_read_env("AI_PROVIDER", "openai") or "openai").lower() - provider = provider_map.get(provider_str, ModelProvider.OPENAI) - - return ModelConfig( - provider=provider, - model_name=_read_env("AI_MODEL", "gpt-3.5-turbo") or "gpt-3.5-turbo", - api_key=_read_env("AI_API_KEY", "") or "", - api_base=_read_env("AI_API_BASE"), - temperature=0.7, - ) - - -def get_embed_config() -> ModelConfig: - provider_map = { - "openai": ModelProvider.OPENAI, - "anthropic": ModelProvider.ANTHROPIC, - "deepseek": ModelProvider.DEEPSEEK, - "qwen": ModelProvider.QWEN, - "siliconflow": ModelProvider.OPENAI, - } - - provider_str = (_read_env("AI_EMBED_PROVIDER", "openai") or "openai").lower() - provider = provider_map.get(provider_str, ModelProvider.OPENAI) - - api_key = _read_env("AI_EMBED_API_KEY") or _read_env("AI_API_KEY", "") or "" - api_base = _read_env("AI_EMBED_API_BASE") or _read_env("AI_API_BASE") - - return ModelConfig( - provider=provider, - model_name=_read_env("AI_EMBED_MODEL", "text-embedding-3-small") - or "text-embedding-3-small", - api_key=api_key, - api_base=api_base, - temperature=0.0, - ) - - -class FakeMessage: - def __init__(self, content: str): - from types import SimpleNamespace - - self.content = content - self.author = SimpleNamespace(id="test_user") - self.replies = [] - - async def reply(self, content: str): - self.replies.append(content) - - -def make_handler() -> MessageHandler: - from types import SimpleNamespace - - fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot")) - handler = MessageHandler(fake_bot) - handler.ai_client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR) - handler.skills_manager = SkillsManager(Path("skills")) - handler.model_profiles_path = TEST_DATA_DIR / "models_test.json" - TEST_DATA_DIR.mkdir(parents=True, exist_ok=True) - _safe_unlink(handler.model_profiles_path) - handler._ai_initialized = True - return handler - - -async def _test_basic_chat(): - print("=== test_basic_chat ===") - - config = get_ai_config() - if not config.api_key: - print("skip: AI_API_KEY not configured") - return - - embed_config = get_embed_config() - client = AIClient(config, embed_config=embed_config, data_dir=TEST_DATA_DIR) - response = await client.chat( - user_id="test_user", - user_message="你好,请介绍一下你自己", - use_memory=False, - use_tools=False, - ) - assert response - print("ok: chat response length", len(response)) - - -async def _test_memory(): - print("=== test_memory ===") - - config = get_ai_config() - if not config.api_key: - print("skip: AI_API_KEY not configured") - return - - client = AIClient(config, embed_config=get_embed_config(), data_dir=TEST_DATA_DIR) - await client.chat(user_id="test_user", user_message="鎴戝彨寮犱笁", use_memory=True) - await client.chat(user_id="test_user", user_message="what is my name", use_memory=True) - - short_term, long_term = await client.memory.get_context("test_user") - assert len(short_term) >= 2 - # 重要性改为模型评估后,是否入长期记忆取决于模型打分,不再固定断言数量。 - assert isinstance(long_term, list) - print("ok: memory short/long", len(short_term), len(long_term)) - - -async def _test_personality(): - print("=== test_personality ===") - - client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR) - names = client.list_personalities() - assert names - assert client.set_personality(names[0]) - - key = "roleplay_test" - added = client.personality.add_personality( - key, - client.personality.get_personality("default"), - ) - assert added - assert key in client.list_personalities() - assert client.personality.remove_personality(key) - assert key not in client.list_personalities() - - print("ok: personality add/remove") - - -async def _test_skills(): - print("=== test_skills ===") - - manager = SkillsManager(Path("skills")) - assert await manager.load_skill("weather") - - tools = manager.get_all_tools() - assert "weather.get_weather" in tools - weather = await tools["weather.get_weather"](city="鍖椾含") - assert weather - - assert await manager.load_skill("skills_creator") - tools = manager.get_all_tools() - assert "skills_creator.create_skill" in tools - - await manager.unload_skill("weather") - await manager.unload_skill("skills_creator") - print("ok: skills load/unload") - - -async def _test_skill_commands(): - print("=== test_skill_commands ===") - - handler = make_handler() - skill_key = f"cmd_zip_skill_{int(time.time() * 1000)}" - - # Prepare a zip package source for install testing - tmp_root = TEST_DATA_DIR / "tmp_skill_pkg" - if tmp_root.exists(): - _safe_rmtree(tmp_root) - tmp_root.mkdir(parents=True, exist_ok=True) - - create_skill_template(skill_key, tmp_root, description="zip skill", author="test") - - zip_path = TEST_DATA_DIR / f"{skill_key}.zip" - - with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: - for file in (tmp_root / skill_key).rglob("*"): - if file.is_file(): - zf.write(file, file.relative_to(tmp_root)) - - install_msg = FakeMessage(f"/skills install {zip_path}") - await handler._handle_command(install_msg, install_msg.content) - assert install_msg.replies, "install command no reply" - - list_msg = FakeMessage("/skills") - await handler._handle_command(list_msg, list_msg.content) - assert list_msg.replies, "list command no reply" - - reload_msg = FakeMessage(f"/skills reload {skill_key}") - await handler._handle_command(reload_msg, reload_msg.content) - assert reload_msg.replies, "reload command no reply" - - uninstall_msg = FakeMessage(f"/skills uninstall {skill_key}") - await handler._handle_command(uninstall_msg, uninstall_msg.content) - assert uninstall_msg.replies, "uninstall command no reply" - - if tmp_root.exists(): - _safe_rmtree(tmp_root) - _safe_unlink(zip_path) - - print("ok: skills install/reload/uninstall command") - - -async def _test_personality_commands(): - print("=== test_personality_commands ===") - - handler = make_handler() - intro = "You are a hot-blooded anime hero. Speak directly and stay in-character." - - add_cmd = ( - "/personality add roleplay_hero " - f"{intro}" - ) - add_msg = FakeMessage(add_cmd) - await handler._handle_command(add_msg, add_msg.content) - assert add_msg.replies - - set_msg = FakeMessage("/personality set roleplay_hero") - await handler._handle_command(set_msg, set_msg.content) - assert set_msg.replies - assert intro in handler.ai_client.personality.get_system_prompt() - - remove_msg = FakeMessage("/personality remove roleplay_hero") - await handler._handle_command(remove_msg, remove_msg.content) - assert remove_msg.replies - - assert "roleplay_hero" not in handler.ai_client.list_personalities() - print("ok: personality add/set/remove command") - - -async def _test_model_commands(): - print("=== test_model_commands ===") - - handler = make_handler() - - list_msg = FakeMessage("/models") - await handler._handle_command(list_msg, list_msg.content) - assert list_msg.replies - assert "default" in list_msg.replies[-1].lower() - - add_msg = FakeMessage("/models add roleplay_llm openai gpt-4o-mini") - await handler._handle_command(add_msg, add_msg.content) - assert add_msg.replies - assert handler.active_model_key == "roleplay_llm" - assert "roleplay_llm" in handler.model_profiles - - switch_msg = FakeMessage("/models switch default") - await handler._handle_command(switch_msg, switch_msg.content) - assert switch_msg.replies - assert handler.active_model_key == "default" - - current_msg = FakeMessage("/models current") - await handler._handle_command(current_msg, current_msg.content) - assert current_msg.replies - - old_config = handler.ai_client.config - shortcut_model = "Qwen/Qwen2.5-7B-Instruct" - shortcut_key = handler._normalize_model_key(shortcut_model) - shortcut_add_msg = FakeMessage(f"/models add {shortcut_model}") - await handler._handle_command(shortcut_add_msg, shortcut_add_msg.content) - assert shortcut_add_msg.replies - assert handler.active_model_key == shortcut_key - assert handler.ai_client.config.model_name == shortcut_model - assert handler.ai_client.config.provider == old_config.provider - assert handler.ai_client.config.api_base == old_config.api_base - assert handler.ai_client.config.api_key == old_config.api_key - - shortcut_remove_msg = FakeMessage(f"/models remove {shortcut_key}") - await handler._handle_command(shortcut_remove_msg, shortcut_remove_msg.content) - assert shortcut_remove_msg.replies - assert shortcut_key not in handler.model_profiles - - remove_msg = FakeMessage("/models remove roleplay_llm") - await handler._handle_command(remove_msg, remove_msg.content) - assert remove_msg.replies - assert "roleplay_llm" not in handler.model_profiles - - _safe_unlink(handler.model_profiles_path) - print("ok: model add/switch/remove command") - - -async def _test_memory_commands(): - print("=== test_memory_commands ===") - - handler = make_handler() - user_id = "test_user" - - await handler.ai_client.clear_all_memory(user_id) - - add_msg = FakeMessage("/memory add this is a long-term memory test") - await handler._handle_command(add_msg, add_msg.content) - assert add_msg.replies - assert "已新增长期记忆" in add_msg.replies[-1] - - memory_id = add_msg.replies[-1].split(": ", 1)[1].split(" ", 1)[0] - assert memory_id - - list_msg = FakeMessage("/memory list 5") - await handler._handle_command(list_msg, list_msg.content) - assert list_msg.replies - assert memory_id in list_msg.replies[-1] - - get_msg = FakeMessage(f"/memory get {memory_id}") - await handler._handle_command(get_msg, get_msg.content) - assert get_msg.replies - assert memory_id in get_msg.replies[-1] - - search_msg = FakeMessage("/memory search 长期记忆") - await handler._handle_command(search_msg, search_msg.content) - assert search_msg.replies - assert memory_id in search_msg.replies[-1] - - update_msg = FakeMessage(f"/memory update {memory_id} 这是更新后的长期记忆") - await handler._handle_command(update_msg, update_msg.content) - assert update_msg.replies - assert "已更新长期记忆" in update_msg.replies[-1] - - # Build short-term memory then clear only short-term. - await handler.ai_client.memory.add_message( - user_id=user_id, - role="user", - content="short memory for clear short test", - ) - assert handler.ai_client.memory.short_term.get(user_id) - - clear_short_msg = FakeMessage("/clear short") - await handler._handle_command(clear_short_msg, clear_short_msg.content) - assert clear_short_msg.replies - assert not handler.ai_client.memory.short_term.get(user_id) - - # Long-term memory should still exist after clearing short-term only. - still_exists = await handler.ai_client.get_long_term_memory(user_id, memory_id) - assert still_exists is not None - - delete_msg = FakeMessage(f"/memory delete {memory_id}") - await handler._handle_command(delete_msg, delete_msg.content) - assert delete_msg.replies - assert "已删除长期记忆" in delete_msg.replies[-1] - - removed = await handler.ai_client.get_long_term_memory(user_id, memory_id) - assert removed is None - - print("ok: memory command CRUD + clear short") - - -async def _test_plain_text_output(): - print("=== test_plain_text_output ===") - - handler = make_handler() - md_text = "# 标题\n**加粗** 和 `代码`\n- 列表\n[链接](https://example.com)" - plain = handler._plain_text(md_text) - - assert "#" not in plain - assert "**" not in plain - assert "`" not in plain - assert "[" not in plain - assert "](" not in plain - print("ok: markdown stripped") - - -async def _test_skills_creator_autoload(): - print("=== test_skills_creator_autoload ===") - - from types import SimpleNamespace - - fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot")) - handler = MessageHandler(fake_bot) - handler.model_profiles_path = TEST_DATA_DIR / "models_autoload_test.json" - _safe_unlink(handler.model_profiles_path) - await handler._init_ai() - - assert handler.skills_manager is not None - assert "skills_creator" in handler.skills_manager.list_skills() - - tool_names = [tool.name for tool in handler.ai_client.tools.list()] - assert "skills_creator.create_skill" in tool_names - print("ok: skills_creator autoloaded") - - -async def _test_mcp(): - print("=== test_mcp ===") - - from src.ai.mcp import MCPManager - from src.ai.mcp.servers import FileSystemMCPServer - - manager = MCPManager(Path("config/mcp.json")) - fs_server = FileSystemMCPServer(root_path=Path("data")) - await manager.register_server(fs_server) - - tools = await manager.get_all_tools_for_ai() - assert len(tools) >= 1 - print("ok: mcp tools", len(tools)) - - -async def _test_long_task(): - print("=== test_long_task ===") - - client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR) - - async def step1(): - await asyncio.sleep(0.1) - return "step1" - - async def step2(): - await asyncio.sleep(0.1) - return "step2" - - client.task_manager.register_action("step1", step1) - client.task_manager.register_action("step2", step2) - - task_id = await client.create_long_task( - user_id="test_user", - title="test", - description="test task", - steps=[ - {"description": "s1", "action": "step1", "params": {}}, - {"description": "s2", "action": "step2", "params": {}}, - ], - ) - - await client.start_task(task_id) - await asyncio.sleep(0.5) - - status = client.get_task_status(task_id) - assert status is not None - assert status["status"] in {"completed", "running"} - print("ok: long task", status["status"]) - - -async def _test_memory_importance_evaluator(): - print("=== test_memory_importance_evaluator ===") - - called = {"value": False} - - async def fake_importance_eval(content, metadata): - called["value"] = True - assert "用户:" in content - assert "助手:" in content - return 0.91 - - store_path = TEST_DATA_DIR / "importance_test.json" - _safe_unlink(store_path) - - memory = MemorySystem( - storage_path=store_path, - importance_evaluator=fake_importance_eval, - use_vector_db=False, - ) - - stored = await memory.add_qa_pair( - user_id="u1", - question="请记住我的昵称是小明", - answer="好的,我记住了你的昵称是小明", - metadata={"source": "test"}, - ) - assert called["value"] - assert stored is not None - assert "用户:" in stored.content - assert "助手:" in stored.content - assert "小明" in stored.content - - long_term = await memory.list_long_term("u1") - assert len(long_term) == 1 - - # add_message 仅写入短期记忆,不触发长期记忆评分写入。 - await memory.add_message(user_id="u1", role="user", content="单条短期消息") - long_term_after_single = await memory.list_long_term("u1") - assert len(long_term_after_single) == 1 - - memory_without_eval = MemorySystem( - storage_path=TEST_DATA_DIR / "importance_fallback_test.json", - use_vector_db=False, - ) - fallback_score = await memory_without_eval._evaluate_importance("任意内容", None) - assert fallback_score == 0.5 - - await memory.close() - await memory_without_eval.close() - _safe_unlink(store_path) - _safe_unlink(TEST_DATA_DIR / "importance_fallback_test.json") - print("ok: memory importance evaluator") - - -def test_basic_chat(): - asyncio.run(_test_basic_chat()) - - -def test_memory(): - asyncio.run(_test_memory()) - - -def test_personality(): - asyncio.run(_test_personality()) - - -def test_skills(): - asyncio.run(_test_skills()) - - -def test_skill_commands(): - asyncio.run(_test_skill_commands()) - - -def test_personality_commands(): - asyncio.run(_test_personality_commands()) - - -def test_model_commands(): - asyncio.run(_test_model_commands()) - - -def test_memory_commands(): - asyncio.run(_test_memory_commands()) - - -def test_plain_text_output(): - asyncio.run(_test_plain_text_output()) - - -def test_skills_creator_autoload(): - asyncio.run(_test_skills_creator_autoload()) - - -def test_mcp(): - asyncio.run(_test_mcp()) - - -def test_long_task(): - asyncio.run(_test_long_task()) - - -def test_memory_importance_evaluator(): - asyncio.run(_test_memory_importance_evaluator()) - - -async def main(): - print("寮€濮?AI 鍔熻兘娴嬭瘯") - - await _test_personality() - await _test_skills() - await _test_skill_commands() - await _test_personality_commands() - await _test_model_commands() - await _test_memory_commands() - await _test_plain_text_output() - await _test_skills_creator_autoload() - await _test_mcp() - await _test_long_task() - await _test_memory_importance_evaluator() - - config = get_ai_config() - if config.api_key: - await _test_basic_chat() - await _test_memory() - else: - print("跳过需要 API Key 的对话/记忆测试") - - print("娴嬭瘯瀹屾垚") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/test_ai_client_forced_tool.py b/tests/test_ai_client_forced_tool.py deleted file mode 100644 index 0669878..0000000 --- a/tests/test_ai_client_forced_tool.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Tests for AIClient forced tool name extraction.""" - -from src.ai.client import AIClient - - -def test_extract_forced_tool_name_full_name(): - tools = [ - "humanizer_zh.read_skill_doc", - "skills_creator.create_skill", - ] - message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced == "humanizer_zh.read_skill_doc" - - -def test_extract_forced_tool_name_unique_prefix(): - tools = [ - "humanizer_zh.read_skill_doc", - "skills_creator.create_skill", - ] - message = "please call tool humanizer_zh only" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced == "humanizer_zh.read_skill_doc" - - -def test_extract_forced_tool_name_compact_prefix_without_underscore(): - tools = [ - "humanizer_zh.read_skill_doc", - "skills_creator.create_skill", - ] - message = "调用humanizerzh人性化处理以下文本" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced == "humanizer_zh.read_skill_doc" - - -def test_extract_forced_tool_name_ambiguous_prefix_returns_none(): - tools = [ - "skills_creator.create_skill", - "skills_creator.reload_skill", - ] - message = "please call tool skills_creator" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced is None - - -def test_extract_prefix_limit_from_user_message(): - assert AIClient._extract_prefix_limit("直接返回前100字") == 100 - assert AIClient._extract_prefix_limit("前 256 字") == 256 - assert AIClient._extract_prefix_limit("返回全文") is None - - -def test_extract_processing_payload_with_marker(): - message = "调用humanizer_zh.read_skill_doc人性化处理以下文本:\n第一段。\n第二段。" - payload = AIClient._extract_processing_payload(message) - assert payload == "第一段。\n第二段。" - - -def test_extract_processing_payload_with_generic_pattern(): - message = "请按技能规则优化如下:\n这是待处理文本。" - payload = AIClient._extract_processing_payload(message) - assert payload == "这是待处理文本。" - - -def test_extract_processing_payload_returns_none_when_absent(): - assert AIClient._extract_processing_payload("请调用工具 humanizer_zh.read_skill_doc") is None diff --git a/tests/test_mcp_tool_registration.py b/tests/test_mcp_tool_registration.py deleted file mode 100644 index 90bde10..0000000 --- a/tests/test_mcp_tool_registration.py +++ /dev/null @@ -1,44 +0,0 @@ -import asyncio -from pathlib import Path - -from src.ai.mcp.base import MCPManager, MCPServer - - -class _DummyMCPServer(MCPServer): - def __init__(self): - super().__init__(name="dummy", version="1.0.0") - - async def initialize(self): - self.register_tool( - name="echo", - description="Echo text", - input_schema={ - "type": "object", - "properties": {"text": {"type": "string"}}, - "required": ["text"], - }, - handler=self.echo, - ) - - async def echo(self, text: str) -> str: - return text - - -def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path): - manager = MCPManager(tmp_path / "mcp.json") - asyncio.run(manager.register_server(_DummyMCPServer())) - - tools = asyncio.run(manager.get_all_tools_for_ai()) - assert len(tools) == 1 - function_info = tools[0]["function"] - assert function_info["name"] == "dummy.echo" - assert function_info["description"] == "Echo text" - assert function_info["parameters"]["required"] == ["text"] - - -def test_mcp_manager_execute_tool(tmp_path: Path): - manager = MCPManager(tmp_path / "mcp.json") - asyncio.run(manager.register_server(_DummyMCPServer())) - - result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"})) - assert result == "hello" diff --git a/tests/test_message_dedup.py b/tests/test_message_dedup.py new file mode 100644 index 0000000..21b9d4e --- /dev/null +++ b/tests/test_message_dedup.py @@ -0,0 +1,22 @@ +from types import SimpleNamespace + +from src.handlers.message_handler_ai import MessageHandler + + +def _build_handler() -> MessageHandler: + fake_bot = SimpleNamespace(robot=SimpleNamespace(id="bot_1", name="TestBot")) + return MessageHandler(fake_bot) + + +def test_message_dedup_by_message_id(): + handler = _build_handler() + msg = SimpleNamespace(id="m1", content="hello", author=SimpleNamespace(id="u1")) + assert handler._is_duplicate_message(msg) is False + assert handler._is_duplicate_message(msg) is True + + +def test_message_dedup_fallback_without_message_id(): + handler = _build_handler() + msg = SimpleNamespace(content="hello", author=SimpleNamespace(id="u1"), group_id="g1") + assert handler._is_duplicate_message(msg) is False + assert handler._is_duplicate_message(msg) is True diff --git a/tests/test_message_handler_text_sanitize.py b/tests/test_message_handler_text_sanitize.py index b41f922..9ce9aa9 100644 --- a/tests/test_message_handler_text_sanitize.py +++ b/tests/test_message_handler_text_sanitize.py @@ -10,16 +10,14 @@ from src.handlers.message_handler_ai import MessageHandler def _handler() -> MessageHandler: - fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot")) + fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot", name="TestBot")) return MessageHandler(fake_bot) def test_plain_text_removes_markdown_link_url(): handler = _handler() text = "参考 [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia) 获取详情。" - result = handler._plain_text(text) - assert "Wikipedia" in result assert "http" not in result.lower() @@ -27,9 +25,7 @@ def test_plain_text_removes_markdown_link_url(): def test_plain_text_removes_bare_url(): handler = _handler() text = "访问 https://example.com/path?a=1 或 www.example.org 查看。" - result = handler._plain_text(text) - assert "http" not in result.lower() assert "www." not in result.lower() assert "[链接已省略]" in result diff --git a/tests/test_personality_scope_priority.py b/tests/test_personality_scope_priority.py new file mode 100644 index 0000000..e00d704 --- /dev/null +++ b/tests/test_personality_scope_priority.py @@ -0,0 +1,44 @@ +from pathlib import Path + +from src.ai.personality import PersonalityProfile, PersonalitySystem, PersonalityTrait + + +def _profile(name: str) -> PersonalityProfile: + return PersonalityProfile( + name=name, + description=f"{name} profile", + traits=[PersonalityTrait.FRIENDLY], + speaking_style="plain", + ) + + +def test_scope_priority_session_over_group_over_user_over_global(tmp_path: Path): + cfg = tmp_path / "personalities.json" + state = tmp_path / "personality_state.json" + system = PersonalitySystem(config_path=cfg, state_path=state) + + system.add_personality("p_global", _profile("global")) + system.add_personality("p_user", _profile("user")) + system.add_personality("p_group", _profile("group")) + system.add_personality("p_session", _profile("session")) + + assert system.set_personality("p_global", scope="global") + assert system.set_personality("p_user", scope="user", scope_id="u1") + assert system.set_personality("p_group", scope="group", scope_id="g1") + assert system.set_personality("p_session", scope="session", scope_id="g1:u1") + + profile = system.get_active_personality(user_id="u1", group_id="g1", session_id="g1:u1") + assert profile is not None + assert profile.name == "session" + + profile_no_session = system.get_active_personality(user_id="u1", group_id="g1", session_id="other") + assert profile_no_session is not None + assert profile_no_session.name == "group" + + profile_user_only = system.get_active_personality(user_id="u1", group_id=None, session_id=None) + assert profile_user_only is not None + assert profile_user_only.name == "user" + + profile_global = system.get_active_personality(user_id="u2", group_id=None, session_id=None) + assert profile_global is not None + assert profile_global.name == "global" diff --git a/tests/test_skills_install_source.py b/tests/test_skills_install_source.py deleted file mode 100644 index f27fb36..0000000 --- a/tests/test_skills_install_source.py +++ /dev/null @@ -1,74 +0,0 @@ -import io -from pathlib import Path -import zipfile - -from src.ai.skills.base import SkillsManager - - -def _build_codex_skill_zip_bytes(markdown_text: str, root_name: str = "Humanizer-zh-main") -> bytes: - buffer = io.BytesIO() - with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf: - zf.writestr(f"{root_name}/SKILL.md", markdown_text) - return buffer.getvalue() - - -def test_resolve_network_source_supports_github_git_url(tmp_path: Path): - manager = SkillsManager(tmp_path / "skills") - url, hint_key, subpath = manager._resolve_network_source( - "https://github.com/op7418/Humanizer-zh.git" - ) - - assert url == "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main" - assert hint_key == "humanizer_zh" - assert subpath is None - - -def test_install_skill_from_local_skill_markdown_source(tmp_path: Path): - source_dir = tmp_path / "Humanizer-zh-main" - source_dir.mkdir(parents=True, exist_ok=True) - (source_dir / "SKILL.md").write_text( - "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n", - encoding="utf-8", - ) - - manager = SkillsManager(tmp_path / "skills") - ok, installed_key = manager.install_skill_from_source(str(source_dir), skill_name="humanizer_zh") - - assert ok - assert installed_key == "humanizer_zh" - installed_dir = tmp_path / "skills" / "humanizer_zh" - assert (installed_dir / "skill.json").exists() - assert (installed_dir / "main.py").exists() - assert (installed_dir / "SKILL.md").exists() - - main_code = (installed_dir / "main.py").read_text(encoding="utf-8") - assert "read_skill_doc" in main_code - skill_text = (installed_dir / "SKILL.md").read_text(encoding="utf-8") - assert "Humanizer-zh" in skill_text - - -def test_install_skill_from_github_git_url_uses_repo_zip_and_markdown_adapter( - tmp_path: Path, monkeypatch -): - manager = SkillsManager(tmp_path / "skills") - zip_bytes = _build_codex_skill_zip_bytes( - "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n" - ) - captured_urls = [] - - def fake_download(url: str, output_zip: Path): - captured_urls.append(url) - output_zip.write_bytes(zip_bytes) - - monkeypatch.setattr(manager, "_download_zip", fake_download) - - ok, installed_key = manager.install_skill_from_source( - "https://github.com/op7418/Humanizer-zh.git" - ) - - assert ok - assert installed_key == "humanizer_zh" - assert captured_urls == [ - "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main" - ] - assert (tmp_path / "skills" / "humanizer_zh" / "SKILL.md").exists()