diff --git a/.env.example b/.env.example index 104e807..e1d4ea3 100644 --- a/.env.example +++ b/.env.example @@ -1,36 +1,36 @@ -# QQ 机器人配置 -# 可从 https://bot.q.qq.com/open 获取 - -# 机器人 AppID(必填) +# QQ Bot credentials BOT_APPID=your_app_id_here - -# 机器人 AppSecret(必填) BOT_SECRET=your_app_secret_here -# 日志级别: DEBUG / INFO / WARNING / ERROR +# Runtime +APP_ENV=dev LOG_LEVEL=INFO - -# 是否启用沙箱环境 +LOG_FORMAT=text SANDBOX_MODE=False -# ==================== AI 配置 ==================== +# Optional admin allow-list (comma separated user IDs). +# Empty means all users are treated as admin. +BOT_ADMIN_IDS= -# 主模型配置 -# 可选 provider: openai / anthropic / deepseek / qwen +# AI chat model AI_PROVIDER=openai AI_MODEL=gpt-4 AI_API_KEY=your_api_key_here -# 可选,自定义 API 地址 AI_API_BASE=https://api.openai.com/v1 -# 嵌入模型配置(用于长期记忆检索) -# 不配置时将回退为主模型的 embedding 能力(如果可用) +# Embedding model (optional) AI_EMBED_PROVIDER=openai AI_EMBED_MODEL=text-embedding-3-small AI_EMBED_API_KEY= AI_EMBED_API_BASE= -# 向量数据库配置 -# true: 使用 Chroma(推荐) -# false: 使用 JSON 存储 +# Memory storage and retrieval AI_USE_VECTOR_DB=true +AI_USE_QUERY_EMBEDDING=false +AI_MEMORY_SCOPE=session + +# Reliability +AI_CHAT_RETRIES=1 +AI_CHAT_RETRY_BACKOFF_SECONDS=0.8 +MESSAGE_DEDUP_SECONDS=30 +MESSAGE_DEDUP_MAX_SIZE=4096 diff --git a/README.md b/README.md index fa0875f..276fec5 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,16 @@ -# QQbot(AI 聊天机器人) +# QQbot (Memory + Persona Core) -一个基于 `botpy` 的 QQ 机器人项目,支持多模型切换、长期/短期记忆、人设管理、Skills 插件与 MCP 能力。 +QQ 机器人项目,保留并强化两大核心能力: +- `Memory`:短期/长期记忆、检索、清理、会话作用域 +- `Persona`:角色配置、作用域优先级(session > group > user > global) -## 功能概览 +## 主要能力 -- 多模型配置与运行时切换(`/models`) -- 人设增删改切换(`/personality`) -- 短期/长期记忆管理(`/clear`、`/memory`) -- Skills 本地与网络安装/卸载/重载(`/skills`) -- 自动去除 Markdown 格式后再回复(适配 QQ 聊天) +- 多模型配置与运行时切换:`/models` +- 人设管理:`/personality` +- 记忆管理:`/memory`、`/clear` +- QQ 消息安全输出:自动清理 Markdown/URL +- 工程增强:消息去重、失败重试、权限边界、结构化日志 ## 快速开始 @@ -20,7 +22,9 @@ pip install -r requirements.txt 2. 配置环境变量 -复制 `.env.example` 为 `.env`,填写 QQ 机器人和 AI 配置。 +```bash +copy .env.example .env +``` 3. 
启动 @@ -28,89 +32,42 @@ pip install -r requirements.txt python main.py ``` -## 命令说明 +## 命令 -### 通用 +- 基础 + - `/help` + - `/clear` `/clear short` `/clear long` `/clear all` +- 人设 + - `/personality` + - `/personality list` + - `/personality set [global|user|group|session]` + - `/personality add ` + - `/personality remove ` +- 模型 + - `/models` + - `/models current` + - `/models add ` + - `/models add [api_base]` + - `/models switch ` + - `/models remove ` +- 记忆 + - `/memory` + - `/memory get ` + - `/memory add ` + - `/memory update ` + - `/memory delete ` + - `/memory search [limit]` -- `/help` -- `/clear`(默认等价 `/clear short`) -- `/clear short` -- `/clear long` -- `/clear all` +## 关键配置 -### 人设命令 - -- `/personality` -- `/personality list` -- `/personality set ` -- `/personality add ` -- `/personality remove ` - -说明: -- `add` 会新增并切换到该人设 -- `Introduction` 会作为人设简介与自定义指令 - -### Skills 命令 - -- `/skills` -- `/skills install [skill_name]` -- `/skills uninstall ` -- `/skills reload ` - -`source` 支持: -- 本地技能名(如 `weather`) -- URL(zip 包) -- GitHub 简写(`owner/repo` 或 `owner/repo#branch`) -- GitHub 仓库 URL(如 `https://github.com/op7418/Humanizer-zh.git`) - -兼容说明: -- 若源中包含标准技能结构(`skill.json` + `main.py`),按原方式安装 -- 若仅包含 `SKILL.md`,会自动生成适配技能并提供 `read_skill_doc` 工具读取文档内容 - -### 模型命令 - -- `/models` -- `/models current` -- `/models add ` -- `/models add [api_base]` -- `/models switch ` -- `/models remove ` - -说明: -- `/models add ` 只替换模型名,沿用当前 API Base 和 API Key - -### 长期记忆命令 - -- `/memory` -- `/memory get ` -- `/memory add ` -- `/memory update ` -- `/memory delete ` -- `/memory search [limit]` - -## 目录结构 - -```text -QQbot/ -├─ src/ -│ ├─ ai/ -│ ├─ handlers/ -│ ├─ core/ -│ └─ utils/ -├─ skills/ -├─ config/ -├─ docs/ -└─ tests/ -``` +- `AI_MEMORY_SCOPE=user|session`:记忆作用域 +- `BOT_ADMIN_IDS`:管理员白名单(逗号分隔) +- `AI_CHAT_RETRIES` / `AI_CHAT_RETRY_BACKOFF_SECONDS`:聊天失败重试 +- `MESSAGE_DEDUP_SECONDS` / `MESSAGE_DEDUP_MAX_SIZE`:消息去重窗口 +- `LOG_FORMAT=text|json`:日志输出格式 ## 测试 ```bash -python -m pytest -q -``` - -如果你使用 conda 环境,请先执行: - -```bash -conda activate qqbot +pytest -q ``` diff --git a/config/mcp.json b/config/mcp.json deleted file mode 100644 index d13402b..0000000 --- a/config/mcp.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "filesystem": { - "enabled": true, - "root_path": "data" - } -} diff --git a/main.py b/main.py index 2d1e162..9be020d 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,12 @@ """ -QQ机器人主入口 +Project entrypoint. """ + +from __future__ import annotations + import sys from pathlib import Path -# 添加项目根目录到Python路径 project_root = Path(__file__).parent sys.path.insert(0, str(project_root)) @@ -23,10 +25,6 @@ def _sqlite_supports_trigram(sqlite_module) -> bool: def _ensure_sqlite_for_chroma(): - """ - Ensure sqlite runtime supports FTS5 trigram tokenizer for Chroma. - On some cloud images, system sqlite lacks trigram support. 
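# A minimal sketch, assuming _sqlite_supports_trigram() probes the runtime in the
# obvious way (its body is not shown in this hunk): create an in-memory FTS5 table
# with the trigram tokenizer and treat any failure as "unsupported".
import sqlite3

def sqlite_supports_trigram(sqlite_module=sqlite3) -> bool:
    try:
        conn = sqlite_module.connect(":memory:")
        conn.execute("CREATE VIRTUAL TABLE probe USING fts5(body, tokenize='trigram')")
        conn.close()
        return True
    except Exception:
        return False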
- """ try: import sqlite3 except Exception: @@ -55,35 +53,8 @@ def _ensure_sqlite_for_chroma(): _ensure_sqlite_for_chroma() -from src.core.bot import MyClient, build_intents -from src.core.config import Config -from src.utils.logger import setup_logger - - -def main(): - """主函数""" - # 设置日志 - logger = setup_logger() - - try: - # 验证配置 - Config.validate() - logger.info("配置验证通过") - - # 创建并启动机器人(最小权限,避免 4014 disallowed intents) - logger.info("正在启动QQ机器人...") - intents = build_intents() - client = MyClient(intents=intents) - client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET) - - except ValueError as e: - logger.error(f"配置错误: {e}") - logger.error("请检查 .env 文件配置") - sys.exit(1) - except Exception as e: - logger.error(f"启动失败: {e}", exc_info=True) - sys.exit(1) +from src.core.bot import main as run_bot_main if __name__ == "__main__": - main() + run_bot_main() diff --git a/requirements.txt b/requirements.txt index a1a0525..295ae36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,12 @@ -# ============================================ -# 核心依赖(必须安装) -# ============================================ +# Core runtime qq-botpy python-dotenv>=1.0.0 -# ============================================ -# AI 功能依赖(可选) -# 如果不需要 AI 对话功能,可以注释掉下面的依赖 -# ============================================ +# AI providers openai>=1.0.0 anthropic>=0.18.0 + +# Memory storage numpy>=1.24.0 -chromadb>=0.4.0 # 向量数据库,用于记忆存储 -pysqlite3-binary>=0.5.3; platform_system != "Windows" # 云端可用于补齐 sqlite trigram 支持 +chromadb>=0.4.0 +pysqlite3-binary>=0.5.3; platform_system != "Windows" diff --git a/skills/cmd_zip_skill/README.md b/skills/cmd_zip_skill/README.md deleted file mode 100644 index f196536..0000000 --- a/skills/cmd_zip_skill/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill/__init__.py b/skills/cmd_zip_skill/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill/main.py b/skills/cmd_zip_skill/main.py deleted file mode 100644 index e673dff..0000000 --- a/skills/cmd_zip_skill/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkillSkill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill/skill.json b/skills/cmd_zip_skill/skill.json deleted file mode 100644 index 7acae47..0000000 --- a/skills/cmd_zip_skill/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465404375/README.md b/skills/cmd_zip_skill_1772465404375/README.md deleted file mode 100644 index 7b863a6..0000000 --- a/skills/cmd_zip_skill_1772465404375/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465404375 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465404375/__init__.py b/skills/cmd_zip_skill_1772465404375/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465404375/main.py b/skills/cmd_zip_skill_1772465404375/main.py deleted file mode 100644 index 37808f8..0000000 --- a/skills/cmd_zip_skill_1772465404375/main.py +++ /dev/null @@ -1,13 +0,0 @@ 
-"""cmd_zip_skill_1772465404375 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465404375Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465404375 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465404375/skill.json b/skills/cmd_zip_skill_1772465404375/skill.json deleted file mode 100644 index c046c86..0000000 --- a/skills/cmd_zip_skill_1772465404375/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465404375", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465434774/README.md b/skills/cmd_zip_skill_1772465434774/README.md deleted file mode 100644 index d0516de..0000000 --- a/skills/cmd_zip_skill_1772465434774/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465434774 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465434774/__init__.py b/skills/cmd_zip_skill_1772465434774/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465434774/main.py b/skills/cmd_zip_skill_1772465434774/main.py deleted file mode 100644 index 2af7b12..0000000 --- a/skills/cmd_zip_skill_1772465434774/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465434774 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465434774Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465434774 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465434774/skill.json b/skills/cmd_zip_skill_1772465434774/skill.json deleted file mode 100644 index 60148e6..0000000 --- a/skills/cmd_zip_skill_1772465434774/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465434774", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465467809/README.md b/skills/cmd_zip_skill_1772465467809/README.md deleted file mode 100644 index f42c1d7..0000000 --- a/skills/cmd_zip_skill_1772465467809/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465467809 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465467809/__init__.py b/skills/cmd_zip_skill_1772465467809/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465467809/main.py b/skills/cmd_zip_skill_1772465467809/main.py deleted file mode 100644 index a61bd41..0000000 --- a/skills/cmd_zip_skill_1772465467809/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465467809 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465467809Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465467809 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465467809/skill.json b/skills/cmd_zip_skill_1772465467809/skill.json deleted file mode 100644 index dc7fc87..0000000 --- a/skills/cmd_zip_skill_1772465467809/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - 
"name": "cmd_zip_skill_1772465467809", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465652075/README.md b/skills/cmd_zip_skill_1772465652075/README.md deleted file mode 100644 index 39ed797..0000000 --- a/skills/cmd_zip_skill_1772465652075/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465652075 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465652075/__init__.py b/skills/cmd_zip_skill_1772465652075/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465652075/main.py b/skills/cmd_zip_skill_1772465652075/main.py deleted file mode 100644 index c0dfa8e..0000000 --- a/skills/cmd_zip_skill_1772465652075/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465652075 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465652075Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465652075 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465652075/skill.json b/skills/cmd_zip_skill_1772465652075/skill.json deleted file mode 100644 index 24e2772..0000000 --- a/skills/cmd_zip_skill_1772465652075/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465652075", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465685352/README.md b/skills/cmd_zip_skill_1772465685352/README.md deleted file mode 100644 index 42f156b..0000000 --- a/skills/cmd_zip_skill_1772465685352/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465685352 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465685352/__init__.py b/skills/cmd_zip_skill_1772465685352/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465685352/main.py b/skills/cmd_zip_skill_1772465685352/main.py deleted file mode 100644 index cff99cf..0000000 --- a/skills/cmd_zip_skill_1772465685352/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465685352 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465685352Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465685352 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465685352/skill.json b/skills/cmd_zip_skill_1772465685352/skill.json deleted file mode 100644 index 23cbd37..0000000 --- a/skills/cmd_zip_skill_1772465685352/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465685352", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465936294/README.md b/skills/cmd_zip_skill_1772465936294/README.md deleted file mode 100644 index 43b734c..0000000 --- a/skills/cmd_zip_skill_1772465936294/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465936294 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465936294/__init__.py 
b/skills/cmd_zip_skill_1772465936294/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465936294/main.py b/skills/cmd_zip_skill_1772465936294/main.py deleted file mode 100644 index 1c25f33..0000000 --- a/skills/cmd_zip_skill_1772465936294/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465936294 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465936294Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465936294 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465936294/skill.json b/skills/cmd_zip_skill_1772465936294/skill.json deleted file mode 100644 index dcb35e7..0000000 --- a/skills/cmd_zip_skill_1772465936294/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465936294", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772465966322/README.md b/skills/cmd_zip_skill_1772465966322/README.md deleted file mode 100644 index d8449f7..0000000 --- a/skills/cmd_zip_skill_1772465966322/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772465966322 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772465966322/__init__.py b/skills/cmd_zip_skill_1772465966322/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772465966322/main.py b/skills/cmd_zip_skill_1772465966322/main.py deleted file mode 100644 index f0b63f6..0000000 --- a/skills/cmd_zip_skill_1772465966322/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772465966322 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772465966322Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"cmd_zip_skill_1772465966322 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772465966322/skill.json b/skills/cmd_zip_skill_1772465966322/skill.json deleted file mode 100644 index 6635930..0000000 --- a/skills/cmd_zip_skill_1772465966322/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772465966322", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/cmd_zip_skill_1772466071278/README.md b/skills/cmd_zip_skill_1772466071278/README.md deleted file mode 100644 index a8b7ecf..0000000 --- a/skills/cmd_zip_skill_1772466071278/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# cmd_zip_skill_1772466071278 - -## 描述 -zip skill - -## 工具 -- example_tool(text) diff --git a/skills/cmd_zip_skill_1772466071278/__init__.py b/skills/cmd_zip_skill_1772466071278/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/cmd_zip_skill_1772466071278/main.py b/skills/cmd_zip_skill_1772466071278/main.py deleted file mode 100644 index e1180e3..0000000 --- a/skills/cmd_zip_skill_1772466071278/main.py +++ /dev/null @@ -1,13 +0,0 @@ -"""cmd_zip_skill_1772466071278 skill""" -from src.ai.skills.base import Skill - - -class CmdZipSkill1772466071278Skill(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - 
return f"cmd_zip_skill_1772466071278 收到: {text}" - - async def cleanup(self): - pass diff --git a/skills/cmd_zip_skill_1772466071278/skill.json b/skills/cmd_zip_skill_1772466071278/skill.json deleted file mode 100644 index e1fa141..0000000 --- a/skills/cmd_zip_skill_1772466071278/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "cmd_zip_skill_1772466071278", - "version": "1.0.0", - "description": "zip skill", - "author": "test", - "dependencies": [], - "enabled": false -} \ No newline at end of file diff --git a/skills/skills_creator/README.md b/skills/skills_creator/README.md deleted file mode 100644 index 5e9eb9d..0000000 --- a/skills/skills_creator/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# skills_creator - -Tools exposed to AI: -- create_skill -- write_skill_main -- update_skill_metadata -- delete_skill -- load_skill -- reload_skill -- list_skills - -This skill enables creating and managing other skills from chat-driven AI tool calls. \ No newline at end of file diff --git a/skills/skills_creator/__init__.py b/skills/skills_creator/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/skills/skills_creator/main.py b/skills/skills_creator/main.py deleted file mode 100644 index a101f7f..0000000 --- a/skills/skills_creator/main.py +++ /dev/null @@ -1,147 +0,0 @@ -"""skills_creator skill""" - -import json -from pathlib import Path -import shutil -from typing import Any, Dict - -from src.ai.skills.base import Skill, SkillsManager, create_skill_template - - -class SkillsCreatorSkill(Skill): - """Provide tools to create and manage local skills from chat.""" - - def __init__(self): - super().__init__() - self.skills_root = Path("skills") - self.protected_skills = {"skills_creator"} - - async def initialize(self): - self.register_tool("create_skill", self.create_skill) - self.register_tool("write_skill_main", self.write_skill_main) - self.register_tool("update_skill_metadata", self.update_skill_metadata) - self.register_tool("delete_skill", self.delete_skill) - self.register_tool("load_skill", self.load_skill) - self.register_tool("reload_skill", self.reload_skill) - self.register_tool("list_skills", self.list_skills) - - def _skill_key(self, skill_name: str) -> str: - return SkillsManager.normalize_skill_key(skill_name) - - def _skill_path(self, skill_name: str) -> Path: - return self.skills_root / self._skill_key(skill_name) - - async def create_skill( - self, - skill_name: str, - description: str = "custom skill", - author: str = "chat_user", - auto_load: bool = True, - overwrite: bool = False, - ) -> str: - skill_key = self._skill_key(skill_name) - skill_path = self._skill_path(skill_key) - - if skill_path.exists() and not overwrite: - return f"技能已存在: {skill_key}" - - if skill_path.exists() and overwrite: - shutil.rmtree(skill_path) - - create_skill_template(skill_key, self.skills_root, description=description, author=author) - - if auto_load and self.manager: - await self.manager.load_skill(skill_key) - - return f"已创建技能: {skill_key}" - - async def write_skill_main(self, skill_name: str, code: str, auto_reload: bool = True) -> str: - skill_key = self._skill_key(skill_name) - skill_path = self._skill_path(skill_key) - if not skill_path.exists(): - return f"技能不存在: {skill_key}" - - main_file = skill_path / "main.py" - main_file.write_text(code, encoding="utf-8") - - if auto_reload and self.manager: - await self.manager.reload_skill(skill_key) - - return f"已更新技能代码: {skill_key}/main.py" - - async def update_skill_metadata(self, skill_name: str, fields: Dict[str, Any]) -> 
str: - skill_key = self._skill_key(skill_name) - skill_path = self._skill_path(skill_key) - if not skill_path.exists(): - return f"技能不存在: {skill_key}" - - metadata_file = skill_path / "skill.json" - if metadata_file.exists(): - metadata = json.loads(metadata_file.read_text(encoding="utf-8")) - else: - metadata = { - "name": skill_key, - "version": "1.0.0", - "description": "", - "author": "chat_user", - "dependencies": [], - "enabled": True, - } - - for key, value in fields.items(): - metadata[key] = value - - metadata.setdefault("name", skill_key) - metadata.setdefault("version", "1.0.0") - metadata.setdefault("description", "") - metadata.setdefault("author", "chat_user") - metadata.setdefault("dependencies", []) - metadata.setdefault("enabled", True) - - metadata_file.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8") - return f"已更新技能元数据: {skill_key}/skill.json" - - async def delete_skill(self, skill_name: str, delete_files: bool = True) -> str: - skill_key = self._skill_key(skill_name) - if skill_key in self.protected_skills: - return f"拒绝删除受保护技能: {skill_key}" - - if self.manager: - removed = await self.manager.uninstall_skill(skill_key, delete_files=delete_files) - return f"删除结果({skill_key}): {removed}" - - skill_path = self._skill_path(skill_key) - if delete_files and skill_path.exists(): - shutil.rmtree(skill_path) - return f"已删除技能目录: {skill_key}" - - return f"技能不存在或未删除: {skill_key}" - - async def load_skill(self, skill_name: str) -> str: - skill_key = self._skill_key(skill_name) - if not self.manager: - return "SkillsManager 不可用" - - success = await self.manager.load_skill(skill_key) - return f"加载技能 {skill_key}: {success}" - - async def reload_skill(self, skill_name: str) -> str: - skill_key = self._skill_key(skill_name) - if not self.manager: - return "SkillsManager 不可用" - - success = await self.manager.reload_skill(skill_key) - return f"重载技能 {skill_key}: {success}" - - async def list_skills(self) -> str: - if not self.manager: - return "SkillsManager 不可用" - - payload = { - "loaded": self.manager.list_skills(), - "available": self.manager.list_available_skills(), - } - return json.dumps(payload, ensure_ascii=False) - - async def cleanup(self): - pass \ No newline at end of file diff --git a/skills/skills_creator/skill.json b/skills/skills_creator/skill.json deleted file mode 100644 index 7875637..0000000 --- a/skills/skills_creator/skill.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "name": "skills_creator", - "version": "1.0.0", - "description": "Create, edit, load, and remove skills from chat", - "author": "QQBot", - "dependencies": [], - "enabled": true -} \ No newline at end of file diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 7650318..32a0906 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,14 +1,7 @@ -""" -AI模块 - 提供AI模型接入、人格系统、记忆系统和长任务处理能力 -""" -from .client import AIClient -from .personality import PersonalitySystem -from .memory import MemorySystem -from .task_manager import LongTaskManager +"""AI package exports.""" -__all__ = [ - 'AIClient', - 'PersonalitySystem', - 'MemorySystem', - 'LongTaskManager' -] +from .client import AIClient +from .memory import MemorySystem +from .personality import PersonalitySystem + +__all__ = ["AIClient", "MemorySystem", "PersonalitySystem"] diff --git a/src/ai/client.py b/src/ai/client.py index 6838621..ac92a8f 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -1,98 +1,97 @@ -""" -AI瀹㈡埛绔?- 鏁村悎鎵€鏈堿I鍔熻兘 """ -import inspect +Unified AI client for chat, memory and persona. 
+""" + +from __future__ import annotations + +import asyncio import json import re -from typing import List, Optional, Dict, Any, AsyncIterator, Tuple from pathlib import Path -from .base import ModelConfig, ModelProvider, Message, ToolRegistry -from .models import OpenAIModel, AnthropicModel -from .personality import PersonalitySystem +from typing import Any, Dict, List, Optional + +import httpx + +from .base import Message, ModelConfig, ModelProvider from .memory import MemorySystem -from .task_manager import LongTaskManager +from .models import AnthropicModel, OpenAIModel +from .personality import PersonalitySystem from src.utils.logger import setup_logger -logger = setup_logger('AIClient') +logger = setup_logger("AIClient") class AIClient: - """AI瀹㈡埛绔?- 缁熶竴鎺ュ彛""" - + """High-level application service for chat and memory/persona orchestration.""" + def __init__( self, model_config: ModelConfig, embed_config: Optional[ModelConfig] = None, data_dir: Path = Path("data/ai"), - use_vector_db: bool = True + use_vector_db: bool = True, + use_query_embedding: bool = False, + chat_retries: int = 1, + chat_retry_backoff: float = 0.8, ): self.config = model_config self.data_dir = data_dir self.data_dir.mkdir(parents=True, exist_ok=True) - - # 初始化主模型 + + self.chat_retries = max(0, int(chat_retries)) + self.chat_retry_backoff = max(0.0, float(chat_retry_backoff)) + self.model = self._create_model(model_config) - - # 初始化嵌入模型(如果提供) - self.embed_model = None - if embed_config: - self.embed_model = self._create_model(embed_config) - logger.info( - f"嵌入模型初始化完成: {embed_config.provider.value}/{embed_config.model_name}" - ) - - # 初始化工具注册表 - self.tools = ToolRegistry() - self._tool_sources: Dict[str, str] = {} - - # 初始化人格系统 + self.embed_model = self._create_model(embed_config) if embed_config else None + self.personality = PersonalitySystem( - config_path=data_dir / "personalities.json" + config_path=data_dir / "personalities.json", + state_path=data_dir / "personality_state.json", ) - - # 初始化记忆系统 + self.memory = MemorySystem( storage_path=data_dir / "long_term_memory.json", embed_func=self._embed_wrapper, importance_evaluator=self._evaluate_memory_importance, - use_vector_db=use_vector_db + use_vector_db=use_vector_db, + use_query_embedding=use_query_embedding, ) - - # 初始化长任务管理器 - self.task_manager = LongTaskManager( - storage_path=data_dir / "tasks.json" - ) - + logger.info( - f"AI 客户端初始化完成: {model_config.provider.value}/{model_config.model_name}" + "AI client initialized", + extra={ + "provider": model_config.provider.value, + "model": model_config.model_name, + "use_vector_db": use_vector_db, + "use_query_embedding": use_query_embedding, + "chat_retries": self.chat_retries, + }, ) - - def _create_model(self, config: ModelConfig): - """创建模型实例。""" - if config.provider == ModelProvider.OPENAI: + + def _create_model(self, config: Optional[ModelConfig]): + if config is None: + return None + + if config.provider in { + ModelProvider.OPENAI, + ModelProvider.DEEPSEEK, + ModelProvider.QWEN, + }: return OpenAIModel(config) - elif config.provider == ModelProvider.ANTHROPIC: + if config.provider == ModelProvider.ANTHROPIC: return AnthropicModel(config) - elif config.provider in [ModelProvider.DEEPSEEK, ModelProvider.QWEN]: - # DeepSeek 和 Qwen 使用 OpenAI 兼容接口 - return OpenAIModel(config) - else: - raise ValueError(f"不支持的模型提供商: {config.provider}") - + raise ValueError(f"Unsupported model provider: {config.provider}") + async def _embed_wrapper(self, text: str) -> List[float]: - """嵌入向量包装器。""" try: - # 如果有独立的嵌入模型,优先使用 
if self.embed_model: return await self.embed_model.embed(text) - # 否则尝试使用主模型 return await self.model.embed(text) except NotImplementedError: - # 如果都不支持嵌入,返回 None(记忆系统会降级) - logger.warning("Current model does not support embeddings; vector retrieval disabled") + logger.warning("Current model does not support embeddings; fallback to local embedding.") return None - except Exception as e: - logger.error(f"生成嵌入向量失败: {e}") + except Exception as exc: + logger.warning(f"Embedding generation failed: {exc}") return None @staticmethod @@ -120,17 +119,12 @@ class AIClient: async def _evaluate_memory_importance( self, content: str, metadata: Optional[Dict] = None ) -> float: - """ - 调用主模型评估记忆重要性,返回 [0, 1] 分值。 - """ system_prompt = ( - "你是记忆重要性评估器。请根据输入内容判断该信息是否值得长期记忆。" - "输出一个 0 到 1 的数字,数字越大表示越重要。" - "只输出数字,不要输出任何解释、单位或多余文本。" + "You evaluate if content should be kept as long-term memory. " + "Return only a float between 0 and 1." ) payload = json.dumps( - {"content": content, "metadata": metadata or {}}, - ensure_ascii=False, + {"content": content, "metadata": metadata or {}}, ensure_ascii=False ) messages = [ Message(role="system", content=system_prompt), @@ -146,636 +140,181 @@ class AIClient: ) score = self._parse_importance_score(response.content) return max(0.0, min(1.0, score)) - except Exception as e: - logger.warning(f"memory importance evaluation failed, fallback to neutral score: {e}") + except Exception as exc: + logger.warning(f"Memory importance evaluation failed, fallback to 0.5: {exc}") return 0.5 - + + @staticmethod + def _preview_log_payload(payload: Any, max_len: int = 240) -> str: + try: + text = json.dumps(payload, ensure_ascii=False, default=str) + except Exception: + text = str(payload) + if len(text) > max_len: + return text[:max_len] + "..." 
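# Worked example of the retry schedule implemented by _chat_with_retry below:
# attempts = 1 + chat_retries, and the sleep before retry k is
# chat_retry_backoff * 2 ** (k - 1) (exponential backoff).
def retry_sleep_schedule(chat_retries: int = 1, backoff: float = 0.8):
    return [backoff * (2 ** (attempt - 1)) for attempt in range(1, chat_retries + 1)]

# retry_sleep_schedule(1, 0.8) -> [0.8]
# retry_sleep_schedule(3, 0.8) -> [0.8, 1.6, 3.2]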
+ return text + + @staticmethod + def _is_retriable_error(exc: Exception) -> bool: + if isinstance(exc, (httpx.ReadTimeout, httpx.ConnectError, asyncio.TimeoutError)): + return True + + lower = str(exc).lower() + retriable_signals = [ + "timed out", + "timeout", + "temporarily unavailable", + "connection reset", + "service unavailable", + "rate limit", + "429", + "502", + "503", + "504", + ] + return any(signal in lower for signal in retriable_signals) + + async def _chat_with_retry(self, messages: List[Message], **kwargs): + attempts = 1 + self.chat_retries + last_error: Optional[Exception] = None + + for attempt in range(1, attempts + 1): + try: + return await self.model.chat(messages, None, **kwargs) + except Exception as exc: + last_error = exc + if attempt >= attempts or not self._is_retriable_error(exc): + break + + sleep_seconds = self.chat_retry_backoff * (2 ** (attempt - 1)) + logger.warning( + "Chat request failed, retrying", + extra={ + "attempt": attempt, + "max_attempts": attempts, + "sleep_seconds": sleep_seconds, + "error": str(exc), + }, + ) + await asyncio.sleep(sleep_seconds) + + if last_error: + raise last_error + raise RuntimeError("chat retry loop ended unexpectedly") + async def chat( self, user_id: str, user_message: str, system_prompt: Optional[str] = None, use_memory: bool = True, - use_tools: bool = True, stream: bool = False, - **kwargs + group_id: Optional[str] = None, + session_id: Optional[str] = None, + memory_key: Optional[str] = None, + **kwargs, ) -> str: - """对话接口。""" - try: - # 构建消息列表 - messages = [] - - # 系统提示词 - if system_prompt is None: - system_prompt = self.personality.get_system_prompt() - - # 注入记忆上下文 - if use_memory: - short_term, long_term = await self.memory.get_context( - user_id=user_id, - query=user_message - ) - - if short_term or long_term: - memory_context = self.memory.format_context(short_term, long_term) - system_prompt += f"\n\n{memory_context}" - - messages.append(Message(role="system", content=system_prompt)) - - # 添加用户消息 - messages.append(Message(role="user", content=user_message)) - - # 准备工具 - tools = None - tool_names: List[str] = [] - if use_tools and self.tools.list(): - tools = self.tools.to_openai_format() - tool_names = [tool.name for tool in self.tools.list()] - forced_tool_name = self._extract_forced_tool_name(user_message, tool_names) - if forced_tool_name: - kwargs = dict(kwargs) - kwargs["forced_tool_name"] = forced_tool_name - if tools: - before_count = len(tools) - tools = [ - tool - for tool in tools - if ((tool.get("function") or {}).get("name") == forced_tool_name) - ] - if len(tools) == 1: - tool_names = [forced_tool_name] - logger.info( - "显式工具调用已收敛工具列表: " - f"{before_count} -> {len(tools)}" - ) - logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") + if stream: + raise NotImplementedError("stream mode is not supported by AIClient.chat") - logger.info( - "LLM请求: " - f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, " - f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, " - f"tool_names={self._preview_log_payload(tool_names)}, " - f"forced_tool={forced_tool_name or '-'}" + memory_user_key = memory_key or user_id + messages: List[Message] = [] + + if system_prompt is None: + system_prompt = self.personality.get_system_prompt( + user_id=user_id, + group_id=group_id, + session_id=session_id, ) - logger.info( - "LLM输入: " - f"user_message={self._preview_log_payload(user_message)}" + + if use_memory: + short_term, long_term = await self.memory.get_context( + user_id=memory_user_key, 
+ query=user_message, ) - - # 调用模型 - if stream: - return self._chat_stream(messages, tools, **kwargs) - else: - response = await self.model.chat(messages, tools, **kwargs) - response_tool_count = len(response.tool_calls or []) - response_tool_names = [] - for tool_call in response.tool_calls or []: - if isinstance(tool_call, dict): - function_info = tool_call.get("function") or {} - response_tool_names.append(function_info.get("name")) - else: - function_info = getattr(tool_call, "function", None) - response_tool_names.append( - getattr(function_info, "name", None) if function_info else None - ) - logger.info( - "LLM首轮输出: " - f"tool_calls={response_tool_count}, " - f"tool_names={self._preview_log_payload(response_tool_names)}, " - f"content={self._preview_log_payload(response.content)}" - ) - - # 处理工具调用 - if response.tool_calls: - response = await self._handle_tool_calls( - messages, response, tools, **kwargs - ) - elif forced_tool_name: - forced_response = await self._run_forced_tool_fallback( - forced_tool_name=forced_tool_name, - user_message=user_message, - ) - if forced_response is not None: - response = forced_response - - # 写入记忆 - if use_memory: - stored_memory = await self.memory.add_qa_pair( - user_id=user_id, - question=user_message, - answer=response.content, - metadata={"source": "chat"}, - ) - if stored_memory: - logger.info( - "已写入长期记忆问答对:\n" - f"{stored_memory.content}\n" - f"memory_id={stored_memory.id}, " - f"importance={stored_memory.importance:.2f}" - ) - - return response.content - - except Exception as e: - logger.error(f"对话失败: {type(e).__name__}: {e!r}") - raise + if short_term or long_term: + memory_context = self.memory.format_context(short_term, long_term) + system_prompt = f"{system_prompt}\n\n{memory_context}".strip() - async def _run_forced_tool_fallback( - self, forced_tool_name: str, user_message: str - ) -> Optional[Message]: - """Execute forced tool locally when model did not emit tool_calls.""" - tool_def = self.tools.get(forced_tool_name) - tool_source = self._tool_sources.get(forced_tool_name, "custom") - if not tool_def: - logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}") - return None - - logger.warning( - "模型未返回 tool_calls,启用本地强制工具执行: " - f"source={tool_source}, name={forced_tool_name}" - ) - - try: - result = tool_def.function() - if inspect.isawaitable(result): - result = await result - except TypeError as exc: - logger.warning( - "本地强制工具执行失败(参数不匹配): " - f"name={forced_tool_name}, error={exc}" - ) - return None - except Exception as exc: - logger.warning( - "本地强制工具执行失败: " - f"name={forced_tool_name}, error={exc}" - ) - return None - - result_text = str(result) - pipelined_text = await self._run_skill_doc_pipeline( - forced_tool_name=forced_tool_name, - skill_doc=result_text, - user_message=user_message, - ) - if pipelined_text is not None: - result_text = pipelined_text - - prefix_limit = self._extract_prefix_limit(user_message) - if prefix_limit: - result_text = result_text[:prefix_limit] + messages.append(Message(role="system", content=system_prompt)) + messages.append(Message(role="user", content=user_message)) logger.info( - "本地强制工具执行成功: " - f"source={tool_source}, name={forced_tool_name}, " - f"result={self._preview_log_payload(result_text)}" + "LLM request", + extra={ + "user_id": user_id, + "group_id": group_id, + "session_id": session_id, + "memory_key": memory_user_key, + "use_memory": use_memory, + "message_preview": self._preview_log_payload(user_message), + }, ) - return Message(role="assistant", content=result_text) - - async def 
_run_skill_doc_pipeline( - self, forced_tool_name: str, skill_doc: str, user_message: str - ) -> Optional[str]: - """Run an extra model step: execute instructions from skill doc on user text.""" - if not forced_tool_name.endswith(".read_skill_doc"): - return None - - target_text = self._extract_processing_payload(user_message) - if not target_text: - return None + response = await self._chat_with_retry(messages, **kwargs) logger.info( - "强制工具后续处理开始: " - f"name={forced_tool_name}, target_len={len(target_text)}" + "LLM response", + extra={"content_preview": self._preview_log_payload(response.content)}, ) - messages = [ - Message( - role="system", - content=( - "你是技能执行器。请严格按下面技能文档处理用户文本。" - "不要复述技能文档,不要解释工具调用过程,只输出最终处理结果。\n\n" - "[技能文档开始]\n" - f"{skill_doc}\n" - "[技能文档结束]" - ), - ), - Message( - role="user", - content=( - "请根据技能文档处理以下文本,保持原意并提升自然度:\n" - f"{target_text}" - ), - ), - ] - - try: - response = await self.model.chat(messages=messages, tools=None) - content = (response.content or "").strip() - if not content: - return None - - logger.info( - "强制工具后续处理完成: " - f"name={forced_tool_name}, output_len={len(content)}" + if use_memory: + stored_memory = await self.memory.add_qa_pair( + user_id=memory_user_key, + question=user_message, + answer=response.content, + metadata={ + "source": "chat", + "user_id": user_id, + "group_id": group_id, + "session_id": session_id, + }, ) - return content - except Exception as exc: - logger.warning( - "强制工具后续处理失败,回退为工具原始输出: " - f"name={forced_tool_name}, error={exc}" - ) - return None - - async def _chat_stream( - self, - messages: List[Message], - tools: Optional[List[Dict]], - **kwargs - ) -> AsyncIterator[str]: - """流式对话。""" - async for chunk in self.model.chat_stream(messages, tools, **kwargs): - yield chunk - - async def _handle_tool_calls( - self, - messages: List[Message], - response: Message, - tools: Optional[List[Dict]], - **kwargs - ) -> Message: - """处理工具调用。""" - messages.append(response) - total_calls = len(response.tool_calls or []) - if total_calls: - logger.info(f"检测到工具调用请求: {total_calls} 个") - - # 执行工具调用 - for tool_call in response.tool_calls or []: - try: - tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call) - except Exception as e: - logger.warning(f"解析工具调用失败: {e}") - fallback_id = tool_call.get('id') if isinstance(tool_call, dict) else getattr(tool_call, 'id', None) - if fallback_id: - messages.append(Message( - role="tool", - content=f"工具参数解析失败: {str(e)}", - tool_call_id=fallback_id, - name="tool" - )) - continue - if not tool_name: - logger.warning(f"跳过无效工具调用: {tool_call}") - continue - - tool_def = self.tools.get(tool_name) - tool_source = self._tool_sources.get(tool_name, "custom") - if not tool_def: - error_msg = f"未找到工具: {tool_name}" - logger.warning(error_msg) - messages.append(Message( - role="tool", - name=tool_name, - content=error_msg, - tool_call_id=tool_call_id - )) - continue - - try: + if stored_memory: logger.info( - "工具调用开始: " - f"source={tool_source}, name={tool_name}, " - f"args={self._preview_log_payload(tool_args)}" + "Long-term memory stored", + extra={ + "memory_id": stored_memory.id, + "importance": stored_memory.importance, + }, ) - result = tool_def.function(**tool_args) - if inspect.isawaitable(result): - result = await result - logger.info( - "工具调用成功: " - f"source={tool_source}, name={tool_name}, " - f"result={self._preview_log_payload(result)}" - ) - messages.append(Message( - role="tool", - name=tool_name, - content=str(result), - tool_call_id=tool_call_id - )) - except Exception as e: - 
logger.warning( - "工具调用失败: " - f"source={tool_source}, name={tool_name}, error={e}" - ) - messages.append(Message( - role="tool", - name=tool_name, - content=f"工具执行失败: {str(e)}", - tool_call_id=tool_call_id - )) - # 再次调用模型获取最终响应 - final_kwargs = dict(kwargs) - # Force only the first model turn, avoid recursive force after tool result. - final_kwargs.pop("forced_tool_name", None) - final_response = await self.model.chat(messages, tools, **final_kwargs) - logger.info( - "LLM最终输出: " - f"content={self._preview_log_payload(final_response.content)}" + return response.content + + def set_personality( + self, personality_name: str, scope: str = "global", scope_id: Optional[str] = None + ) -> bool: + return self.personality.set_personality( + key=personality_name, + scope=scope, + scope_id=scope_id, ) - return final_response - def _parse_tool_call(self, tool_call: Any) -> Tuple[Optional[str], Dict[str, Any], Optional[str]]: - """兼容不同 SDK 返回的工具调用结构。""" - if isinstance(tool_call, dict): - tool_call_id = tool_call.get('id') - function = tool_call.get('function') or {} - tool_name = function.get('name') - raw_args = function.get('arguments') - else: - tool_call_id = getattr(tool_call, 'id', None) - function = getattr(tool_call, 'function', None) - tool_name = getattr(function, 'name', None) if function else None - raw_args = getattr(function, 'arguments', None) if function else None - - tool_args = self._normalize_tool_args(raw_args) - return tool_name, tool_args, tool_call_id - - def _normalize_tool_args(self, raw_args: Any) -> Dict[str, Any]: - """将工具参数统一转换为字典。""" - if raw_args is None: - return {} - - if isinstance(raw_args, dict): - return raw_args - - if isinstance(raw_args, str): - raw_args = raw_args.strip() - if not raw_args: - return {} - parsed = json.loads(raw_args) - if not isinstance(parsed, dict): - raise ValueError(f"工具参数必须是 JSON 对象,实际类型: {type(parsed)}") - return parsed - - if hasattr(raw_args, 'model_dump'): - parsed = raw_args.model_dump() - if isinstance(parsed, dict): - return parsed - - raise ValueError(f"不支持的工具参数类型: {type(raw_args)}") - - @staticmethod - def _preview_log_payload(payload: Any, max_len: int = 240) -> str: - """日志中展示参数/结果时使用的简短预览。""" - try: - text = json.dumps(payload, ensure_ascii=False, default=str) - except Exception: - text = str(payload) - - if len(text) > max_len: - return text[:max_len] + "..." - return text - - @staticmethod - def _extract_prefix_limit(user_message: str) -> Optional[int]: - """Extract requested output prefix length like '前100字'.""" - if not user_message: - return None - - match = re.search(r"前\s*(\d{1,4})\s*字", user_message) - if not match: - return None - - try: - limit = int(match.group(1)) - except (TypeError, ValueError): - return None - - if limit <= 0: - return None - return min(limit, 5000) - - @staticmethod - def _extract_processing_payload(user_message: str) -> Optional[str]: - """Extract text payload like '处理以下文本:...' 
from user message.""" - if not user_message: - return None - - text = user_message.strip() - markers = [ - "以下文本:", - "以下文本:", - "文本:", - "文本:", - ] - for marker in markers: - idx = text.find(marker) - if idx < 0: - continue - payload = text[idx + len(marker) :].strip() - if payload: - return payload - - pattern = re.compile( - r"(?:处理|润色|改写|人性化处理|优化)[\s\S]{0,32}(?:如下|以下)[::]\s*([\s\S]+)$" - ) - match = pattern.search(text) - if match: - payload = (match.group(1) or "").strip() - if payload: - return payload - - return None - - @staticmethod - def _compact_identifier(text: str) -> str: - """Compact identifier for fuzzy matching (e.g. humanizer_zh -> humanizerzh).""" - return re.sub(r"[^a-z0-9]+", "", (text or "").lower()) - - @staticmethod - def _extract_forced_tool_name( - user_message: str, available_tool_names: List[str] - ) -> Optional[str]: - if not user_message or not available_tool_names: - return None - - triggers = [ - "调用工具", - "使用工具", - "只调用", - "务必调用", - "必须调用", - "调用", - "使用", - "tool", - ] - if not any(trigger in user_message for trigger in triggers): - return None - - pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)") - explicit_matches = [ - name for name in pattern.findall(user_message) if name in available_tool_names - ] - if len(explicit_matches) == 1: - return explicit_matches[0] - if len(explicit_matches) > 1: - return None - - contained = [name for name in available_tool_names if name in user_message] - if len(contained) == 1: - return contained[0] - - # 允许只写 skill/tool 前缀(如 humanizer_zh),前提是前缀下只有一个工具。 - prefixes = sorted( - {name.split(".", 1)[0] for name in available_tool_names}, - key=len, - reverse=True, - ) - matched_prefixes = [ - prefix - for prefix in prefixes - if re.search(rf"\b{re.escape(prefix)}\b", user_message) - ] - if len(matched_prefixes) == 1: - prefix_tools = [ - name - for name in available_tool_names - if name.startswith(f"{matched_prefixes[0]}.") - ] - if len(prefix_tools) == 1: - return prefix_tools[0] - - # 模糊匹配:支持省略下划线/点号的写法(如 humanizerzh)。 - compact_message = AIClient._compact_identifier(user_message) - if compact_message: - compact_full_matches = [] - for tool_name in available_tool_names: - compact_tool_name = AIClient._compact_identifier(tool_name) - if compact_tool_name and compact_tool_name in compact_message: - compact_full_matches.append(tool_name) - - if len(compact_full_matches) == 1: - return compact_full_matches[0] - if len(compact_full_matches) > 1: - return None - - compact_prefix_map: Dict[str, List[str]] = {} - for tool_name in available_tool_names: - prefix = tool_name.split(".", 1)[0] - compact_prefix = AIClient._compact_identifier(prefix) - if not compact_prefix: - continue - compact_prefix_map.setdefault(compact_prefix, []).append(tool_name) - - compact_prefix_matches = [ - compact_prefix - for compact_prefix in compact_prefix_map - if compact_prefix in compact_message - ] - if len(compact_prefix_matches) == 1: - matched_tools = compact_prefix_map[compact_prefix_matches[0]] - if len(matched_tools) == 1: - return matched_tools[0] - - return None - - def set_personality(self, personality_name: str) -> bool: - """设置人格。""" - return self.personality.set_personality(personality_name) - def list_personalities(self) -> List[str]: - """列出所有人格。""" return self.personality.list_personalities() def switch_model(self, model_config: ModelConfig) -> bool: - """Runtime switch for primary chat model.""" - new_model = self._create_model(model_config) - self.model = new_model + self.model = self._create_model(model_config) self.config = 
model_config logger.info( - f"已切换主模型: {model_config.provider.value}/{model_config.model_name}" + "Primary model switched", + extra={ + "provider": model_config.provider.value, + "model": model_config.model_name, + }, ) return True - async def create_long_task( - self, - user_id: str, - title: str, - description: str, - steps: List[Dict], - metadata: Optional[Dict] = None - ) -> str: - """创建长任务。""" - return self.task_manager.create_task( - user_id=user_id, - title=title, - description=description, - steps=steps, - metadata=metadata - ) - - async def start_task( - self, - task_id: str, - progress_callback: Optional[callable] = None - ): - """启动任务。""" - await self.task_manager.start_task(task_id, progress_callback) - - def get_task_status(self, task_id: str) -> Optional[Dict]: - """获取任务状态。""" - return self.task_manager.get_task_status(task_id) - - def register_tool( - self, - name: str, - description: str, - parameters: Dict, - function: callable, - source: str = "custom", - ): - """注册工具。""" - from .base import ToolDefinition - tool = ToolDefinition( - name=name, - description=description, - parameters=parameters, - function=function - ) - self.tools.register(tool) - self._tool_sources[name] = source - logger.info(f"已注册工具: {name} (source={source})") - - def unregister_tool(self, name: str) -> bool: - """卸载工具。""" - removed = self.tools.unregister(name) - if removed: - self._tool_sources.pop(name, None) - logger.info(f"已卸载工具: {name}") - return removed - - def unregister_tools_by_prefix(self, prefix: str) -> int: - """按前缀批量卸载工具。""" - removed_count = self.tools.unregister_by_prefix(prefix) - for tool_name in list(self._tool_sources.keys()): - if tool_name.startswith(prefix): - self._tool_sources.pop(tool_name, None) - if removed_count: - logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}") - return removed_count - def clear_memory(self, user_id: str): - """清除用户短期记忆。""" self.memory.clear_short_term(user_id) - logger.info(f"Cleared short-term memory for user {user_id}") + logger.info(f"Cleared short-term memory for {user_id}") async def clear_long_term_memory(self, user_id: str) -> bool: try: await self.memory.clear_long_term(user_id) - logger.info(f"Cleared long-term memory for user {user_id}") + logger.info(f"Cleared long-term memory for {user_id}") return True - except Exception as e: - logger.warning(f"Failed to clear long-term memory for user {user_id}: {e}") + except Exception as exc: + logger.warning(f"Failed to clear long-term memory for {user_id}: {exc}") return False async def list_long_term_memories(self, user_id: str, limit: int = 20): @@ -823,10 +362,18 @@ class AIClient: return await self.memory.delete_long_term(user_id, memory_id) async def clear_all_memory(self, user_id: str) -> bool: - """清除用户全部记忆(短期 + 长期)。""" self.clear_memory(user_id) - try: - return await self.clear_long_term_memory(user_id) - except Exception: - return False + return await self.clear_long_term_memory(user_id) + async def close(self): + await self.memory.close() + + for model in [self.model, self.embed_model]: + if model is None: + continue + close = getattr(model, "close", None) + if close is None: + continue + maybe_awaitable = close() + if asyncio.iscoroutine(maybe_awaitable): + await maybe_awaitable diff --git a/src/ai/mcp/__init__.py b/src/ai/mcp/__init__.py deleted file mode 100644 index 4bf4500..0000000 --- a/src/ai/mcp/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -MCP模块 -""" -from .base import MCPServer, MCPClient, MCPManager, MCPResource, MCPTool, MCPPrompt - -__all__ = [ - 
'MCPServer', - 'MCPClient', - 'MCPManager', - 'MCPResource', - 'MCPTool', - 'MCPPrompt' -] diff --git a/src/ai/mcp/base.py b/src/ai/mcp/base.py deleted file mode 100644 index d375176..0000000 --- a/src/ai/mcp/base.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -MCP (Model Context Protocol) 支持 -""" -import asyncio -import json -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass, asdict -from pathlib import Path -from src.utils.logger import setup_logger - -logger = setup_logger('MCPSystem') - - -@dataclass -class MCPResource: - """MCP资源""" - uri: str - name: str - description: str - mime_type: str - - -@dataclass -class MCPTool: - """MCP工具""" - name: str - description: str - input_schema: Dict[str, Any] - - -@dataclass -class MCPPrompt: - """MCP提示词""" - name: str - description: str - arguments: List[Dict[str, Any]] - - -class MCPServer: - """MCP服务器基类""" - - def __init__(self, name: str, version: str): - self.name = name - self.version = version - self.resources: Dict[str, MCPResource] = {} - self.tools: Dict[str, Callable] = {} - self.tool_specs: Dict[str, MCPTool] = {} - self.prompts: Dict[str, MCPPrompt] = {} - - async def initialize(self): - """初始化服务器""" - pass - - async def shutdown(self): - """关闭服务器""" - pass - - def register_resource(self, resource: MCPResource): - """注册资源""" - self.resources[resource.uri] = resource - - def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable): - """注册工具""" - tool = MCPTool(name=name, description=description, input_schema=input_schema) - self.tool_specs[name] = tool - self.tools[name] = handler - logger.info(f"✅ MCP工具注册: {self.name}.{name}") - - def register_prompt(self, prompt: MCPPrompt): - """注册提示词""" - self.prompts[prompt.name] = prompt - - async def list_resources(self) -> List[MCPResource]: - """列出资源""" - return list(self.resources.values()) - - async def read_resource(self, uri: str) -> Optional[str]: - """读取资源""" - raise NotImplementedError - - async def list_tools(self) -> List[MCPTool]: - """列出工具""" - return list(self.tool_specs.values()) - - async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - if name not in self.tools: - raise ValueError(f"工具不存在: {name}") - - handler = self.tools[name] - logger.info( - "MCP工具调用开始: " - f"server={self.name}, tool={name}, " - f"args={json.dumps(arguments, ensure_ascii=False, default=str)}" - ) - try: - result = await handler(**arguments) - except Exception as exc: - logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}") - raise - logger.info(f"MCP工具调用成功: server={self.name}, tool={name}") - return result - - async def list_prompts(self) -> List[MCPPrompt]: - """列出提示词""" - return list(self.prompts.values()) - - async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]: - """获取提示词""" - raise NotImplementedError - - -class MCPClient: - """MCP客户端""" - - def __init__(self): - self.servers: Dict[str, MCPServer] = {} - - async def connect_server(self, server: MCPServer): - """连接服务器""" - await server.initialize() - self.servers[server.name] = server - logger.info(f"✅ 连接MCP服务器: {server.name} v{server.version}") - - async def disconnect_server(self, server_name: str): - """断开服务器""" - if server_name in self.servers: - await self.servers[server_name].shutdown() - del self.servers[server_name] - logger.info(f"✅ 断开MCP服务器: {server_name}") - - def get_server(self, name: str) -> Optional[MCPServer]: - """获取服务器""" - return self.servers.get(name) - - def list_servers(self) -> List[str]: 
- """列出所有服务器""" - return list(self.servers.keys()) - - async def list_all_resources(self) -> Dict[str, List[MCPResource]]: - """列出所有资源""" - result = {} - for name, server in self.servers.items(): - result[name] = await server.list_resources() - return result - - async def list_all_tools(self) -> Dict[str, List[MCPTool]]: - """列出所有工具""" - result = {} - for name, server in self.servers.items(): - result[name] = await server.list_tools() - return result - - async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - server = self.get_server(server_name) - if not server: - raise ValueError(f"服务器不存在: {server_name}") - - return await server.call_tool(tool_name, arguments) - - -class MCPManager: - """MCP管理器""" - - def __init__(self, config_path: Path): - self.config_path = config_path - self.client = MCPClient() - self.server_configs: Dict[str, Dict] = {} - self._load_config() - - def _load_config(self): - """加载配置""" - if self.config_path.exists(): - with open(self.config_path, 'r', encoding='utf-8') as f: - self.server_configs = json.load(f) - - def _save_config(self): - """保存配置""" - self.config_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.config_path, 'w', encoding='utf-8') as f: - json.dump(self.server_configs, f, ensure_ascii=False, indent=2) - - async def register_server(self, server: MCPServer, config: Optional[Dict] = None): - """注册服务器""" - await self.client.connect_server(server) - - if config: - self.server_configs[server.name] = config - self._save_config() - - async def unregister_server(self, server_name: str): - """注销服务器""" - await self.client.disconnect_server(server_name) - - if server_name in self.server_configs: - del self.server_configs[server_name] - self._save_config() - - def get_client(self) -> MCPClient: - """获取客户端""" - return self.client - - async def get_all_tools_for_ai(self) -> List[Dict]: - """获取所有工具(AI格式)""" - all_tools = [] - tools_by_server = await self.client.list_all_tools() - - for server_name, tools in tools_by_server.items(): - for tool in tools: - all_tools.append({ - "type": "function", - "function": { - "name": f"{server_name}.{tool.name}", - "description": tool.description, - "parameters": tool.input_schema - } - }) - - return all_tools - - async def execute_tool(self, full_tool_name: str, arguments: Dict) -> Any: - """执行工具""" - parts = full_tool_name.split('.', 1) - if len(parts) != 2: - raise ValueError(f"工具名格式错误: {full_tool_name}") - - server_name, tool_name = parts - logger.info(f"MCP执行请求: {full_tool_name}") - return await self.client.call_tool(server_name, tool_name, arguments) diff --git a/src/ai/mcp/servers/__init__.py b/src/ai/mcp/servers/__init__.py deleted file mode 100644 index 2b5a18c..0000000 --- a/src/ai/mcp/servers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -MCP服务器实现 -""" -from .filesystem import FileSystemMCPServer - -__all__ = ['FileSystemMCPServer'] diff --git a/src/ai/mcp/servers/filesystem.py b/src/ai/mcp/servers/filesystem.py deleted file mode 100644 index 8f793da..0000000 --- a/src/ai/mcp/servers/filesystem.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -MCP示例服务器 - 文件系统访问 -""" -from pathlib import Path -from typing import Optional -from ..base import MCPServer, MCPResource - - -class FileSystemMCPServer(MCPServer): - """文件系统MCP服务器""" - - def __init__(self, root_path: Path): - super().__init__(name="filesystem", version="1.0.0") - self.root_path = root_path - - async def initialize(self): - """初始化""" - # 注册工具 - self.register_tool( - name="read_file", - 
description="读取文件内容", - input_schema={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "文件路径" - } - }, - "required": ["path"] - }, - handler=self.read_file - ) - - self.register_tool( - name="write_file", - description="写入文件内容", - input_schema={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "文件路径" - }, - "content": { - "type": "string", - "description": "文件内容" - } - }, - "required": ["path", "content"] - }, - handler=self.write_file - ) - - self.register_tool( - name="list_directory", - description="列出目录内容", - input_schema={ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "目录路径" - } - }, - "required": ["path"] - }, - handler=self.list_directory - ) - - def _resolve_path(self, path: str) -> Path: - """解析路径""" - full_path = (self.root_path / path).resolve() - - # 安全检查:确保路径在root_path内 - if not str(full_path).startswith(str(self.root_path)): - raise ValueError("路径超出允许范围") - - return full_path - - async def read_file(self, path: str) -> str: - """读取文件""" - file_path = self._resolve_path(path) - - if not file_path.exists(): - raise FileNotFoundError(f"文件不存在: {path}") - - if not file_path.is_file(): - raise ValueError(f"不是文件: {path}") - - with open(file_path, 'r', encoding='utf-8') as f: - return f.read() - - async def write_file(self, path: str, content: str) -> str: - """写入文件""" - file_path = self._resolve_path(path) - - file_path.parent.mkdir(parents=True, exist_ok=True) - - with open(file_path, 'w', encoding='utf-8') as f: - f.write(content) - - return f"文件已写入: {path}" - - async def list_directory(self, path: str) -> list: - """列出目录""" - dir_path = self._resolve_path(path) - - if not dir_path.exists(): - raise FileNotFoundError(f"目录不存在: {path}") - - if not dir_path.is_dir(): - raise ValueError(f"不是目录: {path}") - - items = [] - for item in dir_path.iterdir(): - items.append({ - "name": item.name, - "type": "directory" if item.is_dir() else "file", - "size": item.stat().st_size if item.is_file() else None - }) - - return items diff --git a/src/ai/memory.py b/src/ai/memory.py index 6c20dee..1918197 100644 --- a/src/ai/memory.py +++ b/src/ai/memory.py @@ -1,180 +1,166 @@ -""" -记忆系统:短期记忆、长期记忆与 RAG 检索(向量数据库)。 """ -import asyncio +Memory system: short-term window + long-term retrieval. 
+""" + +from __future__ import annotations + import hashlib import shutil import time import uuid -from typing import List, Dict, Optional, Tuple, Callable, Awaitable +from collections import deque from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path -from collections import deque -from .vector_store import VectorStore, VectorMemory, ChromaVectorStore, JSONVectorStore +from typing import Awaitable, Callable, Dict, List, Optional, Tuple + +from .vector_store import ChromaVectorStore, JSONVectorStore, VectorMemory, VectorStore from src.utils.logger import setup_logger -logger = setup_logger('MemorySystem') +logger = setup_logger("MemorySystem") @dataclass class MemoryItem: - """记忆项(用于短期记忆)。""" + """In-memory short-term record.""" + content: str timestamp: datetime user_id: str importance: float = 0.5 metadata: Dict = field(default_factory=dict) - + def to_dict(self) -> Dict: - """转换为字典。""" return { - 'content': self.content, - 'timestamp': self.timestamp.isoformat(), - 'user_id': self.user_id, - 'importance': self.importance, - 'metadata': self.metadata + "content": self.content, + "timestamp": self.timestamp.isoformat(), + "user_id": self.user_id, + "importance": self.importance, + "metadata": self.metadata, } class ShortTermMemory: - """短期记忆(滑动窗口)。""" - + """Short-term memory window.""" + def __init__(self, max_size: int = 20, max_age_minutes: int = 30): self.max_size = max_size self.max_age = timedelta(minutes=max_age_minutes) - self.memories: Dict[str, deque] = {} # user_id -> deque of MemoryItem - + self.memories: Dict[str, deque] = {} + def add(self, user_id: str, content: str, metadata: Optional[Dict] = None): - """添加短期记忆。""" if user_id not in self.memories: self.memories[user_id] = deque(maxlen=self.max_size) - - memory = MemoryItem( - content=content, - timestamp=datetime.now(), - user_id=user_id, - metadata=metadata or {} + + self.memories[user_id].append( + MemoryItem( + content=content, + timestamp=datetime.now(), + user_id=user_id, + metadata=metadata or {}, + ) ) - - self.memories[user_id].append(memory) - + def get(self, user_id: str, limit: Optional[int] = None) -> List[MemoryItem]: - """获取短期记忆。""" - if user_id not in self.memories: + items = list(self.memories.get(user_id, [])) + if not items: return [] - - # 过滤过期记忆 + now = datetime.now() - valid_memories = [ - m for m in self.memories[user_id] - if now - m.timestamp <= self.max_age - ] - - if limit: - valid_memories = valid_memories[-limit:] - - return valid_memories - + valid = [m for m in items if now - m.timestamp <= self.max_age] + self.memories[user_id] = deque(valid, maxlen=self.max_size) + + if limit and limit > 0: + return valid[-limit:] + return valid + def clear(self, user_id: str): - """清除用户短期记忆。""" - if user_id in self.memories: - self.memories.pop(user_id, None) + self.memories.pop(user_id, None) class MemorySystem: - """记忆系统:整合短期记忆与长期记忆。""" - + """Memory system: short-term + long-term storage.""" + def __init__( self, storage_path: Path, - embed_func: Optional[callable] = None, + embed_func: Optional[Callable[[str], Awaitable[List[float]]]] = None, importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None, importance_threshold: float = 0.6, use_vector_db: bool = True, use_query_embedding: bool = False, + max_long_term_per_user: int = 500, + dedup_window_seconds: int = 300, ): self.short_term = ShortTermMemory() self.embed_func = embed_func self.importance_evaluator = importance_evaluator self.importance_threshold = 
importance_threshold - # Only embed retrieval queries when explicitly enabled. self.use_query_embedding = use_query_embedding - - # 初始化向量存储 + self.max_long_term_per_user = max(10, int(max_long_term_per_user)) + self.dedup_window = timedelta(seconds=max(1, int(dedup_window_seconds))) + if use_vector_db: chroma_path = storage_path.parent / "chroma_db" chroma_store = self._init_chroma_store(chroma_path) if chroma_store is not None: self.vector_store = chroma_store - logger.info("Using Chroma vector store") else: self.vector_store = JSONVectorStore(storage_path) else: - # 使用 JSON 存储(向后兼容) self.vector_store = JSONVectorStore(storage_path) - logger.info("使用 JSON 存储") @staticmethod def _is_chroma_table_conflict(error: Exception) -> bool: - msg = str(error).lower() - return "table embeddings already exists" in msg + return "table embeddings already exists" in str(error).lower() @staticmethod def _is_chroma_trigram_error(error: Exception) -> bool: - msg = str(error).lower() - return "no such tokenizer: trigram" in msg + return "no such tokenizer: trigram" in str(error).lower() def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]: - """初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。""" try: return ChromaVectorStore(chroma_path) except Exception as error: if self._is_chroma_trigram_error(error): logger.warning( - "Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。" - "请在运行环境升级 sqlite 或安装 pysqlite3-binary。" + "Chroma unavailable (sqlite trigram unsupported), fallback to JSON store." ) return None - if not self._is_chroma_table_conflict(error): - logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}") + logger.warning(f"Chroma init failed, fallback to JSON: {error}") return None - # 先做一次短暂重试,处理并发启动时的瞬时冲突。 - logger.warning(f"Chroma 初始化出现 schema 冲突,正在重试: {error}") + logger.warning(f"Chroma schema conflict, retry once: {error}") time.sleep(0.2) try: return ChromaVectorStore(chroma_path) except Exception as retry_error: if not self._is_chroma_table_conflict(retry_error): - logger.warning(f"Chroma 重试失败,降级为 JSON 存储: {retry_error}") + logger.warning(f"Chroma retry failed, fallback to JSON: {retry_error}") return None - backup_name = ( + backup_path = chroma_path.parent / ( f"{chroma_path.name}_backup_conflict_" f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" ) - backup_path = chroma_path.parent / backup_name - try: if chroma_path.exists(): shutil.move(str(chroma_path), str(backup_path)) chroma_path.mkdir(parents=True, exist_ok=True) repaired = ChromaVectorStore(chroma_path) logger.warning( - f"检测到 Chroma 元数据库冲突,已重建目录并保留备份: {backup_path}" + f"Chroma metadata repaired by rebuilding directory. 
Backup: {backup_path}" ) return repaired except Exception as repair_error: - logger.warning(f"Chroma 修复失败,降级为 JSON 存储: {repair_error}") + logger.warning(f"Chroma repair failed, fallback to JSON: {repair_error}") return None - + @staticmethod def _normalize_embedding(values: List[float], dim: int = 1024) -> List[float]: if not values: return [0.0] * dim - normalized = [float(v) for v in values[:dim]] if len(normalized) < dim: normalized.extend([0.0] * (dim - len(normalized))) @@ -191,14 +177,11 @@ class MemorySystem: return vec for idx, byte in enumerate(encoded): - bucket = idx % dim - vec[bucket] += (byte / 255.0) + vec[idx % dim] += byte / 255.0 digest = hashlib.sha256(encoded).digest() for idx, byte in enumerate(digest): - bucket = idx % dim - vec[bucket] += ((byte / 255.0) - 0.5) * 0.1 - + vec[idx % dim] += ((byte / 255.0) - 0.5) * 0.1 return vec async def _build_embedding(self, text: str) -> List[float]: @@ -207,9 +190,8 @@ class MemorySystem: embedding = await self.embed_func(text) if embedding: return [float(v) for v in list(embedding)] - except Exception as e: - logger.warning(f"embedding generation failed: {e}") - + except Exception as exc: + logger.warning(f"Embedding generation failed, fallback to local: {exc}") return self._local_embedding(text) async def _add_vector_memory( @@ -231,10 +213,8 @@ class MemorySystem: ): return True - # Chroma collection may have a fixed historical embedding dimension. - candidate_dims = [] - base_len = len(embedding or []) - for dim in [base_len, 1024, 1536, 768, 384, 3072]: + candidate_dims: List[int] = [] + for dim in [len(embedding or []), 1024, 1536, 768, 384, 3072]: if dim and dim > 0 and dim not in candidate_dims: candidate_dims.append(dim) @@ -250,7 +230,6 @@ class MemorySystem: ) if ok: return True - return False @staticmethod @@ -261,15 +240,53 @@ class MemorySystem: value = 0.5 return max(0.0, min(1.0, value)) + async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float: + if not content or not content.strip(): + return 0.0 + + if self.importance_evaluator: + try: + score = await self.importance_evaluator(content, metadata) + return self._normalize_importance(score) + except Exception as exc: + logger.warning(f"Importance evaluation failed, fallback to 0.5: {exc}") + return 0.5 + + async def _is_duplicate_long_term(self, user_id: str, content: str) -> bool: + normalized = " ".join((content or "").split()) + if not normalized: + return True + + all_memories = await self.vector_store.get_all(user_id) + recent_cutoff = datetime.now() - self.dedup_window + for memory in all_memories[-20:]: + if ( + " ".join((memory.content or "").split()) == normalized + and memory.timestamp >= recent_cutoff + ): + return True + return False + + async def _trim_user_long_term(self, user_id: str): + all_memories = await self.vector_store.get_all(user_id) + if len(all_memories) <= self.max_long_term_per_user: + return + + all_memories.sort(key=lambda m: (m.importance, m.timestamp)) + to_delete = len(all_memories) - self.max_long_term_per_user + for memory in all_memories[:to_delete]: + await self.vector_store.delete(memory.id) + async def add_message( self, user_id: str, role: str, content: str, - metadata: Optional[Dict] = None + metadata: Optional[Dict] = None, ): - """向短期记忆添加单条消息(不做长期记忆评分)。""" - self.short_term.add(user_id, content, metadata) + payload = dict(metadata or {}) + payload.setdefault("role", role) + self.short_term.add(user_id, content, payload) async def add_qa_pair( self, @@ -278,9 +295,6 @@ class MemorySystem: 
         answer: str,
         metadata: Optional[Dict] = None,
     ) -> Optional[VectorMemory]:
-        """
-        添加最新问答对,并仅对该问答对做模型重要性评估。
-        """
         user_meta = {"role": "user"}
         assistant_meta = {"role": "assistant"}
         if isinstance(metadata, dict):
@@ -297,6 +311,8 @@
         importance = await self._evaluate_importance(qa_content, qa_metadata)
         if importance < self.importance_threshold:
             return None
+        if await self._is_duplicate_long_term(user_id, qa_content):
+            return None
 
         embedding = await self._build_embedding(qa_content)
         memory_id = str(uuid.uuid4())
@@ -311,97 +327,88 @@
         if not ok:
             return None
 
+        await self._trim_user_long_term(user_id)
         return await self.get_long_term(user_id, memory_id)
-
-    async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float:
-        """评估记忆重要性。"""
-        if not content or not content.strip():
+
+    @staticmethod
+    def _simple_text_score(query: str, text: str) -> float:
+        query_tokens = [token for token in query.lower().split() if token]
+        if not query_tokens:
             return 0.0
+        text_lower = text.lower()
+        hit = sum(1 for token in query_tokens if token in text_lower)
+        return hit / len(query_tokens)
 
-        if self.importance_evaluator:
-            try:
-                score = await self.importance_evaluator(content, metadata)
-                return self._normalize_importance(score)
-            except Exception as e:
-                logger.warning(f"importance evaluation failed, fallback to neutral score: {e}")
-
-        # 当模型评估不可用时,使用中性分数作为兜底。
-        return 0.5
-
     async def get_context(
         self,
         user_id: str,
         query: Optional[str] = None,
         max_short_term: int = 10,
-        max_long_term: int = 5
+        max_long_term: int = 5,
     ) -> Tuple[List[MemoryItem], List[VectorMemory]]:
-        """获取上下文(短期 + 长期记忆)。"""
-        # 获取短期记忆
         short_term_memories = self.short_term.get(user_id, limit=max_short_term)
-
-        # 获取相关长期记忆
-        long_term_memories = []
-
+        long_term_memories: List[VectorMemory] = []
+
         if query and self.use_query_embedding:
             try:
-                # 使用向量检索
                 query_embedding = await self._build_embedding(query)
-                if query_embedding:
-                    long_term_memories = await self.vector_store.search(
-                        user_id=user_id,
-                        query_embedding=query_embedding,
-                        limit=max_long_term
-                    )
-            except Exception as e:
-                logger.warning(f"向量检索失败,改用重要性检索: {e}")
+                long_term_memories = await self.vector_store.search(
+                    user_id=user_id,
+                    query_embedding=query_embedding,
+                    limit=max_long_term,
+                )
+            except Exception as exc:
+                logger.warning(f"Vector search failed, fallback to lexical: {exc}")
 
         if query and not long_term_memories:
-            query_lower = query.lower()
-            try:
-                candidates = await self.vector_store.get_all(user_id)
-                matches = [m for m in candidates if query_lower in m.content.lower()]
-                matches.sort(key=lambda m: (m.importance, m.timestamp), reverse=True)
-                long_term_memories = matches[:max_long_term]
-            except Exception:
-                pass
-
-        # 如果向量检索失败或没有结果,使用重要性检索
+ candidates = await self.vector_store.get_all(user_id) + scored = [] + for memory in candidates: + score = self._simple_text_score(query, memory.content) + if score <= 0: + continue + combined = score * 0.7 + memory.importance * 0.3 + scored.append((combined, memory)) + scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True) + long_term_memories = [item[1] for item in scored[:max_long_term]] + if not long_term_memories: long_term_memories = await self.vector_store.get_by_importance( user_id=user_id, - limit=max_long_term + limit=max_long_term, ) - - # 更新长期记忆访问记录 + for memory in long_term_memories: await self.vector_store.update_access(memory.id) - - return short_term_memories, long_term_memories - - def format_context( - self, - short_term: List[MemoryItem], - long_term: List[VectorMemory] - ) -> str: - """格式化上下文为文本。""" - context = "" - - if long_term: - context += "## 相关历史记忆\n" - for i, memory in enumerate(long_term, 1): - context += f"{i}. {memory.content}\n" - context += "\n" - - if short_term: - context += "## 最近对话\n" - for memory in short_term: - context += f"- {memory.content}\n" - - return context - async def list_long_term( - self, user_id: str, limit: int = 20 - ) -> List[VectorMemory]: + return short_term_memories, long_term_memories + + def format_context( + self, short_term: List[MemoryItem], long_term: List[VectorMemory] + ) -> str: + lines: List[str] = [] + + if long_term: + lines.append("## 相关历史记忆") + for idx, memory in enumerate(long_term, 1): + lines.append(f"{idx}. {memory.content}") + lines.append("") + + if short_term: + lines.append("## 最近对话") + for item in short_term: + role = str((item.metadata or {}).get("role") or "").strip().lower() + if role == "user": + prefix = "用户" + elif role == "assistant": + prefix = "助手" + else: + prefix = "对话" + lines.append(f"- {prefix}: {item.content}") + + return "\n".join(lines).strip() + + async def list_long_term(self, user_id: str, limit: int = 20) -> List[VectorMemory]: memories = await self.vector_store.get_all(user_id) memories.sort(key=lambda m: m.timestamp, reverse=True) if limit > 0: @@ -422,21 +429,24 @@ class MemorySystem: importance: float = 0.8, metadata: Optional[Dict] = None, ) -> Optional[VectorMemory]: - memory_id = str(uuid.uuid4()) - importance = self._normalize_importance(importance) - embedding = await self._build_embedding(content) + if await self._is_duplicate_long_term(user_id, content): + return None + memory_id = str(uuid.uuid4()) + normalized_importance = self._normalize_importance(importance) + embedding = await self._build_embedding(content) ok = await self._add_vector_memory( memory_id=memory_id, user_id=user_id, content=content, embedding=embedding, - importance=importance, + importance=normalized_importance, metadata=metadata or {}, ) if not ok: return None + await self._trim_user_long_term(user_id) return await self.get_long_term(user_id, memory_id) async def search_long_term( @@ -457,10 +467,15 @@ class MemorySystem: return results all_memories = await self.vector_store.get_all(user_id) - query_lower = query.lower() - matched = [m for m in all_memories if query_lower in m.content.lower()] - matched.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) - return matched[:limit] + scored = [] + for memory in all_memories: + score = self._simple_text_score(query, memory.content) + if score <= 0: + continue + combined = score * 0.7 + memory.importance * 0.3 + scored.append((combined, memory)) + scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True) + return [item[1] for item in 
scored[:limit]] async def update_long_term( self, @@ -503,7 +518,6 @@ class MemorySystem: ) if not added: return None - return await self.get_long_term(user_id, memory_id) async def delete_long_term(self, user_id: str, memory_id: str) -> bool: @@ -511,15 +525,12 @@ class MemorySystem: if not memory: return False return await self.vector_store.delete(memory_id) - + def clear_short_term(self, user_id: str): - """清除短期记忆。""" self.short_term.clear(user_id) - + async def clear_long_term(self, user_id: str): - """清除长期记忆。""" await self.vector_store.clear_user(user_id) - + async def close(self): - """关闭记忆系统。""" await self.vector_store.close() diff --git a/src/ai/models/anthropic_model.py b/src/ai/models/anthropic_model.py index 3854b0e..059e88b 100644 --- a/src/ai/models/anthropic_model.py +++ b/src/ai/models/anthropic_model.py @@ -1,100 +1,94 @@ """ -Anthropic Claude模型实现 +Anthropic Claude model implementation. """ -from typing import List, Optional, AsyncIterator + +from __future__ import annotations + +from typing import AsyncIterator, List, Optional + from anthropic import AsyncAnthropic + from ..base import BaseAIModel, Message, ModelConfig class AnthropicModel(BaseAIModel): - """Anthropic Claude模型实现""" - + """Anthropic Claude model implementation.""" + def __init__(self, config: ModelConfig): super().__init__(config) self.client = AsyncAnthropic( api_key=config.api_key, base_url=config.api_base, - timeout=config.timeout + timeout=config.timeout, ) - + async def chat( self, messages: List[Message], tools: Optional[List[dict]] = None, - **kwargs + **kwargs, ) -> Message: - """同步对话""" - # 分离system消息 system_message = None formatted_messages = [] - + for msg in messages: if msg.role == "system": system_message = msg.content else: - formatted_messages.append({ - "role": msg.role, - "content": msg.content - }) - + formatted_messages.append({"role": msg.role, "content": msg.content}) + params = { "model": self.config.model_name, "messages": formatted_messages, "max_tokens": self.config.max_tokens, "temperature": self.config.temperature, } - if system_message: params["system"] = system_message - if tools: params["tools"] = tools - params.update(kwargs) - + response = await self.client.messages.create(**params) - + content = "" tool_calls = [] - for block in response.content: if block.type == "text": content += block.text elif block.type == "tool_use": - tool_calls.append({ - "id": block.id, - "type": "function", - "function": { - "name": block.name, - "arguments": block.input + tool_calls.append( + { + "id": block.id, + "type": "function", + "function": { + "name": block.name, + "arguments": block.input, + }, } - }) - + ) + return Message( role="assistant", content=content, - tool_calls=tool_calls if tool_calls else None + tool_calls=tool_calls if tool_calls else None, ) - + async def chat_stream( self, messages: List[Message], tools: Optional[List[dict]] = None, - **kwargs + **kwargs, ) -> AsyncIterator[str]: - """流式对话""" system_message = None formatted_messages = [] - + for msg in messages: if msg.role == "system": system_message = msg.content else: - formatted_messages.append({ - "role": msg.role, - "content": msg.content - }) - + formatted_messages.append({"role": msg.role, "content": msg.content}) + params = { "model": self.config.model_name, "messages": formatted_messages, @@ -102,19 +96,24 @@ class AnthropicModel(BaseAIModel): "temperature": self.config.temperature, "stream": True, } - if system_message: params["system"] = system_message - if tools: params["tools"] = tools - 
params.update(kwargs) - + async with self.client.messages.stream(**params) as stream: async for text in stream.text_stream: yield text - + async def embed(self, text: str) -> List[float]: - """文本嵌入(Anthropic不直接提供,需要使用其他服务)""" - raise NotImplementedError("Anthropic不提供嵌入API,请使用OpenAI或其他服务") + raise NotImplementedError( + "Anthropic does not provide embedding API; use OpenAI-compatible embedding model." + ) + + async def close(self): + close_fn = getattr(self.client, "close", None) + if callable(close_fn): + maybe_awaitable = close_fn() + if hasattr(maybe_awaitable, "__await__"): + await maybe_awaitable diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index 40c5707..a40b0e9 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -24,7 +24,7 @@ class OpenAIModel(BaseAIModel): self.logger = logger self._embedding_token_limit: Optional[int] = None - http_client = httpx.AsyncClient( + self._http_client = httpx.AsyncClient( timeout=config.timeout, limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), ) @@ -33,7 +33,7 @@ class OpenAIModel(BaseAIModel): api_key=config.api_key, base_url=config.api_base, timeout=config.timeout, - http_client=http_client, + http_client=self._http_client, ) self._supports_tools = False @@ -590,3 +590,8 @@ class OpenAIModel(BaseAIModel): self.logger.error(f"text preview: {repr(candidate_text[:100])}") self.logger.error(f"full traceback:\n{traceback.format_exc()}") raise + + async def close(self): + """Release network resources.""" + if self._http_client: + await self._http_client.aclose() diff --git a/src/ai/personality.py b/src/ai/personality.py index 0f1d458..caeb940 100644 --- a/src/ai/personality.py +++ b/src/ai/personality.py @@ -1,4 +1,6 @@ -"""Personality system for role-play profiles.""" +"""Personality system for role-play profiles with scope priority.""" + +from __future__ import annotations from dataclasses import dataclass, field from enum import Enum @@ -32,8 +34,6 @@ class PersonalityProfile: custom_instructions: str = "" def to_system_prompt(self) -> str: - """Build plain-text system prompt.""" - traits_text = ", ".join([t.value for t in self.traits]) if self.traits else "Friendly" lines = [ "Role Setting", @@ -55,13 +55,29 @@ class PersonalityProfile: class PersonalitySystem: - """Personality management and persistence.""" + """Personality management, persistence and scope resolution.""" - def __init__(self, config_path: Optional[Path] = None): + def __init__( + self, + config_path: Optional[Path] = None, + state_path: Optional[Path] = None, + ): self.config_path = config_path or Path("config/personalities.json") + self.state_path = state_path or self.config_path.with_name("personality_state.json") + self.personalities: Dict[str, PersonalityProfile] = {} - self.current_personality: Optional[PersonalityProfile] = None + self._active_global_key: str = "default" + self._active_user_keys: Dict[str, str] = {} + self._active_group_keys: Dict[str, str] = {} + self._active_session_keys: Dict[str, str] = {} + self._load_personalities() + self._load_state() + self._ensure_valid_active_keys() + + @property + def current_personality(self) -> Optional[PersonalityProfile]: + return self.personalities.get(self._active_global_key) def _dict_to_profile(self, config: Dict) -> PersonalityProfile: trait_names = config.get("traits", []) @@ -83,27 +99,30 @@ class PersonalitySystem: custom_instructions=str(config.get("custom_instructions", "")), ) - def _load_personalities(self): - """Load personality config 
from disk or create defaults.""" + def _profile_to_dict(self, profile: PersonalityProfile) -> Dict: + return { + "name": profile.name, + "description": profile.description, + "traits": [trait.name for trait in profile.traits], + "speaking_style": profile.speaking_style, + "example_responses": profile.example_responses, + "custom_instructions": profile.custom_instructions, + } + def _load_personalities(self): if self.config_path.exists(): with open(self.config_path, "r", encoding="utf-8") as f: data = json.load(f) for key, config in data.items(): - self.personalities[key] = self._dict_to_profile(config) + if isinstance(config, dict): + self.personalities[key] = self._dict_to_profile(config) - if "default" in self.personalities: - self.current_personality = self.personalities["default"] - elif self.personalities: - first_key = next(iter(self.personalities.keys())) - self.current_personality = self.personalities[first_key] - return + if self.personalities: + return self._create_default_personalities() def _create_default_personalities(self): - """Create and persist built-in default profiles.""" - default = PersonalityProfile( name="Assistant", description="A friendly and practical AI assistant.", @@ -146,87 +165,200 @@ class PersonalitySystem: "tech_expert": tech_expert, "creative": creative, } - self.current_personality = default + self._active_global_key = "default" self._save_personalities() def _save_personalities(self): - """Persist personalities to disk.""" - self.config_path.parent.mkdir(parents=True, exist_ok=True) - data = {} - - for key, profile in self.personalities.items(): - data[key] = { - "name": profile.name, - "description": profile.description, - "traits": [trait.name for trait in profile.traits], - "speaking_style": profile.speaking_style, - "example_responses": profile.example_responses, - "custom_instructions": profile.custom_instructions, - } - + data = { + key: self._profile_to_dict(profile) + for key, profile in self.personalities.items() + } with open(self.config_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) - def set_personality(self, key: str) -> bool: - """Switch active personality by key.""" + def _load_state(self): + if not self.state_path.exists(): + return + try: + with open(self.state_path, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception: + return + + self._active_global_key = str(data.get("global", self._active_global_key)) + self._active_user_keys = { + str(k): str(v) + for k, v in (data.get("user") or {}).items() + if k is not None and v is not None + } + self._active_group_keys = { + str(k): str(v) + for k, v in (data.get("group") or {}).items() + if k is not None and v is not None + } + self._active_session_keys = { + str(k): str(v) + for k, v in (data.get("session") or {}).items() + if k is not None and v is not None + } + + def _save_state(self): + self.state_path.parent.mkdir(parents=True, exist_ok=True) + data = { + "global": self._active_global_key, + "user": self._active_user_keys, + "group": self._active_group_keys, + "session": self._active_session_keys, + } + with open(self.state_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def _ensure_valid_active_keys(self): + if self._active_global_key not in self.personalities: + if "default" in self.personalities: + self._active_global_key = "default" + elif self.personalities: + self._active_global_key = next(iter(self.personalities.keys())) + else: + self._active_global_key = "" + + 
self._active_user_keys = { + scope_id: key + for scope_id, key in self._active_user_keys.items() + if key in self.personalities + } + self._active_group_keys = { + scope_id: key + for scope_id, key in self._active_group_keys.items() + if key in self.personalities + } + self._active_session_keys = { + scope_id: key + for scope_id, key in self._active_session_keys.items() + if key in self.personalities + } + self._save_state() + + def set_personality( + self, key: str, scope: str = "global", scope_id: Optional[str] = None + ) -> bool: if key not in self.personalities: return False - self.current_personality = self.personalities[key] + scope_normalized = (scope or "global").strip().lower() + if scope_normalized == "global": + self._active_global_key = key + self._save_state() + return True + + if not scope_id: + return False + + if scope_normalized == "user": + self._active_user_keys[scope_id] = key + elif scope_normalized == "group": + self._active_group_keys[scope_id] = key + elif scope_normalized == "session": + self._active_session_keys[scope_id] = key + else: + return False + + self._save_state() return True - def get_system_prompt(self) -> str: - """Get current personality prompt.""" + def get_active_personality( + self, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> Optional[PersonalityProfile]: + # Priority: session > group > user > global > default + if session_id: + key = self._active_session_keys.get(session_id) + if key in self.personalities: + return self.personalities[key] - if self.current_personality: - return self.current_personality.to_system_prompt() - return "" + if group_id: + key = self._active_group_keys.get(group_id) + if key in self.personalities: + return self.personalities[key] + + if user_id: + key = self._active_user_keys.get(user_id) + if key in self.personalities: + return self.personalities[key] + + if self._active_global_key in self.personalities: + return self.personalities[self._active_global_key] + + return self.personalities.get("default") + + def get_system_prompt( + self, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> str: + profile = self.get_active_personality( + user_id=user_id, + group_id=group_id, + session_id=session_id, + ) + return profile.to_system_prompt() if profile else "" def add_personality(self, key: str, profile: PersonalityProfile) -> bool: - """Add a new personality profile.""" - key = key.strip() if not key: return False self.personalities[key] = profile - if not self.current_personality: - self.current_personality = profile + if not self._active_global_key: + self._active_global_key = key + self._save_state() self._save_personalities() return True def remove_personality(self, key: str) -> bool: - """Remove a personality profile.""" - if key == "default": return False - if key not in self.personalities: return False - removed_profile = self.personalities[key] del self.personalities[key] + if not self.personalities: + self._create_default_personalities() - if self.current_personality == removed_profile: - if "default" in self.personalities: - self.current_personality = self.personalities["default"] - elif self.personalities: - first_key = next(iter(self.personalities.keys())) - self.current_personality = self.personalities[first_key] - else: - self.current_personality = None + if self._active_global_key == key: + self._active_global_key = ( + "default" + if "default" in self.personalities + else 
next(iter(self.personalities.keys())) + ) + + self._active_user_keys = { + scope_id: active_key + for scope_id, active_key in self._active_user_keys.items() + if active_key != key + } + self._active_group_keys = { + scope_id: active_key + for scope_id, active_key in self._active_group_keys.items() + if active_key != key + } + self._active_session_keys = { + scope_id: active_key + for scope_id, active_key in self._active_session_keys.items() + if active_key != key + } self._save_personalities() + self._save_state() return True def list_personalities(self) -> List[str]: - """List all personality keys.""" - return sorted(self.personalities.keys()) def get_personality(self, key: str) -> Optional[PersonalityProfile]: - """Get personality by key.""" - return self.personalities.get(key) diff --git a/src/ai/skills/__init__.py b/src/ai/skills/__init__.py deleted file mode 100644 index 40ee6ef..0000000 --- a/src/ai/skills/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Skills系统初始化 -""" -from .base import Skill, SkillsManager, SkillMetadata, create_skill_template - -__all__ = ['Skill', 'SkillsManager', 'SkillMetadata', 'create_skill_template'] diff --git a/src/ai/skills/base.py b/src/ai/skills/base.py deleted file mode 100644 index e9d4878..0000000 --- a/src/ai/skills/base.py +++ /dev/null @@ -1,750 +0,0 @@ -""" -Skills 系统 - 可扩展技能插件框架。 -""" - -from dataclasses import dataclass -import importlib -import inspect -import json -from pathlib import Path -import re -import shutil -import sys -import tempfile -import time -from typing import Any, Callable, Dict, List, Optional, Tuple -import urllib.parse -import urllib.request -import zipfile -import os -import stat - -from src.utils.logger import setup_logger - -logger = setup_logger("SkillsSystem") - - -@dataclass -class SkillMetadata: - """技能元数据。""" - - name: str - version: str - description: str - author: str - dependencies: List[str] - enabled: bool = True - - -class Skill: - """技能基类。""" - - def __init__(self): - self.metadata: Optional[SkillMetadata] = None - self.tools: Dict[str, Callable] = {} - self.manager = None - - async def initialize(self): - """初始化技能。""" - - async def cleanup(self): - """清理技能。""" - - def get_tools(self) -> Dict[str, Callable]: - """获取技能提供的工具。""" - - return self.tools - - def register_tool(self, name: str, func: Callable): - """注册工具。""" - - self.tools[name] = func - - -class SkillsManager: - """技能管理器。""" - - _SKILL_KEY_PATTERN = re.compile(r"[^a-zA-Z0-9_]") - _GITHUB_SHORTCUT_PATTERN = re.compile( - r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+(?:#[A-Za-z0-9_.-]+)?$" - ) - - def __init__(self, skills_dir: Path): - self.skills_dir = skills_dir - self.skills: Dict[str, Skill] = {} - self.skills_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"✅ Skills 目录: {skills_dir}") - - @classmethod - def normalize_skill_key(cls, raw_name: str) -> str: - """将任意输入规范化为可导入的 Python 包名。""" - - key = raw_name.strip().lower().replace("-", "_").replace(" ", "_") - key = cls._SKILL_KEY_PATTERN.sub("_", key) - key = re.sub(r"_+", "_", key).strip("_") - - if not key: - raise ValueError("技能名不能为空") - - if key[0].isdigit(): - key = f"skill_{key}" - - return key - - def _get_skill_path(self, skill_name: str) -> Path: - return self.skills_dir / self.normalize_skill_key(skill_name) - - @staticmethod - def _on_rmtree_error(func, path, exc_info): - """Handle Windows readonly/locked file deletion errors.""" - try: - os.chmod(path, stat.S_IWRITE) - func(path) - except Exception: - # Keep original failure path for upper retry logic. 
- pass - - def _read_metadata(self, skill_path: Path, fallback_name: str) -> Dict[str, Any]: - metadata_file = skill_path / "skill.json" - if metadata_file.exists(): - with open(metadata_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - else: - metadata = {} - - metadata.setdefault("name", fallback_name) - metadata.setdefault("version", "1.0.0") - metadata.setdefault("description", f"{fallback_name} skill") - metadata.setdefault("author", "unknown") - metadata.setdefault("dependencies", []) - metadata.setdefault("enabled", True) - - with open(metadata_file, "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - - return metadata - - def _ensure_skill_package_layout(self, skill_path: Path, skill_key: str): - """确保技能目录满足运行最小结构。""" - - skill_path.mkdir(parents=True, exist_ok=True) - - init_file = skill_path / "__init__.py" - if not init_file.exists(): - init_file.write_text("", encoding="utf-8") - - main_file = skill_path / "main.py" - if not main_file.exists(): - template = f'''"""{skill_key} skill""" -from src.ai.skills.base import Skill - - -class {"".join(p.capitalize() for p in skill_key.split("_"))}Skill(Skill): - async def initialize(self): - self.register_tool("ping", self.ping) - - async def ping(self, text: str = "ok") -> str: - return text - - async def cleanup(self): - pass -''' - main_file.write_text(template, encoding="utf-8") - - self._read_metadata(skill_path, skill_key) - - async def load_skill(self, skill_name: str) -> bool: - """加载技能。""" - - try: - skill_name = self.normalize_skill_key(skill_name) - - if skill_name in self.skills: - logger.info(f"✅ 技能已加载: {skill_name}") - return True - - skill_path = self._get_skill_path(skill_name) - if not skill_path.exists(): - logger.error(f"❌ 技能不存在: {skill_name}") - return False - - metadata_file = skill_path / "skill.json" - if not metadata_file.exists(): - logger.error(f"❌ 技能元数据不存在: {skill_name}") - return False - - with open(metadata_file, "r", encoding="utf-8") as f: - metadata_dict = json.load(f) - - metadata = SkillMetadata(**metadata_dict) - - if not metadata.enabled: - logger.info(f"⏸️ 技能已禁用: {skill_name}") - return False - - module_path = f"skills.{skill_name}.main" - importlib.invalidate_caches() - - try: - old_dont_write = sys.dont_write_bytecode - sys.dont_write_bytecode = True - try: - if module_path in sys.modules: - module = importlib.reload(sys.modules[module_path]) - else: - module = importlib.import_module(module_path) - finally: - sys.dont_write_bytecode = old_dont_write - except Exception as exc: - logger.error(f"❌ 无法导入技能模块 {module_path}: {exc}") - return False - - skill_class = None - for _, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, Skill) and obj != Skill: - skill_class = obj - break - - if not skill_class: - logger.error(f"❌ 技能中未找到 Skill 子类: {skill_name}") - return False - - skill = skill_class() - skill.metadata = metadata - skill.manager = self - - await skill.initialize() - self.skills[skill_name] = skill - - logger.info(f"✅ 加载技能: {skill_name} v{metadata.version}") - return True - - except Exception as exc: - logger.error(f"❌ 加载技能失败 {skill_name}: {exc}") - return False - - async def load_all_skills(self): - """加载所有可用技能。""" - - for skill_name in self.list_available_skills(): - await self.load_skill(skill_name) - - async def unload_skill(self, skill_name: str) -> bool: - """仅卸载内存中的技能。""" - - skill_name = self.normalize_skill_key(skill_name) - if skill_name not in self.skills: - return False - - skill = self.skills[skill_name] - 
await skill.cleanup() - del self.skills[skill_name] - - sys.modules.pop(f"skills.{skill_name}.main", None) - sys.modules.pop(f"skills.{skill_name}", None) - importlib.invalidate_caches() - - logger.info(f"✅ 卸载技能: {skill_name}") - return True - - async def uninstall_skill(self, skill_name: str, delete_files: bool = True) -> bool: - """卸载技能并可选删除文件。""" - - skill_name = self.normalize_skill_key(skill_name) - - if skill_name in self.skills: - await self.unload_skill(skill_name) - - if not delete_files: - return True - - skill_path = self._get_skill_path(skill_name) - if not skill_path.exists(): - return False - - removed = False - for _ in range(3): - try: - shutil.rmtree(skill_path, ignore_errors=False, onerror=self._on_rmtree_error) - except PermissionError: - pass - - if not skill_path.exists(): - removed = True - break - - time.sleep(0.2) - - if not removed: - try: - metadata_file = skill_path / "skill.json" - metadata = {} - if metadata_file.exists(): - with open(metadata_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - metadata["enabled"] = False - with open(metadata_file, "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - logger.warning(f"⚠️ 删除目录失败,已软卸载技能: {skill_name}") - return True - except Exception: - return False - - importlib.invalidate_caches() - logger.info(f"✅ 删除技能目录: {skill_name}") - return True - - def get_skill(self, skill_name: str) -> Optional[Skill]: - """获取已加载技能实例。""" - - skill_name = self.normalize_skill_key(skill_name) - return self.skills.get(skill_name) - - def list_skills(self) -> List[str]: - """列出已加载技能。""" - - return sorted(self.skills.keys()) - - def list_available_skills(self) -> List[str]: - """列出可加载技能目录。""" - - if not self.skills_dir.exists(): - return [] - - available: List[str] = [] - for skill_dir in self.skills_dir.iterdir(): - if not skill_dir.is_dir() or skill_dir.name.startswith("_"): - continue - - if (skill_dir / "skill.json").exists() and (skill_dir / "main.py").exists(): - try: - with open(skill_dir / "skill.json", "r", encoding="utf-8") as f: - metadata = json.load(f) - if not metadata.get("enabled", True): - continue - available.append(self.normalize_skill_key(skill_dir.name)) - except ValueError: - continue - except Exception: - continue - - return sorted(set(available)) - - def get_all_tools(self) -> Dict[str, Callable]: - """获取全部技能工具。""" - - all_tools: Dict[str, Callable] = {} - for skill_name, skill in self.skills.items(): - for tool_name, tool_func in skill.get_tools().items(): - all_tools[f"{skill_name}.{tool_name}"] = tool_func - return all_tools - - async def reload_skill(self, skill_name: str) -> bool: - """重载技能。""" - - skill_name = self.normalize_skill_key(skill_name) - if skill_name in self.skills: - await self.unload_skill(skill_name) - return await self.load_skill(skill_name) - - def _parse_github_repo_source( - self, - source: str, - ) -> Optional[Tuple[str, str, str, Optional[str]]]: - """解析 GitHub 仓库 URL,返回 owner/repo/branch/subpath。""" - - try: - parsed = urllib.parse.urlparse(source.strip()) - except Exception: - return None - - if parsed.scheme not in {"http", "https"}: - return None - - if parsed.netloc.lower() != "github.com": - return None - - parts = [part for part in parsed.path.split("/") if part] - if len(parts) < 2: - return None - - owner = parts[0] - repo = parts[1] - if repo.endswith(".git"): - repo = repo[:-4] - - if not owner or not repo: - return None - - branch = "main" - subpath: Optional[str] = None - - if len(parts) >= 4 and parts[2] == "tree": - branch = 
parts[3] - if len(parts) > 4: - subpath = "/".join(parts[4:]) - elif len(parts) > 2: - # 不支持 /issues、/blob 等深链接 - return None - - return owner, repo, branch, subpath - - def _resolve_network_source( - self, - source: str, - ) -> Tuple[str, Optional[str], Optional[str]]: - """将网络来源解析为下载 URL,并返回安装提示信息。""" - - source = source.strip() - - github_repo = self._parse_github_repo_source(source) - if github_repo: - owner, repo, branch, subpath = github_repo - codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}" - return codeload_url, self.normalize_skill_key(repo), subpath - - if source.startswith(("http://", "https://")): - return source, None, None - - if self._GITHUB_SHORTCUT_PATTERN.match(source): - repo_ref, _, branch = source.partition("#") - owner, repo = repo_ref.split("/", 1) - if repo.endswith(".git"): - repo = repo[:-4] - branch = branch or "main" - codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}" - return codeload_url, self.normalize_skill_key(repo), None - - raise ValueError("source 必须是 URL 或 owner/repo[#branch]") - - def _download_zip(self, url: str, output_zip: Path): - """下载 zip 包到本地。""" - - req = urllib.request.Request(url, headers={"User-Agent": "QQBot-Skills/1.0"}) - with urllib.request.urlopen(req, timeout=30) as resp: - data = resp.read() - output_zip.write_bytes(data) - - def _find_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]: - """在目录中扫描技能候选项。""" - - candidates: List[Tuple[str, Path]] = [] - for metadata_file in root_dir.rglob("skill.json"): - candidate_dir = metadata_file.parent - if not (candidate_dir / "main.py").exists(): - continue - - try: - with open(metadata_file, "r", encoding="utf-8") as f: - metadata = json.load(f) - raw_name = str(metadata.get("name") or candidate_dir.name) - except Exception: - raw_name = candidate_dir.name - - try: - skill_key = self.normalize_skill_key(raw_name) - except ValueError: - continue - - candidates.append((skill_key, candidate_dir)) - - uniq: Dict[str, Path] = {} - for key, path in candidates: - uniq[key] = path - - return sorted(uniq.items(), key=lambda x: x[0]) - - def _find_codex_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]: - """在目录中扫描仅包含 SKILL.md 的候选项。""" - - candidates: List[Tuple[str, Path]] = [] - for skill_doc in root_dir.rglob("SKILL.md"): - candidate_dir = skill_doc.parent - if (candidate_dir / "skill.json").exists() and (candidate_dir / "main.py").exists(): - continue - - try: - skill_key = self.normalize_skill_key(candidate_dir.name) - except ValueError: - continue - - candidates.append((skill_key, candidate_dir)) - - uniq: Dict[str, Path] = {} - for key, path in candidates: - uniq[key] = path - - return sorted(uniq.items(), key=lambda x: x[0]) - - def _scope_extract_root(self, extract_root: Path, subpath: Optional[str]) -> Path: - """将解压目录缩小到 repo 子路径(若指定)。""" - - if not subpath: - return extract_root - - clean_subpath = Path(subpath.strip("/\\")) - if str(clean_subpath) == ".": - return extract_root - - direct = extract_root / clean_subpath - if direct.exists(): - return direct - - for child in extract_root.iterdir(): - candidate = child / clean_subpath - if candidate.exists(): - return candidate - - return extract_root - - @staticmethod - def _extract_markdown_title(markdown_content: str) -> str: - for line in markdown_content.splitlines(): - stripped = line.strip() - if stripped.startswith("#"): - return stripped.lstrip("#").strip() - return "" - - def _install_codex_skill_adapter(self, source_dir: Path, 
target_path: Path, skill_key: str): - """将 Codex SKILL.md 转换为可加载的项目技能。""" - - skill_doc = source_dir / "SKILL.md" - if not skill_doc.exists(): - raise FileNotFoundError(f"SKILL.md not found in {source_dir}") - - content = skill_doc.read_text(encoding="utf-8") - title = self._extract_markdown_title(content) - description = ( - f"Imported from SKILL.md: {title}" if title else f"Imported from SKILL.md ({skill_key})" - ) - - target_path.mkdir(parents=True, exist_ok=True) - (target_path / "SKILL.md").write_text(content, encoding="utf-8") - (target_path / "__init__.py").write_text("", encoding="utf-8") - - metadata = { - "name": skill_key, - "version": "1.0.0", - "description": description, - "author": "imported", - "dependencies": [], - "enabled": True, - } - with open(target_path / "skill.json", "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - - class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill" - main_code = f'''"""{skill_key} adapter skill (generated from SKILL.md)""" -from pathlib import Path -from src.ai.skills.base import Skill - - -class {class_name}(Skill): - async def initialize(self): - self.register_tool("read_skill_doc", self.read_skill_doc) - - async def read_skill_doc(self) -> str: - skill_doc = Path(__file__).with_name("SKILL.md") - if not skill_doc.exists(): - return "SKILL.md not found." - return skill_doc.read_text(encoding="utf-8") - - async def cleanup(self): - pass -''' - (target_path / "main.py").write_text(main_code, encoding="utf-8") - - def install_skill_from_source( - self, - source: str, - skill_name: Optional[str] = None, - overwrite: bool = False, - ) -> Tuple[bool, str]: - """从网络或本地源安装技能目录(仅落盘,不自动加载)。""" - - desired_key = self.normalize_skill_key(skill_name) if skill_name else None - - with tempfile.TemporaryDirectory(prefix="qqbot_skill_") as tmp: - tmp_dir = Path(tmp) - extract_root: Optional[Path] = None - source_hint_key: Optional[str] = None - source_subpath: Optional[str] = None - - source_path = Path(source) - if source_path.exists(): - if source_path.is_dir(): - extract_root = source_path - elif source_path.is_file() and source_path.suffix.lower() == ".zip": - with zipfile.ZipFile(source_path, "r") as zf: - zf.extractall(tmp_dir / "extract") - extract_root = tmp_dir / "extract" - else: - return False, "本地 source 仅支持目录或 zip 文件" - else: - try: - url, source_hint_key, source_subpath = self._resolve_network_source(source) - except ValueError as exc: - return False, str(exc) - - download_zip = tmp_dir / "download.zip" - try: - self._download_zip(url, download_zip) - except Exception as exc: - # GitHub 简写默认 main 失败时尝试 master - if "codeload.github.com" in url and url.endswith("/main"): - fallback = url[:-4] + "master" - try: - self._download_zip(fallback, download_zip) - except Exception: - return False, f"下载技能失败: {exc}" - else: - return False, f"下载技能失败: {exc}" - - try: - with zipfile.ZipFile(download_zip, "r") as zf: - zf.extractall(tmp_dir / "extract") - except Exception as exc: - return False, f"解压技能失败: {exc}" - - extract_root = tmp_dir / "extract" - - extract_root = self._scope_extract_root(extract_root, source_subpath) - - candidates = self._find_skill_candidates(extract_root) - use_codex_adapter = False - - selected_key: Optional[str] = None - selected_path: Optional[Path] = None - - if candidates: - if desired_key: - for key, path in candidates: - if key == desired_key: - selected_key, selected_path = key, path - break - - if not selected_path: - names = ", ".join([k for k, _ in 
candidates]) - return False, f"源中未找到技能 {desired_key},可选: {names}" - else: - if len(candidates) > 1: - names = ", ".join([k for k, _ in candidates]) - return False, f"检测到多个技能,请指定 skill_name。可选: {names}" - selected_key, selected_path = candidates[0] - else: - codex_candidates = self._find_codex_skill_candidates(extract_root) - if not codex_candidates: - return False, "未找到可安装技能(需包含 skill.json 与 main.py,或 SKILL.md)" - - use_codex_adapter = True - if desired_key: - for key, path in codex_candidates: - if key == desired_key: - selected_key, selected_path = key, path - break - - if not selected_path: - if len(codex_candidates) == 1: - selected_key = desired_key - selected_path = codex_candidates[0][1] - else: - names = ", ".join([k for k, _ in codex_candidates]) - return False, f"源中未找到技能 {desired_key},可选: {names}" - else: - if source_hint_key: - for key, path in codex_candidates: - if key == source_hint_key: - selected_key, selected_path = key, path - break - - if not selected_path: - if len(codex_candidates) > 1: - names = ", ".join([k for k, _ in codex_candidates]) - return False, f"检测到多个 SKILL.md 技能,请指定 skill_name。可选: {names}" - selected_key, selected_path = codex_candidates[0] - if source_hint_key: - selected_key = source_hint_key - - assert selected_key is not None and selected_path is not None - target_path = self._get_skill_path(selected_key) - - if target_path.exists(): - if not overwrite: - return False, f"技能已存在: {selected_key}" - shutil.rmtree(target_path) - - if use_codex_adapter: - self._install_codex_skill_adapter(selected_path, target_path, selected_key) - else: - shutil.copytree( - selected_path, - target_path, - ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".git", ".github"), - ) - self._ensure_skill_package_layout(target_path, selected_key) - - importlib.invalidate_caches() - - logger.info(f"✅ 安装技能成功: {selected_key} <- {source}") - return True, selected_key - - -def create_skill_template( - skill_name: str, - output_dir: Path, - description: str = "技能描述", - author: str = "QQBot", -): - """创建技能模板。""" - - skill_key = SkillsManager.normalize_skill_key(skill_name) - skill_dir = output_dir / skill_key - skill_dir.mkdir(parents=True, exist_ok=True) - - metadata = { - "name": skill_key, - "version": "1.0.0", - "description": description, - "author": author, - "dependencies": [], - "enabled": True, - } - - with open(skill_dir / "skill.json", "w", encoding="utf-8") as f: - json.dump(metadata, f, ensure_ascii=False, indent=2) - - class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill" - main_code = f'''"""{skill_key} skill""" -from src.ai.skills.base import Skill - - -class {class_name}(Skill): - async def initialize(self): - self.register_tool("example_tool", self.example_tool) - - async def example_tool(self, text: str) -> str: - return f"{skill_key} 收到: {{text}}" - - async def cleanup(self): - pass -''' - - with open(skill_dir / "main.py", "w", encoding="utf-8") as f: - f.write(main_code) - - with open(skill_dir / "__init__.py", "w", encoding="utf-8") as f: - f.write("") - - readme = f"""# {skill_key} - -## 描述 -{description} - -## 工具 -- example_tool(text) -""" - - with open(skill_dir / "README.md", "w", encoding="utf-8") as f: - f.write(readme) - - logger.info(f"✅ 创建技能模板: {skill_dir}") diff --git a/src/ai/task_manager.py b/src/ai/task_manager.py deleted file mode 100644 index 2a0d45e..0000000 --- a/src/ai/task_manager.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -长任务管理器 - 处理需要多步骤的复杂任务 -""" -import asyncio -from typing import List, Dict, Optional, Callable, 
Any -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -import json -from pathlib import Path -import uuid - - -class TaskStatus(Enum): - """任务状态""" - PENDING = "pending" - RUNNING = "running" - PAUSED = "paused" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -@dataclass -class TaskStep: - """任务步骤""" - step_id: str - description: str - action: str - params: Dict[str, Any] - status: TaskStatus = TaskStatus.PENDING - result: Optional[Any] = None - error: Optional[str] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - - def to_dict(self) -> Dict: - return { - 'step_id': self.step_id, - 'description': self.description, - 'action': self.action, - 'params': self.params, - 'status': self.status.value, - 'result': self.result, - 'error': self.error, - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None - } - - -@dataclass -class LongTask: - """长任务""" - task_id: str - user_id: str - title: str - description: str - steps: List[TaskStep] = field(default_factory=list) - status: TaskStatus = TaskStatus.PENDING - created_at: datetime = field(default_factory=datetime.now) - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - progress: float = 0.0 - metadata: Dict = field(default_factory=dict) - - def to_dict(self) -> Dict: - return { - 'task_id': self.task_id, - 'user_id': self.user_id, - 'title': self.title, - 'description': self.description, - 'steps': [step.to_dict() for step in self.steps], - 'status': self.status.value, - 'created_at': self.created_at.isoformat(), - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'completed_at': self.completed_at.isoformat() if self.completed_at else None, - 'progress': self.progress, - 'metadata': self.metadata - } - - @classmethod - def from_dict(cls, data: Dict) -> 'LongTask': - steps = [ - TaskStep( - step_id=s['step_id'], - description=s['description'], - action=s['action'], - params=s['params'], - status=TaskStatus(s['status']), - result=s.get('result'), - error=s.get('error'), - started_at=datetime.fromisoformat(s['started_at']) if s.get('started_at') else None, - completed_at=datetime.fromisoformat(s['completed_at']) if s.get('completed_at') else None - ) - for s in data['steps'] - ] - - return cls( - task_id=data['task_id'], - user_id=data['user_id'], - title=data['title'], - description=data['description'], - steps=steps, - status=TaskStatus(data['status']), - created_at=datetime.fromisoformat(data['created_at']), - started_at=datetime.fromisoformat(data['started_at']) if data.get('started_at') else None, - completed_at=datetime.fromisoformat(data['completed_at']) if data.get('completed_at') else None, - progress=data.get('progress', 0.0), - metadata=data.get('metadata', {}) - ) - - -class LongTaskManager: - """长任务管理器""" - - def __init__(self, storage_path: Path): - self.storage_path = storage_path - self.tasks: Dict[str, LongTask] = {} - self.action_handlers: Dict[str, Callable] = {} - self.running_tasks: Dict[str, asyncio.Task] = {} - self._load() - - def _load(self): - """加载任务""" - if self.storage_path.exists(): - with open(self.storage_path, 'r', encoding='utf-8') as f: - data = json.load(f) - for task_data in data: - task = LongTask.from_dict(task_data) - self.tasks[task.task_id] = task - - def _save(self): - """保存任务""" - self.storage_path.parent.mkdir(parents=True, exist_ok=True) - 
data = [task.to_dict() for task in self.tasks.values()] - with open(self.storage_path, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - def register_action(self, action_name: str, handler: Callable): - """注册动作处理器""" - self.action_handlers[action_name] = handler - - def create_task( - self, - user_id: str, - title: str, - description: str, - steps: List[Dict], - metadata: Optional[Dict] = None - ) -> str: - """创建任务""" - task_id = str(uuid.uuid4()) - - task_steps = [ - TaskStep( - step_id=str(uuid.uuid4()), - description=step['description'], - action=step['action'], - params=step.get('params', {}) - ) - for step in steps - ] - - task = LongTask( - task_id=task_id, - user_id=user_id, - title=title, - description=description, - steps=task_steps, - metadata=metadata or {} - ) - - self.tasks[task_id] = task - self._save() - - return task_id - - async def execute_task( - self, - task_id: str, - progress_callback: Optional[Callable[[str, float, str], None]] = None - ) -> bool: - """执行任务""" - if task_id not in self.tasks: - return False - - task = self.tasks[task_id] - - if task.status == TaskStatus.RUNNING: - return False - - task.status = TaskStatus.RUNNING - task.started_at = datetime.now() - self._save() - - try: - total_steps = len(task.steps) - - for i, step in enumerate(task.steps): - # 检查是否被取消 - if task.status == TaskStatus.CANCELLED: - break - - step.status = TaskStatus.RUNNING - step.started_at = datetime.now() - - # 执行步骤 - try: - handler = self.action_handlers.get(step.action) - if not handler: - raise ValueError(f"未找到动作处理器: {step.action}") - - result = await handler(**step.params) - step.result = result - step.status = TaskStatus.COMPLETED - - except Exception as e: - step.error = str(e) - step.status = TaskStatus.FAILED - task.status = TaskStatus.FAILED - self._save() - return False - - finally: - step.completed_at = datetime.now() - - # 更新进度 - task.progress = (i + 1) / total_steps - self._save() - - if progress_callback: - await progress_callback( - task_id, - task.progress, - f"完成步骤 {i+1}/{total_steps}: {step.description}" - ) - - task.status = TaskStatus.COMPLETED - task.completed_at = datetime.now() - task.progress = 1.0 - self._save() - - return True - - except Exception as e: - task.status = TaskStatus.FAILED - self._save() - raise e - - async def start_task( - self, - task_id: str, - progress_callback: Optional[Callable[[str, float, str], None]] = None - ): - """启动任务(异步)""" - if task_id in self.running_tasks: - return - - async def run(): - try: - await self.execute_task(task_id, progress_callback) - finally: - if task_id in self.running_tasks: - del self.running_tasks[task_id] - - self.running_tasks[task_id] = asyncio.create_task(run()) - - def pause_task(self, task_id: str) -> bool: - """暂停任务""" - if task_id not in self.tasks: - return False - - task = self.tasks[task_id] - if task.status == TaskStatus.RUNNING: - task.status = TaskStatus.PAUSED - self._save() - return True - - return False - - def cancel_task(self, task_id: str) -> bool: - """取消任务""" - if task_id not in self.tasks: - return False - - task = self.tasks[task_id] - task.status = TaskStatus.CANCELLED - self._save() - - # 取消正在运行的任务 - if task_id in self.running_tasks: - self.running_tasks[task_id].cancel() - - return True - - def get_task(self, task_id: str) -> Optional[LongTask]: - """获取任务""" - return self.tasks.get(task_id) - - def get_user_tasks(self, user_id: str) -> List[LongTask]: - """获取用户的所有任务""" - return [ - task for task in self.tasks.values() - if task.user_id == user_id 
- ] - - def get_task_status(self, task_id: str) -> Optional[Dict]: - """获取任务状态""" - task = self.get_task(task_id) - if not task: - return None - - completed_steps = sum(1 for step in task.steps if step.status == TaskStatus.COMPLETED) - total_steps = len(task.steps) - - return { - 'task_id': task.task_id, - 'title': task.title, - 'status': task.status.value, - 'progress': task.progress, - 'completed_steps': completed_steps, - 'total_steps': total_steps, - 'current_step': next( - (step.description for step in task.steps if step.status == TaskStatus.RUNNING), - None - ) - } diff --git a/src/ai/vector_store/json_store.py b/src/ai/vector_store/json_store.py index 87fea5b..7abd2b7 100644 --- a/src/ai/vector_store/json_store.py +++ b/src/ai/vector_store/json_store.py @@ -1,64 +1,78 @@ """ -JSON文件存储实现(向后兼容) +JSON-backed vector store implementation. """ + +from __future__ import annotations + +import asyncio import json import uuid -from typing import List, Dict, Optional -from pathlib import Path from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + import numpy as np -from .base import VectorStore, VectorMemory + +from .base import VectorMemory, VectorStore from src.utils.logger import setup_logger -logger = setup_logger('JSONStore') +logger = setup_logger("JSONStore") class JSONVectorStore(VectorStore): - """JSON文件存储实现(向后兼容旧版本)""" - + """JSON file storage implementation.""" + def __init__(self, storage_path: Path): - """初始化JSON存储""" self.storage_path = storage_path - self.memories: Dict[str, List[VectorMemory]] = {} # user_id -> List[VectorMemory] + self.memories: Dict[str, List[VectorMemory]] = {} + self._lock = asyncio.Lock() self._load() - logger.info(f"✅ JSON存储初始化: {storage_path}") - + logger.info(f"JSON storage initialized: {storage_path}") + def _load(self): - """加载记忆""" - if self.storage_path.exists(): - try: - with open(self.storage_path, 'r', encoding='utf-8') as f: - data = json.load(f) - - for user_id, items in data.items(): - self.memories[user_id] = [] - for item in items: - # 兼容旧格式 - if 'id' not in item: - item['id'] = str(uuid.uuid4()) - - memory = VectorMemory.from_dict(item) - self.memories[user_id].append(memory) - - logger.info(f"加载了 {sum(len(v) for v in self.memories.values())} 条记忆") - except Exception as e: - logger.error(f"加载记忆失败: {e}") - self.memories = {} - - def _save(self): - """保存记忆""" + if not self.storage_path.exists(): + return + try: - self.storage_path.parent.mkdir(parents=True, exist_ok=True) - data = { - user_id: [memory.to_dict() for memory in memories] - for user_id, memories in self.memories.items() - } - - with open(self.storage_path, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"保存记忆失败: {e}") - + with open(self.storage_path, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception as exc: + logger.error(f"Failed to load memory file: {exc}") + self.memories = {} + return + + loaded = 0 + memories: Dict[str, List[VectorMemory]] = {} + for user_id, items in (data or {}).items(): + if not isinstance(items, list): + continue + + normalized: List[VectorMemory] = [] + for item in items: + if not isinstance(item, dict): + continue + if "id" not in item: + item["id"] = str(uuid.uuid4()) + try: + normalized.append(VectorMemory.from_dict(item)) + loaded += 1 + except Exception: + continue + + memories[str(user_id)] = normalized + + self.memories = memories + logger.info(f"Loaded {loaded} memories from JSON store") + + def 
_save_locked(self): + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + data = { + user_id: [memory.to_dict() for memory in user_memories] + for user_id, user_memories in self.memories.items() + } + with open(self.storage_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + async def add( self, id: str, @@ -66,9 +80,8 @@ class JSONVectorStore(VectorStore): content: str, embedding: List[float], importance: float, - metadata: Optional[Dict] = None + metadata: Optional[Dict] = None, ) -> bool: - """添加记忆""" try: memory = VectorMemory( id=id, @@ -79,120 +92,108 @@ class JSONVectorStore(VectorStore): timestamp=datetime.now(), metadata=metadata or {}, access_count=0, - last_access=None + last_access=None, ) - - if user_id not in self.memories: - self.memories[user_id] = [] - - self.memories[user_id].append(memory) - self._save() - - logger.debug(f"添加记忆: {id} (用户: {user_id})") + async with self._lock: + self.memories.setdefault(user_id, []).append(memory) + self._save_locked() return True - except Exception as e: - logger.error(f"添加记忆失败: {e}") + except Exception as exc: + logger.error(f"Failed to add memory: {exc}") return False - + async def search( self, user_id: str, query_embedding: List[float], limit: int = 5, - min_importance: float = 0.3 + min_importance: float = 0.3, ) -> List[VectorMemory]: - """搜索相似记忆""" - if user_id not in self.memories: + async with self._lock: + source = list(self.memories.get(user_id, [])) + + candidates = [m for m in source if m.importance >= min_importance] + if not candidates: return [] - - memories = self.memories[user_id] - - # 过滤重要性 - memories = [m for m in memories if m.importance >= min_importance] - - if not memories: - return [] - - # 使用向量相似度排序 + scored_memories = [] - for memory in memories: + for memory in candidates: if memory.embedding: similarity = self._cosine_similarity(query_embedding, memory.embedding) - scored_memories.append((similarity, memory)) - + if similarity is not None: + scored_memories.append((similarity, memory)) + if not scored_memories: - # 如果没有嵌入向量,按重要性排序 return await self.get_by_importance(user_id, limit, min_importance) - + scored_memories.sort(reverse=True, key=lambda x: x[0]) - return [m for _, m in scored_memories[:limit]] - + return [memory for _, memory in scored_memories[:limit]] + async def get_by_importance( self, user_id: str, limit: int = 5, - min_importance: float = 0.3 + min_importance: float = 0.3, ) -> List[VectorMemory]: - """按重要性获取记忆""" - if user_id not in self.memories: - return [] - - memories = [m for m in self.memories[user_id] if m.importance >= min_importance] + async with self._lock: + source = list(self.memories.get(user_id, [])) + + memories = [m for m in source if m.importance >= min_importance] memories.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) return memories[:limit] - - def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: - """计算余弦相似度""" - vec1 = np.array(vec1) - vec2 = np.array(vec2) - return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) - + + @staticmethod + def _cosine_similarity(vec1: List[float], vec2: List[float]) -> Optional[float]: + arr1 = np.array(vec1, dtype=float) + arr2 = np.array(vec2, dtype=float) + denom = np.linalg.norm(arr1) * np.linalg.norm(arr2) + if denom == 0: + return None + return float(np.dot(arr1, arr2) / denom) + async def update_access(self, memory_id: str) -> bool: - """更新访问记录""" try: - for memories in self.memories.values(): - for memory in memories: - if 
memory.id == memory_id: - memory.access_count += 1 - memory.last_access = datetime.now() - self._save() - return True + async with self._lock: + for memories in self.memories.values(): + for memory in memories: + if memory.id == memory_id: + memory.access_count += 1 + memory.last_access = datetime.now() + self._save_locked() + return True return False - except Exception as e: - logger.error(f"更新访问记录失败: {e}") + except Exception as exc: + logger.error(f"Failed to update access: {exc}") return False - + async def delete(self, memory_id: str) -> bool: - """删除记忆""" try: - for user_id, memories in self.memories.items(): - for i, memory in enumerate(memories): - if memory.id == memory_id: - del self.memories[user_id][i] - self._save() - return True + async with self._lock: + for user_id, memories in self.memories.items(): + for idx, memory in enumerate(memories): + if memory.id == memory_id: + del self.memories[user_id][idx] + self._save_locked() + return True return False - except Exception as e: - logger.error(f"删除记忆失败: {e}") + except Exception as exc: + logger.error(f"Failed to delete memory: {exc}") return False - + async def get_all(self, user_id: str) -> List[VectorMemory]: - """获取用户所有记忆""" - return self.memories.get(user_id, []) - + async with self._lock: + return list(self.memories.get(user_id, [])) + async def clear_user(self, user_id: str) -> bool: - """清除用户所有记忆""" try: - if user_id in self.memories: - del self.memories[user_id] - self._save() - logger.info(f"清除用户记忆: {user_id}") + async with self._lock: + self.memories.pop(user_id, None) + self._save_locked() return True - except Exception as e: - logger.error(f"清除用户记忆失败: {e}") + except Exception as exc: + logger.error(f"Failed to clear user memories: {exc}") return False - + async def close(self): - """关闭连接""" - self._save() - logger.info("JSON存储已关闭") + async with self._lock: + self._save_locked() diff --git a/src/core/bot.py b/src/core/bot.py index f05f9e2..e51208e 100644 --- a/src/core/bot.py +++ b/src/core/bot.py @@ -1,105 +1,71 @@ """ -QQ机器人主程序 -基于官方SDK: https://github.com/tencent-connect/botpy -官方文档: https://bot.q.qq.com/wiki/develop/api-v2/ +QQ bot application entry module. 
""" + +from __future__ import annotations + import botpy from botpy.message import Message -from src.core.config import Config -from src.utils.logger import setup_logger -from src.handlers.message_handler_ai import MessageHandler -logger = setup_logger('QQBot') +from src.core.config import Config +from src.handlers.message_handler_ai import MessageHandler +from src.utils.logger import setup_logger + +logger = setup_logger("QQBot") def build_intents() -> botpy.Intents: - """ - 构建最小可用的 intents。 - - - public_guild_messages: 频道公域 @机器人 消息 - - public_messages: 群聊@ + C2C 私聊(好友单聊)消息 - """ intents = botpy.Intents.none() intents.public_guild_messages = True - # 新版 botpy 中,QQ 群聊@ / C2C 私聊依赖 public_messages(GROUP_AND_C2C_EVENT)。 if hasattr(intents, "public_messages"): intents.public_messages = True - logger.info("✅ 已启用 public_messages(支持群聊@与 C2C 私聊)") + logger.info("Enabled public_messages for group/C2C events") else: - logger.warning("⚠️ 当前 botpy 不支持 public_messages,可能无法接收 C2C 私聊事件") + logger.warning("Current botpy version does not expose public_messages") return intents class MyClient(botpy.Client): - """QQ机器人客户端""" - + """QQ bot client wrapper.""" + def __init__(self, intents: botpy.Intents): super().__init__(intents=intents) self.message_handler = MessageHandler(self) - + async def on_ready(self): - """机器人启动完成事件""" - logger.info(f"🤖 机器人已启动: {self.robot.name} (ID: {self.robot.id})") - + logger.info(f"Bot is ready: {self.robot.name} (ID: {self.robot.id})") + async def on_at_message_create(self, message: Message): - """处理@机器人的消息(频道公域消息)""" await self.message_handler.handle_at_message(message) - + async def on_message_create(self, message: Message): - """处理普通消息(需要私域权限)""" await self.message_handler.handle_at_message(message) - + async def on_direct_message_create(self, message: Message): - """处理私信消息""" await self.message_handler.handle_at_message(message) - + async def on_group_at_message_create(self, message: Message): - """处理群聊@消息""" await self.message_handler.handle_at_message(message) - + async def on_c2c_message_create(self, message: Message): - """处理C2C消息(单聊)""" await self.message_handler.handle_at_message(message) - - async def on_guild_create(self, guild): - """机器人加入频道事件""" - logger.info(f"➕ 加入频道: {guild.name} (ID: {guild.id})") - - async def on_guild_delete(self, guild): - """机器人离开频道事件""" - logger.info(f"➖ 离开频道: {guild.name} (ID: {guild.id})") - + async def on_error(self, error): - """错误处理""" - logger.error(f"❌ 发生错误: {error}") + logger.error(f"Bot error: {error}", exc_info=True) + + async def on_close(self): + ai_client = getattr(self.message_handler, "ai_client", None) + if ai_client: + try: + await ai_client.close() + except Exception as exc: + logger.warning(f"Failed to close AI resources cleanly: {exc}") def main(): - """主函数""" - try: - # 验证配置 - Config.validate() - logger.info("✅ 配置验证通过") - - # 创建机器人实例 - 使用最小可用 intents,避免 disallowed intents(4014) - intents = build_intents() - logger.info("✅ Intents 配置完成(最小权限模式)") - - client = MyClient(intents=intents) - - # 启动机器人 - logger.info("🚀 正在启动机器人...") - client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET) - - except ValueError as e: - logger.error(f"❌ 配置错误: {e}") - logger.error("请检查 .env 文件配置") - except Exception as e: - logger.error(f"❌ 启动失败: {e}") - raise - - -if __name__ == "__main__": - main() + Config.validate() + intents = build_intents() + client = MyClient(intents=intents) + client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET) diff --git a/src/core/config.py b/src/core/config.py index b4a6591..8c18007 100644 --- 
a/src/core/config.py +++ b/src/core/config.py @@ -1,72 +1,123 @@ """ -QQ机器人配置管理模块 +Centralized project configuration. """ + +from __future__ import annotations + import os -from typing import Optional +from typing import Optional, Set + from dotenv import load_dotenv -# 加载环境变量 load_dotenv() def _read_env(name: str, default: Optional[str] = None) -> Optional[str]: - """ - 读取并清洗环境变量。 - - 去除首尾空白 - - 空字符串视为未设置 - - 以 # 开头的值视为注释占位,视为未设置 - """ value = os.getenv(name) if value is None: return default - + value = value.strip() - if not value or value.startswith('#'): + if not value or value.startswith("#"): return default - return value +def _read_bool(name: str, default: bool) -> bool: + raw = _read_env(name, None) + if raw is None: + return default + lowered = raw.lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + +def _read_int(name: str, default: int) -> int: + raw = _read_env(name, None) + if raw is None: + return default + try: + return int(raw) + except ValueError: + return default + + +def _read_float(name: str, default: float) -> float: + raw = _read_env(name, None) + if raw is None: + return default + try: + return float(raw) + except ValueError: + return default + + +def _read_csv_set(name: str) -> Set[str]: + raw = _read_env(name, "") or "" + return {item.strip() for item in raw.split(",") if item.strip()} + + class Config: - """机器人配置类""" - - # 机器人基本信息 - BOT_APPID = _read_env('BOT_APPID', '') or '' - BOT_SECRET = _read_env('BOT_SECRET', '') or '' - - # 日志配置 - LOG_LEVEL = _read_env('LOG_LEVEL', 'INFO') or 'INFO' - - # 沙箱模式 - SANDBOX_MODE = os.getenv('SANDBOX_MODE', 'False').lower() == 'true' - - # AI配置 - AI_PROVIDER = _read_env('AI_PROVIDER', 'openai') or 'openai' - AI_MODEL = _read_env('AI_MODEL', 'gpt-4') or 'gpt-4' - AI_API_KEY = _read_env('AI_API_KEY', '') or '' - AI_API_BASE = _read_env('AI_API_BASE', None) - - # AI嵌入模型配置(用于RAG) - AI_EMBED_PROVIDER = _read_env('AI_EMBED_PROVIDER', 'openai') or 'openai' - AI_EMBED_MODEL = _read_env('AI_EMBED_MODEL', 'text-embedding-3-small') or 'text-embedding-3-small' - AI_EMBED_API_KEY = _read_env('AI_EMBED_API_KEY', None) # 留空则使用 AI_API_KEY - AI_EMBED_API_BASE = _read_env('AI_EMBED_API_BASE', None) # 留空则使用 AI_API_BASE - - # 向量数据库配置 - AI_USE_VECTOR_DB = os.getenv('AI_USE_VECTOR_DB', 'true').lower() == 'true' - + """Application runtime configuration.""" + + BOT_APPID = _read_env("BOT_APPID", "") or "" + BOT_SECRET = _read_env("BOT_SECRET", "") or "" + + ENV = _read_env("APP_ENV", "dev") or "dev" + LOG_LEVEL = _read_env("LOG_LEVEL", "INFO") or "INFO" + LOG_FORMAT = (_read_env("LOG_FORMAT", "text") or "text").lower() + + SANDBOX_MODE = _read_bool("SANDBOX_MODE", False) + + AI_PROVIDER = _read_env("AI_PROVIDER", "openai") or "openai" + AI_MODEL = _read_env("AI_MODEL", "gpt-4") or "gpt-4" + AI_API_KEY = _read_env("AI_API_KEY", "") or "" + AI_API_BASE = _read_env("AI_API_BASE", None) + + AI_EMBED_PROVIDER = _read_env("AI_EMBED_PROVIDER", "openai") or "openai" + AI_EMBED_MODEL = ( + _read_env("AI_EMBED_MODEL", "text-embedding-3-small") + or "text-embedding-3-small" + ) + AI_EMBED_API_KEY = _read_env("AI_EMBED_API_KEY", None) + AI_EMBED_API_BASE = _read_env("AI_EMBED_API_BASE", None) + + AI_USE_VECTOR_DB = _read_bool("AI_USE_VECTOR_DB", True) + AI_USE_QUERY_EMBEDDING = _read_bool("AI_USE_QUERY_EMBEDDING", False) + + # `user`: one memory bucket per user + # `session`: memory bucket scoped to chat session (user + channel/group) + AI_MEMORY_SCOPE = 
(_read_env("AI_MEMORY_SCOPE", "session") or "session").lower() + + AI_CHAT_RETRIES = max(0, _read_int("AI_CHAT_RETRIES", 1)) + AI_CHAT_RETRY_BACKOFF_SECONDS = max( + 0.0, _read_float("AI_CHAT_RETRY_BACKOFF_SECONDS", 0.8) + ) + + MESSAGE_DEDUP_SECONDS = max(1, _read_int("MESSAGE_DEDUP_SECONDS", 30)) + MESSAGE_DEDUP_MAX_SIZE = max(128, _read_int("MESSAGE_DEDUP_MAX_SIZE", 4096)) + + BOT_ADMIN_IDS = _read_csv_set("BOT_ADMIN_IDS") + @classmethod - def validate(cls): - """验证配置是否完整""" + def is_admin(cls, user_id: Optional[str]) -> bool: + if not cls.BOT_ADMIN_IDS: + return True + if not user_id: + return False + return str(user_id) in cls.BOT_ADMIN_IDS + + @classmethod + def validate(cls) -> bool: if not cls.BOT_APPID: raise ValueError("BOT_APPID 未配置") if not cls.BOT_SECRET: raise ValueError("BOT_SECRET 未配置") - - # AI配置验证(可选) - if cls.AI_API_KEY: - print(f"✅ AI配置: {cls.AI_PROVIDER}/{cls.AI_MODEL}") - else: - print("⚠️ AI_API_KEY 未设置,AI功能将不可用") - + + if cls.AI_MEMORY_SCOPE not in {"user", "session"}: + raise ValueError("AI_MEMORY_SCOPE 仅支持 user 或 session") + return True diff --git a/src/handlers/__init__.py b/src/handlers/__init__.py index 0e5c565..15e0907 100644 --- a/src/handlers/__init__.py +++ b/src/handlers/__init__.py @@ -1,6 +1,5 @@ -""" -消息处理模块 -""" -from .message_handler import MessageHandler +"""Message handlers.""" -__all__ = ['MessageHandler'] +from .message_handler_ai import MessageHandler + +__all__ = ["MessageHandler"] diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 55da2ef..874b742 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -1,45 +1,54 @@ """ -集成 AI 能力的 QQ 消息处理器。 +QQ message handler with AI, memory and persona capabilities. """ +from __future__ import annotations + import asyncio import json -from pathlib import Path import re -from typing import Any, Dict, Optional +import time +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional, Tuple import httpx - -from botpy.message import Message +try: + from botpy.message import Message +except Exception: # pragma: no cover - used for test environments without botpy + class Message: # type: ignore[no-redef] + pass from src.ai import AIClient from src.ai.base import ModelConfig, ModelProvider -from src.ai.mcp import MCPManager -from src.ai.mcp.servers import FileSystemMCPServer from src.ai.personality import PersonalityProfile, PersonalityTrait -from src.ai.skills import SkillsManager from src.core.config import Config from src.utils.logger import setup_logger logger = setup_logger("MessageHandler") +@dataclass +class ChatScope: + user_id: str + group_id: Optional[str] + session_id: str + memory_key: str + + class MessageHandler: - """消息处理器(集成 AI 功能)。""" - MODEL_KEY_PATTERN = re.compile(r"[^a-zA-Z0-9_]") - + MENTION_PATTERNS = [re.compile(r"<@!\d+>"), re.compile(r"<@\d+>")] MARKDOWN_PATTERNS = [ (re.compile(r"```[\s\S]*?```"), lambda m: m.group(0).replace("```", "")), (re.compile(r"`([^`]+)`"), r"\1"), (re.compile(r"\*\*([^*]+)\*\*"), r"\1"), (re.compile(r"\*([^*]+)\*"), r"\1"), (re.compile(r"__([^_]+)__"), r"\1"), - # Avoid stripping underscores inside identifiers like model keys. (re.compile(r"(?\s?", re.MULTILINE), ""), - # Keep link label only to avoid QQ URL-blocking in private messages. 
(re.compile(r"\[([^\]]+)\]\(([^)]+)\)"), r"\1"), (re.compile(r"^[-*]\s+", re.MULTILINE), "- "), (re.compile(r"^\d+\.\s+", re.MULTILINE), "- "), @@ -52,82 +61,16 @@ class MessageHandler: def __init__(self, bot): self.bot = bot - self.ai_client = None - self.skills_manager = None - self.mcp_manager = None + self.ai_client: Optional[AIClient] = None self.model_profiles_path = Path("config/models.json") self.model_profiles: Dict[str, Dict[str, Any]] = {} self.active_model_key = "default" self._ai_initialized = False self._init_lock = asyncio.Lock() - @staticmethod - def _get_user_id(message: Message) -> str: - author = getattr(message, "author", None) - if not author: - return "unknown" - - return ( - getattr(author, "id", None) - or getattr(author, "user_openid", None) - or getattr(author, "member_openid", None) - or getattr(author, "union_openid", None) - or "unknown" - ) - - @staticmethod - def _build_skills_usage(command_name: str = "/skills") -> str: - return ( - "技能命令:\n" - f"{command_name} 或 {command_name} list\n" - f"{command_name} install [skill_name]\n" - f" source 支持:本地技能名、URL、owner/repo、owner/repo#branch、GitHub 仓库 URL(.git)\n" - f"{command_name} uninstall \n" - f"{command_name} reload " - ) - - @staticmethod - def _build_personality_usage() -> str: - return ( - "人设命令:\n" - "/personality\n" - "/personality list\n" - "/personality set \n" - "/personality add \n" - " 兼容旧格式: 或 管道 name|description|speaking_style|TRAIT1,TRAIT2|custom_instructions\n" - "/personality remove " - ) - - @staticmethod - def _build_models_usage(command_name: str = "/models") -> str: - return ( - "模型命令:\n" - f"{command_name} 或 {command_name} list\n" - f"{command_name} current\n" - f"{command_name} add (保持 provider/api_base/api_key 不变)\n" - f"{command_name} add [api_base]\n" - f"{command_name} add \n" - " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n" - f"{command_name} switch \n" - f"{command_name} remove " - ) - - @staticmethod - def _build_memory_usage(command_name: str = "/memory") -> str: - return ( - "记忆命令:\n" - f"{command_name} 或 {command_name} list [limit]\n" - f"{command_name} get \n" - f"{command_name} add \n" - f"{command_name} add \n" - f"{command_name} update \n" - f"{command_name} delete \n" - f"{command_name} search [limit]\n" - "/clear (默认等价 /clear short)\n" - "/clear short\n" - "/clear long\n" - "/clear all" - ) + self._seen_messages: OrderedDict[str, float] = OrderedDict() + self._dedup_window_seconds = Config.MESSAGE_DEDUP_SECONDS + self._dedup_max_size = Config.MESSAGE_DEDUP_MAX_SIZE @staticmethod def _provider_map() -> Dict[str, ModelProvider]: @@ -139,83 +82,34 @@ class MessageHandler: "siliconflow": ModelProvider.OPENAI, } - @classmethod - def _normalize_model_key(cls, raw_key: str) -> str: - key = raw_key.strip().lower().replace("-", "_").replace(" ", "_") - key = cls.MODEL_KEY_PATTERN.sub("_", key) - key = re.sub(r"_+", "_", key).strip("_") - if not key: - raise ValueError("模型 key 不能为空") - if key[0].isdigit(): - key = f"model_{key}" - return key + @staticmethod + def _get_user_id(message: Message) -> str: + author = getattr(message, "author", None) + if not author: + return "unknown" + return ( + getattr(author, "id", None) + or getattr(author, "user_openid", None) + or getattr(author, "member_openid", None) + or getattr(author, "union_openid", None) + or "unknown" + ) @staticmethod - def _compact_model_key(raw_key: str) -> str: - return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower()) + def _get_group_id(message: Message) -> 
Optional[str]: + return ( + getattr(message, "group_openid", None) + or getattr(message, "group_id", None) + or getattr(message, "channel_id", None) + or getattr(message, "guild_id", None) + ) - def _ordered_model_keys(self) -> list[str]: - return sorted(self.model_profiles.keys()) - - def _resolve_model_selector(self, selector: str) -> str: - raw = (selector or "").strip() - if not raw: - raise ValueError("模型 key 不能为空") - - ordered_keys = self._ordered_model_keys() - if raw.isdigit(): - index = int(raw) - if index < 1 or index > len(ordered_keys): - raise ValueError( - f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}" - ) - return ordered_keys[index - 1] - - if raw in self.model_profiles: - return raw - - normalized_selector: Optional[str] - try: - normalized_selector = self._normalize_model_key(raw) - except ValueError: - normalized_selector = None - - if normalized_selector and normalized_selector in self.model_profiles: - return normalized_selector - - normalized_candidates: Dict[str, list[str]] = {} - compact_candidates: Dict[str, list[str]] = {} - for key in ordered_keys: - try: - normalized_key = self._normalize_model_key(key) - except ValueError: - normalized_key = key.strip().lower() - normalized_candidates.setdefault(normalized_key, []).append(key) - compact_candidates.setdefault( - self._compact_model_key(normalized_key), [] - ).append(key) - - if normalized_selector and normalized_selector in normalized_candidates: - matches = normalized_candidates[normalized_selector] - if len(matches) == 1: - return matches[0] - raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}") - - compact_selector = self._compact_model_key(normalized_selector or raw) - if compact_selector in compact_candidates: - matches = compact_candidates[compact_selector] - if len(matches) == 1: - return matches[0] - raise ValueError(f"匹配到多个模型 key,请使用完整 key: {', '.join(matches)}") - - raise ValueError(f"模型配置不存在: {raw}") - - @classmethod - def _parse_provider(cls, raw_provider: str) -> ModelProvider: - provider = cls._provider_map().get(raw_provider.strip().lower()) - if not provider: - raise ValueError(f"不支持的 provider: {raw_provider}") - return provider + def _build_scope(self, message: Message) -> ChatScope: + user_id = self._get_user_id(message) + group_id = self._get_group_id(message) + session_id = f"{group_id}:{user_id}" if group_id else user_id + memory_key = user_id if Config.AI_MEMORY_SCOPE == "user" else session_id + return ChatScope(user_id=user_id, group_id=group_id, session_id=session_id, memory_key=memory_key) @staticmethod def _coerce_float(value: Any, default: float) -> float: @@ -246,10 +140,64 @@ class MessageHandler: return default return bool(value) + @classmethod + def _normalize_model_key(cls, raw_key: str) -> str: + key = raw_key.strip().lower().replace("-", "_").replace(" ", "_") + key = cls.MODEL_KEY_PATTERN.sub("_", key) + key = re.sub(r"_+", "_", key).strip("_") + if not key: + raise ValueError("模型 key 不能为空") + if key[0].isdigit(): + key = f"model_{key}" + return key + @staticmethod - def _model_config_to_dict( - config: ModelConfig, include_api_key: bool = True - ) -> Dict[str, Any]: + def _compact_model_key(raw_key: str) -> str: + return re.sub(r"[^a-z0-9]", "", (raw_key or "").strip().lower()) + + def _ordered_model_keys(self) -> list[str]: + return sorted(self.model_profiles.keys()) + + @classmethod + def _parse_provider(cls, raw_provider: str) -> ModelProvider: + provider = cls._provider_map().get(raw_provider.strip().lower()) + if not provider: + raise ValueError(f"不支持的 
provider: {raw_provider}") + return provider + + def _resolve_model_selector(self, selector: str) -> str: + raw = (selector or "").strip() + if not raw: + raise ValueError("模型 key 不能为空") + + ordered_keys = self._ordered_model_keys() + if raw.isdigit(): + index = int(raw) + if index < 1 or index > len(ordered_keys): + raise ValueError(f"模型序号超出范围: {index},可选 1-{len(ordered_keys)}") + return ordered_keys[index - 1] + + if raw in self.model_profiles: + return raw + + normalized_selector: Optional[str] + try: + normalized_selector = self._normalize_model_key(raw) + except ValueError: + normalized_selector = None + + if normalized_selector and normalized_selector in self.model_profiles: + return normalized_selector + + compact_selector = self._compact_model_key(normalized_selector or raw) + for key in ordered_keys: + if compact_selector == self._compact_model_key(key): + return key + + raise ValueError(f"模型配置不存在: {raw}") + + @staticmethod + def _model_config_to_dict(config: ModelConfig, include_api_key: bool = True) -> Dict[str, Any]: data: Dict[str, Any] = { "provider": config.provider.value, "model_name": config.model_name, @@ -267,9 +215,7 @@ class MessageHandler: data["api_key"] = config.api_key return data - def _model_config_from_dict( - self, raw: Dict[str, Any], fallback: ModelConfig - ) -> ModelConfig: + def _model_config_from_dict(self, raw: Dict[str, Any], fallback: ModelConfig) -> ModelConfig: provider = self._parse_provider(str(raw.get("provider") or fallback.provider.value)) model_name = str(raw.get("model_name") or fallback.model_name).strip() if not model_name: @@ -291,12 +237,8 @@ class MessageHandler: temperature=self._coerce_float(raw.get("temperature"), fallback.temperature), max_tokens=self._coerce_int(raw.get("max_tokens"), fallback.max_tokens), top_p=self._coerce_float(raw.get("top_p"), fallback.top_p), - frequency_penalty=self._coerce_float( - raw.get("frequency_penalty"), fallback.frequency_penalty - ), - presence_penalty=self._coerce_float( - raw.get("presence_penalty"), fallback.presence_penalty - ), + frequency_penalty=self._coerce_float(raw.get("frequency_penalty"), fallback.frequency_penalty), + presence_penalty=self._coerce_float(raw.get("presence_penalty"), fallback.presence_penalty), timeout=self._coerce_int(raw.get("timeout"), fallback.timeout), stream=self._coerce_bool(raw.get("stream"), fallback.stream), ) @@ -323,43 +265,20 @@ class MessageHandler: for raw_key, raw_profile in raw_profiles.items(): if not isinstance(raw_profile, dict): continue - key_text = str(raw_key or "").strip() if not key_text: continue - try: normalized_key = self._normalize_model_key(key_text) except ValueError: continue - - if normalized_key in profiles and profiles[normalized_key] != raw_profile: - logger.warning( - f"duplicate model key after normalization, keep first: {normalized_key}" - ) - continue - profiles[normalized_key] = raw_profile if not profiles: - profiles = { - "default": self._model_config_to_dict( - default_config, include_api_key=False - ) - } + profiles = {"default": self._model_config_to_dict(default_config, include_api_key=False)} active_raw = str(payload.get("active") or "").strip() - active = "" - if active_raw in profiles: - active = active_raw - elif active_raw: - try: - normalized_active = self._normalize_model_key(active_raw) - except ValueError: - normalized_active = "" - if normalized_active in profiles: - active = normalized_active - + active = active_raw if active_raw in profiles else "" if not active: active = "default" if "default" in profiles else 
sorted(profiles.keys())[0] @@ -367,44 +286,21 @@ class MessageHandler: self.active_model_key = active try: - active_config = self._model_config_from_dict( - self.model_profiles[self.active_model_key], default_config - ) - except Exception as exc: - logger.warning(f"active model profile invalid, fallback to default: {exc}") - self.model_profiles = { - "default": self._model_config_to_dict( - default_config, include_api_key=False - ) - } - self.active_model_key = "default" + active_config = self._model_config_from_dict(self.model_profiles[active], default_config) + except Exception: active_config = default_config - self._save_model_profiles() return active_config def _ensure_model_profiles_ready(self): if not self.ai_client: return - if self.model_profiles and self.active_model_key in self.model_profiles: return - active_config = self._load_model_profiles(self.ai_client.config) if active_config != self.ai_client.config: self.ai_client.switch_model(active_config) - def _plain_text(self, text: str) -> str: - if not text: - return "" - - result = text - for pattern, replacement in self.MARKDOWN_PATTERNS: - result = pattern.sub(replacement, result) - result = self._strip_urls(result) - - return result.strip() - @classmethod def _strip_urls(cls, text: str) -> str: result = text @@ -412,99 +308,85 @@ class MessageHandler: result = pattern.sub("[链接已省略]", result) return result + def _plain_text(self, text: str) -> str: + if not text: + return "" + result = text + for pattern, replacement in self.MARKDOWN_PATTERNS: + result = pattern.sub(replacement, result) + return self._strip_urls(result).strip() + + @staticmethod + def _looks_like_url_block_error(exc: Exception) -> bool: + lower = str(exc).lower() + return "url" in lower and ( + "not allow" in lower or "not allowed" in lower or "不允许" in lower or "forbidden" in lower + ) + async def _reply_plain(self, message: Message, text: str): - content = self._plain_text(text) + content = self._plain_text(text) or " " try: await message.reply(content=content) + return except Exception as exc: - # QQ C2C may reject any message containing URL. 
- if "不允许发送url" not in str(exc).lower(): - raise + if not self._looks_like_url_block_error(exc): + logger.warning(f"reply failed once, retrying: {exc}") + await asyncio.sleep(0.2) + await message.reply(content=content) + return + fallback = self._strip_urls(content).strip() or "内容包含受限链接,已省略。" + await message.reply(content=fallback) - logger.warning("消息被平台判定包含 URL,尝试二次清洗后重发") - fallback = self._strip_urls(content).strip() or "内容包含受限链接,已省略。" - await message.reply(content=fallback) + def _extract_message_identity(self, message: Message) -> str: + msg_id = ( + getattr(message, "id", None) + or getattr(message, "message_id", None) + or getattr(message, "msg_id", None) + ) + if msg_id: + return str(msg_id) + author = self._get_user_id(message) + group = self._get_group_id(message) or "-" + content = (getattr(message, "content", "") or "").strip() + return f"fallback:{author}:{group}:{hash(content)}" - def _register_skill_tools(self, skill_name: str) -> int: - if not self.skills_manager or not self.ai_client: - return 0 + def _is_duplicate_message(self, message: Message) -> bool: + now = time.time() + message_key = self._extract_message_identity(message) + expired_before = now - self._dedup_window_seconds + while self._seen_messages: + first_key = next(iter(self._seen_messages)) + if self._seen_messages[first_key] >= expired_before: + break + self._seen_messages.popitem(last=False) + if message_key in self._seen_messages: + return True + self._seen_messages[message_key] = now + if len(self._seen_messages) > self._dedup_max_size: + self._seen_messages.popitem(last=False) + return False - skill = self.skills_manager.get_skill(skill_name) - if not skill: - return 0 + def _is_admin(self, user_id: str) -> bool: + return Config.is_admin(user_id) - count = 0 - for tool_name, tool_func in skill.get_tools().items(): - full_tool_name = f"{skill_name}.{tool_name}" - self.ai_client.register_tool( - name=full_tool_name, - description=f"技能工具: {full_tool_name}", - parameters={"type": "object", "properties": {}}, - function=tool_func, - source="skills", - ) - count += 1 - - return count - - async def _register_mcp_tools(self) -> int: - if not self.mcp_manager or not self.ai_client: - return 0 - - tools = await self.mcp_manager.get_all_tools_for_ai() - count = 0 - - for item in tools: - function_info = item.get("function") if isinstance(item, dict) else None - if not isinstance(function_info, dict): - continue - - full_tool_name = function_info.get("name") - if not full_tool_name: - continue - - parameters = function_info.get("parameters") - if not isinstance(parameters, dict): - parameters = {"type": "object", "properties": {}} - - async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs): - if not self.mcp_manager: - raise RuntimeError("MCP manager not initialized") - return await self.mcp_manager.execute_tool(_full_tool_name, kwargs) - - self.ai_client.register_tool( - name=full_tool_name, - description=f"MCP工具: {full_tool_name}", - parameters=parameters, - function=_mcp_proxy, - source="mcp", - ) - count += 1 - - return count + async def _require_admin(self, message: Message, user_id: str, action: str) -> bool: + if self._is_admin(user_id): + return True + await self._reply_plain(message, f"权限不足,无法执行: {action}") + return False def _parse_traits(self, traits_raw) -> list[PersonalityTrait]: - traits = [] - if isinstance(traits_raw, str): trait_names = [name.strip().upper() for name in traits_raw.split(",") if name.strip()] elif isinstance(traits_raw, list): trait_names = [str(name).strip().upper() for name 
in traits_raw if str(name).strip()] else: trait_names = [] - - for name in trait_names: - if name in PersonalityTrait.__members__: - traits.append(PersonalityTrait[name]) - - if not traits: - traits = [PersonalityTrait.FRIENDLY] - - return traits + traits = [PersonalityTrait[name] for name in trait_names if name in PersonalityTrait.__members__] + return traits or [PersonalityTrait.FRIENDLY] def _parse_personality_payload(self, key: str, payload: str) -> PersonalityProfile: payload = payload.strip() - if payload.startswith("{"): data = json.loads(payload) return PersonalityProfile( @@ -517,98 +399,129 @@ class MessageHandler: ) if "|" not in payload: - introduction = payload return PersonalityProfile( name=key, - description=introduction, + description=payload, traits=[PersonalityTrait.FRIENDLY], speaking_style="自然口语", - custom_instructions=introduction, + custom_instructions=payload, ) parts = [part.strip() for part in payload.split("|")] - name = parts[0] if len(parts) >= 1 and parts[0] else key - description = parts[1] if len(parts) >= 2 and parts[1] else "自定义人设" - speaking_style = parts[2] if len(parts) >= 3 and parts[2] else "自然口语" - traits = self._parse_traits(parts[3] if len(parts) >= 4 else "") - custom_instructions = parts[4] if len(parts) >= 5 else "" - return PersonalityProfile( - name=name, - description=description, - traits=traits, - speaking_style=speaking_style, - custom_instructions=custom_instructions, + name=parts[0] if len(parts) >= 1 and parts[0] else key, + description=parts[1] if len(parts) >= 2 and parts[1] else "自定义人设", + traits=self._parse_traits(parts[3] if len(parts) >= 4 else ""), + speaking_style=parts[2] if len(parts) >= 3 and parts[2] else "自然口语", + custom_instructions=parts[4] if len(parts) >= 5 else "", + ) + + @staticmethod + def _parse_scope_token(token: Optional[str], scope: ChatScope) -> Tuple[str, Optional[str]]: + raw = (token or "").strip().lower() + if not raw: + return "global", None + if ":" in raw: + scope_name, scope_id = raw.split(":", 1) + return scope_name, scope_id or None + if raw == "global": + return "global", None + if raw == "user": + return "user", scope.user_id + if raw == "group": + return "group", scope.group_id + if raw == "session": + return "session", scope.session_id + return "global", None + + @staticmethod + def _build_personality_usage() -> str: + return ( + "人设命令:\n" + "/personality\n" + "/personality list\n" + "/personality set [global|user|group|session]\n" + "/personality add \n" + "/personality remove " + ) + + @staticmethod + def _build_models_usage(command_name: str = "/models") -> str: + return ( + "模型命令:\n" + f"{command_name} 或 {command_name} list\n" + f"{command_name} current\n" + f"{command_name} add \n" + f"{command_name} add [api_base]\n" + f"{command_name} add \n" + f"{command_name} switch \n" + f"{command_name} remove " + ) + + @staticmethod + def _build_memory_usage(command_name: str = "/memory") -> str: + return ( + "记忆命令:\n" + f"{command_name} 或 {command_name} list [limit]\n" + f"{command_name} get \n" + f"{command_name} add \n" + f"{command_name} add \n" + f"{command_name} update \n" + f"{command_name} delete \n" + f"{command_name} search [limit]\n" + "/clear (默认等价 /clear short)\n" + "/clear short\n" + "/clear long\n" + "/clear all" ) async def _init_ai(self): - if self._ai_initialized: - return + provider = self._provider_map().get(Config.AI_PROVIDER.lower(), ModelProvider.OPENAI) + config = ModelConfig( + provider=provider, + model_name=Config.AI_MODEL, + api_key=Config.AI_API_KEY, + 
api_base=Config.AI_API_BASE, + temperature=0.7, + ) + config = self._load_model_profiles(config) - try: - provider = self._provider_map().get( - Config.AI_PROVIDER.lower(), ModelProvider.OPENAI + embed_config = None + if Config.AI_EMBED_PROVIDER and Config.AI_EMBED_MODEL: + embed_provider = self._provider_map().get(Config.AI_EMBED_PROVIDER.lower(), ModelProvider.OPENAI) + embed_config = ModelConfig( + provider=embed_provider, + model_name=Config.AI_EMBED_MODEL, + api_key=Config.AI_EMBED_API_KEY or Config.AI_API_KEY, + api_base=Config.AI_EMBED_API_BASE or Config.AI_API_BASE, ) - config = ModelConfig( - provider=provider, - model_name=Config.AI_MODEL, - api_key=Config.AI_API_KEY, - api_base=Config.AI_API_BASE, - temperature=0.7, - ) - config = self._load_model_profiles(config) + self.ai_client = AIClient( + model_config=config, + embed_config=embed_config, + data_dir=Path("data/ai"), + use_vector_db=Config.AI_USE_VECTOR_DB, + use_query_embedding=Config.AI_USE_QUERY_EMBEDDING, + chat_retries=Config.AI_CHAT_RETRIES, + chat_retry_backoff=Config.AI_CHAT_RETRY_BACKOFF_SECONDS, + ) + self._ai_initialized = True + logger.info("AI system initialized") - embed_config = None - if Config.AI_EMBED_PROVIDER and Config.AI_EMBED_MODEL: - embed_provider = self._provider_map().get( - Config.AI_EMBED_PROVIDER.lower(), ModelProvider.OPENAI - ) - - embed_config = ModelConfig( - provider=embed_provider, - model_name=Config.AI_EMBED_MODEL, - api_key=Config.AI_EMBED_API_KEY or Config.AI_API_KEY, - api_base=Config.AI_EMBED_API_BASE or Config.AI_API_BASE, - ) - - self.ai_client = AIClient( - model_config=config, - embed_config=embed_config, - data_dir=Path("data/ai"), - use_vector_db=Config.AI_USE_VECTOR_DB, - ) - - self.skills_manager = SkillsManager(Path("skills")) - await self.skills_manager.load_all_skills() - - total_tools = 0 - for skill_name in self.skills_manager.list_skills(): - total_tools += self._register_skill_tools(skill_name) - logger.info( - f"技能系统初始化完成: {len(self.skills_manager.list_skills())} skills, {total_tools} tools" - ) - - try: - self.mcp_manager = MCPManager(Path("config/mcp.json")) - fs_server = FileSystemMCPServer(root_path=Path("data")) - await self.mcp_manager.register_server(fs_server) - mcp_tool_count = await self._register_mcp_tools() - logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools") - except Exception as exc: - logger.warning(f"MCP 初始化失败: {exc}") - - self._ai_initialized = True - logger.info("AI 系统初始化完成") - - except Exception as exc: - logger.error(f"AI 初始化失败: {exc}") - import traceback - - logger.error(traceback.format_exc()) + def _strip_mentions(self, content: str) -> str: + cleaned = content or "" + for pattern in self.MENTION_PATTERNS: + cleaned = pattern.sub("", cleaned) + if getattr(self.bot, "robot", None) and getattr(self.bot.robot, "name", None): + cleaned = cleaned.replace(f"@{self.bot.robot.name}", "") + return cleaned.strip() async def handle_at_message(self, message: Message): try: + if self._is_duplicate_message(message): + logger.info("duplicate message skipped") + return + if not self._ai_initialized: async with self._init_lock: if not self._ai_initialized: @@ -618,191 +531,83 @@ class MessageHandler: await self._reply_plain(message, "AI 初始化失败,请检查配置") return - content = (message.content or "").strip() - user_id = self._get_user_id(message) - - if content.startswith(f"<@!{self.bot.robot.id}>"): - content = content.replace(f"<@!{self.bot.robot.id}>", "").strip() - + scope = self._build_scope(message) + content = self._strip_mentions((message.content or "").strip()) 
if not content: return if content.startswith("/"): - await self._handle_command(message, content) + await self._handle_command(message, content, scope) return response = await self.ai_client.chat( - user_id=user_id, + user_id=scope.user_id, user_message=content, use_memory=True, - use_tools=True, + group_id=scope.group_id, + session_id=scope.session_id, + memory_key=scope.memory_key, ) await self._reply_plain(message, response) - except Exception as exc: - logger.error(f"处理消息失败: {exc}") - import traceback - - logger.error(traceback.format_exc()) + logger.error(f"message handling failed: {exc}", exc_info=True) if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)): - await self._reply_plain( - message, - "模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。", - ) + await self._reply_plain(message, "模型响应超时,请稍后重试,或将 timeout 调大(建议 120-180 秒)。") return await self._reply_plain(message, "消息处理失败,请稍后重试") - async def _handle_skills_command(self, message: Message, command: str): - if not self.skills_manager or not self.ai_client: - await self._reply_plain(message, "技能系统未初始化") + async def _handle_personality_command(self, message: Message, command: str, scope: ChatScope): + if not self.ai_client: + await self._reply_plain(message, "AI 客户端未初始化") return - parts = command.split() - command_name = parts[0].lower() - action = parts[1].lower() if len(parts) > 1 else "list" + parts = command.split(maxsplit=4) + action = parts[1].lower() if len(parts) > 1 else "show" - if action in {"list", "ls"} and len(parts) <= 2: - loaded = self.skills_manager.list_skills() - available = self.skills_manager.list_available_skills() - unloaded = [name for name in available if name not in loaded] - - lines = ["技能状态:"] - lines.append("已加载: " + (", ".join(loaded) if loaded else "无")) - lines.append("可安装: " + (", ".join(unloaded) if unloaded else "无")) - lines.append(self._build_skills_usage(command_name)) - await self._reply_plain(message, "\n".join(lines)) - return - - if action not in {"install", "uninstall", "remove", "reload"}: - await self._reply_plain(message, self._build_skills_usage(command_name)) - return - - if len(parts) < 3: - await self._reply_plain(message, self._build_skills_usage(command_name)) - return - - if action == "install": - source = parts[2] - desired_name = parts[3] if len(parts) > 3 else None - - try: - source_key = self.skills_manager.normalize_skill_key(source) - except Exception: - source_key = None - - if source_key and source_key in self.skills_manager.list_available_skills(): - success = await self.skills_manager.load_skill(source_key) - if not success: - await self._reply_plain(message, f"加载技能失败: {source_key}") - return - - tool_count = self._register_skill_tools(source_key) - await self._reply_plain( - message, - f"已加载本地技能: {source_key}\n注册工具: {tool_count}", - ) - return - - ok, result = self.skills_manager.install_skill_from_source( - source=source, - skill_name=desired_name, - overwrite=False, + if action in {"show", "current"} and len(parts) <= 2: + profile = self.ai_client.personality.get_active_personality( + user_id=scope.user_id, + group_id=scope.group_id, + session_id=scope.session_id, ) - if not ok: - await self._reply_plain(message, f"安装失败: {result}") - return - - installed_key = result - success = await self.skills_manager.load_skill(installed_key) - if not success: - await self._reply_plain(message, f"安装成功但加载失败: {installed_key}") - return - - tool_count = self._register_skill_tools(installed_key) - await self._reply_plain( - message, - f"已从来源安装并加载技能: 
{installed_key}\n注册工具: {tool_count}", - ) - return - - skill_name = parts[2] - try: - skill_key = self.skills_manager.normalize_skill_key(skill_name) - except Exception: - await self._reply_plain(message, f"非法技能名: {skill_name}") - return - - if action in {"uninstall", "remove"}: - removed_tools = self.ai_client.unregister_tools_by_prefix(f"{skill_key}.") - removed = await self.skills_manager.uninstall_skill(skill_key, delete_files=True) - if not removed: - await self._reply_plain(message, f"卸载失败或技能不存在: {skill_key}") - return - - await self._reply_plain( - message, - f"已卸载技能: {skill_key}\n注销工具: {removed_tools}", - ) - return - - self.ai_client.unregister_tools_by_prefix(f"{skill_key}.") - success = await self.skills_manager.reload_skill(skill_key) - if not success: - await self._reply_plain(message, f"重载失败: {skill_key}") - return - - tool_count = self._register_skill_tools(skill_key) - await self._reply_plain(message, f"已重载技能: {skill_key}\n注册工具: {tool_count}") - - async def _handle_personality_command(self, message: Message, command: str): - parts = command.split(maxsplit=3) - - if len(parts) == 1: - current = self.ai_client.personality.current_personality names = self.ai_client.list_personalities() - if current: - await self._reply_plain( - message, - f"当前人设: {current.name}\n简介: {current.description}\n可用: {', '.join(names)}", - ) + if profile: + await self._reply_plain(message, f"当前人设: {profile.name}\n简介: {profile.description}\n可用: {', '.join(names)}") else: - await self._reply_plain(message, "当前没有激活的人设") + await self._reply_plain(message, "当前没有可用人设") return - action = parts[1].lower() - if action == "list": - names = self.ai_client.list_personalities() - await self._reply_plain(message, "可用人设: " + ", ".join(names)) + await self._reply_plain(message, "可用人设: " + ", ".join(self.ai_client.list_personalities())) return if action in {"set", "use"}: if len(parts) < 3: await self._reply_plain(message, self._build_personality_usage()) return - - key = parts[2] - if self.ai_client.set_personality(key): + key = parts[2].strip() + scope_name, scope_id = self._parse_scope_token(parts[3] if len(parts) >= 4 else None, scope) + if scope_name != "global" and not await self._require_admin(message, scope.user_id, "personality scoped set"): + return + if self.ai_client.set_personality(key, scope=scope_name, scope_id=scope_id): await self._reply_plain(message, f"已切换人设: {key}") else: - await self._reply_plain(message, f"人设不存在: {key}") + await self._reply_plain(message, f"人设不存在或 scope 参数非法: {key}") return if action in {"add", "create"}: + if not await self._require_admin(message, scope.user_id, "personality add"): + return if len(parts) < 4: await self._reply_plain(message, self._build_personality_usage()) return - - key = parts[2] - payload = parts[3] - + key = parts[2].strip() + payload = parts[3].strip() try: profile = self._parse_personality_payload(key, payload) - ok = self.ai_client.personality.add_personality(key, profile) - if not ok: + if not self.ai_client.personality.add_personality(key, profile): await self._reply_plain(message, f"新增人设失败: {key}") return - self.ai_client.set_personality(key) await self._reply_plain(message, f"已新增并切换人设: {key}") except Exception as exc: @@ -810,33 +615,32 @@ class MessageHandler: return if action in {"remove", "delete"}: + if not await self._require_admin(message, scope.user_id, "personality remove"): + return if len(parts) < 3: await self._reply_plain(message, self._build_personality_usage()) return - - key = parts[2] - removed = 
self.ai_client.personality.remove_personality(key) - if removed: + key = parts[2].strip() + if self.ai_client.personality.remove_personality(key): await self._reply_plain(message, f"已删除人设: {key}") else: await self._reply_plain(message, f"删除失败(可能是默认人设或不存在): {key}") return - # 兼容旧命令: /personality if self.ai_client.set_personality(parts[1]): await self._reply_plain(message, f"已切换人设: {parts[1]}") return await self._reply_plain(message, self._build_personality_usage()) - async def _handle_memory_command(self, message: Message, command: str): + async def _handle_memory_command(self, message: Message, command: str, scope: ChatScope): if not self.ai_client: await self._reply_plain(message, "AI 客户端未初始化") return - user_id = self._get_user_id(message) parts = command.split(maxsplit=3) action = parts[1].lower() if len(parts) > 1 else "list" + memory_user_id = scope.memory_key if action in {"list", "ls"}: limit = 10 @@ -846,179 +650,11 @@ class MessageHandler: except ValueError: await self._reply_plain(message, self._build_memory_usage()) return - - memories = await self.ai_client.list_long_term_memories(user_id, limit=limit) + memories = await self.ai_client.list_long_term_memories(memory_user_id, limit=limit) if not memories: await self._reply_plain(message, "暂无长期记忆") return - lines = ["长期记忆列表:"] - for memory in memories: - content = self._plain_text(memory.content).replace("\n", " ").strip() - if len(content) > 60: - content = content[:57] + "..." - lines.append( - f"- {memory.id} | 重要性={memory.importance:.2f} | {memory.timestamp.isoformat()} | {content}" - ) - await self._reply_plain(message, "\n".join(lines)) - return - - if action == "get": - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - memory_id = parts[2].strip() - memory = await self.ai_client.get_long_term_memory(user_id, memory_id) - if not memory: - await self._reply_plain(message, f"记忆不存在: {memory_id}") - return - - meta_text = json.dumps(memory.metadata or {}, ensure_ascii=False) - await self._reply_plain( - message, - "记忆详情:\n" - f"id: {memory.id}\n" - f"重要性: {memory.importance:.2f}\n" - f"时间: {memory.timestamp.isoformat()}\n" - f"访问次数: {memory.access_count}\n" - f"内容: {memory.content}\n" - f"元数据: {meta_text}", - ) - return - - if action == "add": - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - payload = " ".join(parts[2:]).strip() - content = payload - importance = 0.8 - metadata = None - - if payload.startswith("{"): - try: - data = json.loads(payload) - if not isinstance(data, dict): - raise ValueError("payload 必须是对象") - content = str(data.get("content") or "").strip() - importance = float(data.get("importance", 0.8)) - raw_meta = data.get("metadata") - metadata = raw_meta if isinstance(raw_meta, dict) else None - except Exception as exc: - await self._reply_plain(message, f"JSON 解析失败: {exc}") - return - - if not content: - await self._reply_plain(message, "内容不能为空") - return - - memory = await self.ai_client.add_long_term_memory( - user_id=user_id, - content=content, - importance=importance, - metadata=metadata, - ) - if not memory: - await self._reply_plain(message, "新增长期记忆失败") - return - - await self._reply_plain( - message, f"已新增长期记忆: {memory.id} (重要性={memory.importance:.2f})" - ) - return - - if action in {"update", "set"}: - if len(parts) < 4: - await self._reply_plain(message, self._build_memory_usage()) - return - - memory_id = parts[2].strip() - payload = parts[3].strip() - content = payload - importance = None - metadata = None 
- - if payload.startswith("{"): - try: - data = json.loads(payload) - if not isinstance(data, dict): - raise ValueError("payload 必须是对象") - if "content" in data: - content = str(data.get("content") or "") - else: - content = None - if "importance" in data: - importance = float(data.get("importance")) - raw_meta = data.get("metadata") - if raw_meta is not None: - if not isinstance(raw_meta, dict): - raise ValueError("metadata 必须是对象") - metadata = raw_meta - except Exception as exc: - await self._reply_plain(message, f"JSON 解析失败: {exc}") - return - - updated = await self.ai_client.update_long_term_memory( - user_id=user_id, - memory_id=memory_id, - content=content, - importance=importance, - metadata=metadata, - ) - if not updated: - await self._reply_plain(message, f"更新失败或记忆不存在: {memory_id}") - return - - await self._reply_plain( - message, - f"已更新长期记忆: {memory_id} (重要性={updated.importance:.2f})", - ) - return - - if action in {"delete", "remove", "rm"}: - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - memory_id = parts[2].strip() - deleted = await self.ai_client.delete_long_term_memory(user_id, memory_id) - if not deleted: - await self._reply_plain(message, f"删除失败或记忆不存在: {memory_id}") - return - - await self._reply_plain(message, f"已删除长期记忆: {memory_id}") - return - - if action in {"search", "find"}: - if len(parts) < 3: - await self._reply_plain(message, self._build_memory_usage()) - return - - query_payload = " ".join(parts[2:]).strip() - limit = 10 - query = query_payload - if " " in query_payload: - possible_query, possible_limit = query_payload.rsplit(" ", 1) - if possible_limit.isdigit(): - query = possible_query - limit = max(1, min(100, int(possible_limit))) - - if not query: - await self._reply_plain(message, "搜索词不能为空") - return - - memories = await self.ai_client.search_long_term_memories( - user_id=user_id, - query=query, - limit=limit, - ) - if not memories: - await self._reply_plain(message, f"没有匹配到相关记忆: {query}") - return - - lines = [f"搜索结果({len(memories)} 条):"] for memory in memories: content = self._plain_text(memory.content).replace("\n", " ").strip() if len(content) > 60: @@ -1027,13 +663,123 @@ class MessageHandler: await self._reply_plain(message, "\n".join(lines)) return + if action == "get": + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + memory_id = parts[2].strip() + memory = await self.ai_client.get_long_term_memory(memory_user_id, memory_id) + if not memory: + await self._reply_plain(message, f"记忆不存在: {memory_id}") + return + meta_text = json.dumps(memory.metadata or {}, ensure_ascii=False) + await self._reply_plain( + message, + "记忆详情:\n" + f"id: {memory.id}\n重要性: {memory.importance:.2f}\n时间: {memory.timestamp.isoformat()}\n" + f"访问次数: {memory.access_count}\n内容: {memory.content}\n元数据: {meta_text}", + ) + return + + if action == "add": + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + payload = " ".join(parts[2:]).strip() + content = payload + importance = 0.8 + metadata = None + if payload.startswith("{"): + try: + data = json.loads(payload) + content = str(data.get("content") or "").strip() + importance = float(data.get("importance", 0.8)) + raw_meta = data.get("metadata") + metadata = raw_meta if isinstance(raw_meta, dict) else None + except Exception as exc: + await self._reply_plain(message, f"JSON 解析失败: {exc}") + return + if not content: + await self._reply_plain(message, "内容不能为空") + return + memory = await 
self.ai_client.add_long_term_memory(memory_user_id, content, importance, metadata) + if not memory: + await self._reply_plain(message, "新增长期记忆失败(可能命中去重策略)") + return + await self._reply_plain(message, f"已新增长期记忆: {memory.id} (重要性={memory.importance:.2f})") + return + + if action in {"update", "set"}: + if len(parts) < 4: + await self._reply_plain(message, self._build_memory_usage()) + return + memory_id = parts[2].strip() + payload = parts[3].strip() + content = payload + importance = None + metadata = None + if payload.startswith("{"): + try: + data = json.loads(payload) + content = str(data.get("content") or "") if "content" in data else None + importance = float(data.get("importance")) if "importance" in data else None + raw_meta = data.get("metadata") + if raw_meta is not None and not isinstance(raw_meta, dict): + raise ValueError("metadata 必须是对象") + metadata = raw_meta + except Exception as exc: + await self._reply_plain(message, f"JSON 解析失败: {exc}") + return + + updated = await self.ai_client.update_long_term_memory( + user_id=memory_user_id, + memory_id=memory_id, + content=content, + importance=importance, + metadata=metadata, + ) + if not updated: + await self._reply_plain(message, f"更新失败或记忆不存在: {memory_id}") + return + await self._reply_plain(message, f"已更新长期记忆: {memory_id} (重要性={updated.importance:.2f})") + return + + if action in {"delete", "remove", "rm"}: + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + memory_id = parts[2].strip() + deleted = await self.ai_client.delete_long_term_memory(memory_user_id, memory_id) + await self._reply_plain(message, f"已删除长期记忆: {memory_id}" if deleted else f"删除失败或记忆不存在: {memory_id}") + return + + if action in {"search", "find"}: + if len(parts) < 3: + await self._reply_plain(message, self._build_memory_usage()) + return + payload = " ".join(parts[2:]).strip() + query, limit = payload, 10 + if " " in payload: + maybe_query, maybe_limit = payload.rsplit(" ", 1) + if maybe_limit.isdigit(): + query, limit = maybe_query, max(1, min(100, int(maybe_limit))) + memories = await self.ai_client.search_long_term_memories(memory_user_id, query, limit) + if not memories: + await self._reply_plain(message, f"没有匹配到相关记忆: {query}") + return + lines = [f"搜索结果({len(memories)} 条):"] + for memory in memories: + content = self._plain_text(memory.content).replace("\n", " ").strip() + lines.append(f"- {memory.id} | 重要性={memory.importance:.2f} | {content[:60]}") + await self._reply_plain(message, "\n".join(lines)) + return + await self._reply_plain(message, self._build_memory_usage()) - async def _handle_models_command(self, message: Message, command: str): + async def _handle_models_command(self, message: Message, command: str, scope: ChatScope): if not self.ai_client: await self._reply_plain(message, "AI 客户端未初始化") return - self._ensure_model_profiles_ready() parts = command.split(maxsplit=3) @@ -1041,69 +787,40 @@ class MessageHandler: if action in {"list", "ls"} and len(parts) <= 2: lines = [f"当前模型配置: {self.active_model_key}"] - ordered_keys = self._ordered_model_keys() - for idx, key in enumerate(ordered_keys, start=1): + for idx, key in enumerate(self._ordered_model_keys(), start=1): profile = self.model_profiles.get(key, {}) marker = "*" if key == self.active_model_key else "-" - provider = str(profile.get("provider") or "?") - model_name = str(profile.get("model_name") or "?") - lines.append(f"{marker} {idx}. 
{key}: {provider}/{model_name}") - - if ordered_keys: - lines.append(f"提示: 可用 /models switch <序号>,例如 /models switch 2") - + lines.append(f"{marker} {idx}. {key}: {profile.get('provider')}/{profile.get('model_name')}") lines.append(self._build_models_usage("/models")) await self._reply_plain(message, "\n".join(lines)) return if action in {"current", "show"}: - config = self.ai_client.config - await self._reply_plain( - message, - "当前模型:\n" - f"配置名: {self.active_model_key}\n" - f"供应商: {config.provider.value}\n" - f"模型: {config.model_name}\n" - f"API 地址: {config.api_base or '-'}", - ) + cfg = self.ai_client.config + await self._reply_plain(message, f"当前模型: {self.active_model_key}\n{cfg.provider.value}/{cfg.model_name}") + return + + if not await self._require_admin(message, scope.user_id, f"models {action}"): return if action in {"switch", "set", "use"}: if len(parts) < 3: await self._reply_plain(message, self._build_models_usage()) return - - try: - key = self._resolve_model_selector(parts[2]) - except ValueError as exc: - await self._reply_plain(message, str(exc)) - return - - try: - config = self._model_config_from_dict( - self.model_profiles[key], self.ai_client.config - ) - self.ai_client.switch_model(config) - except Exception as exc: - await self._reply_plain(message, f"切换模型失败: {exc}") - return - + key = self._resolve_model_selector(parts[2]) + config = self._model_config_from_dict(self.model_profiles[key], self.ai_client.config) + self.ai_client.switch_model(config) self.active_model_key = key self._save_model_profiles() - await self._reply_plain( - message, f"已切换模型: {key} ({config.provider.value}/{config.model_name})" - ) + await self._reply_plain(message, f"已切换模型: {key} ({config.provider.value}/{config.model_name})") return if action in {"add", "create"}: - # 快捷方式: /models add - # 仅替换 model_name,保留 provider/api_base/api_key。 + if len(parts) < 3: + await self._reply_plain(message, self._build_models_usage()) + return if len(parts) == 3: model_name = parts[2].strip() - if not model_name: - await self._reply_plain(message, self._build_models_usage()) - return - key = self._normalize_model_key(model_name) config = ModelConfig( provider=self.ai_client.config.provider, @@ -1115,241 +832,79 @@ class MessageHandler: top_p=self.ai_client.config.top_p, frequency_penalty=self.ai_client.config.frequency_penalty, presence_penalty=self.ai_client.config.presence_penalty, - timeout=self.ai_client.config.outtime, + timeout=self.ai_client.config.timeout, stream=self.ai_client.config.stream, ) - - self.model_profiles[key] = self._model_config_to_dict( - config, include_api_key=False - ) - self.active_model_key = key - self._save_model_profiles() - self.ai_client.switch_model(config) - - await self._reply_plain( - message, - f"已保存并切换模型: {key} ({config.provider.value}/{config.model_name})", - ) - return - - if len(parts) < 4: - await self._reply_plain(message, self._build_models_usage()) - return - - try: + else: key = self._normalize_model_key(parts[2]) - except ValueError as exc: - await self._reply_plain(message, str(exc)) - return - - payload = parts[3].strip() - include_api_key = False - - try: + payload = parts[3].strip() if payload.startswith("{"): raw_profile = json.loads(payload) - if not isinstance(raw_profile, dict): - raise ValueError("模型参数必须是 JSON 对象") else: payload_parts = payload.split() if len(payload_parts) < 2: - raise ValueError( - "用法: /models add [api_base]" - ) + await self._reply_plain(message, self._build_models_usage()) + return raw_profile = { "provider": payload_parts[0], 
"model_name": payload_parts[1], } if len(payload_parts) >= 3: raw_profile["api_base"] = payload_parts[2] - - include_api_key = bool(raw_profile.get("api_key")) config = self._model_config_from_dict(raw_profile, self.ai_client.config) - except Exception as exc: - await self._reply_plain(message, f"新增模型失败: {exc}") - return - self.model_profiles[key] = self._model_config_to_dict( - config, include_api_key=include_api_key - ) + self.model_profiles[key] = self._model_config_to_dict(config, include_api_key=False) self.active_model_key = key self._save_model_profiles() self.ai_client.switch_model(config) - - await self._reply_plain( - message, - f"已保存并切换模型: {key} ({config.provider.value}/{config.model_name})", - ) + await self._reply_plain(message, f"已保存并切换模型: {key} ({config.provider.value}/{config.model_name})") return if action in {"remove", "delete", "rm"}: if len(parts) < 3: await self._reply_plain(message, self._build_models_usage()) return - - try: - key = self._resolve_model_selector(parts[2]) - except ValueError as exc: - await self._reply_plain(message, str(exc)) - return - + key = self._resolve_model_selector(parts[2]) if key == "default": await self._reply_plain(message, "默认模型配置不能删除") return - - del self.model_profiles[key] - switched_to = None - - if self.active_model_key == key: - fallback_key = ( - "default" - if "default" in self.model_profiles - else (sorted(self.model_profiles.keys())[0] if self.model_profiles else None) - ) - - if not fallback_key: - self.model_profiles["default"] = self._model_config_to_dict( - self.ai_client.config, include_api_key=False - ) - fallback_key = "default" - - fallback_config = self._model_config_from_dict( - self.model_profiles[fallback_key], self.ai_client.config - ) - self.ai_client.switch_model(fallback_config) - self.active_model_key = fallback_key - switched_to = fallback_key - + self.model_profiles.pop(key, None) self._save_model_profiles() - - if switched_to: - await self._reply_plain( - message, f"已删除模型: {key},已切换到: {switched_to}" - ) - else: - await self._reply_plain(message, f"已删除模型: {key}") + await self._reply_plain(message, f"已删除模型: {key}") return await self._reply_plain(message, self._build_models_usage()) - async def _handle_command(self, message: Message, command: str): - user_id = self._get_user_id(message) - + async def _handle_command(self, message: Message, command: str, scope: ChatScope): if command == "/help": - await self._reply_plain( - message, - "命令帮助\n" - "====================\n" - "基础命令\n" - "/help\n" - "/clear (默认等价 /clear short)\n" - "/clear short\n" - "/clear long\n" - "/clear all\n" - "\n" - "人设命令\n" - "/personality\n" - "/personality list\n" - "/personality set \n" - "/personality add \n" - "/personality remove \n" - "\n" - "技能命令\n" - "/skills\n" - "/skills install [skill_name]\n" - "/skills uninstall \n" - "/skills reload \n" - "\n" - "模型命令\n" - "/models\n" - "/models current\n" - "/models add \n" - "/models add [api_base]\n" - "/models switch \n" - "/models remove \n" - "\n" - "记忆命令\n" - "/memory\n" - "/memory get \n" - "/memory add \n" - "/memory update \n" - "/memory delete \n" - "/memory search [limit]\n" - "\n" - "任务命令\n" - "/task ", - ) + await self._reply_plain(message, "输入 /personality /models /memory /clear 查看对应能力") return if command.startswith("/clear"): - clear_parts = command.split(maxsplit=1) - clear_scope = clear_parts[1].strip().lower() if len(clear_parts) > 1 else "short" - + clear_scope = command.split(maxsplit=1)[1].strip().lower() if len(command.split(maxsplit=1)) > 1 else "short" if clear_scope 
in {"all", ""}: - cleared_all = await self.ai_client.clear_all_memory(user_id) - if cleared_all: - await self._reply_plain(message, "已清除短期记忆和长期记忆") - else: - await self._reply_plain( - message, - "短期记忆已清除,但长期记忆清除失败", - ) + ok = await self.ai_client.clear_all_memory(scope.memory_key) + await self._reply_plain(message, "已清除短期记忆和长期记忆" if ok else "短期记忆已清除,但长期记忆清除失败") return - if clear_scope in {"short", "short_term"}: - self.ai_client.clear_memory(user_id) + self.ai_client.clear_memory(scope.memory_key) await self._reply_plain(message, "已清除短期记忆") return - if clear_scope in {"long", "long_term"}: - cleared_long = await self.ai_client.clear_long_term_memory(user_id) - if cleared_long: - await self._reply_plain(message, "已清除长期记忆") - else: - await self._reply_plain(message, "清除长期记忆失败") + ok = await self.ai_client.clear_long_term_memory(scope.memory_key) + await self._reply_plain(message, "已清除长期记忆" if ok else "清除长期记忆失败") return - - await self._reply_plain( - message, - "用法:\n/clear\n/clear short\n/clear long\n/clear all", - ) + await self._reply_plain(message, "用法:\n/clear\n/clear short\n/clear long\n/clear all") return if command.startswith("/personality"): - await self._handle_personality_command(message, command) + await self._handle_personality_command(message, command, scope) return - - if command.startswith("/skills") or command.startswith("/skill"): - await self._handle_skills_command(message, command) - return - if command.startswith("/models") or command.startswith("/model"): - await self._handle_models_command(message, command) + await self._handle_models_command(message, command, scope) return - if command.startswith("/memory") or command.startswith("/mem"): - await self._handle_memory_command(message, command) - return - - if command.startswith("/task"): - parts = command.split(maxsplit=1) - if len(parts) == 2: - task_id = parts[1] - status = self.ai_client.get_task_status(task_id) - if status: - await self._reply_plain( - message, - "任务状态:\n" - f"标题: {status['title']}\n" - f"状态: {status['status']}\n" - f"进度: {status['progress'] * 100:.1f}%\n" - f"步骤: {status['completed_steps']}/{status['total_steps']}", - ) - else: - await self._reply_plain(message, "任务不存在") - else: - await self._reply_plain(message, "用法: /task ") + await self._handle_memory_command(message, command, scope) return await self._reply_plain(message, "未知命令,请输入 /help 查看帮助") - diff --git a/src/utils/logger.py b/src/utils/logger.py index feef2b0..f7a0da6 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -1,63 +1,101 @@ """ -日志配置模块 +Logging helpers. 
""" + +from __future__ import annotations + +import json import logging import os +from datetime import datetime, timezone from pathlib import Path +from typing import Any, Dict -def setup_logger(name='QQBot', level=None): - """ - 设置日志记录器 - - Args: - name: 日志记录器名称 - level: 日志级别,默认从环境变量读取 - - Returns: - logging.Logger: 配置好的日志记录器 - """ - # 创建logs目录 - log_dir = Path(__file__).parent.parent.parent / 'logs' +class _JsonFormatter(logging.Formatter): + """Simple JSON formatter for structured logs.""" + + _BASE_FIELDS = { + "name", + "msg", + "args", + "levelname", + "levelno", + "pathname", + "filename", + "module", + "exc_info", + "exc_text", + "stack_info", + "lineno", + "funcName", + "created", + "msecs", + "relativeCreated", + "thread", + "threadName", + "processName", + "process", + } + + def format(self, record: logging.LogRecord) -> str: + payload: Dict[str, Any] = { + "ts": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "file": record.filename, + "line": record.lineno, + } + + for key, value in record.__dict__.items(): + if key in self._BASE_FIELDS or key.startswith("_"): + continue + payload[key] = value + + if record.exc_info: + payload["exc"] = self.formatException(record.exc_info) + + return json.dumps(payload, ensure_ascii=False) + + +def setup_logger(name: str = "QQBot", level: str | None = None) -> logging.Logger: + log_dir = Path(__file__).parent.parent.parent / "logs" log_dir.mkdir(exist_ok=True) - - # 设置日志级别 + if level is None: - level = os.getenv('LOG_LEVEL', 'INFO') - - # 创建日志记录器 + level = os.getenv("LOG_LEVEL", "INFO") + log_format = (os.getenv("LOG_FORMAT", "text") or "text").lower() + logger = logging.getLogger(name) - logger.setLevel(getattr(logging, level.upper())) - # 避免向 root logger 传播导致重复输出 + logger.setLevel(getattr(logging, str(level).upper(), logging.INFO)) logger.propagate = False - - # 避免重复添加处理器 + if logger.handlers: return logger - - # 控制台处理器 + console_handler = logging.StreamHandler() + file_handler = logging.FileHandler(log_dir / "bot.log", encoding="utf-8") + + if log_format == "json": + formatter = _JsonFormatter() + console_handler.setFormatter(formatter) + file_handler.setFormatter(formatter) + else: + console_fmt = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + file_fmt = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + console_handler.setFormatter(console_fmt) + file_handler.setFormatter(file_fmt) + console_handler.setLevel(logging.INFO) - console_format = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - console_handler.setFormatter(console_format) - - # 文件处理器 - file_handler = logging.FileHandler( - log_dir / 'bot.log', - encoding='utf-8' - ) file_handler.setLevel(logging.DEBUG) - file_format = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - file_handler.setFormatter(file_format) - - # 添加处理器 + logger.addHandler(console_handler) logger.addHandler(file_handler) - return logger diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c5c7f8c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +from pathlib import Path +import sys + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + 
sys.path.insert(0, str(PROJECT_ROOT)) diff --git a/tests/test_ai.py b/tests/test_ai.py deleted file mode 100644 index 0638f2c..0000000 --- a/tests/test_ai.py +++ /dev/null @@ -1,628 +0,0 @@ -"""AI integration tests.""" - -import asyncio -import json -import os -from pathlib import Path -import shutil -import stat -import sys -import tempfile -import time -import zipfile - -from dotenv import load_dotenv - -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from src.ai import AIClient -from src.ai.base import ModelConfig, ModelProvider -from src.ai.memory import MemorySystem -from src.ai.skills import SkillsManager, create_skill_template -from src.handlers.message_handler_ai import MessageHandler - -load_dotenv(project_root / ".env") - - -TEST_DATA_DIR = Path("data/ai_test") - - -def _safe_rmtree(path: Path): - if not path.exists(): - return - - def _onerror(func, target, exc_info): - try: - os.chmod(target, stat.S_IWRITE) - func(target) - except Exception: - pass - - for _ in range(3): - try: - shutil.rmtree(path, onerror=_onerror) - return - except PermissionError: - time.sleep(0.2) - - -def _safe_unlink(path: Path): - if not path.exists(): - return - - for _ in range(3): - try: - path.unlink() - return - except PermissionError: - time.sleep(0.2) - - -def _read_env(name: str, default=None): - value = os.getenv(name) - if value is None: - return default - - value = value.strip() - if not value or value.startswith("#"): - return default - - return value - - -def get_ai_config() -> ModelConfig: - provider_map = { - "openai": ModelProvider.OPENAI, - "anthropic": ModelProvider.ANTHROPIC, - "deepseek": ModelProvider.DEEPSEEK, - "qwen": ModelProvider.QWEN, - "siliconflow": ModelProvider.OPENAI, - } - - provider_str = (_read_env("AI_PROVIDER", "openai") or "openai").lower() - provider = provider_map.get(provider_str, ModelProvider.OPENAI) - - return ModelConfig( - provider=provider, - model_name=_read_env("AI_MODEL", "gpt-3.5-turbo") or "gpt-3.5-turbo", - api_key=_read_env("AI_API_KEY", "") or "", - api_base=_read_env("AI_API_BASE"), - temperature=0.7, - ) - - -def get_embed_config() -> ModelConfig: - provider_map = { - "openai": ModelProvider.OPENAI, - "anthropic": ModelProvider.ANTHROPIC, - "deepseek": ModelProvider.DEEPSEEK, - "qwen": ModelProvider.QWEN, - "siliconflow": ModelProvider.OPENAI, - } - - provider_str = (_read_env("AI_EMBED_PROVIDER", "openai") or "openai").lower() - provider = provider_map.get(provider_str, ModelProvider.OPENAI) - - api_key = _read_env("AI_EMBED_API_KEY") or _read_env("AI_API_KEY", "") or "" - api_base = _read_env("AI_EMBED_API_BASE") or _read_env("AI_API_BASE") - - return ModelConfig( - provider=provider, - model_name=_read_env("AI_EMBED_MODEL", "text-embedding-3-small") - or "text-embedding-3-small", - api_key=api_key, - api_base=api_base, - temperature=0.0, - ) - - -class FakeMessage: - def __init__(self, content: str): - from types import SimpleNamespace - - self.content = content - self.author = SimpleNamespace(id="test_user") - self.replies = [] - - async def reply(self, content: str): - self.replies.append(content) - - -def make_handler() -> MessageHandler: - from types import SimpleNamespace - - fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot")) - handler = MessageHandler(fake_bot) - handler.ai_client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR) - handler.skills_manager = SkillsManager(Path("skills")) - handler.model_profiles_path = TEST_DATA_DIR / "models_test.json" - 
TEST_DATA_DIR.mkdir(parents=True, exist_ok=True) - _safe_unlink(handler.model_profiles_path) - handler._ai_initialized = True - return handler - - -async def _test_basic_chat(): - print("=== test_basic_chat ===") - - config = get_ai_config() - if not config.api_key: - print("skip: AI_API_KEY not configured") - return - - embed_config = get_embed_config() - client = AIClient(config, embed_config=embed_config, data_dir=TEST_DATA_DIR) - response = await client.chat( - user_id="test_user", - user_message="你好,请介绍一下你自己", - use_memory=False, - use_tools=False, - ) - assert response - print("ok: chat response length", len(response)) - - -async def _test_memory(): - print("=== test_memory ===") - - config = get_ai_config() - if not config.api_key: - print("skip: AI_API_KEY not configured") - return - - client = AIClient(config, embed_config=get_embed_config(), data_dir=TEST_DATA_DIR) - await client.chat(user_id="test_user", user_message="鎴戝彨寮犱笁", use_memory=True) - await client.chat(user_id="test_user", user_message="what is my name", use_memory=True) - - short_term, long_term = await client.memory.get_context("test_user") - assert len(short_term) >= 2 - # 重要性改为模型评估后,是否入长期记忆取决于模型打分,不再固定断言数量。 - assert isinstance(long_term, list) - print("ok: memory short/long", len(short_term), len(long_term)) - - -async def _test_personality(): - print("=== test_personality ===") - - client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR) - names = client.list_personalities() - assert names - assert client.set_personality(names[0]) - - key = "roleplay_test" - added = client.personality.add_personality( - key, - client.personality.get_personality("default"), - ) - assert added - assert key in client.list_personalities() - assert client.personality.remove_personality(key) - assert key not in client.list_personalities() - - print("ok: personality add/remove") - - -async def _test_skills(): - print("=== test_skills ===") - - manager = SkillsManager(Path("skills")) - assert await manager.load_skill("weather") - - tools = manager.get_all_tools() - assert "weather.get_weather" in tools - weather = await tools["weather.get_weather"](city="鍖椾含") - assert weather - - assert await manager.load_skill("skills_creator") - tools = manager.get_all_tools() - assert "skills_creator.create_skill" in tools - - await manager.unload_skill("weather") - await manager.unload_skill("skills_creator") - print("ok: skills load/unload") - - -async def _test_skill_commands(): - print("=== test_skill_commands ===") - - handler = make_handler() - skill_key = f"cmd_zip_skill_{int(time.time() * 1000)}" - - # Prepare a zip package source for install testing - tmp_root = TEST_DATA_DIR / "tmp_skill_pkg" - if tmp_root.exists(): - _safe_rmtree(tmp_root) - tmp_root.mkdir(parents=True, exist_ok=True) - - create_skill_template(skill_key, tmp_root, description="zip skill", author="test") - - zip_path = TEST_DATA_DIR / f"{skill_key}.zip" - - with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: - for file in (tmp_root / skill_key).rglob("*"): - if file.is_file(): - zf.write(file, file.relative_to(tmp_root)) - - install_msg = FakeMessage(f"/skills install {zip_path}") - await handler._handle_command(install_msg, install_msg.content) - assert install_msg.replies, "install command no reply" - - list_msg = FakeMessage("/skills") - await handler._handle_command(list_msg, list_msg.content) - assert list_msg.replies, "list command no reply" - - reload_msg = FakeMessage(f"/skills reload {skill_key}") - await handler._handle_command(reload_msg, 
reload_msg.content) - assert reload_msg.replies, "reload command no reply" - - uninstall_msg = FakeMessage(f"/skills uninstall {skill_key}") - await handler._handle_command(uninstall_msg, uninstall_msg.content) - assert uninstall_msg.replies, "uninstall command no reply" - - if tmp_root.exists(): - _safe_rmtree(tmp_root) - _safe_unlink(zip_path) - - print("ok: skills install/reload/uninstall command") - - -async def _test_personality_commands(): - print("=== test_personality_commands ===") - - handler = make_handler() - intro = "You are a hot-blooded anime hero. Speak directly and stay in-character." - - add_cmd = ( - "/personality add roleplay_hero " - f"{intro}" - ) - add_msg = FakeMessage(add_cmd) - await handler._handle_command(add_msg, add_msg.content) - assert add_msg.replies - - set_msg = FakeMessage("/personality set roleplay_hero") - await handler._handle_command(set_msg, set_msg.content) - assert set_msg.replies - assert intro in handler.ai_client.personality.get_system_prompt() - - remove_msg = FakeMessage("/personality remove roleplay_hero") - await handler._handle_command(remove_msg, remove_msg.content) - assert remove_msg.replies - - assert "roleplay_hero" not in handler.ai_client.list_personalities() - print("ok: personality add/set/remove command") - - -async def _test_model_commands(): - print("=== test_model_commands ===") - - handler = make_handler() - - list_msg = FakeMessage("/models") - await handler._handle_command(list_msg, list_msg.content) - assert list_msg.replies - assert "default" in list_msg.replies[-1].lower() - - add_msg = FakeMessage("/models add roleplay_llm openai gpt-4o-mini") - await handler._handle_command(add_msg, add_msg.content) - assert add_msg.replies - assert handler.active_model_key == "roleplay_llm" - assert "roleplay_llm" in handler.model_profiles - - switch_msg = FakeMessage("/models switch default") - await handler._handle_command(switch_msg, switch_msg.content) - assert switch_msg.replies - assert handler.active_model_key == "default" - - current_msg = FakeMessage("/models current") - await handler._handle_command(current_msg, current_msg.content) - assert current_msg.replies - - old_config = handler.ai_client.config - shortcut_model = "Qwen/Qwen2.5-7B-Instruct" - shortcut_key = handler._normalize_model_key(shortcut_model) - shortcut_add_msg = FakeMessage(f"/models add {shortcut_model}") - await handler._handle_command(shortcut_add_msg, shortcut_add_msg.content) - assert shortcut_add_msg.replies - assert handler.active_model_key == shortcut_key - assert handler.ai_client.config.model_name == shortcut_model - assert handler.ai_client.config.provider == old_config.provider - assert handler.ai_client.config.api_base == old_config.api_base - assert handler.ai_client.config.api_key == old_config.api_key - - shortcut_remove_msg = FakeMessage(f"/models remove {shortcut_key}") - await handler._handle_command(shortcut_remove_msg, shortcut_remove_msg.content) - assert shortcut_remove_msg.replies - assert shortcut_key not in handler.model_profiles - - remove_msg = FakeMessage("/models remove roleplay_llm") - await handler._handle_command(remove_msg, remove_msg.content) - assert remove_msg.replies - assert "roleplay_llm" not in handler.model_profiles - - _safe_unlink(handler.model_profiles_path) - print("ok: model add/switch/remove command") - - -async def _test_memory_commands(): - print("=== test_memory_commands ===") - - handler = make_handler() - user_id = "test_user" - - await handler.ai_client.clear_all_memory(user_id) - - add_msg = 
FakeMessage("/memory add this is a long-term memory test") - await handler._handle_command(add_msg, add_msg.content) - assert add_msg.replies - assert "已新增长期记忆" in add_msg.replies[-1] - - memory_id = add_msg.replies[-1].split(": ", 1)[1].split(" ", 1)[0] - assert memory_id - - list_msg = FakeMessage("/memory list 5") - await handler._handle_command(list_msg, list_msg.content) - assert list_msg.replies - assert memory_id in list_msg.replies[-1] - - get_msg = FakeMessage(f"/memory get {memory_id}") - await handler._handle_command(get_msg, get_msg.content) - assert get_msg.replies - assert memory_id in get_msg.replies[-1] - - search_msg = FakeMessage("/memory search 长期记忆") - await handler._handle_command(search_msg, search_msg.content) - assert search_msg.replies - assert memory_id in search_msg.replies[-1] - - update_msg = FakeMessage(f"/memory update {memory_id} 这是更新后的长期记忆") - await handler._handle_command(update_msg, update_msg.content) - assert update_msg.replies - assert "已更新长期记忆" in update_msg.replies[-1] - - # Build short-term memory then clear only short-term. - await handler.ai_client.memory.add_message( - user_id=user_id, - role="user", - content="short memory for clear short test", - ) - assert handler.ai_client.memory.short_term.get(user_id) - - clear_short_msg = FakeMessage("/clear short") - await handler._handle_command(clear_short_msg, clear_short_msg.content) - assert clear_short_msg.replies - assert not handler.ai_client.memory.short_term.get(user_id) - - # Long-term memory should still exist after clearing short-term only. - still_exists = await handler.ai_client.get_long_term_memory(user_id, memory_id) - assert still_exists is not None - - delete_msg = FakeMessage(f"/memory delete {memory_id}") - await handler._handle_command(delete_msg, delete_msg.content) - assert delete_msg.replies - assert "已删除长期记忆" in delete_msg.replies[-1] - - removed = await handler.ai_client.get_long_term_memory(user_id, memory_id) - assert removed is None - - print("ok: memory command CRUD + clear short") - - -async def _test_plain_text_output(): - print("=== test_plain_text_output ===") - - handler = make_handler() - md_text = "# 标题\n**加粗** 和 `代码`\n- 列表\n[链接](https://example.com)" - plain = handler._plain_text(md_text) - - assert "#" not in plain - assert "**" not in plain - assert "`" not in plain - assert "[" not in plain - assert "](" not in plain - print("ok: markdown stripped") - - -async def _test_skills_creator_autoload(): - print("=== test_skills_creator_autoload ===") - - from types import SimpleNamespace - - fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot")) - handler = MessageHandler(fake_bot) - handler.model_profiles_path = TEST_DATA_DIR / "models_autoload_test.json" - _safe_unlink(handler.model_profiles_path) - await handler._init_ai() - - assert handler.skills_manager is not None - assert "skills_creator" in handler.skills_manager.list_skills() - - tool_names = [tool.name for tool in handler.ai_client.tools.list()] - assert "skills_creator.create_skill" in tool_names - print("ok: skills_creator autoloaded") - - -async def _test_mcp(): - print("=== test_mcp ===") - - from src.ai.mcp import MCPManager - from src.ai.mcp.servers import FileSystemMCPServer - - manager = MCPManager(Path("config/mcp.json")) - fs_server = FileSystemMCPServer(root_path=Path("data")) - await manager.register_server(fs_server) - - tools = await manager.get_all_tools_for_ai() - assert len(tools) >= 1 - print("ok: mcp tools", len(tools)) - - -async def _test_long_task(): - print("=== test_long_task 
===") - - client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR) - - async def step1(): - await asyncio.sleep(0.1) - return "step1" - - async def step2(): - await asyncio.sleep(0.1) - return "step2" - - client.task_manager.register_action("step1", step1) - client.task_manager.register_action("step2", step2) - - task_id = await client.create_long_task( - user_id="test_user", - title="test", - description="test task", - steps=[ - {"description": "s1", "action": "step1", "params": {}}, - {"description": "s2", "action": "step2", "params": {}}, - ], - ) - - await client.start_task(task_id) - await asyncio.sleep(0.5) - - status = client.get_task_status(task_id) - assert status is not None - assert status["status"] in {"completed", "running"} - print("ok: long task", status["status"]) - - -async def _test_memory_importance_evaluator(): - print("=== test_memory_importance_evaluator ===") - - called = {"value": False} - - async def fake_importance_eval(content, metadata): - called["value"] = True - assert "用户:" in content - assert "助手:" in content - return 0.91 - - store_path = TEST_DATA_DIR / "importance_test.json" - _safe_unlink(store_path) - - memory = MemorySystem( - storage_path=store_path, - importance_evaluator=fake_importance_eval, - use_vector_db=False, - ) - - stored = await memory.add_qa_pair( - user_id="u1", - question="请记住我的昵称是小明", - answer="好的,我记住了你的昵称是小明", - metadata={"source": "test"}, - ) - assert called["value"] - assert stored is not None - assert "用户:" in stored.content - assert "助手:" in stored.content - assert "小明" in stored.content - - long_term = await memory.list_long_term("u1") - assert len(long_term) == 1 - - # add_message 仅写入短期记忆,不触发长期记忆评分写入。 - await memory.add_message(user_id="u1", role="user", content="单条短期消息") - long_term_after_single = await memory.list_long_term("u1") - assert len(long_term_after_single) == 1 - - memory_without_eval = MemorySystem( - storage_path=TEST_DATA_DIR / "importance_fallback_test.json", - use_vector_db=False, - ) - fallback_score = await memory_without_eval._evaluate_importance("任意内容", None) - assert fallback_score == 0.5 - - await memory.close() - await memory_without_eval.close() - _safe_unlink(store_path) - _safe_unlink(TEST_DATA_DIR / "importance_fallback_test.json") - print("ok: memory importance evaluator") - - -def test_basic_chat(): - asyncio.run(_test_basic_chat()) - - -def test_memory(): - asyncio.run(_test_memory()) - - -def test_personality(): - asyncio.run(_test_personality()) - - -def test_skills(): - asyncio.run(_test_skills()) - - -def test_skill_commands(): - asyncio.run(_test_skill_commands()) - - -def test_personality_commands(): - asyncio.run(_test_personality_commands()) - - -def test_model_commands(): - asyncio.run(_test_model_commands()) - - -def test_memory_commands(): - asyncio.run(_test_memory_commands()) - - -def test_plain_text_output(): - asyncio.run(_test_plain_text_output()) - - -def test_skills_creator_autoload(): - asyncio.run(_test_skills_creator_autoload()) - - -def test_mcp(): - asyncio.run(_test_mcp()) - - -def test_long_task(): - asyncio.run(_test_long_task()) - - -def test_memory_importance_evaluator(): - asyncio.run(_test_memory_importance_evaluator()) - - -async def main(): - print("寮€濮?AI 鍔熻兘娴嬭瘯") - - await _test_personality() - await _test_skills() - await _test_skill_commands() - await _test_personality_commands() - await _test_model_commands() - await _test_memory_commands() - await _test_plain_text_output() - await _test_skills_creator_autoload() - await _test_mcp() - await _test_long_task() - 
await _test_memory_importance_evaluator() - - config = get_ai_config() - if config.api_key: - await _test_basic_chat() - await _test_memory() - else: - print("跳过需要 API Key 的对话/记忆测试") - - print("娴嬭瘯瀹屾垚") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/test_ai_client_forced_tool.py b/tests/test_ai_client_forced_tool.py deleted file mode 100644 index 0669878..0000000 --- a/tests/test_ai_client_forced_tool.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Tests for AIClient forced tool name extraction.""" - -from src.ai.client import AIClient - - -def test_extract_forced_tool_name_full_name(): - tools = [ - "humanizer_zh.read_skill_doc", - "skills_creator.create_skill", - ] - message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced == "humanizer_zh.read_skill_doc" - - -def test_extract_forced_tool_name_unique_prefix(): - tools = [ - "humanizer_zh.read_skill_doc", - "skills_creator.create_skill", - ] - message = "please call tool humanizer_zh only" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced == "humanizer_zh.read_skill_doc" - - -def test_extract_forced_tool_name_compact_prefix_without_underscore(): - tools = [ - "humanizer_zh.read_skill_doc", - "skills_creator.create_skill", - ] - message = "调用humanizerzh人性化处理以下文本" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced == "humanizer_zh.read_skill_doc" - - -def test_extract_forced_tool_name_ambiguous_prefix_returns_none(): - tools = [ - "skills_creator.create_skill", - "skills_creator.reload_skill", - ] - message = "please call tool skills_creator" - - forced = AIClient._extract_forced_tool_name(message, tools) - - assert forced is None - - -def test_extract_prefix_limit_from_user_message(): - assert AIClient._extract_prefix_limit("直接返回前100字") == 100 - assert AIClient._extract_prefix_limit("前 256 字") == 256 - assert AIClient._extract_prefix_limit("返回全文") is None - - -def test_extract_processing_payload_with_marker(): - message = "调用humanizer_zh.read_skill_doc人性化处理以下文本:\n第一段。\n第二段。" - payload = AIClient._extract_processing_payload(message) - assert payload == "第一段。\n第二段。" - - -def test_extract_processing_payload_with_generic_pattern(): - message = "请按技能规则优化如下:\n这是待处理文本。" - payload = AIClient._extract_processing_payload(message) - assert payload == "这是待处理文本。" - - -def test_extract_processing_payload_returns_none_when_absent(): - assert AIClient._extract_processing_payload("请调用工具 humanizer_zh.read_skill_doc") is None diff --git a/tests/test_mcp_tool_registration.py b/tests/test_mcp_tool_registration.py deleted file mode 100644 index 90bde10..0000000 --- a/tests/test_mcp_tool_registration.py +++ /dev/null @@ -1,44 +0,0 @@ -import asyncio -from pathlib import Path - -from src.ai.mcp.base import MCPManager, MCPServer - - -class _DummyMCPServer(MCPServer): - def __init__(self): - super().__init__(name="dummy", version="1.0.0") - - async def initialize(self): - self.register_tool( - name="echo", - description="Echo text", - input_schema={ - "type": "object", - "properties": {"text": {"type": "string"}}, - "required": ["text"], - }, - handler=self.echo, - ) - - async def echo(self, text: str) -> str: - return text - - -def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path): - manager = MCPManager(tmp_path / "mcp.json") - asyncio.run(manager.register_server(_DummyMCPServer())) - - tools = asyncio.run(manager.get_all_tools_for_ai()) - assert len(tools) == 1 - 
function_info = tools[0]["function"] - assert function_info["name"] == "dummy.echo" - assert function_info["description"] == "Echo text" - assert function_info["parameters"]["required"] == ["text"] - - -def test_mcp_manager_execute_tool(tmp_path: Path): - manager = MCPManager(tmp_path / "mcp.json") - asyncio.run(manager.register_server(_DummyMCPServer())) - - result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"})) - assert result == "hello" diff --git a/tests/test_message_dedup.py b/tests/test_message_dedup.py new file mode 100644 index 0000000..21b9d4e --- /dev/null +++ b/tests/test_message_dedup.py @@ -0,0 +1,22 @@ +from types import SimpleNamespace + +from src.handlers.message_handler_ai import MessageHandler + + +def _build_handler() -> MessageHandler: + fake_bot = SimpleNamespace(robot=SimpleNamespace(id="bot_1", name="TestBot")) + return MessageHandler(fake_bot) + + +def test_message_dedup_by_message_id(): + handler = _build_handler() + msg = SimpleNamespace(id="m1", content="hello", author=SimpleNamespace(id="u1")) + assert handler._is_duplicate_message(msg) is False + assert handler._is_duplicate_message(msg) is True + + +def test_message_dedup_fallback_without_message_id(): + handler = _build_handler() + msg = SimpleNamespace(content="hello", author=SimpleNamespace(id="u1"), group_id="g1") + assert handler._is_duplicate_message(msg) is False + assert handler._is_duplicate_message(msg) is True diff --git a/tests/test_message_handler_text_sanitize.py b/tests/test_message_handler_text_sanitize.py index b41f922..9ce9aa9 100644 --- a/tests/test_message_handler_text_sanitize.py +++ b/tests/test_message_handler_text_sanitize.py @@ -10,16 +10,14 @@ from src.handlers.message_handler_ai import MessageHandler def _handler() -> MessageHandler: - fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot")) + fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot", name="TestBot")) return MessageHandler(fake_bot) def test_plain_text_removes_markdown_link_url(): handler = _handler() text = "参考 [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia) 获取详情。" - result = handler._plain_text(text) - assert "Wikipedia" in result assert "http" not in result.lower() @@ -27,9 +25,7 @@ def test_plain_text_removes_markdown_link_url(): def test_plain_text_removes_bare_url(): handler = _handler() text = "访问 https://example.com/path?a=1 或 www.example.org 查看。" - result = handler._plain_text(text) - assert "http" not in result.lower() assert "www." 
not in result.lower() assert "[链接已省略]" in result diff --git a/tests/test_personality_scope_priority.py b/tests/test_personality_scope_priority.py new file mode 100644 index 0000000..e00d704 --- /dev/null +++ b/tests/test_personality_scope_priority.py @@ -0,0 +1,44 @@ +from pathlib import Path + +from src.ai.personality import PersonalityProfile, PersonalitySystem, PersonalityTrait + + +def _profile(name: str) -> PersonalityProfile: + return PersonalityProfile( + name=name, + description=f"{name} profile", + traits=[PersonalityTrait.FRIENDLY], + speaking_style="plain", + ) + + +def test_scope_priority_session_over_group_over_user_over_global(tmp_path: Path): + cfg = tmp_path / "personalities.json" + state = tmp_path / "personality_state.json" + system = PersonalitySystem(config_path=cfg, state_path=state) + + system.add_personality("p_global", _profile("global")) + system.add_personality("p_user", _profile("user")) + system.add_personality("p_group", _profile("group")) + system.add_personality("p_session", _profile("session")) + + assert system.set_personality("p_global", scope="global") + assert system.set_personality("p_user", scope="user", scope_id="u1") + assert system.set_personality("p_group", scope="group", scope_id="g1") + assert system.set_personality("p_session", scope="session", scope_id="g1:u1") + + profile = system.get_active_personality(user_id="u1", group_id="g1", session_id="g1:u1") + assert profile is not None + assert profile.name == "session" + + profile_no_session = system.get_active_personality(user_id="u1", group_id="g1", session_id="other") + assert profile_no_session is not None + assert profile_no_session.name == "group" + + profile_user_only = system.get_active_personality(user_id="u1", group_id=None, session_id=None) + assert profile_user_only is not None + assert profile_user_only.name == "user" + + profile_global = system.get_active_personality(user_id="u2", group_id=None, session_id=None) + assert profile_global is not None + assert profile_global.name == "global" diff --git a/tests/test_skills_install_source.py b/tests/test_skills_install_source.py deleted file mode 100644 index f27fb36..0000000 --- a/tests/test_skills_install_source.py +++ /dev/null @@ -1,74 +0,0 @@ -import io -from pathlib import Path -import zipfile - -from src.ai.skills.base import SkillsManager - - -def _build_codex_skill_zip_bytes(markdown_text: str, root_name: str = "Humanizer-zh-main") -> bytes: - buffer = io.BytesIO() - with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf: - zf.writestr(f"{root_name}/SKILL.md", markdown_text) - return buffer.getvalue() - - -def test_resolve_network_source_supports_github_git_url(tmp_path: Path): - manager = SkillsManager(tmp_path / "skills") - url, hint_key, subpath = manager._resolve_network_source( - "https://github.com/op7418/Humanizer-zh.git" - ) - - assert url == "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main" - assert hint_key == "humanizer_zh" - assert subpath is None - - -def test_install_skill_from_local_skill_markdown_source(tmp_path: Path): - source_dir = tmp_path / "Humanizer-zh-main" - source_dir.mkdir(parents=True, exist_ok=True) - (source_dir / "SKILL.md").write_text( - "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n", - encoding="utf-8", - ) - - manager = SkillsManager(tmp_path / "skills") - ok, installed_key = manager.install_skill_from_source(str(source_dir), skill_name="humanizer_zh") - - assert ok - assert installed_key == "humanizer_zh" - installed_dir = tmp_path / "skills" / 
"humanizer_zh" - assert (installed_dir / "skill.json").exists() - assert (installed_dir / "main.py").exists() - assert (installed_dir / "SKILL.md").exists() - - main_code = (installed_dir / "main.py").read_text(encoding="utf-8") - assert "read_skill_doc" in main_code - skill_text = (installed_dir / "SKILL.md").read_text(encoding="utf-8") - assert "Humanizer-zh" in skill_text - - -def test_install_skill_from_github_git_url_uses_repo_zip_and_markdown_adapter( - tmp_path: Path, monkeypatch -): - manager = SkillsManager(tmp_path / "skills") - zip_bytes = _build_codex_skill_zip_bytes( - "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n" - ) - captured_urls = [] - - def fake_download(url: str, output_zip: Path): - captured_urls.append(url) - output_zip.write_bytes(zip_bytes) - - monkeypatch.setattr(manager, "_download_zip", fake_download) - - ok, installed_key = manager.install_skill_from_source( - "https://github.com/op7418/Humanizer-zh.git" - ) - - assert ok - assert installed_key == "humanizer_zh" - assert captured_urls == [ - "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main" - ] - assert (tmp_path / "skills" / "humanizer_zh" / "SKILL.md").exists()