Refactor configuration and enhance AI capabilities

Updated .env.example to improve clarity and added new configuration options for memory and reliability settings. Refactored main.py to streamline the bot's entry point and improved error handling. Enhanced README to reflect new features and command structure. Removed deprecated cmd_zip_skill and skills_creator modules to clean up the codebase. Updated AIClient and MemorySystem for better performance and flexibility in handling user interactions.
Mimikko-zeus
2026-03-03 21:56:33 +08:00
parent 726d41ad79
commit a754e7843f
72 changed files with 1607 additions and 5015 deletions

.env.example

@@ -1,36 +1,36 @@
-# QQ 机器人配置
-# 可从 https://bot.q.qq.com/open 获取
-# 机器人 AppID(必填)
+# QQ Bot credentials
 BOT_APPID=your_app_id_here
-# 机器人 AppSecret(必填)
 BOT_SECRET=your_app_secret_here

-# 日志级别: DEBUG / INFO / WARNING / ERROR
+# Runtime
+APP_ENV=dev
 LOG_LEVEL=INFO
+LOG_FORMAT=text
-# 是否启用沙箱环境
 SANDBOX_MODE=False

-# ==================== AI 配置 ====================
+# Optional admin allow-list (comma separated user IDs).
+# Empty means all users are treated as admin.
+BOT_ADMIN_IDS=

-# 主模型配置
-# 可选 provider: openai / anthropic / deepseek / qwen
+# AI chat model
 AI_PROVIDER=openai
 AI_MODEL=gpt-4
 AI_API_KEY=your_api_key_here
-# 可选,自定义 API 地址
 AI_API_BASE=https://api.openai.com/v1

-# 嵌入模型配置(用于长期记忆检索)
-# 不配置时将回退为主模型的 embedding 能力(如果可用)
+# Embedding model (optional)
 AI_EMBED_PROVIDER=openai
 AI_EMBED_MODEL=text-embedding-3-small
 AI_EMBED_API_KEY=
 AI_EMBED_API_BASE=

-# 向量数据库配置
-# true: 使用 Chroma(推荐)
-# false: 使用 JSON 存储
+# Memory storage and retrieval
 AI_USE_VECTOR_DB=true
+AI_USE_QUERY_EMBEDDING=false
+AI_MEMORY_SCOPE=session
+
+# Reliability
+AI_CHAT_RETRIES=1
+AI_CHAT_RETRY_BACKOFF_SECONDS=0.8
+MESSAGE_DEDUP_SECONDS=30
+MESSAGE_DEDUP_MAX_SIZE=4096

README.md

@@ -1,14 +1,16 @@
-# QQbot(AI 聊天机器人)
+# QQbot (Memory + Persona Core)

-一个基于 `botpy` 的 QQ 机器人项目,支持多模型切换、长期/短期记忆、人设管理、Skills 插件与 MCP 能力
+QQ 机器人项目,保留并强化两大核心能力:
+
+- `Memory`:短期/长期记忆、检索、清理、会话作用域
+- `Persona`:角色配置、作用域优先级(session > group > user > global)

-## 功能概览
+## 主要能力

-- 多模型配置与运行时切换(`/models`)
-- 人设增删改切换(`/personality`)
-- 短期/长期记忆管理(`/clear`、`/memory`)
-- Skills 本地与网络安装/卸载/重载(`/skills`)
-- 自动去除 Markdown 格式后再回复(适配 QQ 聊天)
+- 多模型配置与运行时切换:`/models`
+- 人设管理:`/personality`
+- 记忆管理:`/memory`、`/clear`
+- QQ 消息安全输出:自动清理 Markdown/URL
+- 工程增强:消息去重、失败重试、权限边界、结构化日志

 ## 快速开始
@@ -20,7 +22,9 @@ pip install -r requirements.txt
 2. 配置环境变量

-复制 `.env.example` 为 `.env`,填写 QQ 机器人和 AI 配置。
+```bash
+copy .env.example .env
+```

 3. 启动
@@ -28,89 +32,42 @@ pip install -r requirements.txt
 python main.py
 ```

-## 命令说明
+## 命令

-### 通用
+- 基础
+  - `/help`
+  - `/clear` `/clear short` `/clear long` `/clear all`
+- 人设
+  - `/personality`
+  - `/personality list`
+  - `/personality set <key> [global|user|group|session]`
+  - `/personality add <name> <Introduction>`
+  - `/personality remove <key>`
+- 模型
+  - `/models`
+  - `/models current`
+  - `/models add <model_name>`
+  - `/models add <key> <provider> <model_name> [api_base]`
+  - `/models switch <key|index>`
+  - `/models remove <key|index>`
+- 记忆
+  - `/memory`
+  - `/memory get <id>`
+  - `/memory add <content|json>`
+  - `/memory update <id> <content|json>`
+  - `/memory delete <id>`
+  - `/memory search <query> [limit]`

-- `/help`
-- `/clear`(默认等价 `/clear short`)
-- `/clear short`
-- `/clear long`
-- `/clear all`
+## 关键配置

-### 人设命令
+- `AI_MEMORY_SCOPE=user|session`:记忆作用域
+- `BOT_ADMIN_IDS`:管理员白名单(逗号分隔)
+- `AI_CHAT_RETRIES` / `AI_CHAT_RETRY_BACKOFF_SECONDS`:聊天失败重试
+- `MESSAGE_DEDUP_SECONDS` / `MESSAGE_DEDUP_MAX_SIZE`:消息去重窗口
+- `LOG_FORMAT=text|json`:日志输出格式

-- `/personality`
-- `/personality list`
-- `/personality set <key>`
-- `/personality add <name> <Introduction>`
-- `/personality remove <key>`
-
-说明:
-- `add` 会新增并切换到该人设
-- `Introduction` 会作为人设简介与自定义指令
-
-### Skills 命令
-
-- `/skills`
-- `/skills install <source> [skill_name]`
-- `/skills uninstall <skill_name>`
-- `/skills reload <skill_name>`
-
-`source` 支持:
-- 本地技能名(如 `weather`)
-- URL(zip 包)
-- GitHub 简写(`owner/repo` 或 `owner/repo#branch`)
-- GitHub 仓库 URL(`https://github.com/op7418/Humanizer-zh.git`)
-
-兼容说明:
-- 若源中包含标准技能结构(`skill.json` + `main.py`),按原方式安装
-- 若仅包含 `SKILL.md`,会自动生成适配技能并提供 `read_skill_doc` 工具读取文档内容
-
-### 模型命令
-
-- `/models`
-- `/models current`
-- `/models add <model_name>`
-- `/models add <key> <provider> <model_name> [api_base]`
-- `/models switch <key|index>`
-- `/models remove <key|index>`
-
-说明:
-- `/models add <model_name>` 只替换模型名,沿用当前 API Base 和 API Key
-
-### 长期记忆命令
-
-- `/memory`
-- `/memory get <id>`
-- `/memory add <content|json>`
-- `/memory update <id> <content|json>`
-- `/memory delete <id>`
-- `/memory search <query> [limit]`
-
-## 目录结构
-
-```text
-QQbot/
-├─ src/
-│  ├─ ai/
-│  ├─ handlers/
-│  ├─ core/
-│  └─ utils/
-├─ skills/
-├─ config/
-├─ docs/
-└─ tests/
-```
-
 ## 测试

 ```bash
-python -m pytest -q
+pytest -q
 ```
-
-如果你使用 conda 环境,请先执行:
-
-```bash
-conda activate qqbot
-```


@@ -1,6 +0,0 @@
{
"filesystem": {
"enabled": true,
"root_path": "data"
}
}

main.py

@@ -1,10 +1,12 @@
 """
-QQ机器人主入口
+Project entrypoint.
 """

+from __future__ import annotations
+
 import sys
 from pathlib import Path

-# 添加项目根目录到Python路径
 project_root = Path(__file__).parent
 sys.path.insert(0, str(project_root))
@@ -23,10 +25,6 @@ def _sqlite_supports_trigram(sqlite_module) -> bool:
 def _ensure_sqlite_for_chroma():
-    """
-    Ensure sqlite runtime supports FTS5 trigram tokenizer for Chroma.
-    On some cloud images, system sqlite lacks trigram support.
-    """
     try:
         import sqlite3
     except Exception:
@@ -55,35 +53,8 @@ def _ensure_sqlite_for_chroma():
 _ensure_sqlite_for_chroma()

-from src.core.bot import MyClient, build_intents
-from src.core.config import Config
-from src.utils.logger import setup_logger
-
-
-def main():
-    """主函数"""
-    # 设置日志
-    logger = setup_logger()
-
-    try:
-        # 验证配置
-        Config.validate()
-        logger.info("配置验证通过")
-
-        # 创建并启动机器人(最小权限,避免 4014 disallowed intents)
-        logger.info("正在启动QQ机器人...")
-        intents = build_intents()
-        client = MyClient(intents=intents)
-        client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET)
-    except ValueError as e:
-        logger.error(f"配置错误: {e}")
-        logger.error("请检查 .env 文件配置")
-        sys.exit(1)
-    except Exception as e:
-        logger.error(f"启动失败: {e}", exc_info=True)
-        sys.exit(1)
+from src.core.bot import main as run_bot_main

 if __name__ == "__main__":
-    main()
+    run_bot_main()
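`main.py` keeps `_sqlite_supports_trigram` / `_ensure_sqlite_for_chroma` because Chroma's full-text index needs FTS5's trigram tokenizer (added in SQLite 3.34), which system SQLite on some cloud images lacks. The repo's exact check is not shown in this hunk; one common way to implement such a probe (a sketch, with an illustrative function name) is to try creating an in-memory FTS5 table with that tokenizer:

```python
import sqlite3


def sqlite_supports_trigram(sqlite_module=sqlite3) -> bool:
    """Return True if the linked SQLite provides FTS5 with the trigram tokenizer."""
    try:
        conn = sqlite_module.connect(":memory:")
        try:
            # Fails with OperationalError when FTS5 or the trigram tokenizer is missing.
            conn.execute(
                "CREATE VIRTUAL TABLE trigram_probe USING fts5(body, tokenize='trigram')"
            )
            return True
        finally:
            conn.close()
    except Exception:
        return False
```

If the probe fails, the startup hook can swap in a newer SQLite build before `chromadb` is imported.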

requirements.txt

@@ -1,15 +1,12 @@
-# ============================================
-# 核心依赖(必须安装)
-# ============================================
+# Core runtime
 qq-botpy
 python-dotenv>=1.0.0

-# ============================================
-# AI 功能依赖(可选)
-# 如果不需要 AI 对话功能,可以注释掉下面的依赖
-# ============================================
+# AI providers
 openai>=1.0.0
 anthropic>=0.18.0

+# Memory storage
 numpy>=1.24.0
-chromadb>=0.4.0  # 向量数据库,用于记忆存储
-pysqlite3-binary>=0.5.3; platform_system != "Windows"  # 云端可用于补齐 sqlite trigram 支持
+chromadb>=0.4.0
+pysqlite3-binary>=0.5.3; platform_system != "Windows"
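The `pysqlite3-binary` pin exists because `chromadb` requires a newer SQLite than many server images ship. The usual deployment pattern (assumed here; this exact code is not part of the diff) is to alias the bundled module over the stdlib name before `chromadb` is imported:

```python
# Swap stdlib sqlite3 for pysqlite3-binary's bundled, modern SQLite.
# This must run before `import chromadb`, which inspects sqlite3.sqlite_version.
import sys

try:
    import pysqlite3  # bundles a recent SQLite with FTS5 trigram support
    sys.modules["sqlite3"] = pysqlite3
except ImportError:
    pass  # keep the system sqlite3 (usually new enough on desktops)

import sqlite3

print(sqlite3.sqlite_version)
```

The environment marker `platform_system != "Windows"` skips the wheel on Windows, where `pysqlite3-binary` builds are not published.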


@@ -1,7 +0,0 @@
# cmd_zip_skill
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill skill"""
from src.ai.skills.base import Skill
class CmdZipSkillSkill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465404375
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465404375 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465404375Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465404375 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465404375",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465434774
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465434774 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465434774Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465434774 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465434774",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465467809
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465467809 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465467809Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465467809 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465467809",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465652075
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465652075 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465652075Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465652075 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465652075",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465685352
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465685352 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465685352Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465685352 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465685352",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465936294
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465936294 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465936294Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465936294 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465936294",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772465966322
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772465966322 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772465966322Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772465966322 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772465966322",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,7 +0,0 @@
# cmd_zip_skill_1772466071278
## 描述
zip skill
## 工具
- example_tool(text)


@@ -1,13 +0,0 @@
"""cmd_zip_skill_1772466071278 skill"""
from src.ai.skills.base import Skill
class CmdZipSkill1772466071278Skill(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"cmd_zip_skill_1772466071278 收到: {text}"
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "cmd_zip_skill_1772466071278",
"version": "1.0.0",
"description": "zip skill",
"author": "test",
"dependencies": [],
"enabled": false
}


@@ -1,12 +0,0 @@
# skills_creator
Tools exposed to AI:
- create_skill
- write_skill_main
- update_skill_metadata
- delete_skill
- load_skill
- reload_skill
- list_skills
This skill enables creating and managing other skills from chat-driven AI tool calls.


@@ -1,147 +0,0 @@
"""skills_creator skill"""
import json
from pathlib import Path
import shutil
from typing import Any, Dict
from src.ai.skills.base import Skill, SkillsManager, create_skill_template
class SkillsCreatorSkill(Skill):
"""Provide tools to create and manage local skills from chat."""
def __init__(self):
super().__init__()
self.skills_root = Path("skills")
self.protected_skills = {"skills_creator"}
async def initialize(self):
self.register_tool("create_skill", self.create_skill)
self.register_tool("write_skill_main", self.write_skill_main)
self.register_tool("update_skill_metadata", self.update_skill_metadata)
self.register_tool("delete_skill", self.delete_skill)
self.register_tool("load_skill", self.load_skill)
self.register_tool("reload_skill", self.reload_skill)
self.register_tool("list_skills", self.list_skills)
def _skill_key(self, skill_name: str) -> str:
return SkillsManager.normalize_skill_key(skill_name)
def _skill_path(self, skill_name: str) -> Path:
return self.skills_root / self._skill_key(skill_name)
async def create_skill(
self,
skill_name: str,
description: str = "custom skill",
author: str = "chat_user",
auto_load: bool = True,
overwrite: bool = False,
) -> str:
skill_key = self._skill_key(skill_name)
skill_path = self._skill_path(skill_key)
if skill_path.exists() and not overwrite:
return f"技能已存在: {skill_key}"
if skill_path.exists() and overwrite:
shutil.rmtree(skill_path)
create_skill_template(skill_key, self.skills_root, description=description, author=author)
if auto_load and self.manager:
await self.manager.load_skill(skill_key)
return f"已创建技能: {skill_key}"
async def write_skill_main(self, skill_name: str, code: str, auto_reload: bool = True) -> str:
skill_key = self._skill_key(skill_name)
skill_path = self._skill_path(skill_key)
if not skill_path.exists():
return f"技能不存在: {skill_key}"
main_file = skill_path / "main.py"
main_file.write_text(code, encoding="utf-8")
if auto_reload and self.manager:
await self.manager.reload_skill(skill_key)
return f"已更新技能代码: {skill_key}/main.py"
async def update_skill_metadata(self, skill_name: str, fields: Dict[str, Any]) -> str:
skill_key = self._skill_key(skill_name)
skill_path = self._skill_path(skill_key)
if not skill_path.exists():
return f"技能不存在: {skill_key}"
metadata_file = skill_path / "skill.json"
if metadata_file.exists():
metadata = json.loads(metadata_file.read_text(encoding="utf-8"))
else:
metadata = {
"name": skill_key,
"version": "1.0.0",
"description": "",
"author": "chat_user",
"dependencies": [],
"enabled": True,
}
for key, value in fields.items():
metadata[key] = value
metadata.setdefault("name", skill_key)
metadata.setdefault("version", "1.0.0")
metadata.setdefault("description", "")
metadata.setdefault("author", "chat_user")
metadata.setdefault("dependencies", [])
metadata.setdefault("enabled", True)
metadata_file.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8")
return f"已更新技能元数据: {skill_key}/skill.json"
async def delete_skill(self, skill_name: str, delete_files: bool = True) -> str:
skill_key = self._skill_key(skill_name)
if skill_key in self.protected_skills:
return f"拒绝删除受保护技能: {skill_key}"
if self.manager:
removed = await self.manager.uninstall_skill(skill_key, delete_files=delete_files)
return f"删除结果({skill_key}): {removed}"
skill_path = self._skill_path(skill_key)
if delete_files and skill_path.exists():
shutil.rmtree(skill_path)
return f"已删除技能目录: {skill_key}"
return f"技能不存在或未删除: {skill_key}"
async def load_skill(self, skill_name: str) -> str:
skill_key = self._skill_key(skill_name)
if not self.manager:
return "SkillsManager 不可用"
success = await self.manager.load_skill(skill_key)
return f"加载技能 {skill_key}: {success}"
async def reload_skill(self, skill_name: str) -> str:
skill_key = self._skill_key(skill_name)
if not self.manager:
return "SkillsManager 不可用"
success = await self.manager.reload_skill(skill_key)
return f"重载技能 {skill_key}: {success}"
async def list_skills(self) -> str:
if not self.manager:
return "SkillsManager 不可用"
payload = {
"loaded": self.manager.list_skills(),
"available": self.manager.list_available_skills(),
}
return json.dumps(payload, ensure_ascii=False)
async def cleanup(self):
pass


@@ -1,8 +0,0 @@
{
"name": "skills_creator",
"version": "1.0.0",
"description": "Create, edit, load, and remove skills from chat",
"author": "QQBot",
"dependencies": [],
"enabled": true
}


@@ -1,14 +1,7 @@
-"""
-AI模块 - 提供AI模型接入、人格系统、记忆系统和长任务处理能力
-"""
+"""AI package exports."""

 from .client import AIClient
-from .personality import PersonalitySystem
 from .memory import MemorySystem
-from .task_manager import LongTaskManager
+from .personality import PersonalitySystem

-__all__ = [
-    'AIClient',
-    'PersonalitySystem',
-    'MemorySystem',
-    'LongTaskManager'
-]
+__all__ = ["AIClient", "MemorySystem", "PersonalitySystem"]


@@ -1,98 +1,97 @@
 """
-AI客户端 - 整合所有AI功能
+Unified AI client for chat, memory and persona.
 """

-import inspect
+from __future__ import annotations
+
+import asyncio
 import json
 import re
-from typing import List, Optional, Dict, Any, AsyncIterator, Tuple
 from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import httpx

-from .base import ModelConfig, ModelProvider, Message, ToolRegistry
-from .models import OpenAIModel, AnthropicModel
-from .personality import PersonalitySystem
+from .base import Message, ModelConfig, ModelProvider
 from .memory import MemorySystem
-from .task_manager import LongTaskManager
+from .models import AnthropicModel, OpenAIModel
+from .personality import PersonalitySystem
 from src.utils.logger import setup_logger

-logger = setup_logger('AIClient')
+logger = setup_logger("AIClient")


 class AIClient:
-    """AI客户端 - 统一接口"""
+    """High-level application service for chat and memory/persona orchestration."""

     def __init__(
         self,
         model_config: ModelConfig,
         embed_config: Optional[ModelConfig] = None,
         data_dir: Path = Path("data/ai"),
-        use_vector_db: bool = True
+        use_vector_db: bool = True,
+        use_query_embedding: bool = False,
+        chat_retries: int = 1,
+        chat_retry_backoff: float = 0.8,
     ):
         self.config = model_config
         self.data_dir = data_dir
         self.data_dir.mkdir(parents=True, exist_ok=True)

-        # 初始化主模型
+        self.chat_retries = max(0, int(chat_retries))
+        self.chat_retry_backoff = max(0.0, float(chat_retry_backoff))
         self.model = self._create_model(model_config)
+        self.embed_model = self._create_model(embed_config) if embed_config else None

-        # 初始化嵌入模型(如果提供)
-        self.embed_model = None
-        if embed_config:
-            self.embed_model = self._create_model(embed_config)
-            logger.info(
-                f"嵌入模型初始化完成: {embed_config.provider.value}/{embed_config.model_name}"
-            )
-
-        # 初始化工具注册表
-        self.tools = ToolRegistry()
-        self._tool_sources: Dict[str, str] = {}
-
-        # 初始化人格系统
         self.personality = PersonalitySystem(
-            config_path=data_dir / "personalities.json"
+            config_path=data_dir / "personalities.json",
+            state_path=data_dir / "personality_state.json",
         )

-        # 初始化记忆系统
         self.memory = MemorySystem(
             storage_path=data_dir / "long_term_memory.json",
             embed_func=self._embed_wrapper,
             importance_evaluator=self._evaluate_memory_importance,
-            use_vector_db=use_vector_db
-        )
-
-        # 初始化长任务管理器
-        self.task_manager = LongTaskManager(
-            storage_path=data_dir / "tasks.json"
+            use_vector_db=use_vector_db,
+            use_query_embedding=use_query_embedding,
         )

         logger.info(
-            f"AI 客户端初始化完成: {model_config.provider.value}/{model_config.model_name}"
+            "AI client initialized",
+            extra={
+                "provider": model_config.provider.value,
+                "model": model_config.model_name,
+                "use_vector_db": use_vector_db,
+                "use_query_embedding": use_query_embedding,
+                "chat_retries": self.chat_retries,
+            },
         )

-    def _create_model(self, config: ModelConfig):
-        """创建模型实例。"""
-        if config.provider == ModelProvider.OPENAI:
+    def _create_model(self, config: Optional[ModelConfig]):
+        if config is None:
+            return None
+        if config.provider in {
+            ModelProvider.OPENAI,
+            ModelProvider.DEEPSEEK,
+            ModelProvider.QWEN,
+        }:
             return OpenAIModel(config)
-        elif config.provider == ModelProvider.ANTHROPIC:
+        if config.provider == ModelProvider.ANTHROPIC:
             return AnthropicModel(config)
-        elif config.provider in [ModelProvider.DEEPSEEK, ModelProvider.QWEN]:
-            # DeepSeek 和 Qwen 使用 OpenAI 兼容接口
-            return OpenAIModel(config)
-        else:
-            raise ValueError(f"不支持的模型提供商: {config.provider}")
+        raise ValueError(f"Unsupported model provider: {config.provider}")

     async def _embed_wrapper(self, text: str) -> List[float]:
-        """嵌入向量包装器。"""
         try:
-            # 如果有独立的嵌入模型,优先使用
             if self.embed_model:
                 return await self.embed_model.embed(text)
-            # 否则尝试使用主模型
             return await self.model.embed(text)
         except NotImplementedError:
-            # 如果都不支持嵌入,返回 None(记忆系统会降级)
-            logger.warning("Current model does not support embeddings; vector retrieval disabled")
+            logger.warning("Current model does not support embeddings; fallback to local embedding.")
             return None
-        except Exception as e:
-            logger.error(f"生成嵌入向量失败: {e}")
+        except Exception as exc:
+            logger.warning(f"Embedding generation failed: {exc}")
             return None

     @staticmethod
@@ -120,17 +119,12 @@ class AIClient:
     async def _evaluate_memory_importance(
         self, content: str, metadata: Optional[Dict] = None
     ) -> float:
-        """
-        调用主模型评估记忆重要性,返回 [0, 1] 分值。
-        """
         system_prompt = (
-            "你是记忆重要性评估器。请根据输入内容判断该信息是否值得长期记忆。"
-            "输出一个 0 到 1 的数字,数字越大表示越重要。"
-            "只输出数字,不要输出任何解释、单位或多余文本。"
+            "You evaluate if content should be kept as long-term memory. "
+            "Return only a float between 0 and 1."
         )
         payload = json.dumps(
-            {"content": content, "metadata": metadata or {}},
-            ensure_ascii=False,
+            {"content": content, "metadata": metadata or {}}, ensure_ascii=False
         )
         messages = [
             Message(role="system", content=system_prompt),
@@ -146,636 +140,181 @@ class AIClient:
) )
score = self._parse_importance_score(response.content) score = self._parse_importance_score(response.content)
return max(0.0, min(1.0, score)) return max(0.0, min(1.0, score))
except Exception as e: except Exception as exc:
logger.warning(f"memory importance evaluation failed, fallback to neutral score: {e}") logger.warning(f"Memory importance evaluation failed, fallback to 0.5: {exc}")
return 0.5 return 0.5
@staticmethod
def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
try:
text = json.dumps(payload, ensure_ascii=False, default=str)
except Exception:
text = str(payload)
if len(text) > max_len:
return text[:max_len] + "..."
return text
@staticmethod
def _is_retriable_error(exc: Exception) -> bool:
if isinstance(exc, (httpx.ReadTimeout, httpx.ConnectError, asyncio.TimeoutError)):
return True
lower = str(exc).lower()
retriable_signals = [
"timed out",
"timeout",
"temporarily unavailable",
"connection reset",
"service unavailable",
"rate limit",
"429",
"502",
"503",
"504",
]
return any(signal in lower for signal in retriable_signals)
async def _chat_with_retry(self, messages: List[Message], **kwargs):
attempts = 1 + self.chat_retries
last_error: Optional[Exception] = None
for attempt in range(1, attempts + 1):
try:
return await self.model.chat(messages, None, **kwargs)
except Exception as exc:
last_error = exc
if attempt >= attempts or not self._is_retriable_error(exc):
break
sleep_seconds = self.chat_retry_backoff * (2 ** (attempt - 1))
logger.warning(
"Chat request failed, retrying",
extra={
"attempt": attempt,
"max_attempts": attempts,
"sleep_seconds": sleep_seconds,
"error": str(exc),
},
)
await asyncio.sleep(sleep_seconds)
if last_error:
raise last_error
raise RuntimeError("chat retry loop ended unexpectedly")
async def chat( async def chat(
self, self,
user_id: str, user_id: str,
user_message: str, user_message: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
use_memory: bool = True, use_memory: bool = True,
use_tools: bool = True,
stream: bool = False, stream: bool = False,
**kwargs group_id: Optional[str] = None,
session_id: Optional[str] = None,
memory_key: Optional[str] = None,
**kwargs,
) -> str: ) -> str:
"""对话接口。""" if stream:
try: raise NotImplementedError("stream mode is not supported by AIClient.chat")
# 构建消息列表
messages = [] memory_user_key = memory_key or user_id
messages: List[Message] = []
# 系统提示词
if system_prompt is None: if system_prompt is None:
system_prompt = self.personality.get_system_prompt() system_prompt = self.personality.get_system_prompt(
user_id=user_id,
group_id=group_id,
session_id=session_id,
)
# 注入记忆上下文
if use_memory: if use_memory:
short_term, long_term = await self.memory.get_context( short_term, long_term = await self.memory.get_context(
user_id=user_id, user_id=memory_user_key,
query=user_message query=user_message,
) )
if short_term or long_term: if short_term or long_term:
memory_context = self.memory.format_context(short_term, long_term) memory_context = self.memory.format_context(short_term, long_term)
system_prompt += f"\n\n{memory_context}" system_prompt = f"{system_prompt}\n\n{memory_context}".strip()
messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="system", content=system_prompt))
# 添加用户消息
messages.append(Message(role="user", content=user_message)) messages.append(Message(role="user", content=user_message))
# 准备工具
tools = None
tool_names: List[str] = []
if use_tools and self.tools.list():
tools = self.tools.to_openai_format()
tool_names = [tool.name for tool in self.tools.list()]
forced_tool_name = self._extract_forced_tool_name(user_message, tool_names)
if forced_tool_name:
kwargs = dict(kwargs)
kwargs["forced_tool_name"] = forced_tool_name
if tools:
before_count = len(tools)
tools = [
tool
for tool in tools
if ((tool.get("function") or {}).get("name") == forced_tool_name)
]
if len(tools) == 1:
tool_names = [forced_tool_name]
logger.info( logger.info(
"显式工具调用已收敛工具列表: " "LLM request",
f"{before_count} -> {len(tools)}" extra={
) "user_id": user_id,
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") "group_id": group_id,
"session_id": session_id,
logger.info( "memory_key": memory_user_key,
"LLM请求: " "use_memory": use_memory,
f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, " "message_preview": self._preview_log_payload(user_message),
f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, " },
f"tool_names={self._preview_log_payload(tool_names)}, "
f"forced_tool={forced_tool_name or '-'}"
)
logger.info(
"LLM输入: "
f"user_message={self._preview_log_payload(user_message)}"
) )
# 调用模型 response = await self._chat_with_retry(messages, **kwargs)
if stream:
return self._chat_stream(messages, tools, **kwargs)
else:
response = await self.model.chat(messages, tools, **kwargs)
response_tool_count = len(response.tool_calls or [])
response_tool_names = []
for tool_call in response.tool_calls or []:
if isinstance(tool_call, dict):
function_info = tool_call.get("function") or {}
response_tool_names.append(function_info.get("name"))
else:
function_info = getattr(tool_call, "function", None)
response_tool_names.append(
getattr(function_info, "name", None) if function_info else None
)
logger.info( logger.info(
"LLM首轮输出: " "LLM response",
f"tool_calls={response_tool_count}, " extra={"content_preview": self._preview_log_payload(response.content)},
f"tool_names={self._preview_log_payload(response_tool_names)}, "
f"content={self._preview_log_payload(response.content)}"
) )
# 处理工具调用
if response.tool_calls:
response = await self._handle_tool_calls(
messages, response, tools, **kwargs
)
elif forced_tool_name:
forced_response = await self._run_forced_tool_fallback(
forced_tool_name=forced_tool_name,
user_message=user_message,
)
if forced_response is not None:
response = forced_response
# 写入记忆
if use_memory: if use_memory:
stored_memory = await self.memory.add_qa_pair( stored_memory = await self.memory.add_qa_pair(
user_id=user_id, user_id=memory_user_key,
question=user_message, question=user_message,
answer=response.content, answer=response.content,
metadata={"source": "chat"}, metadata={
"source": "chat",
"user_id": user_id,
"group_id": group_id,
"session_id": session_id,
},
) )
if stored_memory: if stored_memory:
logger.info( logger.info(
"已写入长期记忆问答对:\n" "Long-term memory stored",
f"{stored_memory.content}\n" extra={
f"memory_id={stored_memory.id}, " "memory_id": stored_memory.id,
f"importance={stored_memory.importance:.2f}" "importance": stored_memory.importance,
},
) )
return response.content return response.content
except Exception as e: def set_personality(
logger.error(f"对话失败: {type(e).__name__}: {e!r}") self, personality_name: str, scope: str = "global", scope_id: Optional[str] = None
raise ) -> bool:
return self.personality.set_personality(
async def _run_forced_tool_fallback( key=personality_name,
self, forced_tool_name: str, user_message: str scope=scope,
) -> Optional[Message]: scope_id=scope_id,
"""Execute forced tool locally when model did not emit tool_calls."""
tool_def = self.tools.get(forced_tool_name)
tool_source = self._tool_sources.get(forced_tool_name, "custom")
if not tool_def:
logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}")
return None
logger.warning(
"模型未返回 tool_calls启用本地强制工具执行: "
f"source={tool_source}, name={forced_tool_name}"
) )
try:
result = tool_def.function()
if inspect.isawaitable(result):
result = await result
except TypeError as exc:
logger.warning(
"本地强制工具执行失败(参数不匹配): "
f"name={forced_tool_name}, error={exc}"
)
return None
except Exception as exc:
logger.warning(
"本地强制工具执行失败: "
f"name={forced_tool_name}, error={exc}"
)
return None
result_text = str(result)
pipelined_text = await self._run_skill_doc_pipeline(
forced_tool_name=forced_tool_name,
skill_doc=result_text,
user_message=user_message,
)
if pipelined_text is not None:
result_text = pipelined_text
prefix_limit = self._extract_prefix_limit(user_message)
if prefix_limit:
result_text = result_text[:prefix_limit]
logger.info(
"Local forced tool execution succeeded: "
f"source={tool_source}, name={forced_tool_name}, "
f"result={self._preview_log_payload(result_text)}"
)
return Message(role="assistant", content=result_text)
async def _run_skill_doc_pipeline(
self, forced_tool_name: str, skill_doc: str, user_message: str
) -> Optional[str]:
"""Run an extra model step: execute instructions from skill doc on user text."""
if not forced_tool_name.endswith(".read_skill_doc"):
return None
target_text = self._extract_processing_payload(user_message)
if not target_text:
return None
logger.info(
"Forced-tool follow-up processing started: "
f"name={forced_tool_name}, target_len={len(target_text)}"
)
messages = [
Message(
role="system",
content=(
"你是技能执行器。请严格按下面技能文档处理用户文本。"
"不要复述技能文档,不要解释工具调用过程,只输出最终处理结果。\n\n"
"[技能文档开始]\n"
f"{skill_doc}\n"
"[技能文档结束]"
),
),
Message(
role="user",
content=(
"请根据技能文档处理以下文本,保持原意并提升自然度:\n"
f"{target_text}"
),
),
]
try:
response = await self.model.chat(messages=messages, tools=None)
content = (response.content or "").strip()
if not content:
return None
logger.info(
"Forced-tool follow-up processing finished: "
f"name={forced_tool_name}, output_len={len(content)}"
)
return content
except Exception as exc:
logger.warning(
"Forced-tool follow-up processing failed, falling back to raw tool output: "
f"name={forced_tool_name}, error={exc}"
)
return None
async def _chat_stream(
self,
messages: List[Message],
tools: Optional[List[Dict]],
**kwargs
) -> AsyncIterator[str]:
"""Streaming chat."""
async for chunk in self.model.chat_stream(messages, tools, **kwargs):
yield chunk
async def _handle_tool_calls(
self,
messages: List[Message],
response: Message,
tools: Optional[List[Dict]],
**kwargs
) -> Message:
"""Handle tool calls emitted by the model."""
messages.append(response)
total_calls = len(response.tool_calls or [])
if total_calls:
logger.info(f"Tool call requests detected: {total_calls}")
# Execute each tool call
for tool_call in response.tool_calls or []:
try:
tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
except Exception as e:
logger.warning(f"Failed to parse tool call: {e}")
fallback_id = tool_call.get('id') if isinstance(tool_call, dict) else getattr(tool_call, 'id', None)
if fallback_id:
messages.append(Message(
role="tool",
content=f"Failed to parse tool arguments: {str(e)}",
tool_call_id=fallback_id,
name="tool"
))
continue
if not tool_name:
logger.warning(f"Skipping invalid tool call: {tool_call}")
continue
tool_def = self.tools.get(tool_name)
tool_source = self._tool_sources.get(tool_name, "custom")
if not tool_def:
error_msg = f"Tool not found: {tool_name}"
logger.warning(error_msg)
messages.append(Message(
role="tool",
name=tool_name,
content=error_msg,
tool_call_id=tool_call_id
))
continue
try:
logger.info(
"Tool call started: "
f"source={tool_source}, name={tool_name}, "
f"args={self._preview_log_payload(tool_args)}"
)
result = tool_def.function(**tool_args)
if inspect.isawaitable(result):
result = await result
logger.info(
"Tool call succeeded: "
f"source={tool_source}, name={tool_name}, "
f"result={self._preview_log_payload(result)}"
)
messages.append(Message(
role="tool",
name=tool_name,
content=str(result),
tool_call_id=tool_call_id
))
except Exception as e:
logger.warning(
"Tool call failed: "
f"source={tool_source}, name={tool_name}, error={e}"
)
messages.append(Message(
role="tool",
name=tool_name,
content=f"Tool execution failed: {str(e)}",
tool_call_id=tool_call_id
))
# Call the model again for the final response
final_kwargs = dict(kwargs)
# Force only the first model turn, avoid recursive force after tool result.
final_kwargs.pop("forced_tool_name", None)
final_response = await self.model.chat(messages, tools, **final_kwargs)
logger.info(
"Final LLM output: "
f"content={self._preview_log_payload(final_response.content)}"
)
return final_response
def _parse_tool_call(self, tool_call: Any) -> Tuple[Optional[str], Dict[str, Any], Optional[str]]:
"""Normalize tool-call structures across SDK variants."""
if isinstance(tool_call, dict):
tool_call_id = tool_call.get('id')
function = tool_call.get('function') or {}
tool_name = function.get('name')
raw_args = function.get('arguments')
else:
tool_call_id = getattr(tool_call, 'id', None)
function = getattr(tool_call, 'function', None)
tool_name = getattr(function, 'name', None) if function else None
raw_args = getattr(function, 'arguments', None) if function else None
tool_args = self._normalize_tool_args(raw_args)
return tool_name, tool_args, tool_call_id
def _normalize_tool_args(self, raw_args: Any) -> Dict[str, Any]:
"""Coerce tool arguments into a dict."""
if raw_args is None:
return {}
if isinstance(raw_args, dict):
return raw_args
if isinstance(raw_args, str):
raw_args = raw_args.strip()
if not raw_args:
return {}
parsed = json.loads(raw_args)
if not isinstance(parsed, dict):
raise ValueError(f"Tool arguments must be a JSON object, got: {type(parsed)}")
return parsed
if hasattr(raw_args, 'model_dump'):
parsed = raw_args.model_dump()
if isinstance(parsed, dict):
return parsed
raise ValueError(f"Unsupported tool argument type: {type(raw_args)}")
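The normalization above can be exercised standalone. A minimal sketch (the `normalize_tool_args` name is a hypothetical stand-in mirroring `_normalize_tool_args`, minus the pydantic `model_dump` branch):

```python
import json
from typing import Any, Dict

def normalize_tool_args(raw_args: Any) -> Dict[str, Any]:
    """Coerce tool-call arguments (None, dict, or JSON string) into a dict."""
    if raw_args is None:
        return {}
    if isinstance(raw_args, dict):
        return raw_args
    if isinstance(raw_args, str):
        raw_args = raw_args.strip()
        if not raw_args:
            return {}
        parsed = json.loads(raw_args)
        if not isinstance(parsed, dict):
            raise ValueError(f"arguments must be a JSON object, got {type(parsed)}")
        return parsed
    raise ValueError(f"unsupported argument type: {type(raw_args)}")

print(normalize_tool_args('{"city": "Tokyo"}'))  # → {'city': 'Tokyo'}
```

Providers often return `arguments` as a JSON string, so the string branch is the common path.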
@staticmethod
def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
"""Short preview of args/results for logging."""
try:
text = json.dumps(payload, ensure_ascii=False, default=str)
except Exception:
text = str(payload)
if len(text) > max_len:
return text[:max_len] + "..."
return text
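The truncation logic above is easy to verify in isolation; a self-contained sketch (hypothetical `preview_log_payload`, mirroring the static method):

```python
import json
from typing import Any

def preview_log_payload(payload: Any, max_len: int = 240) -> str:
    """Render any payload as a short, log-safe preview string."""
    try:
        text = json.dumps(payload, ensure_ascii=False, default=str)
    except Exception:
        text = str(payload)
    if len(text) > max_len:
        return text[:max_len] + "..."
    return text

print(preview_log_payload({"query": "x" * 500}, max_len=40))
```

Note the serialized form (including JSON quotes) is what gets truncated, so the preview can be up to `max_len + 3` characters.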
@staticmethod
def _extract_prefix_limit(user_message: str) -> Optional[int]:
"""Extract requested output prefix length like '前100字'."""
if not user_message:
return None
match = re.search(r"前\s*(\d{1,4})\s*字", user_message)
if not match:
return None
try:
limit = int(match.group(1))
except (TypeError, ValueError):
return None
if limit <= 0:
return None
return min(limit, 5000)
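A standalone sketch of the extraction above (hypothetical `extract_prefix_limit`, assuming the intended pattern is 前…字, "first N characters", as the docstring describes):

```python
import re
from typing import Optional

def extract_prefix_limit(user_message: str) -> Optional[int]:
    """Parse a requested output length such as '前100字' (first 100 characters)."""
    if not user_message:
        return None
    match = re.search(r"前\s*(\d{1,4})\s*字", user_message)
    if not match:
        return None
    limit = int(match.group(1))
    if limit <= 0:
        return None
    # Cap the requested prefix so a typo cannot demand a huge slice
    return min(limit, 5000)

print(extract_prefix_limit("只输出前100字"))  # → 100
```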
@staticmethod
def _extract_processing_payload(user_message: str) -> Optional[str]:
"""Extract text payload like '处理以下文本:...' from user message."""
if not user_message:
return None
text = user_message.strip()
markers = [
"以下文本:",
"以下文本:",
"文本:",
"文本:",
]
for marker in markers:
idx = text.find(marker)
if idx < 0:
continue
payload = text[idx + len(marker) :].strip()
if payload:
return payload
pattern = re.compile(
r"(?:处理|润色|改写|人性化处理|优化)[\s\S]{0,32}(?:如下|以下)[:]\s*([\s\S]+)$"
)
match = pattern.search(text)
if match:
payload = (match.group(1) or "").strip()
if payload:
return payload
return None
@staticmethod
def _compact_identifier(text: str) -> str:
"""Compact identifier for fuzzy matching (e.g. humanizer_zh -> humanizerzh)."""
return re.sub(r"[^a-z0-9]+", "", (text or "").lower())
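The compaction step above underpins the fuzzy tool-name matching; a one-function sketch (hypothetical `compact_identifier`, mirroring the static method):

```python
import re

def compact_identifier(text: str) -> str:
    """Lowercase and strip non-alphanumerics so 'humanizer_zh' matches 'humanizerzh'."""
    return re.sub(r"[^a-z0-9]+", "", (text or "").lower())

print(compact_identifier("humanizer_zh.read_skill_doc"))  # → humanizerzhreadskilldoc
```

Both the user message and the tool names are compacted before a substring comparison, so punctuation differences never block a match.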
@staticmethod
def _extract_forced_tool_name(
user_message: str, available_tool_names: List[str]
) -> Optional[str]:
if not user_message or not available_tool_names:
return None
triggers = [
"调用工具",
"使用工具",
"只调用",
"务必调用",
"必须调用",
"调用",
"使用",
"tool",
]
if not any(trigger in user_message for trigger in triggers):
return None
pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)")
explicit_matches = [
name for name in pattern.findall(user_message) if name in available_tool_names
]
if len(explicit_matches) == 1:
return explicit_matches[0]
if len(explicit_matches) > 1:
return None
contained = [name for name in available_tool_names if name in user_message]
if len(contained) == 1:
return contained[0]
# Allow a bare skill/tool prefix (e.g. humanizer_zh) when exactly one tool lives under it.
prefixes = sorted(
{name.split(".", 1)[0] for name in available_tool_names},
key=len,
reverse=True,
)
matched_prefixes = [
prefix
for prefix in prefixes
if re.search(rf"\b{re.escape(prefix)}\b", user_message)
]
if len(matched_prefixes) == 1:
prefix_tools = [
name
for name in available_tool_names
if name.startswith(f"{matched_prefixes[0]}.")
]
if len(prefix_tools) == 1:
return prefix_tools[0]
# Fuzzy match: tolerate omitted underscores/dots (e.g. humanizerzh).
compact_message = AIClient._compact_identifier(user_message)
if compact_message:
compact_full_matches = []
for tool_name in available_tool_names:
compact_tool_name = AIClient._compact_identifier(tool_name)
if compact_tool_name and compact_tool_name in compact_message:
compact_full_matches.append(tool_name)
if len(compact_full_matches) == 1:
return compact_full_matches[0]
if len(compact_full_matches) > 1:
return None
compact_prefix_map: Dict[str, List[str]] = {}
for tool_name in available_tool_names:
prefix = tool_name.split(".", 1)[0]
compact_prefix = AIClient._compact_identifier(prefix)
if not compact_prefix:
continue
compact_prefix_map.setdefault(compact_prefix, []).append(tool_name)
compact_prefix_matches = [
compact_prefix
for compact_prefix in compact_prefix_map
if compact_prefix in compact_message
]
if len(compact_prefix_matches) == 1:
matched_tools = compact_prefix_map[compact_prefix_matches[0]]
if len(matched_tools) == 1:
return matched_tools[0]
return None
def set_personality(self, personality_name: str) -> bool:
"""Set the active personality."""
return self.personality.set_personality(personality_name)
    def list_personalities(self) -> List[str]:
-       """List all personalities."""
        return self.personality.list_personalities()

    def switch_model(self, model_config: ModelConfig) -> bool:
-       """Runtime switch for primary chat model."""
-       new_model = self._create_model(model_config)
-       self.model = new_model
+       self.model = self._create_model(model_config)
        self.config = model_config
        logger.info(
-           f"Switched primary model: {model_config.provider.value}/{model_config.model_name}"
+           "Primary model switched",
+           extra={
+               "provider": model_config.provider.value,
+               "model": model_config.model_name,
+           },
        )
        return True
async def create_long_task(
self,
user_id: str,
title: str,
description: str,
steps: List[Dict],
metadata: Optional[Dict] = None
) -> str:
"""Create a long-running task."""
return self.task_manager.create_task(
user_id=user_id,
title=title,
description=description,
steps=steps,
metadata=metadata
)
async def start_task(
self,
task_id: str,
progress_callback: Optional[callable] = None
):
"""Start a task."""
await self.task_manager.start_task(task_id, progress_callback)
def get_task_status(self, task_id: str) -> Optional[Dict]:
"""Get task status."""
return self.task_manager.get_task_status(task_id)
def register_tool(
self,
name: str,
description: str,
parameters: Dict,
function: callable,
source: str = "custom",
):
"""Register a tool."""
from .base import ToolDefinition
tool = ToolDefinition(
name=name,
description=description,
parameters=parameters,
function=function
)
self.tools.register(tool)
self._tool_sources[name] = source
logger.info(f"Registered tool: {name} (source={source})")
def unregister_tool(self, name: str) -> bool:
"""Unregister a tool."""
removed = self.tools.unregister(name)
if removed:
self._tool_sources.pop(name, None)
logger.info(f"Unregistered tool: {name}")
return removed
def unregister_tools_by_prefix(self, prefix: str) -> int:
"""Unregister all tools sharing a prefix."""
removed_count = self.tools.unregister_by_prefix(prefix)
for tool_name in list(self._tool_sources.keys()):
if tool_name.startswith(prefix):
self._tool_sources.pop(tool_name, None)
if removed_count:
logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
return removed_count
    def clear_memory(self, user_id: str):
-       """清除用户短期记忆。"""
        self.memory.clear_short_term(user_id)
-       logger.info(f"Cleared short-term memory for user {user_id}")
+       logger.info(f"Cleared short-term memory for {user_id}")

    async def clear_long_term_memory(self, user_id: str) -> bool:
        try:
            await self.memory.clear_long_term(user_id)
-           logger.info(f"Cleared long-term memory for user {user_id}")
+           logger.info(f"Cleared long-term memory for {user_id}")
            return True
-       except Exception as e:
-           logger.warning(f"Failed to clear long-term memory for user {user_id}: {e}")
+       except Exception as exc:
+           logger.warning(f"Failed to clear long-term memory for {user_id}: {exc}")
            return False

    async def list_long_term_memories(self, user_id: str, limit: int = 20):
@@ -823,10 +362,18 @@ class AIClient:
        return await self.memory.delete_long_term(user_id, memory_id)

    async def clear_all_memory(self, user_id: str) -> bool:
-       """清除用户全部记忆(短期 + 长期)。"""
        self.clear_memory(user_id)
+       try:
            return await self.clear_long_term_memory(user_id)
+       except Exception:
+           return False
+
+   async def close(self):
+       await self.memory.close()
+       for model in [self.model, self.embed_model]:
+           if model is None:
+               continue
+           close = getattr(model, "close", None)
+           if close is None:
+               continue
+           maybe_awaitable = close()
+           if asyncio.iscoroutine(maybe_awaitable):
+               await maybe_awaitable
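The shutdown pattern above — calling `close()` whether it is synchronous or a coroutine — can be isolated into a small helper. A sketch under stated assumptions (`close_quietly` and the two resource classes are hypothetical illustrations, not part of this commit):

```python
import asyncio
import inspect

async def close_quietly(obj) -> None:
    """Call obj.close() whether it is synchronous or async."""
    close = getattr(obj, "close", None)
    if close is None:
        return
    result = close()
    if inspect.isawaitable(result):
        await result

class SyncResource:
    def close(self):
        self.closed = True

class AsyncResource:
    async def close(self):
        self.closed = True

async def main() -> bool:
    sync_res, async_res = SyncResource(), AsyncResource()
    await close_quietly(sync_res)
    await close_quietly(async_res)
    return sync_res.closed and async_res.closed

print(asyncio.run(main()))  # → True
```

`inspect.isawaitable` covers coroutines as well as awaitable futures, which is slightly broader than `asyncio.iscoroutine` used above.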


@@ -1,13 +0,0 @@
"""
MCP module exports.
"""
from .base import MCPServer, MCPClient, MCPManager, MCPResource, MCPTool, MCPPrompt
__all__ = [
'MCPServer',
'MCPClient',
'MCPManager',
'MCPResource',
'MCPTool',
'MCPPrompt'
]


@@ -1,230 +0,0 @@
"""
MCP (Model Context Protocol) support.
"""
import asyncio
import json
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, asdict
from pathlib import Path
from src.utils.logger import setup_logger
logger = setup_logger('MCPSystem')
@dataclass
class MCPResource:
    """MCP resource."""
    uri: str
    name: str
    description: str
    mime_type: str


@dataclass
class MCPTool:
    """MCP tool."""
    name: str
    description: str
    input_schema: Dict[str, Any]


@dataclass
class MCPPrompt:
    """MCP prompt."""
    name: str
    description: str
    arguments: List[Dict[str, Any]]
class MCPServer:
    """Base class for MCP servers."""

    def __init__(self, name: str, version: str):
        self.name = name
        self.version = version
        self.resources: Dict[str, MCPResource] = {}
        self.tools: Dict[str, Callable] = {}
        self.tool_specs: Dict[str, MCPTool] = {}
        self.prompts: Dict[str, MCPPrompt] = {}
    async def initialize(self):
        """Initialize the server."""
        pass

    async def shutdown(self):
        """Shut down the server."""
        pass

    def register_resource(self, resource: MCPResource):
        """Register a resource."""
        self.resources[resource.uri] = resource

    def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
        """Register a tool."""
        tool = MCPTool(name=name, description=description, input_schema=input_schema)
        self.tool_specs[name] = tool
        self.tools[name] = handler
        logger.info(f"✅ MCP tool registered: {self.name}.{name}")

    def register_prompt(self, prompt: MCPPrompt):
        """Register a prompt."""
        self.prompts[prompt.name] = prompt
    async def list_resources(self) -> List[MCPResource]:
        """List resources."""
        return list(self.resources.values())

    async def read_resource(self, uri: str) -> Optional[str]:
        """Read a resource."""
        raise NotImplementedError

    async def list_tools(self) -> List[MCPTool]:
        """List tools."""
        return list(self.tool_specs.values())

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool."""
        if name not in self.tools:
            raise ValueError(f"Tool does not exist: {name}")
        handler = self.tools[name]
        logger.info(
            "MCP tool call started: "
            f"server={self.name}, tool={name}, "
            f"args={json.dumps(arguments, ensure_ascii=False, default=str)}"
        )
        try:
            result = await handler(**arguments)
        except Exception as exc:
            logger.warning(f"MCP tool call failed: server={self.name}, tool={name}, error={exc}")
            raise
        logger.info(f"MCP tool call succeeded: server={self.name}, tool={name}")
        return result

    async def list_prompts(self) -> List[MCPPrompt]:
        """List prompts."""
        return list(self.prompts.values())

    async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]:
        """Get a prompt."""
        raise NotImplementedError
class MCPClient:
    """MCP client."""

    def __init__(self):
        self.servers: Dict[str, MCPServer] = {}

    async def connect_server(self, server: MCPServer):
        """Connect to a server."""
        await server.initialize()
        self.servers[server.name] = server
        logger.info(f"✅ Connected MCP server: {server.name} v{server.version}")

    async def disconnect_server(self, server_name: str):
        """Disconnect a server."""
        if server_name in self.servers:
            await self.servers[server_name].shutdown()
            del self.servers[server_name]
            logger.info(f"✅ Disconnected MCP server: {server_name}")

    def get_server(self, name: str) -> Optional[MCPServer]:
        """Get a server by name."""
        return self.servers.get(name)

    def list_servers(self) -> List[str]:
        """List all servers."""
        return list(self.servers.keys())

    async def list_all_resources(self) -> Dict[str, List[MCPResource]]:
        """List resources from all servers."""
        result = {}
        for name, server in self.servers.items():
            result[name] = await server.list_resources()
        return result

    async def list_all_tools(self) -> Dict[str, List[MCPTool]]:
        """List tools from all servers."""
        result = {}
        for name, server in self.servers.items():
            result[name] = await server.list_tools()
        return result

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool on a named server."""
        server = self.get_server(server_name)
        if not server:
            raise ValueError(f"Server does not exist: {server_name}")
        return await server.call_tool(tool_name, arguments)
class MCPManager:
    """MCP manager."""

    def __init__(self, config_path: Path):
        self.config_path = config_path
        self.client = MCPClient()
        self.server_configs: Dict[str, Dict] = {}
        self._load_config()

    def _load_config(self):
        """Load config."""
        if self.config_path.exists():
            with open(self.config_path, 'r', encoding='utf-8') as f:
                self.server_configs = json.load(f)

    def _save_config(self):
        """Save config."""
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.config_path, 'w', encoding='utf-8') as f:
            json.dump(self.server_configs, f, ensure_ascii=False, indent=2)

    async def register_server(self, server: MCPServer, config: Optional[Dict] = None):
        """Register a server."""
        await self.client.connect_server(server)
        if config:
            self.server_configs[server.name] = config
            self._save_config()

    async def unregister_server(self, server_name: str):
        """Unregister a server."""
        await self.client.disconnect_server(server_name)
        if server_name in self.server_configs:
            del self.server_configs[server_name]
            self._save_config()

    def get_client(self) -> MCPClient:
        """Get the client."""
        return self.client

    async def get_all_tools_for_ai(self) -> List[Dict]:
        """Get all tools in AI function-calling format."""
        all_tools = []
        tools_by_server = await self.client.list_all_tools()
        for server_name, tools in tools_by_server.items():
            for tool in tools:
                all_tools.append({
                    "type": "function",
                    "function": {
                        "name": f"{server_name}.{tool.name}",
                        "description": tool.description,
                        "parameters": tool.input_schema
                    }
                })
        return all_tools

    async def execute_tool(self, full_tool_name: str, arguments: Dict) -> Any:
        """Execute a tool by its fully qualified name."""
        parts = full_tool_name.split('.', 1)
        if len(parts) != 2:
            raise ValueError(f"Malformed tool name: {full_tool_name}")
        server_name, tool_name = parts
        logger.info(f"MCP execution request: {full_tool_name}")
        return await self.client.call_tool(server_name, tool_name, arguments)
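The `server.tool` routing above can be sketched end-to-end without the full MCP classes; `EchoServer` and this standalone `execute_tool` are hypothetical stand-ins for illustration only:

```python
import asyncio

class EchoServer:
    """Minimal stand-in for an MCP server exposing one async tool."""
    name = "echo"

    async def call_tool(self, tool_name, arguments):
        if tool_name != "say":
            raise ValueError(f"Tool does not exist: {tool_name}")
        return arguments["text"]

async def execute_tool(servers, full_tool_name, arguments):
    # 'server.tool' -> route to the named server, mirroring MCPManager.execute_tool
    server_name, _, tool_name = full_tool_name.partition(".")
    if not tool_name:
        raise ValueError(f"Malformed tool name: {full_tool_name}")
    return await servers[server_name].call_tool(tool_name, arguments)

print(asyncio.run(execute_tool({"echo": EchoServer()}, "echo.say", {"text": "hi"})))  # → hi
```

Splitting on the first dot only (`split('.', 1)` / `partition`) is what lets tool names themselves contain dots.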


@@ -1,6 +0,0 @@
"""
MCP server implementations.
"""
from .filesystem import FileSystemMCPServer
__all__ = ['FileSystemMCPServer']


@@ -1,123 +0,0 @@
"""
Example MCP server: file-system access.
"""
from pathlib import Path
from typing import Optional
from ..base import MCPServer, MCPResource
class FileSystemMCPServer(MCPServer):
    """File-system MCP server."""

    def __init__(self, root_path: Path):
        super().__init__(name="filesystem", version="1.0.0")
        self.root_path = root_path

    async def initialize(self):
        """Initialize the server and register its tools."""
        # Register tools
        self.register_tool(
            name="read_file",
            description="Read the contents of a file",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "File path"
                    }
                },
                "required": ["path"]
            },
            handler=self.read_file
        )
        self.register_tool(
            name="write_file",
            description="Write content to a file",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "File path"
                    },
                    "content": {
                        "type": "string",
                        "description": "File content"
                    }
                },
                "required": ["path", "content"]
            },
            handler=self.write_file
        )
        self.register_tool(
            name="list_directory",
            description="List the contents of a directory",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Directory path"
                    }
                },
                "required": ["path"]
            },
            handler=self.list_directory
        )
    def _resolve_path(self, path: str) -> Path:
        """Resolve a path and keep it inside root_path."""
        full_path = (self.root_path / path).resolve()
        # Security check: ensure the resolved path stays inside root_path
        if not str(full_path).startswith(str(self.root_path)):
            raise ValueError("Path escapes the allowed root")
        return full_path
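A note on the sandbox check above: comparing raw string prefixes can be fooled by sibling directories sharing a prefix (e.g. `/data/root` vs `/data/root2`). A stricter sketch comparing path components (`resolve_inside_root` is a hypothetical variant, not the shipped method):

```python
import tempfile
from pathlib import Path

def resolve_inside_root(root: Path, user_path: str) -> Path:
    """Resolve user_path under root; reject anything that escapes it."""
    root = root.resolve()
    full_path = (root / user_path).resolve()
    # Compare path components, not string prefixes, so siblings
    # like root2/ never pass as being inside root/.
    if full_path != root and root not in full_path.parents:
        raise ValueError("Path escapes the allowed root")
    return full_path

root = Path(tempfile.mkdtemp())
print(resolve_inside_root(root, "docs/readme.md").name)  # → readme.md
```

`Path.resolve()` also collapses `..` segments, so `../etc/passwd` is caught even though the literal string starts inside the root.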
    async def read_file(self, path: str) -> str:
        """Read a file."""
        file_path = self._resolve_path(path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {path}")
        if not file_path.is_file():
            raise ValueError(f"Not a file: {path}")
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    async def write_file(self, path: str, content: str) -> str:
        """Write a file, creating parent directories as needed."""
        file_path = self._resolve_path(path)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        return f"File written: {path}"
    async def list_directory(self, path: str) -> list:
        """List a directory."""
        dir_path = self._resolve_path(path)
        if not dir_path.exists():
            raise FileNotFoundError(f"Directory not found: {path}")
        if not dir_path.is_dir():
            raise ValueError(f"Not a directory: {path}")
        items = []
        for item in dir_path.iterdir():
            items.append({
                "name": item.name,
                "type": "directory" if item.is_dir() else "file",
                "size": item.stat().st_size if item.is_file() else None
            })
        return items


@@ -1,25 +1,29 @@
 """
-记忆系统:短期记忆、长期记忆与 RAG 检索(向量数据库)。
+Memory system: short-term window + long-term retrieval.
 """
-import asyncio
+from __future__ import annotations
+
 import hashlib
 import shutil
 import time
 import uuid
+from collections import deque
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple, Callable, Awaitable
-from collections import deque
+from typing import Awaitable, Callable, Dict, List, Optional, Tuple

-from .vector_store import VectorStore, VectorMemory, ChromaVectorStore, JSONVectorStore
+from .vector_store import ChromaVectorStore, JSONVectorStore, VectorMemory, VectorStore
 from src.utils.logger import setup_logger

-logger = setup_logger('MemorySystem')
+logger = setup_logger("MemorySystem")
@dataclass
class MemoryItem:
-   """记忆项(用于短期记忆)。"""
+   """In-memory short-term record."""
    content: str
    timestamp: datetime
    user_id: str
@@ -27,154 +31,136 @@ class MemoryItem:
    metadata: Dict = field(default_factory=dict)

    def to_dict(self) -> Dict:
-       """转换为字典。"""
        return {
-           'content': self.content,
-           'timestamp': self.timestamp.isoformat(),
-           'user_id': self.user_id,
-           'importance': self.importance,
-           'metadata': self.metadata
+           "content": self.content,
+           "timestamp": self.timestamp.isoformat(),
+           "user_id": self.user_id,
+           "importance": self.importance,
+           "metadata": self.metadata,
        }
class ShortTermMemory:
-   """短期记忆(滑动窗口)。"""
+   """Short-term memory window."""

    def __init__(self, max_size: int = 20, max_age_minutes: int = 30):
        self.max_size = max_size
        self.max_age = timedelta(minutes=max_age_minutes)
-       self.memories: Dict[str, deque] = {}  # user_id -> deque of MemoryItem
+       self.memories: Dict[str, deque] = {}

    def add(self, user_id: str, content: str, metadata: Optional[Dict] = None):
-       """添加短期记忆。"""
        if user_id not in self.memories:
            self.memories[user_id] = deque(maxlen=self.max_size)
-       memory = MemoryItem(
-           content=content,
-           timestamp=datetime.now(),
-           user_id=user_id,
-           metadata=metadata or {}
-       )
-       self.memories[user_id].append(memory)
+       self.memories[user_id].append(
+           MemoryItem(
+               content=content,
+               timestamp=datetime.now(),
+               user_id=user_id,
+               metadata=metadata or {},
+           )
+       )

    def get(self, user_id: str, limit: Optional[int] = None) -> List[MemoryItem]:
-       """获取短期记忆。"""
-       if user_id not in self.memories:
+       items = list(self.memories.get(user_id, []))
+       if not items:
            return []
-       # 过滤过期记忆
        now = datetime.now()
-       valid_memories = [
-           m for m in self.memories[user_id]
-           if now - m.timestamp <= self.max_age
-       ]
-       if limit:
-           valid_memories = valid_memories[-limit:]
-       return valid_memories
+       valid = [m for m in items if now - m.timestamp <= self.max_age]
+       self.memories[user_id] = deque(valid, maxlen=self.max_size)
+       if limit and limit > 0:
+           return valid[-limit:]
+       return valid

    def clear(self, user_id: str):
-       """清除用户短期记忆。"""
-       if user_id in self.memories:
-           self.memories.pop(user_id, None)
+       self.memories.pop(user_id, None)
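The sliding-window behaviour above — bounded size plus age-based expiry — can be demonstrated in a few lines (`SlidingWindow` is a hypothetical simplification of `ShortTermMemory`):

```python
from collections import deque
from datetime import datetime, timedelta

class SlidingWindow:
    """Keep at most max_size recent items, dropping entries older than max_age."""

    def __init__(self, max_size: int = 20, max_age_minutes: int = 30):
        self.items = deque(maxlen=max_size)  # deque evicts oldest automatically
        self.max_age = timedelta(minutes=max_age_minutes)

    def add(self, content: str):
        self.items.append((datetime.now(), content))

    def get(self):
        cutoff = datetime.now() - self.max_age
        return [content for ts, content in self.items if ts >= cutoff]

w = SlidingWindow(max_size=3)
for msg in ["a", "b", "c", "d"]:
    w.add(msg)
print(w.get())  # → ['b', 'c', 'd']
```

`deque(maxlen=...)` handles the size bound for free; only the age filter needs explicit code.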
class MemorySystem:
-   """记忆系统:整合短期记忆与长期记忆。"""
+   """Memory system: short-term + long-term storage."""

    def __init__(
        self,
        storage_path: Path,
-       embed_func: Optional[callable] = None,
+       embed_func: Optional[Callable[[str], Awaitable[List[float]]]] = None,
        importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
        importance_threshold: float = 0.6,
        use_vector_db: bool = True,
        use_query_embedding: bool = False,
+       max_long_term_per_user: int = 500,
+       dedup_window_seconds: int = 300,
    ):
        self.short_term = ShortTermMemory()
        self.embed_func = embed_func
        self.importance_evaluator = importance_evaluator
        self.importance_threshold = importance_threshold
-       # Only embed retrieval queries when explicitly enabled.
        self.use_query_embedding = use_query_embedding
+       self.max_long_term_per_user = max(10, int(max_long_term_per_user))
+       self.dedup_window = timedelta(seconds=max(1, int(dedup_window_seconds)))

        if use_vector_db:
            chroma_path = storage_path.parent / "chroma_db"
            chroma_store = self._init_chroma_store(chroma_path)
            if chroma_store is not None:
                self.vector_store = chroma_store
-               logger.info("Using Chroma vector store")
            else:
                self.vector_store = JSONVectorStore(storage_path)
        else:
            self.vector_store = JSONVectorStore(storage_path)
-           logger.info("使用 JSON 存储")
    @staticmethod
    def _is_chroma_table_conflict(error: Exception) -> bool:
-       msg = str(error).lower()
-       return "table embeddings already exists" in msg
+       return "table embeddings already exists" in str(error).lower()

    @staticmethod
    def _is_chroma_trigram_error(error: Exception) -> bool:
-       msg = str(error).lower()
-       return "no such tokenizer: trigram" in msg
+       return "no such tokenizer: trigram" in str(error).lower()
    def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
-       """初始化 Chroma遇到已知 sqlite schema 冲突时尝试修复。"""
        try:
            return ChromaVectorStore(chroma_path)
        except Exception as error:
            if self._is_chroma_trigram_error(error):
                logger.warning(
-                   "Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。"
-                   "请在运行环境升级 sqlite 或安装 pysqlite3-binary。"
+                   "Chroma unavailable (sqlite trigram unsupported), fallback to JSON store."
                )
                return None
            if not self._is_chroma_table_conflict(error):
-               logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}")
+               logger.warning(f"Chroma init failed, fallback to JSON: {error}")
                return None
-           # 先做一次短暂重试,处理并发启动时的瞬时冲突。
-           logger.warning(f"Chroma 初始化出现 schema 冲突,正在重试: {error}")
+           logger.warning(f"Chroma schema conflict, retry once: {error}")
            time.sleep(0.2)
            try:
                return ChromaVectorStore(chroma_path)
            except Exception as retry_error:
                if not self._is_chroma_table_conflict(retry_error):
-                   logger.warning(f"Chroma 重试失败,降级为 JSON 存储: {retry_error}")
+                   logger.warning(f"Chroma retry failed, fallback to JSON: {retry_error}")
                    return None
-           backup_name = (
-               f"{chroma_path.name}_backup_conflict_"
-               f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
-           )
-           backup_path = chroma_path.parent / backup_name
+           backup_path = chroma_path.parent / (
+               f"{chroma_path.name}_backup_conflict_"
+               f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
+           )
            try:
                if chroma_path.exists():
                    shutil.move(str(chroma_path), str(backup_path))
                chroma_path.mkdir(parents=True, exist_ok=True)
                repaired = ChromaVectorStore(chroma_path)
                logger.warning(
-                   f"检测到 Chroma 元数据库冲突,已重建目录并保留备份: {backup_path}"
+                   f"Chroma metadata repaired by rebuilding directory. Backup: {backup_path}"
                )
                return repaired
            except Exception as repair_error:
-               logger.warning(f"Chroma 修复失败,降级为 JSON 存储: {repair_error}")
+               logger.warning(f"Chroma repair failed, fallback to JSON: {repair_error}")
                return None
    @staticmethod
    def _normalize_embedding(values: List[float], dim: int = 1024) -> List[float]:
        if not values:
            return [0.0] * dim
        normalized = [float(v) for v in values[:dim]]
        if len(normalized) < dim:
            normalized.extend([0.0] * (dim - len(normalized)))
@@ -191,14 +177,11 @@ class MemorySystem:
            return vec
        for idx, byte in enumerate(encoded):
-           bucket = idx % dim
-           vec[bucket] += (byte / 255.0)
+           vec[idx % dim] += byte / 255.0
        digest = hashlib.sha256(encoded).digest()
        for idx, byte in enumerate(digest):
-           bucket = idx % dim
-           vec[bucket] += ((byte / 255.0) - 0.5) * 0.1
+           vec[idx % dim] += ((byte / 255.0) - 0.5) * 0.1
        return vec
    async def _build_embedding(self, text: str) -> List[float]:
@@ -207,9 +190,8 @@ class MemorySystem:
            embedding = await self.embed_func(text)
            if embedding:
                return [float(v) for v in list(embedding)]
-       except Exception as e:
-           logger.warning(f"embedding generation failed: {e}")
+       except Exception as exc:
+           logger.warning(f"Embedding generation failed, fallback to local: {exc}")
        return self._local_embedding(text)
    async def _add_vector_memory(
@@ -231,10 +213,8 @@ class MemorySystem:
        ):
            return True

-       # Chroma collection may have a fixed historical embedding dimension.
-       candidate_dims = []
-       base_len = len(embedding or [])
-       for dim in [base_len, 1024, 1536, 768, 384, 3072]:
+       candidate_dims: List[int] = []
+       for dim in [len(embedding or []), 1024, 1536, 768, 384, 3072]:
            if dim and dim > 0 and dim not in candidate_dims:
                candidate_dims.append(dim)
@@ -250,7 +230,6 @@ class MemorySystem:
            )
            if ok:
                return True
        return False
@staticmethod @staticmethod
@@ -261,15 +240,53 @@ class MemorySystem:
value = 0.5 value = 0.5
return max(0.0, min(1.0, value)) return max(0.0, min(1.0, value))
async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float:
if not content or not content.strip():
return 0.0
if self.importance_evaluator:
try:
score = await self.importance_evaluator(content, metadata)
return self._normalize_importance(score)
except Exception as exc:
logger.warning(f"Importance evaluation failed, fallback to 0.5: {exc}")
return 0.5
async def _is_duplicate_long_term(self, user_id: str, content: str) -> bool:
normalized = " ".join((content or "").split())
if not normalized:
return True
all_memories = await self.vector_store.get_all(user_id)
recent_cutoff = datetime.now() - self.dedup_window
for memory in all_memories[-20:]:
if (
" ".join((memory.content or "").split()) == normalized
and memory.timestamp >= recent_cutoff
):
return True
return False
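The dedup check hinges on whitespace normalization before comparing against recent memories; roughly:

```python
def normalize_for_dedup(content) -> str:
    # Collapse any run of whitespace to a single space so cosmetic
    # differences (newlines, double spaces) do not defeat deduplication.
    return " ".join((content or "").split())
```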
async def _trim_user_long_term(self, user_id: str):
all_memories = await self.vector_store.get_all(user_id)
if len(all_memories) <= self.max_long_term_per_user:
return
all_memories.sort(key=lambda m: (m.importance, m.timestamp))
to_delete = len(all_memories) - self.max_long_term_per_user
for memory in all_memories[:to_delete]:
await self.vector_store.delete(memory.id)
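The trimming policy keeps the highest-(importance, timestamp) entries up to the cap. As a pure function over hypothetical `(importance, timestamp, id)` tuples:

```python
def select_for_deletion(memories, cap):
    # memories: list of (importance, timestamp, id) tuples.
    # Sort ascending so the least important / oldest come first,
    # then delete just enough entries to get back under the cap.
    if len(memories) <= cap:
        return []
    ordered = sorted(memories, key=lambda m: (m[0], m[1]))
    return [m[2] for m in ordered[: len(memories) - cap]]
```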
async def add_message( async def add_message(
self, self,
user_id: str, user_id: str,
role: str, role: str,
content: str, content: str,
metadata: Optional[Dict] = None metadata: Optional[Dict] = None,
): ):
"""Append a single message to short-term memory (no long-term scoring).""" payload = dict(metadata or {})
self.short_term.add(user_id, content, metadata) payload.setdefault("role", role)
self.short_term.add(user_id, content, payload)
async def add_qa_pair( async def add_qa_pair(
self, self,
@@ -278,9 +295,6 @@ class MemorySystem:
answer: str, answer: str,
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
) -> Optional[VectorMemory]: ) -> Optional[VectorMemory]:
"""
Add the latest QA pair; run model importance evaluation on that pair only.
"""
user_meta = {"role": "user"} user_meta = {"role": "user"}
assistant_meta = {"role": "assistant"} assistant_meta = {"role": "assistant"}
if isinstance(metadata, dict): if isinstance(metadata, dict):
@@ -297,6 +311,8 @@ class MemorySystem:
importance = await self._evaluate_importance(qa_content, qa_metadata) importance = await self._evaluate_importance(qa_content, qa_metadata)
if importance < self.importance_threshold: if importance < self.importance_threshold:
return None return None
if await self._is_duplicate_long_term(user_id, qa_content):
return None
embedding = await self._build_embedding(qa_content) embedding = await self._build_embedding(qa_content)
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
@@ -311,97 +327,88 @@ class MemorySystem:
if not ok: if not ok:
return None return None
await self._trim_user_long_term(user_id)
return await self.get_long_term(user_id, memory_id) return await self.get_long_term(user_id, memory_id)
async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float: @staticmethod
"""Evaluate memory importance.""" def _simple_text_score(query: str, text: str) -> float:
if not content or not content.strip(): query_tokens = [token for token in query.lower().split() if token]
if not query_tokens:
return 0.0 return 0.0
text_lower = text.lower()
if self.importance_evaluator: hit = sum(1 for token in query_tokens if token in text_lower)
try: return hit / len(query_tokens)
score = await self.importance_evaluator(content, metadata)
return self._normalize_importance(score)
except Exception as e:
logger.warning(f"importance evaluation failed, fallback to neutral score: {e}")
# Fall back to a neutral score when model evaluation is unavailable.
return 0.5
async def get_context( async def get_context(
self, self,
user_id: str, user_id: str,
query: Optional[str] = None, query: Optional[str] = None,
max_short_term: int = 10, max_short_term: int = 10,
max_long_term: int = 5 max_long_term: int = 5,
) -> Tuple[List[MemoryItem], List[VectorMemory]]: ) -> Tuple[List[MemoryItem], List[VectorMemory]]:
"""Get context (short-term + long-term memory)."""
# Fetch short-term memory
short_term_memories = self.short_term.get(user_id, limit=max_short_term) short_term_memories = self.short_term.get(user_id, limit=max_short_term)
long_term_memories: List[VectorMemory] = []
# Fetch relevant long-term memories
long_term_memories = []
if query and self.use_query_embedding: if query and self.use_query_embedding:
try: try:
# Use vector search
query_embedding = await self._build_embedding(query) query_embedding = await self._build_embedding(query)
if query_embedding:
long_term_memories = await self.vector_store.search( long_term_memories = await self.vector_store.search(
user_id=user_id, user_id=user_id,
query_embedding=query_embedding, query_embedding=query_embedding,
limit=max_long_term limit=max_long_term,
) )
except Exception as e: except Exception as exc:
logger.warning(f"Vector search failed, falling back to importance retrieval: {e}") logger.warning(f"Vector search failed, fallback to lexical: {exc}")
if query and not long_term_memories: if query and not long_term_memories:
query_lower = query.lower()
try:
candidates = await self.vector_store.get_all(user_id) candidates = await self.vector_store.get_all(user_id)
matches = [m for m in candidates if query_lower in m.content.lower()] scored = []
matches.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) for memory in candidates:
long_term_memories = matches[:max_long_term] score = self._simple_text_score(query, memory.content)
except Exception: if score <= 0:
pass continue
combined = score * 0.7 + memory.importance * 0.3
scored.append((combined, memory))
scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True)
long_term_memories = [item[1] for item in scored[:max_long_term]]
# If vector search fails or returns no results, fall back to importance-based retrieval
if not long_term_memories: if not long_term_memories:
long_term_memories = await self.vector_store.get_by_importance( long_term_memories = await self.vector_store.get_by_importance(
user_id=user_id, user_id=user_id,
limit=max_long_term limit=max_long_term,
) )
# Update long-term memory access records
for memory in long_term_memories: for memory in long_term_memories:
await self.vector_store.update_access(memory.id) await self.vector_store.update_access(memory.id)
return short_term_memories, long_term_memories return short_term_memories, long_term_memories
def format_context( def format_context(
self, self, short_term: List[MemoryItem], long_term: List[VectorMemory]
short_term: List[MemoryItem],
long_term: List[VectorMemory]
) -> str: ) -> str:
"""Format context as text.""" lines: List[str] = []
context = ""
if long_term: if long_term:
context += "## 相关历史记忆\n" lines.append("## 相关历史记忆")
for i, memory in enumerate(long_term, 1): for idx, memory in enumerate(long_term, 1):
context += f"{i}. {memory.content}\n" lines.append(f"{idx}. {memory.content}")
context += "\n" lines.append("")
if short_term: if short_term:
context += "## 最近对话\n" lines.append("## 最近对话")
for memory in short_term: for item in short_term:
context += f"- {memory.content}\n" role = str((item.metadata or {}).get("role") or "").strip().lower()
if role == "user":
prefix = "用户"
elif role == "assistant":
prefix = "助手"
else:
prefix = "对话"
lines.append(f"- {prefix}: {item.content}")
return context return "\n".join(lines).strip()
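The role-to-prefix mapping inside `format_context` can be isolated as below (the labels are the same Chinese display strings used above):

```python
def role_prefix(metadata) -> str:
    # Map the stored role to a display label; unknown roles fall back
    # to the generic "对话" (conversation) prefix.
    role = str((metadata or {}).get("role") or "").strip().lower()
    return {"user": "用户", "assistant": "助手"}.get(role, "对话")
```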
async def list_long_term( async def list_long_term(self, user_id: str, limit: int = 20) -> List[VectorMemory]:
self, user_id: str, limit: int = 20
) -> List[VectorMemory]:
memories = await self.vector_store.get_all(user_id) memories = await self.vector_store.get_all(user_id)
memories.sort(key=lambda m: m.timestamp, reverse=True) memories.sort(key=lambda m: m.timestamp, reverse=True)
if limit > 0: if limit > 0:
@@ -422,21 +429,24 @@ class MemorySystem:
importance: float = 0.8, importance: float = 0.8,
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
) -> Optional[VectorMemory]: ) -> Optional[VectorMemory]:
memory_id = str(uuid.uuid4()) if await self._is_duplicate_long_term(user_id, content):
importance = self._normalize_importance(importance) return None
embedding = await self._build_embedding(content)
memory_id = str(uuid.uuid4())
normalized_importance = self._normalize_importance(importance)
embedding = await self._build_embedding(content)
ok = await self._add_vector_memory( ok = await self._add_vector_memory(
memory_id=memory_id, memory_id=memory_id,
user_id=user_id, user_id=user_id,
content=content, content=content,
embedding=embedding, embedding=embedding,
importance=importance, importance=normalized_importance,
metadata=metadata or {}, metadata=metadata or {},
) )
if not ok: if not ok:
return None return None
await self._trim_user_long_term(user_id)
return await self.get_long_term(user_id, memory_id) return await self.get_long_term(user_id, memory_id)
async def search_long_term( async def search_long_term(
@@ -457,10 +467,15 @@ class MemorySystem:
return results return results
all_memories = await self.vector_store.get_all(user_id) all_memories = await self.vector_store.get_all(user_id)
query_lower = query.lower() scored = []
matched = [m for m in all_memories if query_lower in m.content.lower()] for memory in all_memories:
matched.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) score = self._simple_text_score(query, memory.content)
return matched[:limit] if score <= 0:
continue
combined = score * 0.7 + memory.importance * 0.3
scored.append((combined, memory))
scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True)
return [item[1] for item in scored[:limit]]
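Both `get_context` and `search_long_term` share the same lexical fallback: a token-hit ratio blended 70/30 with stored importance. A sketch:

```python
def simple_text_score(query: str, text: str) -> float:
    # Fraction of query tokens that occur (as substrings) in the text.
    tokens = [t for t in query.lower().split() if t]
    if not tokens:
        return 0.0
    text_lower = text.lower()
    return sum(1 for t in tokens if t in text_lower) / len(tokens)

def combined_score(query: str, content: str, importance: float) -> float:
    # Blend lexical match with stored importance, as in the diff above.
    return simple_text_score(query, content) * 0.7 + importance * 0.3
```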
async def update_long_term( async def update_long_term(
self, self,
@@ -503,7 +518,6 @@ class MemorySystem:
) )
if not added: if not added:
return None return None
return await self.get_long_term(user_id, memory_id) return await self.get_long_term(user_id, memory_id)
async def delete_long_term(self, user_id: str, memory_id: str) -> bool: async def delete_long_term(self, user_id: str, memory_id: str) -> bool:
@@ -513,13 +527,10 @@ class MemorySystem:
return await self.vector_store.delete(memory_id) return await self.vector_store.delete(memory_id)
def clear_short_term(self, user_id: str): def clear_short_term(self, user_id: str):
"""Clear short-term memory."""
self.short_term.clear(user_id) self.short_term.clear(user_id)
async def clear_long_term(self, user_id: str): async def clear_long_term(self, user_id: str):
"""Clear long-term memory."""
await self.vector_store.clear_user(user_id) await self.vector_store.clear_user(user_id)
async def close(self): async def close(self):
"""Close the memory system."""
await self.vector_store.close() await self.vector_store.close()

View File

@@ -1,30 +1,33 @@
""" """
Anthropic Claude模型实现 Anthropic Claude model implementation.
""" """
from typing import List, Optional, AsyncIterator
from __future__ import annotations
from typing import AsyncIterator, List, Optional
from anthropic import AsyncAnthropic from anthropic import AsyncAnthropic
from ..base import BaseAIModel, Message, ModelConfig from ..base import BaseAIModel, Message, ModelConfig
class AnthropicModel(BaseAIModel): class AnthropicModel(BaseAIModel):
"""Anthropic Claude模型实现""" """Anthropic Claude model implementation."""
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
self.client = AsyncAnthropic( self.client = AsyncAnthropic(
api_key=config.api_key, api_key=config.api_key,
base_url=config.api_base, base_url=config.api_base,
timeout=config.timeout timeout=config.timeout,
) )
async def chat( async def chat(
self, self,
messages: List[Message], messages: List[Message],
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
**kwargs **kwargs,
) -> Message: ) -> Message:
"""Synchronous chat."""
# Split out the system message
system_message = None system_message = None
formatted_messages = [] formatted_messages = []
@@ -32,10 +35,7 @@ class AnthropicModel(BaseAIModel):
if msg.role == "system": if msg.role == "system":
system_message = msg.content system_message = msg.content
else: else:
formatted_messages.append({ formatted_messages.append({"role": msg.role, "content": msg.content})
"role": msg.role,
"content": msg.content
})
params = { params = {
"model": self.config.model_name, "model": self.config.model_name,
@@ -43,46 +43,43 @@ class AnthropicModel(BaseAIModel):
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"temperature": self.config.temperature, "temperature": self.config.temperature,
} }
if system_message: if system_message:
params["system"] = system_message params["system"] = system_message
if tools: if tools:
params["tools"] = tools params["tools"] = tools
params.update(kwargs) params.update(kwargs)
response = await self.client.messages.create(**params) response = await self.client.messages.create(**params)
content = "" content = ""
tool_calls = [] tool_calls = []
for block in response.content: for block in response.content:
if block.type == "text": if block.type == "text":
content += block.text content += block.text
elif block.type == "tool_use": elif block.type == "tool_use":
tool_calls.append({ tool_calls.append(
{
"id": block.id, "id": block.id,
"type": "function", "type": "function",
"function": { "function": {
"name": block.name, "name": block.name,
"arguments": block.input "arguments": block.input,
},
} }
}) )
return Message( return Message(
role="assistant", role="assistant",
content=content, content=content,
tool_calls=tool_calls if tool_calls else None tool_calls=tool_calls if tool_calls else None,
) )
async def chat_stream( async def chat_stream(
self, self,
messages: List[Message], messages: List[Message],
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
**kwargs **kwargs,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Streaming chat."""
system_message = None system_message = None
formatted_messages = [] formatted_messages = []
@@ -90,10 +87,7 @@ class AnthropicModel(BaseAIModel):
if msg.role == "system": if msg.role == "system":
system_message = msg.content system_message = msg.content
else: else:
formatted_messages.append({ formatted_messages.append({"role": msg.role, "content": msg.content})
"role": msg.role,
"content": msg.content
})
params = { params = {
"model": self.config.model_name, "model": self.config.model_name,
@@ -102,13 +96,10 @@ class AnthropicModel(BaseAIModel):
"temperature": self.config.temperature, "temperature": self.config.temperature,
"stream": True, "stream": True,
} }
if system_message: if system_message:
params["system"] = system_message params["system"] = system_message
if tools: if tools:
params["tools"] = tools params["tools"] = tools
params.update(kwargs) params.update(kwargs)
async with self.client.messages.stream(**params) as stream: async with self.client.messages.stream(**params) as stream:
@@ -116,5 +107,13 @@ class AnthropicModel(BaseAIModel):
yield text yield text
async def embed(self, text: str) -> List[float]: async def embed(self, text: str) -> List[float]:
"""Text embedding (not offered directly by Anthropic; use another service).""" raise NotImplementedError(
raise NotImplementedError("Anthropic does not provide an embedding API; use OpenAI or another service") "Anthropic does not provide embedding API; use OpenAI-compatible embedding model."
)
async def close(self):
close_fn = getattr(self.client, "close", None)
if callable(close_fn):
maybe_awaitable = close_fn()
if hasattr(maybe_awaitable, "__await__"):
await maybe_awaitable
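The duck-typed shutdown pattern above (works whether the SDK's `close()` is sync or async) generalizes to a small helper; `inspect.isawaitable` is used here as an equivalent of the `__await__` check:

```python
import asyncio
import inspect

async def close_client(client) -> None:
    # Call close() if present; await the result only when it is
    # awaitable, so both sync and async SDK clients are handled.
    close_fn = getattr(client, "close", None)
    if callable(close_fn):
        result = close_fn()
        if inspect.isawaitable(result):
            await result
```

Drive it with `asyncio.run(close_client(client))` during shutdown.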

View File

@@ -24,7 +24,7 @@ class OpenAIModel(BaseAIModel):
self.logger = logger self.logger = logger
self._embedding_token_limit: Optional[int] = None self._embedding_token_limit: Optional[int] = None
http_client = httpx.AsyncClient( self._http_client = httpx.AsyncClient(
timeout=config.timeout, timeout=config.timeout,
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
) )
@@ -33,7 +33,7 @@ class OpenAIModel(BaseAIModel):
api_key=config.api_key, api_key=config.api_key,
base_url=config.api_base, base_url=config.api_base,
timeout=config.timeout, timeout=config.timeout,
http_client=http_client, http_client=self._http_client,
) )
self._supports_tools = False self._supports_tools = False
@@ -590,3 +590,8 @@ class OpenAIModel(BaseAIModel):
self.logger.error(f"text preview: {repr(candidate_text[:100])}") self.logger.error(f"text preview: {repr(candidate_text[:100])}")
self.logger.error(f"full traceback:\n{traceback.format_exc()}") self.logger.error(f"full traceback:\n{traceback.format_exc()}")
raise raise
async def close(self):
"""Release network resources."""
if self._http_client:
await self._http_client.aclose()

View File

@@ -1,4 +1,6 @@
"""Personality system for role-play profiles.""" """Personality system for role-play profiles with scope priority."""
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@@ -32,8 +34,6 @@ class PersonalityProfile:
custom_instructions: str = "" custom_instructions: str = ""
def to_system_prompt(self) -> str: def to_system_prompt(self) -> str:
"""Build plain-text system prompt."""
traits_text = ", ".join([t.value for t in self.traits]) if self.traits else "Friendly" traits_text = ", ".join([t.value for t in self.traits]) if self.traits else "Friendly"
lines = [ lines = [
"Role Setting", "Role Setting",
@@ -55,13 +55,29 @@ class PersonalityProfile:
class PersonalitySystem: class PersonalitySystem:
"""Personality management and persistence.""" """Personality management, persistence and scope resolution."""
def __init__(self, config_path: Optional[Path] = None): def __init__(
self,
config_path: Optional[Path] = None,
state_path: Optional[Path] = None,
):
self.config_path = config_path or Path("config/personalities.json") self.config_path = config_path or Path("config/personalities.json")
self.state_path = state_path or self.config_path.with_name("personality_state.json")
self.personalities: Dict[str, PersonalityProfile] = {} self.personalities: Dict[str, PersonalityProfile] = {}
self.current_personality: Optional[PersonalityProfile] = None self._active_global_key: str = "default"
self._active_user_keys: Dict[str, str] = {}
self._active_group_keys: Dict[str, str] = {}
self._active_session_keys: Dict[str, str] = {}
self._load_personalities() self._load_personalities()
self._load_state()
self._ensure_valid_active_keys()
@property
def current_personality(self) -> Optional[PersonalityProfile]:
return self.personalities.get(self._active_global_key)
def _dict_to_profile(self, config: Dict) -> PersonalityProfile: def _dict_to_profile(self, config: Dict) -> PersonalityProfile:
trait_names = config.get("traits", []) trait_names = config.get("traits", [])
@@ -83,27 +99,30 @@ class PersonalitySystem:
custom_instructions=str(config.get("custom_instructions", "")), custom_instructions=str(config.get("custom_instructions", "")),
) )
def _load_personalities(self): def _profile_to_dict(self, profile: PersonalityProfile) -> Dict:
"""Load personality config from disk or create defaults.""" return {
"name": profile.name,
"description": profile.description,
"traits": [trait.name for trait in profile.traits],
"speaking_style": profile.speaking_style,
"example_responses": profile.example_responses,
"custom_instructions": profile.custom_instructions,
}
def _load_personalities(self):
if self.config_path.exists(): if self.config_path.exists():
with open(self.config_path, "r", encoding="utf-8") as f: with open(self.config_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
for key, config in data.items(): for key, config in data.items():
if isinstance(config, dict):
self.personalities[key] = self._dict_to_profile(config) self.personalities[key] = self._dict_to_profile(config)
if "default" in self.personalities: if self.personalities:
self.current_personality = self.personalities["default"]
elif self.personalities:
first_key = next(iter(self.personalities.keys()))
self.current_personality = self.personalities[first_key]
return return
self._create_default_personalities() self._create_default_personalities()
def _create_default_personalities(self): def _create_default_personalities(self):
"""Create and persist built-in default profiles."""
default = PersonalityProfile( default = PersonalityProfile(
name="Assistant", name="Assistant",
description="A friendly and practical AI assistant.", description="A friendly and practical AI assistant.",
@@ -146,87 +165,200 @@ class PersonalitySystem:
"tech_expert": tech_expert, "tech_expert": tech_expert,
"creative": creative, "creative": creative,
} }
self.current_personality = default self._active_global_key = "default"
self._save_personalities() self._save_personalities()
def _save_personalities(self): def _save_personalities(self):
"""Persist personalities to disk."""
self.config_path.parent.mkdir(parents=True, exist_ok=True) self.config_path.parent.mkdir(parents=True, exist_ok=True)
data = {} data = {
key: self._profile_to_dict(profile)
for key, profile in self.personalities.items(): for key, profile in self.personalities.items()
data[key] = {
"name": profile.name,
"description": profile.description,
"traits": [trait.name for trait in profile.traits],
"speaking_style": profile.speaking_style,
"example_responses": profile.example_responses,
"custom_instructions": profile.custom_instructions,
} }
with open(self.config_path, "w", encoding="utf-8") as f: with open(self.config_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
def set_personality(self, key: str) -> bool: def _load_state(self):
"""Switch active personality by key.""" if not self.state_path.exists():
return
try:
with open(self.state_path, "r", encoding="utf-8") as f:
data = json.load(f)
except Exception:
return
self._active_global_key = str(data.get("global", self._active_global_key))
self._active_user_keys = {
str(k): str(v)
for k, v in (data.get("user") or {}).items()
if k is not None and v is not None
}
self._active_group_keys = {
str(k): str(v)
for k, v in (data.get("group") or {}).items()
if k is not None and v is not None
}
self._active_session_keys = {
str(k): str(v)
for k, v in (data.get("session") or {}).items()
if k is not None and v is not None
}
def _save_state(self):
self.state_path.parent.mkdir(parents=True, exist_ok=True)
data = {
"global": self._active_global_key,
"user": self._active_user_keys,
"group": self._active_group_keys,
"session": self._active_session_keys,
}
with open(self.state_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def _ensure_valid_active_keys(self):
if self._active_global_key not in self.personalities:
if "default" in self.personalities:
self._active_global_key = "default"
elif self.personalities:
self._active_global_key = next(iter(self.personalities.keys()))
else:
self._active_global_key = ""
self._active_user_keys = {
scope_id: key
for scope_id, key in self._active_user_keys.items()
if key in self.personalities
}
self._active_group_keys = {
scope_id: key
for scope_id, key in self._active_group_keys.items()
if key in self.personalities
}
self._active_session_keys = {
scope_id: key
for scope_id, key in self._active_session_keys.items()
if key in self.personalities
}
self._save_state()
def set_personality(
self, key: str, scope: str = "global", scope_id: Optional[str] = None
) -> bool:
if key not in self.personalities: if key not in self.personalities:
return False return False
self.current_personality = self.personalities[key] scope_normalized = (scope or "global").strip().lower()
if scope_normalized == "global":
self._active_global_key = key
self._save_state()
return True return True
def get_system_prompt(self) -> str: if not scope_id:
"""Get current personality prompt.""" return False
if self.current_personality: if scope_normalized == "user":
return self.current_personality.to_system_prompt() self._active_user_keys[scope_id] = key
return "" elif scope_normalized == "group":
self._active_group_keys[scope_id] = key
elif scope_normalized == "session":
self._active_session_keys[scope_id] = key
else:
return False
self._save_state()
return True
def get_active_personality(
self,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> Optional[PersonalityProfile]:
# Priority: session > group > user > global > default
if session_id:
key = self._active_session_keys.get(session_id)
if key in self.personalities:
return self.personalities[key]
if group_id:
key = self._active_group_keys.get(group_id)
if key in self.personalities:
return self.personalities[key]
if user_id:
key = self._active_user_keys.get(user_id)
if key in self.personalities:
return self.personalities[key]
if self._active_global_key in self.personalities:
return self.personalities[self._active_global_key]
return self.personalities.get("default")
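The session > group > user > global > default precedence can be expressed as a pure lookup; the dict arguments here are hypothetical stand-ins for the private `_active_*_keys` maps:

```python
def resolve_active(personalities, global_key,
                   user_keys, group_keys, session_keys,
                   user_id=None, group_id=None, session_id=None):
    # Check the narrowest scope first; a mapping only wins if it
    # points at a personality that still exists.
    for mapping, scope_id in ((session_keys, session_id),
                              (group_keys, group_id),
                              (user_keys, user_id)):
        if scope_id:
            key = mapping.get(scope_id)
            if key in personalities:
                return personalities[key]
    if global_key in personalities:
        return personalities[global_key]
    return personalities.get("default")
```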
def get_system_prompt(
self,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
profile = self.get_active_personality(
user_id=user_id,
group_id=group_id,
session_id=session_id,
)
return profile.to_system_prompt() if profile else ""
def add_personality(self, key: str, profile: PersonalityProfile) -> bool: def add_personality(self, key: str, profile: PersonalityProfile) -> bool:
"""Add a new personality profile."""
key = key.strip() key = key.strip()
if not key: if not key:
return False return False
self.personalities[key] = profile self.personalities[key] = profile
if not self.current_personality: if not self._active_global_key:
self.current_personality = profile self._active_global_key = key
self._save_state()
self._save_personalities() self._save_personalities()
return True return True
def remove_personality(self, key: str) -> bool: def remove_personality(self, key: str) -> bool:
"""Remove a personality profile."""
if key == "default": if key == "default":
return False return False
if key not in self.personalities: if key not in self.personalities:
return False return False
removed_profile = self.personalities[key]
del self.personalities[key] del self.personalities[key]
if not self.personalities:
self._create_default_personalities()
if self.current_personality == removed_profile: if self._active_global_key == key:
if "default" in self.personalities: self._active_global_key = (
self.current_personality = self.personalities["default"] "default"
elif self.personalities: if "default" in self.personalities
first_key = next(iter(self.personalities.keys())) else next(iter(self.personalities.keys()))
self.current_personality = self.personalities[first_key] )
else:
self.current_personality = None self._active_user_keys = {
scope_id: active_key
for scope_id, active_key in self._active_user_keys.items()
if active_key != key
}
self._active_group_keys = {
scope_id: active_key
for scope_id, active_key in self._active_group_keys.items()
if active_key != key
}
self._active_session_keys = {
scope_id: active_key
for scope_id, active_key in self._active_session_keys.items()
if active_key != key
}
self._save_personalities() self._save_personalities()
self._save_state()
return True return True
def list_personalities(self) -> List[str]: def list_personalities(self) -> List[str]:
"""List all personality keys."""
return sorted(self.personalities.keys()) return sorted(self.personalities.keys())
def get_personality(self, key: str) -> Optional[PersonalityProfile]: def get_personality(self, key: str) -> Optional[PersonalityProfile]:
"""Get personality by key."""
return self.personalities.get(key) return self.personalities.get(key)

View File

@@ -1,6 +0,0 @@
"""
Skills system initialization
"""
from .base import Skill, SkillsManager, SkillMetadata, create_skill_template
__all__ = ['Skill', 'SkillsManager', 'SkillMetadata', 'create_skill_template']

View File

@@ -1,750 +0,0 @@
"""
Skills system - an extensible skill plugin framework.
"""
from dataclasses import dataclass
import importlib
import inspect
import json
from pathlib import Path
import re
import shutil
import sys
import tempfile
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
import urllib.parse
import urllib.request
import zipfile
import os
import stat
from src.utils.logger import setup_logger
logger = setup_logger("SkillsSystem")
@dataclass
class SkillMetadata:
"""Skill metadata."""
name: str
version: str
description: str
author: str
dependencies: List[str]
enabled: bool = True
class Skill:
"""Base class for skills."""
def __init__(self):
self.metadata: Optional[SkillMetadata] = None
self.tools: Dict[str, Callable] = {}
self.manager = None
async def initialize(self):
"""Initialize the skill."""
async def cleanup(self):
"""Clean up the skill."""
def get_tools(self) -> Dict[str, Callable]:
"""Return the tools provided by this skill."""
return self.tools
def register_tool(self, name: str, func: Callable):
"""Register a tool."""
self.tools[name] = func
class SkillsManager:
"""Skills manager."""
_SKILL_KEY_PATTERN = re.compile(r"[^a-zA-Z0-9_]")
_GITHUB_SHORTCUT_PATTERN = re.compile(
r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+(?:#[A-Za-z0-9_.-]+)?$"
)
def __init__(self, skills_dir: Path):
self.skills_dir = skills_dir
self.skills: Dict[str, Skill] = {}
self.skills_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"✅ Skills directory: {skills_dir}")
@classmethod
def normalize_skill_key(cls, raw_name: str) -> str:
"""Normalize arbitrary input into an importable Python package name."""
key = raw_name.strip().lower().replace("-", "_").replace(" ", "_")
key = cls._SKILL_KEY_PATTERN.sub("_", key)
key = re.sub(r"_+", "_", key).strip("_")
if not key:
raise ValueError("Skill name must not be empty")
if key[0].isdigit():
key = f"skill_{key}"
return key
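For reference, the removed `normalize_skill_key` helper reduces to this sketch (behavior inferred from the code above):

```python
import re

_INVALID = re.compile(r"[^a-zA-Z0-9_]")

def normalize_skill_key(raw_name: str) -> str:
    # Lowercase, unify separators, strip invalid characters, collapse
    # underscores, and prefix digit-leading names so the result is a
    # valid, importable package name.
    key = raw_name.strip().lower().replace("-", "_").replace(" ", "_")
    key = _INVALID.sub("_", key)
    key = re.sub(r"_+", "_", key).strip("_")
    if not key:
        raise ValueError("skill name must not be empty")
    if key[0].isdigit():
        key = f"skill_{key}"
    return key
```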
def _get_skill_path(self, skill_name: str) -> Path:
return self.skills_dir / self.normalize_skill_key(skill_name)
@staticmethod
def _on_rmtree_error(func, path, exc_info):
"""Handle Windows readonly/locked file deletion errors."""
try:
os.chmod(path, stat.S_IWRITE)
func(path)
except Exception:
# Keep original failure path for upper retry logic.
pass
def _read_metadata(self, skill_path: Path, fallback_name: str) -> Dict[str, Any]:
metadata_file = skill_path / "skill.json"
if metadata_file.exists():
with open(metadata_file, "r", encoding="utf-8") as f:
metadata = json.load(f)
else:
metadata = {}
metadata.setdefault("name", fallback_name)
metadata.setdefault("version", "1.0.0")
metadata.setdefault("description", f"{fallback_name} skill")
metadata.setdefault("author", "unknown")
metadata.setdefault("dependencies", [])
metadata.setdefault("enabled", True)
with open(metadata_file, "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
return metadata
def _ensure_skill_package_layout(self, skill_path: Path, skill_key: str):
"""Ensure the skill directory has the minimal runnable layout."""
skill_path.mkdir(parents=True, exist_ok=True)
init_file = skill_path / "__init__.py"
if not init_file.exists():
init_file.write_text("", encoding="utf-8")
main_file = skill_path / "main.py"
if not main_file.exists():
template = f'''"""{skill_key} skill"""
from src.ai.skills.base import Skill
class {"".join(p.capitalize() for p in skill_key.split("_"))}Skill(Skill):
async def initialize(self):
self.register_tool("ping", self.ping)
async def ping(self, text: str = "ok") -> str:
return text
async def cleanup(self):
pass
'''
main_file.write_text(template, encoding="utf-8")
self._read_metadata(skill_path, skill_key)
async def load_skill(self, skill_name: str) -> bool:
"""Load a skill."""
try:
skill_name = self.normalize_skill_key(skill_name)
if skill_name in self.skills:
logger.info(f"✅ Skill already loaded: {skill_name}")
return True
skill_path = self._get_skill_path(skill_name)
if not skill_path.exists():
logger.error(f"❌ Skill not found: {skill_name}")
return False
metadata_file = skill_path / "skill.json"
if not metadata_file.exists():
logger.error(f"❌ Skill metadata missing: {skill_name}")
return False
with open(metadata_file, "r", encoding="utf-8") as f:
metadata_dict = json.load(f)
metadata = SkillMetadata(**metadata_dict)
if not metadata.enabled:
logger.info(f"⏸️ Skill disabled: {skill_name}")
return False
module_path = f"skills.{skill_name}.main"
importlib.invalidate_caches()
try:
old_dont_write = sys.dont_write_bytecode
sys.dont_write_bytecode = True
try:
if module_path in sys.modules:
module = importlib.reload(sys.modules[module_path])
else:
module = importlib.import_module(module_path)
finally:
sys.dont_write_bytecode = old_dont_write
except Exception as exc:
logger.error(f"❌ Failed to import skill module {module_path}: {exc}")
return False
skill_class = None
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, Skill) and obj != Skill:
skill_class = obj
break
if not skill_class:
logger.error(f"❌ No Skill subclass found in skill: {skill_name}")
return False
skill = skill_class()
skill.metadata = metadata
skill.manager = self
await skill.initialize()
self.skills[skill_name] = skill
logger.info(f"✅ 加载技能: {skill_name} v{metadata.version}")
return True
except Exception as exc:
logger.error(f"❌ 加载技能失败 {skill_name}: {exc}")
return False
async def load_all_skills(self):
"""加载所有可用技能。"""
for skill_name in self.list_available_skills():
await self.load_skill(skill_name)
async def unload_skill(self, skill_name: str) -> bool:
"""仅卸载内存中的技能。"""
skill_name = self.normalize_skill_key(skill_name)
if skill_name not in self.skills:
return False
skill = self.skills[skill_name]
await skill.cleanup()
del self.skills[skill_name]
sys.modules.pop(f"skills.{skill_name}.main", None)
sys.modules.pop(f"skills.{skill_name}", None)
importlib.invalidate_caches()
logger.info(f"✅ 卸载技能: {skill_name}")
return True
async def uninstall_skill(self, skill_name: str, delete_files: bool = True) -> bool:
"""卸载技能并可选删除文件。"""
skill_name = self.normalize_skill_key(skill_name)
if skill_name in self.skills:
await self.unload_skill(skill_name)
if not delete_files:
return True
skill_path = self._get_skill_path(skill_name)
if not skill_path.exists():
return False
removed = False
for _ in range(3):
try:
shutil.rmtree(skill_path, ignore_errors=False, onerror=self._on_rmtree_error)
except PermissionError:
pass
if not skill_path.exists():
removed = True
break
time.sleep(0.2)
if not removed:
try:
metadata_file = skill_path / "skill.json"
metadata = {}
if metadata_file.exists():
with open(metadata_file, "r", encoding="utf-8") as f:
metadata = json.load(f)
metadata["enabled"] = False
with open(metadata_file, "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
logger.warning(f"⚠️ 删除目录失败,已软卸载技能: {skill_name}")
return True
except Exception:
return False
importlib.invalidate_caches()
logger.info(f"✅ 删除技能目录: {skill_name}")
return True
def get_skill(self, skill_name: str) -> Optional[Skill]:
"""获取已加载技能实例。"""
skill_name = self.normalize_skill_key(skill_name)
return self.skills.get(skill_name)
def list_skills(self) -> List[str]:
"""列出已加载技能。"""
return sorted(self.skills.keys())
def list_available_skills(self) -> List[str]:
"""列出可加载技能目录。"""
if not self.skills_dir.exists():
return []
available: List[str] = []
for skill_dir in self.skills_dir.iterdir():
if not skill_dir.is_dir() or skill_dir.name.startswith("_"):
continue
if (skill_dir / "skill.json").exists() and (skill_dir / "main.py").exists():
try:
with open(skill_dir / "skill.json", "r", encoding="utf-8") as f:
metadata = json.load(f)
if not metadata.get("enabled", True):
continue
available.append(self.normalize_skill_key(skill_dir.name))
except ValueError:
continue
except Exception:
continue
return sorted(set(available))
def get_all_tools(self) -> Dict[str, Callable]:
"""获取全部技能工具。"""
all_tools: Dict[str, Callable] = {}
for skill_name, skill in self.skills.items():
for tool_name, tool_func in skill.get_tools().items():
all_tools[f"{skill_name}.{tool_name}"] = tool_func
return all_tools
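# Example (illustrative): a loaded "weather" skill that registered a
# "query" tool is exposed here as all_tools["weather.query"].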
async def reload_skill(self, skill_name: str) -> bool:
"""重载技能。"""
skill_name = self.normalize_skill_key(skill_name)
if skill_name in self.skills:
await self.unload_skill(skill_name)
return await self.load_skill(skill_name)
def _parse_github_repo_source(
self,
source: str,
) -> Optional[Tuple[str, str, str, Optional[str]]]:
"""解析 GitHub 仓库 URL返回 owner/repo/branch/subpath。"""
try:
parsed = urllib.parse.urlparse(source.strip())
except Exception:
return None
if parsed.scheme not in {"http", "https"}:
return None
if parsed.netloc.lower() != "github.com":
return None
parts = [part for part in parsed.path.split("/") if part]
if len(parts) < 2:
return None
owner = parts[0]
repo = parts[1]
if repo.endswith(".git"):
repo = repo[:-4]
if not owner or not repo:
return None
branch = "main"
subpath: Optional[str] = None
if len(parts) >= 4 and parts[2] == "tree":
branch = parts[3]
if len(parts) > 4:
subpath = "/".join(parts[4:])
elif len(parts) > 2:
# Deep links such as /issues or /blob are not supported
return None
return owner, repo, branch, subpath
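# Example mappings (illustrative):
#   https://github.com/owner/repo                   -> ("owner", "repo", "main", None)
#   https://github.com/owner/repo/tree/dev/skills/x -> ("owner", "repo", "dev", "skills/x")
#   https://github.com/owner/repo/issues/1          -> None (deep links unsupported)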
def _resolve_network_source(
self,
source: str,
) -> Tuple[str, Optional[str], Optional[str]]:
"""将网络来源解析为下载 URL并返回安装提示信息。"""
source = source.strip()
github_repo = self._parse_github_repo_source(source)
if github_repo:
owner, repo, branch, subpath = github_repo
codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}"
return codeload_url, self.normalize_skill_key(repo), subpath
if source.startswith(("http://", "https://")):
return source, None, None
if self._GITHUB_SHORTCUT_PATTERN.match(source):
repo_ref, _, branch = source.partition("#")
owner, repo = repo_ref.split("/", 1)
if repo.endswith(".git"):
repo = repo[:-4]
branch = branch or "main"
codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}"
return codeload_url, self.normalize_skill_key(repo), None
raise ValueError("source 必须是 URL 或 owner/repo[#branch]")
def _download_zip(self, url: str, output_zip: Path):
"""下载 zip 包到本地。"""
req = urllib.request.Request(url, headers={"User-Agent": "QQBot-Skills/1.0"})
with urllib.request.urlopen(req, timeout=30) as resp:
data = resp.read()
output_zip.write_bytes(data)
def _find_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]:
"""在目录中扫描技能候选项。"""
candidates: List[Tuple[str, Path]] = []
for metadata_file in root_dir.rglob("skill.json"):
candidate_dir = metadata_file.parent
if not (candidate_dir / "main.py").exists():
continue
try:
with open(metadata_file, "r", encoding="utf-8") as f:
metadata = json.load(f)
raw_name = str(metadata.get("name") or candidate_dir.name)
except Exception:
raw_name = candidate_dir.name
try:
skill_key = self.normalize_skill_key(raw_name)
except ValueError:
continue
candidates.append((skill_key, candidate_dir))
uniq: Dict[str, Path] = {}
for key, path in candidates:
uniq[key] = path
return sorted(uniq.items(), key=lambda x: x[0])
def _find_codex_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]:
"""在目录中扫描仅包含 SKILL.md 的候选项。"""
candidates: List[Tuple[str, Path]] = []
for skill_doc in root_dir.rglob("SKILL.md"):
candidate_dir = skill_doc.parent
if (candidate_dir / "skill.json").exists() and (candidate_dir / "main.py").exists():
continue
try:
skill_key = self.normalize_skill_key(candidate_dir.name)
except ValueError:
continue
candidates.append((skill_key, candidate_dir))
uniq: Dict[str, Path] = {}
for key, path in candidates:
uniq[key] = path
return sorted(uniq.items(), key=lambda x: x[0])
def _scope_extract_root(self, extract_root: Path, subpath: Optional[str]) -> Path:
"""将解压目录缩小到 repo 子路径(若指定)。"""
if not subpath:
return extract_root
clean_subpath = Path(subpath.strip("/\\"))
if str(clean_subpath) == ".":
return extract_root
direct = extract_root / clean_subpath
if direct.exists():
return direct
for child in extract_root.iterdir():
candidate = child / clean_subpath
if candidate.exists():
return candidate
return extract_root
@staticmethod
def _extract_markdown_title(markdown_content: str) -> str:
for line in markdown_content.splitlines():
stripped = line.strip()
if stripped.startswith("#"):
return stripped.lstrip("#").strip()
return ""
def _install_codex_skill_adapter(self, source_dir: Path, target_path: Path, skill_key: str):
"""将 Codex SKILL.md 转换为可加载的项目技能。"""
skill_doc = source_dir / "SKILL.md"
if not skill_doc.exists():
raise FileNotFoundError(f"SKILL.md not found in {source_dir}")
content = skill_doc.read_text(encoding="utf-8")
title = self._extract_markdown_title(content)
description = (
f"Imported from SKILL.md: {title}" if title else f"Imported from SKILL.md ({skill_key})"
)
target_path.mkdir(parents=True, exist_ok=True)
(target_path / "SKILL.md").write_text(content, encoding="utf-8")
(target_path / "__init__.py").write_text("", encoding="utf-8")
metadata = {
"name": skill_key,
"version": "1.0.0",
"description": description,
"author": "imported",
"dependencies": [],
"enabled": True,
}
with open(target_path / "skill.json", "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill"
main_code = f'''"""{skill_key} adapter skill (generated from SKILL.md)"""
from pathlib import Path
from src.ai.skills.base import Skill
class {class_name}(Skill):
async def initialize(self):
self.register_tool("read_skill_doc", self.read_skill_doc)
async def read_skill_doc(self) -> str:
skill_doc = Path(__file__).with_name("SKILL.md")
if not skill_doc.exists():
return "SKILL.md not found."
return skill_doc.read_text(encoding="utf-8")
async def cleanup(self):
pass
'''
(target_path / "main.py").write_text(main_code, encoding="utf-8")
def install_skill_from_source(
self,
source: str,
skill_name: Optional[str] = None,
overwrite: bool = False,
) -> Tuple[bool, str]:
"""从网络或本地源安装技能目录(仅落盘,不自动加载)。"""
desired_key = self.normalize_skill_key(skill_name) if skill_name else None
with tempfile.TemporaryDirectory(prefix="qqbot_skill_") as tmp:
tmp_dir = Path(tmp)
extract_root: Optional[Path] = None
source_hint_key: Optional[str] = None
source_subpath: Optional[str] = None
source_path = Path(source)
if source_path.exists():
if source_path.is_dir():
extract_root = source_path
elif source_path.is_file() and source_path.suffix.lower() == ".zip":
with zipfile.ZipFile(source_path, "r") as zf:
zf.extractall(tmp_dir / "extract")
extract_root = tmp_dir / "extract"
else:
return False, "本地 source 仅支持目录或 zip 文件"
else:
try:
url, source_hint_key, source_subpath = self._resolve_network_source(source)
except ValueError as exc:
return False, str(exc)
download_zip = tmp_dir / "download.zip"
try:
self._download_zip(url, download_zip)
except Exception as exc:
# For GitHub shorthand, retry with master when the default main branch fails
if "codeload.github.com" in url and url.endswith("/main"):
fallback = url[:-4] + "master"
try:
self._download_zip(fallback, download_zip)
except Exception:
return False, f"下载技能失败: {exc}"
else:
return False, f"下载技能失败: {exc}"
try:
with zipfile.ZipFile(download_zip, "r") as zf:
zf.extractall(tmp_dir / "extract")
except Exception as exc:
return False, f"解压技能失败: {exc}"
extract_root = tmp_dir / "extract"
extract_root = self._scope_extract_root(extract_root, source_subpath)
candidates = self._find_skill_candidates(extract_root)
use_codex_adapter = False
selected_key: Optional[str] = None
selected_path: Optional[Path] = None
if candidates:
if desired_key:
for key, path in candidates:
if key == desired_key:
selected_key, selected_path = key, path
break
if not selected_path:
names = ", ".join([k for k, _ in candidates])
return False, f"源中未找到技能 {desired_key},可选: {names}"
else:
if len(candidates) > 1:
names = ", ".join([k for k, _ in candidates])
return False, f"检测到多个技能,请指定 skill_name。可选: {names}"
selected_key, selected_path = candidates[0]
else:
codex_candidates = self._find_codex_skill_candidates(extract_root)
if not codex_candidates:
return False, "未找到可安装技能(需包含 skill.json 与 main.py或 SKILL.md"
use_codex_adapter = True
if desired_key:
for key, path in codex_candidates:
if key == desired_key:
selected_key, selected_path = key, path
break
if not selected_path:
if len(codex_candidates) == 1:
selected_key = desired_key
selected_path = codex_candidates[0][1]
else:
names = ", ".join([k for k, _ in codex_candidates])
return False, f"源中未找到技能 {desired_key},可选: {names}"
else:
if source_hint_key:
for key, path in codex_candidates:
if key == source_hint_key:
selected_key, selected_path = key, path
break
if not selected_path:
if len(codex_candidates) > 1:
names = ", ".join([k for k, _ in codex_candidates])
return False, f"检测到多个 SKILL.md 技能,请指定 skill_name。可选: {names}"
selected_key, selected_path = codex_candidates[0]
if source_hint_key:
selected_key = source_hint_key
assert selected_key is not None and selected_path is not None
target_path = self._get_skill_path(selected_key)
if target_path.exists():
if not overwrite:
return False, f"技能已存在: {selected_key}"
shutil.rmtree(target_path)
if use_codex_adapter:
self._install_codex_skill_adapter(selected_path, target_path, selected_key)
else:
shutil.copytree(
selected_path,
target_path,
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".git", ".github"),
)
self._ensure_skill_package_layout(target_path, selected_key)
importlib.invalidate_caches()
logger.info(f"✅ 安装技能成功: {selected_key} <- {source}")
return True, selected_key
def create_skill_template(
skill_name: str,
output_dir: Path,
description: str = "技能描述",
author: str = "QQBot",
):
"""创建技能模板。"""
skill_key = SkillsManager.normalize_skill_key(skill_name)
skill_dir = output_dir / skill_key
skill_dir.mkdir(parents=True, exist_ok=True)
metadata = {
"name": skill_key,
"version": "1.0.0",
"description": description,
"author": author,
"dependencies": [],
"enabled": True,
}
with open(skill_dir / "skill.json", "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill"
main_code = f'''"""{skill_key} skill"""
from src.ai.skills.base import Skill
class {class_name}(Skill):
async def initialize(self):
self.register_tool("example_tool", self.example_tool)
async def example_tool(self, text: str) -> str:
return f"{skill_key} 收到: {{text}}"
async def cleanup(self):
pass
'''
with open(skill_dir / "main.py", "w", encoding="utf-8") as f:
f.write(main_code)
with open(skill_dir / "__init__.py", "w", encoding="utf-8") as f:
f.write("")
readme = f"""# {skill_key}
## Description
{description}
## Tools
- example_tool(text)
"""
with open(skill_dir / "README.md", "w", encoding="utf-8") as f:
f.write(readme)
logger.info(f"✅ 创建技能模板: {skill_dir}")

View File

@@ -1,327 +0,0 @@
"""
Long task manager - handles complex multi-step tasks.
"""
import asyncio
from typing import List, Dict, Optional, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import json
from pathlib import Path
import uuid
class TaskStatus(Enum):
"""任务状态"""
PENDING = "pending"
RUNNING = "running"
PAUSED = "paused"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class TaskStep:
"""任务步骤"""
step_id: str
description: str
action: str
params: Dict[str, Any]
status: TaskStatus = TaskStatus.PENDING
result: Optional[Any] = None
error: Optional[str] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
def to_dict(self) -> Dict:
return {
'step_id': self.step_id,
'description': self.description,
'action': self.action,
'params': self.params,
'status': self.status.value,
'result': self.result,
'error': self.error,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None
}
@dataclass
class LongTask:
"""长任务"""
task_id: str
user_id: str
title: str
description: str
steps: List[TaskStep] = field(default_factory=list)
status: TaskStatus = TaskStatus.PENDING
created_at: datetime = field(default_factory=datetime.now)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
progress: float = 0.0
metadata: Dict = field(default_factory=dict)
def to_dict(self) -> Dict:
return {
'task_id': self.task_id,
'user_id': self.user_id,
'title': self.title,
'description': self.description,
'steps': [step.to_dict() for step in self.steps],
'status': self.status.value,
'created_at': self.created_at.isoformat(),
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'progress': self.progress,
'metadata': self.metadata
}
@classmethod
def from_dict(cls, data: Dict) -> 'LongTask':
steps = [
TaskStep(
step_id=s['step_id'],
description=s['description'],
action=s['action'],
params=s['params'],
status=TaskStatus(s['status']),
result=s.get('result'),
error=s.get('error'),
started_at=datetime.fromisoformat(s['started_at']) if s.get('started_at') else None,
completed_at=datetime.fromisoformat(s['completed_at']) if s.get('completed_at') else None
)
for s in data['steps']
]
return cls(
task_id=data['task_id'],
user_id=data['user_id'],
title=data['title'],
description=data['description'],
steps=steps,
status=TaskStatus(data['status']),
created_at=datetime.fromisoformat(data['created_at']),
started_at=datetime.fromisoformat(data['started_at']) if data.get('started_at') else None,
completed_at=datetime.fromisoformat(data['completed_at']) if data.get('completed_at') else None,
progress=data.get('progress', 0.0),
metadata=data.get('metadata', {})
)
class LongTaskManager:
"""长任务管理器"""
def __init__(self, storage_path: Path):
self.storage_path = storage_path
self.tasks: Dict[str, LongTask] = {}
self.action_handlers: Dict[str, Callable] = {}
self.running_tasks: Dict[str, asyncio.Task] = {}
self._load()
def _load(self):
"""加载任务"""
if self.storage_path.exists():
with open(self.storage_path, 'r', encoding='utf-8') as f:
data = json.load(f)
for task_data in data:
task = LongTask.from_dict(task_data)
self.tasks[task.task_id] = task
def _save(self):
"""保存任务"""
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
data = [task.to_dict() for task in self.tasks.values()]
with open(self.storage_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def register_action(self, action_name: str, handler: Callable):
"""注册动作处理器"""
self.action_handlers[action_name] = handler
def create_task(
self,
user_id: str,
title: str,
description: str,
steps: List[Dict],
metadata: Optional[Dict] = None
) -> str:
"""创建任务"""
task_id = str(uuid.uuid4())
task_steps = [
TaskStep(
step_id=str(uuid.uuid4()),
description=step['description'],
action=step['action'],
params=step.get('params', {})
)
for step in steps
]
task = LongTask(
task_id=task_id,
user_id=user_id,
title=title,
description=description,
steps=task_steps,
metadata=metadata or {}
)
self.tasks[task_id] = task
self._save()
return task_id
async def execute_task(
self,
task_id: str,
progress_callback: Optional[Callable[[str, float, str], None]] = None
) -> bool:
"""执行任务"""
if task_id not in self.tasks:
return False
task = self.tasks[task_id]
if task.status == TaskStatus.RUNNING:
return False
task.status = TaskStatus.RUNNING
task.started_at = datetime.now()
self._save()
try:
total_steps = len(task.steps)
for i, step in enumerate(task.steps):
# Stop if the task was cancelled
if task.status == TaskStatus.CANCELLED:
break
step.status = TaskStatus.RUNNING
step.started_at = datetime.now()
# Run the step
try:
handler = self.action_handlers.get(step.action)
if not handler:
raise ValueError(f"未找到动作处理器: {step.action}")
result = await handler(**step.params)
step.result = result
step.status = TaskStatus.COMPLETED
except Exception as e:
step.error = str(e)
step.status = TaskStatus.FAILED
task.status = TaskStatus.FAILED
self._save()
return False
finally:
step.completed_at = datetime.now()
# Update progress
task.progress = (i + 1) / total_steps
self._save()
if progress_callback:
await progress_callback(
task_id,
task.progress,
f"完成步骤 {i+1}/{total_steps}: {step.description}"
)
task.status = TaskStatus.COMPLETED
task.completed_at = datetime.now()
task.progress = 1.0
self._save()
return True
except Exception as e:
task.status = TaskStatus.FAILED
self._save()
raise e
async def start_task(
self,
task_id: str,
progress_callback: Optional[Callable[[str, float, str], None]] = None
):
"""启动任务(异步)"""
if task_id in self.running_tasks:
return
async def run():
try:
await self.execute_task(task_id, progress_callback)
finally:
if task_id in self.running_tasks:
del self.running_tasks[task_id]
self.running_tasks[task_id] = asyncio.create_task(run())
def pause_task(self, task_id: str) -> bool:
"""暂停任务"""
if task_id not in self.tasks:
return False
task = self.tasks[task_id]
if task.status == TaskStatus.RUNNING:
task.status = TaskStatus.PAUSED
self._save()
return True
return False
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
if task_id not in self.tasks:
return False
task = self.tasks[task_id]
task.status = TaskStatus.CANCELLED
self._save()
# Also cancel the running asyncio task, if any
if task_id in self.running_tasks:
self.running_tasks[task_id].cancel()
return True
def get_task(self, task_id: str) -> Optional[LongTask]:
"""获取任务"""
return self.tasks.get(task_id)
def get_user_tasks(self, user_id: str) -> List[LongTask]:
"""获取用户的所有任务"""
return [
task for task in self.tasks.values()
if task.user_id == user_id
]
def get_task_status(self, task_id: str) -> Optional[Dict]:
"""获取任务状态"""
task = self.get_task(task_id)
if not task:
return None
completed_steps = sum(1 for step in task.steps if step.status == TaskStatus.COMPLETED)
total_steps = len(task.steps)
return {
'task_id': task.task_id,
'title': task.title,
'status': task.status.value,
'progress': task.progress,
'completed_steps': completed_steps,
'total_steps': total_steps,
'current_step': next(
(step.description for step in task.steps if step.status == TaskStatus.RUNNING),
None
)
}
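The removed manager persisted tasks as JSON via paired to_dict/from_dict methods, storing enums by their string value and datetimes in ISO format. A minimal standalone sketch of that round-trip pattern (field set reduced for illustration; this is not the deleted module itself):

```python
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Optional


class TaskStatus(Enum):
    PENDING = "pending"
    COMPLETED = "completed"


@dataclass
class Step:
    step_id: str
    status: TaskStatus = TaskStatus.PENDING
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict:
        return {
            "step_id": self.step_id,
            "status": self.status.value,  # enum stored by value, so the JSON stays readable
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "Step":
        return cls(
            step_id=data["step_id"],
            status=TaskStatus(data["status"]),
            completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None,
        )


# Round-trip: serialize, then rebuild an equal object
step = Step("s1", TaskStatus.COMPLETED, datetime(2026, 3, 3, 12, 0))
restored = Step.from_dict(step.to_dict())
```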

View File

@@ -1,63 +1,77 @@
""" """
JSON文件存储实现(向后兼容) JSON-backed vector store implementation.
""" """
from __future__ import annotations
import asyncio
import json import json
import uuid import uuid
from typing import List, Dict, Optional
from pathlib import Path
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np import numpy as np
from .base import VectorStore, VectorMemory
from .base import VectorMemory, VectorStore
from src.utils.logger import setup_logger from src.utils.logger import setup_logger
logger = setup_logger('JSONStore') logger = setup_logger("JSONStore")
class JSONVectorStore(VectorStore): class JSONVectorStore(VectorStore):
"""JSON文件存储实现(向后兼容旧版本)""" """JSON file storage implementation."""
def __init__(self, storage_path: Path): def __init__(self, storage_path: Path):
"""初始化JSON存储"""
self.storage_path = storage_path self.storage_path = storage_path
self.memories: Dict[str, List[VectorMemory]] = {} # user_id -> List[VectorMemory] self.memories: Dict[str, List[VectorMemory]] = {}
self._lock = asyncio.Lock()
self._load() self._load()
logger.info(f"JSON存储初始化: {storage_path}") logger.info(f"JSON storage initialized: {storage_path}")
def _load(self): def _load(self):
"""加载记忆""" if not self.storage_path.exists():
if self.storage_path.exists(): return
try: try:
with open(self.storage_path, 'r', encoding='utf-8') as f: with open(self.storage_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
except Exception as exc:
for user_id, items in data.items(): logger.error(f"Failed to load memory file: {exc}")
self.memories[user_id] = []
for item in items:
# 兼容旧格式
if 'id' not in item:
item['id'] = str(uuid.uuid4())
memory = VectorMemory.from_dict(item)
self.memories[user_id].append(memory)
logger.info(f"加载了 {sum(len(v) for v in self.memories.values())} 条记忆")
except Exception as e:
logger.error(f"加载记忆失败: {e}")
self.memories = {} self.memories = {}
return
def _save(self): loaded = 0
"""保存记忆""" memories: Dict[str, List[VectorMemory]] = {}
for user_id, items in (data or {}).items():
if not isinstance(items, list):
continue
normalized: List[VectorMemory] = []
for item in items:
if not isinstance(item, dict):
continue
if "id" not in item:
item["id"] = str(uuid.uuid4())
try: try:
normalized.append(VectorMemory.from_dict(item))
loaded += 1
except Exception:
continue
memories[str(user_id)] = normalized
self.memories = memories
logger.info(f"Loaded {loaded} memories from JSON store")
def _save_locked(self):
self.storage_path.parent.mkdir(parents=True, exist_ok=True) self.storage_path.parent.mkdir(parents=True, exist_ok=True)
data = { data = {
user_id: [memory.to_dict() for memory in memories] user_id: [memory.to_dict() for memory in user_memories]
for user_id, memories in self.memories.items() for user_id, user_memories in self.memories.items()
} }
with open(self.storage_path, "w", encoding="utf-8") as f:
with open(self.storage_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"保存记忆失败: {e}")
async def add( async def add(
self, self,
@@ -66,9 +80,8 @@ class JSONVectorStore(VectorStore):
content: str, content: str,
embedding: List[float], embedding: List[float],
importance: float, importance: float,
metadata: Optional[Dict] = None metadata: Optional[Dict] = None,
) -> bool: ) -> bool:
"""添加记忆"""
try: try:
memory = VectorMemory( memory = VectorMemory(
id=id, id=id,
@@ -79,19 +92,14 @@ class JSONVectorStore(VectorStore):
timestamp=datetime.now(), timestamp=datetime.now(),
metadata=metadata or {}, metadata=metadata or {},
access_count=0, access_count=0,
last_access=None last_access=None,
) )
async with self._lock:
if user_id not in self.memories: self.memories.setdefault(user_id, []).append(memory)
self.memories[user_id] = [] self._save_locked()
self.memories[user_id].append(memory)
self._save()
logger.debug(f"添加记忆: {id} (用户: {user_id})")
return True return True
except Exception as e: except Exception as exc:
logger.error(f"添加记忆失败: {e}") logger.error(f"Failed to add memory: {exc}")
return False return False
async def search( async def search(
@@ -99,100 +107,93 @@ class JSONVectorStore(VectorStore):
user_id: str, user_id: str,
query_embedding: List[float], query_embedding: List[float],
limit: int = 5, limit: int = 5,
min_importance: float = 0.3 min_importance: float = 0.3,
) -> List[VectorMemory]: ) -> List[VectorMemory]:
"""搜索相似记忆""" async with self._lock:
if user_id not in self.memories: source = list(self.memories.get(user_id, []))
candidates = [m for m in source if m.importance >= min_importance]
if not candidates:
return [] return []
memories = self.memories[user_id]
# 过滤重要性
memories = [m for m in memories if m.importance >= min_importance]
if not memories:
return []
# 使用向量相似度排序
scored_memories = [] scored_memories = []
for memory in memories: for memory in candidates:
if memory.embedding: if memory.embedding:
similarity = self._cosine_similarity(query_embedding, memory.embedding) similarity = self._cosine_similarity(query_embedding, memory.embedding)
if similarity is not None:
scored_memories.append((similarity, memory)) scored_memories.append((similarity, memory))
if not scored_memories: if not scored_memories:
# 如果没有嵌入向量,按重要性排序
return await self.get_by_importance(user_id, limit, min_importance) return await self.get_by_importance(user_id, limit, min_importance)
scored_memories.sort(reverse=True, key=lambda x: x[0]) scored_memories.sort(reverse=True, key=lambda x: x[0])
return [m for _, m in scored_memories[:limit]] return [memory for _, memory in scored_memories[:limit]]
async def get_by_importance( async def get_by_importance(
self, self,
user_id: str, user_id: str,
limit: int = 5, limit: int = 5,
min_importance: float = 0.3 min_importance: float = 0.3,
) -> List[VectorMemory]: ) -> List[VectorMemory]:
"""按重要性获取记忆""" async with self._lock:
if user_id not in self.memories: source = list(self.memories.get(user_id, []))
return []
memories = [m for m in self.memories[user_id] if m.importance >= min_importance] memories = [m for m in source if m.importance >= min_importance]
memories.sort(key=lambda m: (m.importance, m.timestamp), reverse=True) memories.sort(key=lambda m: (m.importance, m.timestamp), reverse=True)
return memories[:limit] return memories[:limit]
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: @staticmethod
"""计算余弦相似度""" def _cosine_similarity(vec1: List[float], vec2: List[float]) -> Optional[float]:
vec1 = np.array(vec1) arr1 = np.array(vec1, dtype=float)
vec2 = np.array(vec2) arr2 = np.array(vec2, dtype=float)
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) denom = np.linalg.norm(arr1) * np.linalg.norm(arr2)
if denom == 0:
return None
return float(np.dot(arr1, arr2) / denom)
async def update_access(self, memory_id: str) -> bool: async def update_access(self, memory_id: str) -> bool:
"""更新访问记录"""
try: try:
async with self._lock:
for memories in self.memories.values(): for memories in self.memories.values():
for memory in memories: for memory in memories:
if memory.id == memory_id: if memory.id == memory_id:
memory.access_count += 1 memory.access_count += 1
memory.last_access = datetime.now() memory.last_access = datetime.now()
self._save() self._save_locked()
return True return True
return False return False
except Exception as e: except Exception as exc:
logger.error(f"更新访问记录失败: {e}") logger.error(f"Failed to update access: {exc}")
return False return False
async def delete(self, memory_id: str) -> bool: async def delete(self, memory_id: str) -> bool:
"""删除记忆"""
try: try:
async with self._lock:
for user_id, memories in self.memories.items(): for user_id, memories in self.memories.items():
for i, memory in enumerate(memories): for idx, memory in enumerate(memories):
if memory.id == memory_id: if memory.id == memory_id:
del self.memories[user_id][i] del self.memories[user_id][idx]
self._save() self._save_locked()
return True return True
return False return False
except Exception as e: except Exception as exc:
logger.error(f"删除记忆失败: {e}") logger.error(f"Failed to delete memory: {exc}")
return False return False
async def get_all(self, user_id: str) -> List[VectorMemory]: async def get_all(self, user_id: str) -> List[VectorMemory]:
"""获取用户所有记忆""" async with self._lock:
return self.memories.get(user_id, []) return list(self.memories.get(user_id, []))
async def clear_user(self, user_id: str) -> bool: async def clear_user(self, user_id: str) -> bool:
"""清除用户所有记忆"""
try: try:
if user_id in self.memories: async with self._lock:
del self.memories[user_id] self.memories.pop(user_id, None)
self._save() self._save_locked()
logger.info(f"清除用户记忆: {user_id}")
return True return True
except Exception as e: except Exception as exc:
logger.error(f"清除用户记忆失败: {e}") logger.error(f"Failed to clear user memories: {exc}")
return False return False
async def close(self): async def close(self):
"""关闭连接""" async with self._lock:
self._save() self._save_locked()
logger.info("JSON存储已关闭")

View File

@@ -1,105 +1,71 @@
""" """
QQ机器人主程序 QQ bot application entry module.
基于官方SDK: https://github.com/tencent-connect/botpy
官方文档: https://bot.q.qq.com/wiki/develop/api-v2/
""" """
from __future__ import annotations
import botpy import botpy
from botpy.message import Message from botpy.message import Message
from src.core.config import Config
from src.core.config import Config
from src.handlers.message_handler_ai import MessageHandler
from src.utils.logger import setup_logger

logger = setup_logger("QQBot")


def build_intents() -> botpy.Intents:
    """Build the minimal viable intents.

    - public_guild_messages: public @-mention messages in guild channels
    - public_messages: group @-mention and C2C (friend) direct messages
    """
    intents = botpy.Intents.none()
    intents.public_guild_messages = True
    # In recent botpy versions, group @-mentions and C2C direct messages
    # depend on public_messages (GROUP_AND_C2C_EVENT).
    if hasattr(intents, "public_messages"):
        intents.public_messages = True
        logger.info("Enabled public_messages for group/C2C events")
    else:
        logger.warning("Current botpy version does not expose public_messages")
    return intents


class MyClient(botpy.Client):
    """QQ bot client wrapper."""

    def __init__(self, intents: botpy.Intents):
        super().__init__(intents=intents)
        self.message_handler = MessageHandler(self)

    async def on_ready(self):
        logger.info(f"Bot is ready: {self.robot.name} (ID: {self.robot.id})")

    async def on_at_message_create(self, message: Message):
        """Guild public @-mention messages."""
        await self.message_handler.handle_at_message(message)

    async def on_message_create(self, message: Message):
        """Plain guild messages (requires private-domain permission)."""
        await self.message_handler.handle_at_message(message)

    async def on_direct_message_create(self, message: Message):
        """Guild direct messages."""
        await self.message_handler.handle_at_message(message)

    async def on_group_at_message_create(self, message: Message):
        """Group @-mention messages."""
        await self.message_handler.handle_at_message(message)

    async def on_c2c_message_create(self, message: Message):
        """C2C (one-on-one) messages."""
        await self.message_handler.handle_at_message(message)

    async def on_error(self, error):
        logger.error(f"Bot error: {error}", exc_info=True)

    async def on_close(self):
        ai_client = getattr(self.message_handler, "ai_client", None)
        if ai_client:
            try:
                await ai_client.close()
            except Exception as exc:
                logger.warning(f"Failed to close AI resources cleanly: {exc}")


def main():
    Config.validate()
    # Use minimal viable intents to avoid "disallowed intents" (error 4014).
    intents = build_intents()
    client = MyClient(intents=intents)
    client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET)


if __name__ == "__main__":
    main()

src/core/config.py

@@ -1,72 +1,123 @@
""" """
QQ机器人配置管理模块 Centralized project configuration.
""" """
from __future__ import annotations
import os import os
from typing import Optional from typing import Optional, Set
from dotenv import load_dotenv from dotenv import load_dotenv
# 加载环境变量
load_dotenv() load_dotenv()
def _read_env(name: str, default: Optional[str] = None) -> Optional[str]: def _read_env(name: str, default: Optional[str] = None) -> Optional[str]:
"""
读取并清洗环境变量。
- 去除首尾空白
- 空字符串视为未设置
- 以 # 开头的值视为注释占位,视为未设置
"""
value = os.getenv(name) value = os.getenv(name)
if value is None: if value is None:
return default return default
value = value.strip() value = value.strip()
if not value or value.startswith('#'): if not value or value.startswith("#"):
return default return default
return value return value
def _read_bool(name: str, default: bool) -> bool:
raw = _read_env(name, None)
if raw is None:
return default
lowered = raw.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
return default
def _read_int(name: str, default: int) -> int:
raw = _read_env(name, None)
if raw is None:
return default
try:
return int(raw)
except ValueError:
return default
def _read_float(name: str, default: float) -> float:
raw = _read_env(name, None)
if raw is None:
return default
try:
return float(raw)
except ValueError:
return default
def _read_csv_set(name: str) -> Set[str]:
raw = _read_env(name, "") or ""
return {item.strip() for item in raw.split(",") if item.strip()}
class Config: class Config:
"""机器人配置类""" """Application runtime configuration."""
# 机器人基本信息 BOT_APPID = _read_env("BOT_APPID", "") or ""
BOT_APPID = _read_env('BOT_APPID', '') or '' BOT_SECRET = _read_env("BOT_SECRET", "") or ""
BOT_SECRET = _read_env('BOT_SECRET', '') or ''
# 日志配置 ENV = _read_env("APP_ENV", "dev") or "dev"
LOG_LEVEL = _read_env('LOG_LEVEL', 'INFO') or 'INFO' LOG_LEVEL = _read_env("LOG_LEVEL", "INFO") or "INFO"
LOG_FORMAT = (_read_env("LOG_FORMAT", "text") or "text").lower()
# 沙箱模式 SANDBOX_MODE = _read_bool("SANDBOX_MODE", False)
SANDBOX_MODE = os.getenv('SANDBOX_MODE', 'False').lower() == 'true'
# AI配置 AI_PROVIDER = _read_env("AI_PROVIDER", "openai") or "openai"
AI_PROVIDER = _read_env('AI_PROVIDER', 'openai') or 'openai' AI_MODEL = _read_env("AI_MODEL", "gpt-4") or "gpt-4"
AI_MODEL = _read_env('AI_MODEL', 'gpt-4') or 'gpt-4' AI_API_KEY = _read_env("AI_API_KEY", "") or ""
AI_API_KEY = _read_env('AI_API_KEY', '') or '' AI_API_BASE = _read_env("AI_API_BASE", None)
AI_API_BASE = _read_env('AI_API_BASE', None)
# AI嵌入模型配置用于RAG AI_EMBED_PROVIDER = _read_env("AI_EMBED_PROVIDER", "openai") or "openai"
AI_EMBED_PROVIDER = _read_env('AI_EMBED_PROVIDER', 'openai') or 'openai' AI_EMBED_MODEL = (
AI_EMBED_MODEL = _read_env('AI_EMBED_MODEL', 'text-embedding-3-small') or 'text-embedding-3-small' _read_env("AI_EMBED_MODEL", "text-embedding-3-small")
AI_EMBED_API_KEY = _read_env('AI_EMBED_API_KEY', None) # 留空则使用 AI_API_KEY or "text-embedding-3-small"
AI_EMBED_API_BASE = _read_env('AI_EMBED_API_BASE', None) # 留空则使用 AI_API_BASE )
AI_EMBED_API_KEY = _read_env("AI_EMBED_API_KEY", None)
AI_EMBED_API_BASE = _read_env("AI_EMBED_API_BASE", None)
# 向量数据库配置 AI_USE_VECTOR_DB = _read_bool("AI_USE_VECTOR_DB", True)
AI_USE_VECTOR_DB = os.getenv('AI_USE_VECTOR_DB', 'true').lower() == 'true' AI_USE_QUERY_EMBEDDING = _read_bool("AI_USE_QUERY_EMBEDDING", False)
# `user`: one memory bucket per user
# `session`: memory bucket scoped to chat session (user + channel/group)
AI_MEMORY_SCOPE = (_read_env("AI_MEMORY_SCOPE", "session") or "session").lower()
AI_CHAT_RETRIES = max(0, _read_int("AI_CHAT_RETRIES", 1))
AI_CHAT_RETRY_BACKOFF_SECONDS = max(
0.0, _read_float("AI_CHAT_RETRY_BACKOFF_SECONDS", 0.8)
)
MESSAGE_DEDUP_SECONDS = max(1, _read_int("MESSAGE_DEDUP_SECONDS", 30))
MESSAGE_DEDUP_MAX_SIZE = max(128, _read_int("MESSAGE_DEDUP_MAX_SIZE", 4096))
BOT_ADMIN_IDS = _read_csv_set("BOT_ADMIN_IDS")
@classmethod @classmethod
def validate(cls): def is_admin(cls, user_id: Optional[str]) -> bool:
"""验证配置是否完整""" if not cls.BOT_ADMIN_IDS:
return True
if not user_id:
return False
return str(user_id) in cls.BOT_ADMIN_IDS
@classmethod
def validate(cls) -> bool:
if not cls.BOT_APPID: if not cls.BOT_APPID:
raise ValueError("BOT_APPID 未配置") raise ValueError("BOT_APPID 未配置")
if not cls.BOT_SECRET: if not cls.BOT_SECRET:
raise ValueError("BOT_SECRET 未配置") raise ValueError("BOT_SECRET 未配置")
# AI配置验证可选 if cls.AI_MEMORY_SCOPE not in {"user", "session"}:
if cls.AI_API_KEY: raise ValueError("AI_MEMORY_SCOPE 仅支持 user 或 session")
print(f"✅ AI配置: {cls.AI_PROVIDER}/{cls.AI_MODEL}")
else:
print("⚠️ AI_API_KEY 未设置AI功能将不可用")
return True return True
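The env-parsing helpers above are easy to sanity-check in isolation. A minimal sketch, with the boolean-parsing rules re-implemented locally so the snippet is self-contained (`read_bool` is an illustrative name, not the project's API):

```python
import os


def read_bool(name: str, default: bool) -> bool:
    # Mirrors the _read_bool semantics: strip whitespace, treat empty
    # strings and "#..." comment placeholders as unset, and accept
    # 1/true/yes/on and 0/false/no/off case-insensitively.
    raw = os.getenv(name)
    if raw is not None:
        raw = raw.strip()
    if not raw or raw.startswith("#"):
        return default
    lowered = raw.lower()
    if lowered in {"1", "true", "yes", "on"}:
        return True
    if lowered in {"0", "false", "no", "off"}:
        return False
    return default


os.environ["SANDBOX_MODE"] = "YES"
os.environ["AI_USE_VECTOR_DB"] = "# set me later"
print(read_bool("SANDBOX_MODE", False))     # True (accepted truthy token)
print(read_bool("AI_USE_VECTOR_DB", True))  # True (placeholder falls back to default)
```

Unrecognized tokens also fall back to the default rather than raising, which keeps a typo in `.env` from crashing startup.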

src/handlers/__init__.py

@@ -1,6 +1,5 @@
""" """Message handlers."""
消息处理模块
"""
from .message_handler import MessageHandler
__all__ = ['MessageHandler'] from .message_handler_ai import MessageHandler
__all__ = ["MessageHandler"]

File diff suppressed because it is too large

src/utils/logger.py

@@ -1,63 +1,101 @@
""" """
日志配置模块 Logging helpers.
""" """
from __future__ import annotations
import json
import logging import logging
import os import os
from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Dict
def setup_logger(name='QQBot', level=None): class _JsonFormatter(logging.Formatter):
""" """Simple JSON formatter for structured logs."""
设置日志记录器
Args: _BASE_FIELDS = {
name: 日志记录器名称 "name",
level: 日志级别,默认从环境变量读取 "msg",
"args",
"levelname",
"levelno",
"pathname",
"filename",
"module",
"exc_info",
"exc_text",
"stack_info",
"lineno",
"funcName",
"created",
"msecs",
"relativeCreated",
"thread",
"threadName",
"processName",
"process",
}
Returns: def format(self, record: logging.LogRecord) -> str:
logging.Logger: 配置好的日志记录器 payload: Dict[str, Any] = {
""" "ts": datetime.now(timezone.utc).isoformat(),
# 创建logs目录 "level": record.levelname,
log_dir = Path(__file__).parent.parent.parent / 'logs' "logger": record.name,
"message": record.getMessage(),
"file": record.filename,
"line": record.lineno,
}
for key, value in record.__dict__.items():
if key in self._BASE_FIELDS or key.startswith("_"):
continue
payload[key] = value
if record.exc_info:
payload["exc"] = self.formatException(record.exc_info)
return json.dumps(payload, ensure_ascii=False)
def setup_logger(name: str = "QQBot", level: str | None = None) -> logging.Logger:
log_dir = Path(__file__).parent.parent.parent / "logs"
log_dir.mkdir(exist_ok=True) log_dir.mkdir(exist_ok=True)
# 设置日志级别
if level is None: if level is None:
level = os.getenv('LOG_LEVEL', 'INFO') level = os.getenv("LOG_LEVEL", "INFO")
log_format = (os.getenv("LOG_FORMAT", "text") or "text").lower()
# 创建日志记录器
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(getattr(logging, level.upper())) logger.setLevel(getattr(logging, str(level).upper(), logging.INFO))
# 避免向 root logger 传播导致重复输出
logger.propagate = False logger.propagate = False
# 避免重复添加处理器
if logger.handlers: if logger.handlers:
return logger return logger
# 控制台处理器
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
file_handler = logging.FileHandler(log_dir / "bot.log", encoding="utf-8")
if log_format == "json":
formatter = _JsonFormatter()
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
else:
console_fmt = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_fmt = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
console_handler.setFormatter(console_fmt)
file_handler.setFormatter(file_fmt)
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
console_format = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_format)
# 文件处理器
file_handler = logging.FileHandler(
log_dir / 'bot.log',
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
file_format = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_format)
# 添加处理器
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger.addHandler(file_handler) logger.addHandler(file_handler)
return logger return logger
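A record run through a formatter with the layout above parses back into a flat dict, which is what makes the JSON mode useful for log shippers. A self-contained sketch (stdlib only; `format_record` re-creates the field layout locally rather than importing the project module):

```python
import json
import logging
from datetime import datetime, timezone


def format_record(record: logging.LogRecord) -> str:
    # Same base field layout as the _JsonFormatter above.
    payload = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "level": record.levelname,
        "logger": record.name,
        "message": record.getMessage(),
        "file": record.filename,
        "line": record.lineno,
    }
    return json.dumps(payload, ensure_ascii=False)


# Build a LogRecord by hand, as the logging machinery would.
record = logging.LogRecord(
    name="QQBot", level=logging.INFO, pathname="main.py", lineno=42,
    msg="bot ready: %s", args=("TestBot",), exc_info=None,
)
parsed = json.loads(format_record(record))
print(parsed["level"], parsed["message"])  # INFO bot ready: TestBot
```

Note that `record.getMessage()` applies `%`-style argument substitution, so the stored `message` is the fully rendered string.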

tests/conftest.py (new file)

@@ -0,0 +1,7 @@
from pathlib import Path
import sys
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))


@@ -1,628 +0,0 @@
"""AI integration tests."""
import asyncio
import json
import os
from pathlib import Path
import shutil
import stat
import sys
import tempfile
import time
import zipfile
from dotenv import load_dotenv
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.ai import AIClient
from src.ai.base import ModelConfig, ModelProvider
from src.ai.memory import MemorySystem
from src.ai.skills import SkillsManager, create_skill_template
from src.handlers.message_handler_ai import MessageHandler
load_dotenv(project_root / ".env")
TEST_DATA_DIR = Path("data/ai_test")
def _safe_rmtree(path: Path):
if not path.exists():
return
def _onerror(func, target, exc_info):
try:
os.chmod(target, stat.S_IWRITE)
func(target)
except Exception:
pass
for _ in range(3):
try:
shutil.rmtree(path, onerror=_onerror)
return
except PermissionError:
time.sleep(0.2)
def _safe_unlink(path: Path):
if not path.exists():
return
for _ in range(3):
try:
path.unlink()
return
except PermissionError:
time.sleep(0.2)
def _read_env(name: str, default=None):
value = os.getenv(name)
if value is None:
return default
value = value.strip()
if not value or value.startswith("#"):
return default
return value
def get_ai_config() -> ModelConfig:
provider_map = {
"openai": ModelProvider.OPENAI,
"anthropic": ModelProvider.ANTHROPIC,
"deepseek": ModelProvider.DEEPSEEK,
"qwen": ModelProvider.QWEN,
"siliconflow": ModelProvider.OPENAI,
}
provider_str = (_read_env("AI_PROVIDER", "openai") or "openai").lower()
provider = provider_map.get(provider_str, ModelProvider.OPENAI)
return ModelConfig(
provider=provider,
model_name=_read_env("AI_MODEL", "gpt-3.5-turbo") or "gpt-3.5-turbo",
api_key=_read_env("AI_API_KEY", "") or "",
api_base=_read_env("AI_API_BASE"),
temperature=0.7,
)
def get_embed_config() -> ModelConfig:
provider_map = {
"openai": ModelProvider.OPENAI,
"anthropic": ModelProvider.ANTHROPIC,
"deepseek": ModelProvider.DEEPSEEK,
"qwen": ModelProvider.QWEN,
"siliconflow": ModelProvider.OPENAI,
}
provider_str = (_read_env("AI_EMBED_PROVIDER", "openai") or "openai").lower()
provider = provider_map.get(provider_str, ModelProvider.OPENAI)
api_key = _read_env("AI_EMBED_API_KEY") or _read_env("AI_API_KEY", "") or ""
api_base = _read_env("AI_EMBED_API_BASE") or _read_env("AI_API_BASE")
return ModelConfig(
provider=provider,
model_name=_read_env("AI_EMBED_MODEL", "text-embedding-3-small")
or "text-embedding-3-small",
api_key=api_key,
api_base=api_base,
temperature=0.0,
)
class FakeMessage:
def __init__(self, content: str):
from types import SimpleNamespace
self.content = content
self.author = SimpleNamespace(id="test_user")
self.replies = []
async def reply(self, content: str):
self.replies.append(content)
def make_handler() -> MessageHandler:
from types import SimpleNamespace
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
handler = MessageHandler(fake_bot)
handler.ai_client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
handler.skills_manager = SkillsManager(Path("skills"))
handler.model_profiles_path = TEST_DATA_DIR / "models_test.json"
TEST_DATA_DIR.mkdir(parents=True, exist_ok=True)
_safe_unlink(handler.model_profiles_path)
handler._ai_initialized = True
return handler
async def _test_basic_chat():
print("=== test_basic_chat ===")
config = get_ai_config()
if not config.api_key:
print("skip: AI_API_KEY not configured")
return
embed_config = get_embed_config()
client = AIClient(config, embed_config=embed_config, data_dir=TEST_DATA_DIR)
response = await client.chat(
user_id="test_user",
user_message="你好,请介绍一下你自己",
use_memory=False,
use_tools=False,
)
assert response
print("ok: chat response length", len(response))
async def _test_memory():
print("=== test_memory ===")
config = get_ai_config()
if not config.api_key:
print("skip: AI_API_KEY not configured")
return
client = AIClient(config, embed_config=get_embed_config(), data_dir=TEST_DATA_DIR)
await client.chat(user_id="test_user", user_message="鎴戝彨寮犱笁", use_memory=True)
await client.chat(user_id="test_user", user_message="what is my name", use_memory=True)
short_term, long_term = await client.memory.get_context("test_user")
assert len(short_term) >= 2
# Importance is now scored by the model, so whether an entry reaches
# long-term memory depends on that score; no fixed count is asserted.
assert isinstance(long_term, list)
print("ok: memory short/long", len(short_term), len(long_term))
async def _test_personality():
print("=== test_personality ===")
client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
names = client.list_personalities()
assert names
assert client.set_personality(names[0])
key = "roleplay_test"
added = client.personality.add_personality(
key,
client.personality.get_personality("default"),
)
assert added
assert key in client.list_personalities()
assert client.personality.remove_personality(key)
assert key not in client.list_personalities()
print("ok: personality add/remove")
async def _test_skills():
print("=== test_skills ===")
manager = SkillsManager(Path("skills"))
assert await manager.load_skill("weather")
tools = manager.get_all_tools()
assert "weather.get_weather" in tools
weather = await tools["weather.get_weather"](city="鍖椾含")
assert weather
assert await manager.load_skill("skills_creator")
tools = manager.get_all_tools()
assert "skills_creator.create_skill" in tools
await manager.unload_skill("weather")
await manager.unload_skill("skills_creator")
print("ok: skills load/unload")
async def _test_skill_commands():
print("=== test_skill_commands ===")
handler = make_handler()
skill_key = f"cmd_zip_skill_{int(time.time() * 1000)}"
# Prepare a zip package source for install testing
tmp_root = TEST_DATA_DIR / "tmp_skill_pkg"
if tmp_root.exists():
_safe_rmtree(tmp_root)
tmp_root.mkdir(parents=True, exist_ok=True)
create_skill_template(skill_key, tmp_root, description="zip skill", author="test")
zip_path = TEST_DATA_DIR / f"{skill_key}.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
for file in (tmp_root / skill_key).rglob("*"):
if file.is_file():
zf.write(file, file.relative_to(tmp_root))
install_msg = FakeMessage(f"/skills install {zip_path}")
await handler._handle_command(install_msg, install_msg.content)
assert install_msg.replies, "install command no reply"
list_msg = FakeMessage("/skills")
await handler._handle_command(list_msg, list_msg.content)
assert list_msg.replies, "list command no reply"
reload_msg = FakeMessage(f"/skills reload {skill_key}")
await handler._handle_command(reload_msg, reload_msg.content)
assert reload_msg.replies, "reload command no reply"
uninstall_msg = FakeMessage(f"/skills uninstall {skill_key}")
await handler._handle_command(uninstall_msg, uninstall_msg.content)
assert uninstall_msg.replies, "uninstall command no reply"
if tmp_root.exists():
_safe_rmtree(tmp_root)
_safe_unlink(zip_path)
print("ok: skills install/reload/uninstall command")
async def _test_personality_commands():
print("=== test_personality_commands ===")
handler = make_handler()
intro = "You are a hot-blooded anime hero. Speak directly and stay in-character."
add_cmd = (
"/personality add roleplay_hero "
f"{intro}"
)
add_msg = FakeMessage(add_cmd)
await handler._handle_command(add_msg, add_msg.content)
assert add_msg.replies
set_msg = FakeMessage("/personality set roleplay_hero")
await handler._handle_command(set_msg, set_msg.content)
assert set_msg.replies
assert intro in handler.ai_client.personality.get_system_prompt()
remove_msg = FakeMessage("/personality remove roleplay_hero")
await handler._handle_command(remove_msg, remove_msg.content)
assert remove_msg.replies
assert "roleplay_hero" not in handler.ai_client.list_personalities()
print("ok: personality add/set/remove command")
async def _test_model_commands():
print("=== test_model_commands ===")
handler = make_handler()
list_msg = FakeMessage("/models")
await handler._handle_command(list_msg, list_msg.content)
assert list_msg.replies
assert "default" in list_msg.replies[-1].lower()
add_msg = FakeMessage("/models add roleplay_llm openai gpt-4o-mini")
await handler._handle_command(add_msg, add_msg.content)
assert add_msg.replies
assert handler.active_model_key == "roleplay_llm"
assert "roleplay_llm" in handler.model_profiles
switch_msg = FakeMessage("/models switch default")
await handler._handle_command(switch_msg, switch_msg.content)
assert switch_msg.replies
assert handler.active_model_key == "default"
current_msg = FakeMessage("/models current")
await handler._handle_command(current_msg, current_msg.content)
assert current_msg.replies
old_config = handler.ai_client.config
shortcut_model = "Qwen/Qwen2.5-7B-Instruct"
shortcut_key = handler._normalize_model_key(shortcut_model)
shortcut_add_msg = FakeMessage(f"/models add {shortcut_model}")
await handler._handle_command(shortcut_add_msg, shortcut_add_msg.content)
assert shortcut_add_msg.replies
assert handler.active_model_key == shortcut_key
assert handler.ai_client.config.model_name == shortcut_model
assert handler.ai_client.config.provider == old_config.provider
assert handler.ai_client.config.api_base == old_config.api_base
assert handler.ai_client.config.api_key == old_config.api_key
shortcut_remove_msg = FakeMessage(f"/models remove {shortcut_key}")
await handler._handle_command(shortcut_remove_msg, shortcut_remove_msg.content)
assert shortcut_remove_msg.replies
assert shortcut_key not in handler.model_profiles
remove_msg = FakeMessage("/models remove roleplay_llm")
await handler._handle_command(remove_msg, remove_msg.content)
assert remove_msg.replies
assert "roleplay_llm" not in handler.model_profiles
_safe_unlink(handler.model_profiles_path)
print("ok: model add/switch/remove command")
async def _test_memory_commands():
print("=== test_memory_commands ===")
handler = make_handler()
user_id = "test_user"
await handler.ai_client.clear_all_memory(user_id)
add_msg = FakeMessage("/memory add this is a long-term memory test")
await handler._handle_command(add_msg, add_msg.content)
assert add_msg.replies
assert "已新增长期记忆" in add_msg.replies[-1]
memory_id = add_msg.replies[-1].split(": ", 1)[1].split(" ", 1)[0]
assert memory_id
list_msg = FakeMessage("/memory list 5")
await handler._handle_command(list_msg, list_msg.content)
assert list_msg.replies
assert memory_id in list_msg.replies[-1]
get_msg = FakeMessage(f"/memory get {memory_id}")
await handler._handle_command(get_msg, get_msg.content)
assert get_msg.replies
assert memory_id in get_msg.replies[-1]
search_msg = FakeMessage("/memory search 长期记忆")
await handler._handle_command(search_msg, search_msg.content)
assert search_msg.replies
assert memory_id in search_msg.replies[-1]
update_msg = FakeMessage(f"/memory update {memory_id} 这是更新后的长期记忆")
await handler._handle_command(update_msg, update_msg.content)
assert update_msg.replies
assert "已更新长期记忆" in update_msg.replies[-1]
# Build short-term memory then clear only short-term.
await handler.ai_client.memory.add_message(
user_id=user_id,
role="user",
content="short memory for clear short test",
)
assert handler.ai_client.memory.short_term.get(user_id)
clear_short_msg = FakeMessage("/clear short")
await handler._handle_command(clear_short_msg, clear_short_msg.content)
assert clear_short_msg.replies
assert not handler.ai_client.memory.short_term.get(user_id)
# Long-term memory should still exist after clearing short-term only.
still_exists = await handler.ai_client.get_long_term_memory(user_id, memory_id)
assert still_exists is not None
delete_msg = FakeMessage(f"/memory delete {memory_id}")
await handler._handle_command(delete_msg, delete_msg.content)
assert delete_msg.replies
assert "已删除长期记忆" in delete_msg.replies[-1]
removed = await handler.ai_client.get_long_term_memory(user_id, memory_id)
assert removed is None
print("ok: memory command CRUD + clear short")
async def _test_plain_text_output():
print("=== test_plain_text_output ===")
handler = make_handler()
md_text = "# 标题\n**加粗** 和 `代码`\n- 列表\n[链接](https://example.com)"
plain = handler._plain_text(md_text)
assert "#" not in plain
assert "**" not in plain
assert "`" not in plain
assert "[" not in plain
assert "](" not in plain
print("ok: markdown stripped")
async def _test_skills_creator_autoload():
print("=== test_skills_creator_autoload ===")
from types import SimpleNamespace
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
handler = MessageHandler(fake_bot)
handler.model_profiles_path = TEST_DATA_DIR / "models_autoload_test.json"
_safe_unlink(handler.model_profiles_path)
await handler._init_ai()
assert handler.skills_manager is not None
assert "skills_creator" in handler.skills_manager.list_skills()
tool_names = [tool.name for tool in handler.ai_client.tools.list()]
assert "skills_creator.create_skill" in tool_names
print("ok: skills_creator autoloaded")
async def _test_mcp():
print("=== test_mcp ===")
from src.ai.mcp import MCPManager
from src.ai.mcp.servers import FileSystemMCPServer
manager = MCPManager(Path("config/mcp.json"))
fs_server = FileSystemMCPServer(root_path=Path("data"))
await manager.register_server(fs_server)
tools = await manager.get_all_tools_for_ai()
assert len(tools) >= 1
print("ok: mcp tools", len(tools))
async def _test_long_task():
print("=== test_long_task ===")
client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
async def step1():
await asyncio.sleep(0.1)
return "step1"
async def step2():
await asyncio.sleep(0.1)
return "step2"
client.task_manager.register_action("step1", step1)
client.task_manager.register_action("step2", step2)
task_id = await client.create_long_task(
user_id="test_user",
title="test",
description="test task",
steps=[
{"description": "s1", "action": "step1", "params": {}},
{"description": "s2", "action": "step2", "params": {}},
],
)
await client.start_task(task_id)
await asyncio.sleep(0.5)
status = client.get_task_status(task_id)
assert status is not None
assert status["status"] in {"completed", "running"}
print("ok: long task", status["status"])
async def _test_memory_importance_evaluator():
print("=== test_memory_importance_evaluator ===")
called = {"value": False}
async def fake_importance_eval(content, metadata):
called["value"] = True
assert "用户:" in content
assert "助手:" in content
return 0.91
store_path = TEST_DATA_DIR / "importance_test.json"
_safe_unlink(store_path)
memory = MemorySystem(
storage_path=store_path,
importance_evaluator=fake_importance_eval,
use_vector_db=False,
)
stored = await memory.add_qa_pair(
user_id="u1",
question="请记住我的昵称是小明",
answer="好的,我记住了你的昵称是小明",
metadata={"source": "test"},
)
assert called["value"]
assert stored is not None
assert "用户:" in stored.content
assert "助手:" in stored.content
assert "小明" in stored.content
long_term = await memory.list_long_term("u1")
assert len(long_term) == 1
# add_message only writes to short-term memory; it does not trigger
# long-term importance scoring.
await memory.add_message(user_id="u1", role="user", content="单条短期消息")
long_term_after_single = await memory.list_long_term("u1")
assert len(long_term_after_single) == 1
memory_without_eval = MemorySystem(
storage_path=TEST_DATA_DIR / "importance_fallback_test.json",
use_vector_db=False,
)
fallback_score = await memory_without_eval._evaluate_importance("任意内容", None)
assert fallback_score == 0.5
await memory.close()
await memory_without_eval.close()
_safe_unlink(store_path)
_safe_unlink(TEST_DATA_DIR / "importance_fallback_test.json")
print("ok: memory importance evaluator")
def test_basic_chat():
asyncio.run(_test_basic_chat())
def test_memory():
asyncio.run(_test_memory())
def test_personality():
asyncio.run(_test_personality())
def test_skills():
asyncio.run(_test_skills())
def test_skill_commands():
asyncio.run(_test_skill_commands())
def test_personality_commands():
asyncio.run(_test_personality_commands())
def test_model_commands():
asyncio.run(_test_model_commands())
def test_memory_commands():
asyncio.run(_test_memory_commands())
def test_plain_text_output():
asyncio.run(_test_plain_text_output())
def test_skills_creator_autoload():
asyncio.run(_test_skills_creator_autoload())
def test_mcp():
asyncio.run(_test_mcp())
def test_long_task():
asyncio.run(_test_long_task())
def test_memory_importance_evaluator():
asyncio.run(_test_memory_importance_evaluator())
async def main():
print("寮€濮?AI 鍔熻兘娴嬭瘯")
await _test_personality()
await _test_skills()
await _test_skill_commands()
await _test_personality_commands()
await _test_model_commands()
await _test_memory_commands()
await _test_plain_text_output()
await _test_skills_creator_autoload()
await _test_mcp()
await _test_long_task()
await _test_memory_importance_evaluator()
config = get_ai_config()
if config.api_key:
await _test_basic_chat()
await _test_memory()
else:
print("跳过需要 API Key 的对话/记忆测试")
print("娴嬭瘯瀹屾垚")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,73 +0,0 @@
"""Tests for AIClient forced tool name extraction."""
from src.ai.client import AIClient
def test_extract_forced_tool_name_full_name():
tools = [
"humanizer_zh.read_skill_doc",
"skills_creator.create_skill",
]
message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars"
forced = AIClient._extract_forced_tool_name(message, tools)
assert forced == "humanizer_zh.read_skill_doc"
def test_extract_forced_tool_name_unique_prefix():
tools = [
"humanizer_zh.read_skill_doc",
"skills_creator.create_skill",
]
message = "please call tool humanizer_zh only"
forced = AIClient._extract_forced_tool_name(message, tools)
assert forced == "humanizer_zh.read_skill_doc"
def test_extract_forced_tool_name_compact_prefix_without_underscore():
tools = [
"humanizer_zh.read_skill_doc",
"skills_creator.create_skill",
]
message = "调用humanizerzh人性化处理以下文本"
forced = AIClient._extract_forced_tool_name(message, tools)
assert forced == "humanizer_zh.read_skill_doc"
def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
tools = [
"skills_creator.create_skill",
"skills_creator.reload_skill",
]
message = "please call tool skills_creator"
forced = AIClient._extract_forced_tool_name(message, tools)
assert forced is None
def test_extract_prefix_limit_from_user_message():
assert AIClient._extract_prefix_limit("直接返回前100字") == 100
assert AIClient._extract_prefix_limit("前 256 字") == 256
assert AIClient._extract_prefix_limit("返回全文") is None
def test_extract_processing_payload_with_marker():
message = "调用humanizer_zh.read_skill_doc人性化处理以下文本\n第一段。\n第二段。"
payload = AIClient._extract_processing_payload(message)
assert payload == "第一段。\n第二段。"
def test_extract_processing_payload_with_generic_pattern():
message = "请按技能规则优化如下:\n这是待处理文本。"
payload = AIClient._extract_processing_payload(message)
assert payload == "这是待处理文本。"
def test_extract_processing_payload_returns_none_when_absent():
assert AIClient._extract_processing_payload("请调用工具 humanizer_zh.read_skill_doc") is None


@@ -1,44 +0,0 @@
import asyncio
from pathlib import Path
from src.ai.mcp.base import MCPManager, MCPServer
class _DummyMCPServer(MCPServer):
def __init__(self):
super().__init__(name="dummy", version="1.0.0")
async def initialize(self):
self.register_tool(
name="echo",
description="Echo text",
input_schema={
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"],
},
handler=self.echo,
)
async def echo(self, text: str) -> str:
return text
def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
manager = MCPManager(tmp_path / "mcp.json")
asyncio.run(manager.register_server(_DummyMCPServer()))
tools = asyncio.run(manager.get_all_tools_for_ai())
assert len(tools) == 1
function_info = tools[0]["function"]
assert function_info["name"] == "dummy.echo"
assert function_info["description"] == "Echo text"
assert function_info["parameters"]["required"] == ["text"]
def test_mcp_manager_execute_tool(tmp_path: Path):
manager = MCPManager(tmp_path / "mcp.json")
asyncio.run(manager.register_server(_DummyMCPServer()))
result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
assert result == "hello"


@@ -0,0 +1,22 @@
from types import SimpleNamespace
from src.handlers.message_handler_ai import MessageHandler
def _build_handler() -> MessageHandler:
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="bot_1", name="TestBot"))
return MessageHandler(fake_bot)
def test_message_dedup_by_message_id():
handler = _build_handler()
msg = SimpleNamespace(id="m1", content="hello", author=SimpleNamespace(id="u1"))
assert handler._is_duplicate_message(msg) is False
assert handler._is_duplicate_message(msg) is True
def test_message_dedup_fallback_without_message_id():
handler = _build_handler()
msg = SimpleNamespace(content="hello", author=SimpleNamespace(id="u1"), group_id="g1")
assert handler._is_duplicate_message(msg) is False
assert handler._is_duplicate_message(msg) is True


@@ -10,16 +10,14 @@ from src.handlers.message_handler_ai import MessageHandler
def _handler() -> MessageHandler:
    fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot", name="TestBot"))
    return MessageHandler(fake_bot)


def test_plain_text_removes_markdown_link_url():
    handler = _handler()
    text = "参考 [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia) 获取详情。"
    result = handler._plain_text(text)
    assert "Wikipedia" in result
    assert "http" not in result.lower()
@@ -27,9 +25,7 @@ def test_plain_text_removes_markdown_link_url():
def test_plain_text_removes_bare_url(): def test_plain_text_removes_bare_url():
handler = _handler() handler = _handler()
text = "访问 https://example.com/path?a=1 或 www.example.org 查看。" text = "访问 https://example.com/path?a=1 或 www.example.org 查看。"
result = handler._plain_text(text) result = handler._plain_text(text)
assert "http" not in result.lower() assert "http" not in result.lower()
assert "www." not in result.lower() assert "www." not in result.lower()
assert "[链接已省略]" in result assert "[链接已省略]" in result
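A regex-based sketch of the behaviour these assertions describe: a Markdown link keeps its visible label and drops the URL, while a bare URL is replaced with a placeholder. The placeholder string matches the one the test expects; the real `_plain_text` may strip more Markdown than this:

```python
import re

# [label](http... or www...) — capture the label, drop the URL.
_MD_LINK = re.compile(r"\[([^\]]+)\]\((?:https?://|www\.)[^)]+\)")
# Bare URLs, including scheme-less www. forms.
_BARE_URL = re.compile(r"(?:https?://|www\.)\S+")


def strip_urls(text: str) -> str:
    # Markdown links first, so their URLs never reach the bare-URL pass.
    text = _MD_LINK.sub(r"\1", text)
    return _BARE_URL.sub("[链接已省略]", text)
```

Replacing rather than deleting bare URLs keeps the sentence readable when the URL was its object ("visit ... to see").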

View File

@@ -0,0 +1,44 @@
from pathlib import Path

from src.ai.personality import PersonalityProfile, PersonalitySystem, PersonalityTrait


def _profile(name: str) -> PersonalityProfile:
    return PersonalityProfile(
        name=name,
        description=f"{name} profile",
        traits=[PersonalityTrait.FRIENDLY],
        speaking_style="plain",
    )


def test_scope_priority_session_over_group_over_user_over_global(tmp_path: Path):
    cfg = tmp_path / "personalities.json"
    state = tmp_path / "personality_state.json"
    system = PersonalitySystem(config_path=cfg, state_path=state)

    system.add_personality("p_global", _profile("global"))
    system.add_personality("p_user", _profile("user"))
    system.add_personality("p_group", _profile("group"))
    system.add_personality("p_session", _profile("session"))

    assert system.set_personality("p_global", scope="global")
    assert system.set_personality("p_user", scope="user", scope_id="u1")
    assert system.set_personality("p_group", scope="group", scope_id="g1")
    assert system.set_personality("p_session", scope="session", scope_id="g1:u1")

    profile = system.get_active_personality(user_id="u1", group_id="g1", session_id="g1:u1")
    assert profile is not None
    assert profile.name == "session"

    profile_no_session = system.get_active_personality(user_id="u1", group_id="g1", session_id="other")
    assert profile_no_session is not None
    assert profile_no_session.name == "group"

    profile_user_only = system.get_active_personality(user_id="u1", group_id=None, session_id=None)
    assert profile_user_only is not None
    assert profile_user_only.name == "user"

    profile_global = system.get_active_personality(user_id="u2", group_id=None, session_id=None)
    assert profile_global is not None
    assert profile_global.name == "global"
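The precedence this test asserts (session > group > user > global) amounts to an ordered lookup over scoped bindings. An illustrative sketch, assuming only the precedence rule; the real `PersonalitySystem` also persists its state to JSON and validates profiles:

```python
from typing import Optional


class ScopedStore:
    """Resolves a value by scope precedence: session > group > user > global."""

    def __init__(self):
        self._values: dict[tuple[str, str], str] = {}

    def set(self, scope: str, scope_id: str, value: str) -> None:
        self._values[(scope, scope_id)] = value

    def resolve(self, user_id: Optional[str], group_id: Optional[str],
                session_id: Optional[str]) -> Optional[str]:
        # Check the most specific scope first, then fall through to global.
        for scope, scope_id in (
            ("session", session_id),
            ("group", group_id),
            ("user", user_id),
            ("global", ""),
        ):
            if scope_id is None:
                continue
            value = self._values.get((scope, scope_id))
            if value is not None:
                return value
        return None
```

Keying sessions as `"group:user"` (as the test does with `"g1:u1"`) lets one user carry different personas in different groups without extra bookkeeping.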

View File

@@ -1,74 +0,0 @@
import io
import zipfile
from pathlib import Path

from src.ai.skills.base import SkillsManager


def _build_codex_skill_zip_bytes(markdown_text: str, root_name: str = "Humanizer-zh-main") -> bytes:
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr(f"{root_name}/SKILL.md", markdown_text)
    return buffer.getvalue()


def test_resolve_network_source_supports_github_git_url(tmp_path: Path):
    manager = SkillsManager(tmp_path / "skills")
    url, hint_key, subpath = manager._resolve_network_source(
        "https://github.com/op7418/Humanizer-zh.git"
    )
    assert url == "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main"
    assert hint_key == "humanizer_zh"
    assert subpath is None


def test_install_skill_from_local_skill_markdown_source(tmp_path: Path):
    source_dir = tmp_path / "Humanizer-zh-main"
    source_dir.mkdir(parents=True, exist_ok=True)
    (source_dir / "SKILL.md").write_text(
        "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n",
        encoding="utf-8",
    )

    manager = SkillsManager(tmp_path / "skills")
    ok, installed_key = manager.install_skill_from_source(str(source_dir), skill_name="humanizer_zh")

    assert ok
    assert installed_key == "humanizer_zh"
    installed_dir = tmp_path / "skills" / "humanizer_zh"
    assert (installed_dir / "skill.json").exists()
    assert (installed_dir / "main.py").exists()
    assert (installed_dir / "SKILL.md").exists()

    main_code = (installed_dir / "main.py").read_text(encoding="utf-8")
    assert "read_skill_doc" in main_code
    skill_text = (installed_dir / "SKILL.md").read_text(encoding="utf-8")
    assert "Humanizer-zh" in skill_text


def test_install_skill_from_github_git_url_uses_repo_zip_and_markdown_adapter(
    tmp_path: Path, monkeypatch
):
    manager = SkillsManager(tmp_path / "skills")
    zip_bytes = _build_codex_skill_zip_bytes(
        "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n"
    )
    captured_urls = []

    def fake_download(url: str, output_zip: Path):
        captured_urls.append(url)
        output_zip.write_bytes(zip_bytes)

    monkeypatch.setattr(manager, "_download_zip", fake_download)

    ok, installed_key = manager.install_skill_from_source(
        "https://github.com/op7418/Humanizer-zh.git"
    )

    assert ok
    assert installed_key == "humanizer_zh"
    assert captured_urls == [
        "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main"
    ]
    assert (tmp_path / "skills" / "humanizer_zh" / "SKILL.md").exists()