Refactor configuration and enhance AI capabilities

Updated .env.example for clarity and added new configuration options for memory and reliability settings. Refactored main.py to streamline the bot's entry point and improve error handling. Updated the README to reflect the new feature set and command structure. Removed the deprecated cmd_zip_skill and skills_creator modules to clean up the codebase. Updated AIClient and MemorySystem for better performance and more flexible handling of user interactions.
.env.example (36 changed lines)
@@ -1,36 +1,36 @@
-# QQ bot configuration
-# Obtainable from https://bot.q.qq.com/open
+# QQ Bot credentials
 
-# Bot AppID (required)
 BOT_APPID=your_app_id_here
 
-# Bot AppSecret (required)
 BOT_SECRET=your_app_secret_here
 
-# Log level: DEBUG / INFO / WARNING / ERROR
+# Runtime
+APP_ENV=dev
 LOG_LEVEL=INFO
+LOG_FORMAT=text
-# Whether to enable the sandbox environment
 SANDBOX_MODE=False
 
-# ==================== AI settings ====================
+# Optional admin allow-list (comma separated user IDs).
+# Empty means all users are treated as admin.
+BOT_ADMIN_IDS=
 
-# Primary model settings
-# Supported providers: openai / anthropic / deepseek / qwen
+# AI chat model
 AI_PROVIDER=openai
 AI_MODEL=gpt-4
 AI_API_KEY=your_api_key_here
-# Optional custom API base URL
 AI_API_BASE=https://api.openai.com/v1
 
-# Embedding model settings (for long-term memory retrieval)
-# When unset, falls back to the chat model's embedding capability (if available)
+# Embedding model (optional)
 AI_EMBED_PROVIDER=openai
 AI_EMBED_MODEL=text-embedding-3-small
 AI_EMBED_API_KEY=
 AI_EMBED_API_BASE=
 
-# Vector database settings
-# true: use Chroma (recommended)
-# false: use JSON storage
+# Memory storage and retrieval
 AI_USE_VECTOR_DB=true
+AI_USE_QUERY_EMBEDDING=false
+AI_MEMORY_SCOPE=session
 
+# Reliability
+AI_CHAT_RETRIES=1
+AI_CHAT_RETRY_BACKOFF_SECONDS=0.8
+MESSAGE_DEDUP_SECONDS=30
+MESSAGE_DEDUP_MAX_SIZE=4096
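The new Reliability block introduces `AI_CHAT_RETRIES` and `AI_CHAT_RETRY_BACKOFF_SECONDS`. A minimal sketch of how such settings could drive a retry loop; the function and variable names here are illustrative, not the project's actual code:

```python
import asyncio


async def chat_with_retries(call, retries: int = 1, backoff: float = 0.8):
    """Try `call` once, then up to `retries` more times with growing backoff."""
    last_exc = None
    for attempt in range(retries + 1):
        try:
            return await call()
        except Exception as exc:  # a real client would narrow this
            last_exc = exc
            if attempt < retries:
                # wait backoff, 2*backoff, ... before the next attempt
                await asyncio.sleep(backoff * (attempt + 1))
    raise last_exc


async def demo():
    calls = {"n": 0}

    async def flaky():
        # fails on the first call, succeeds on the retry
        calls["n"] += 1
        if calls["n"] < 2:
            raise RuntimeError("transient error")
        return "ok"

    return await chat_with_retries(flaky, retries=1, backoff=0.0)


print(asyncio.run(demo()))  # ok
```

With `retries=1`, a transient first failure is absorbed; a second failure propagates to the caller.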
README.md (131 changed lines)
@@ -1,14 +1,16 @@
-# QQbot (AI chat bot)
+# QQbot (Memory + Persona Core)
 
-A QQ bot built on `botpy`, supporting multi-model switching, long/short-term memory, persona management, Skills plugins and MCP capabilities.
+A QQ bot project that keeps and strengthens two core capabilities:
+- `Memory`: short/long-term memory, retrieval, cleanup, session scoping
+- `Persona`: role configuration, scope priority (session > group > user > global)
 
-## Feature overview
+## Main capabilities
 
-- Multi-model configuration and runtime switching (`/models`)
+- Multi-model configuration and runtime switching: `/models`
-- Persona add/remove/edit/switch (`/personality`)
+- Persona management: `/personality`
-- Short/long-term memory management (`/clear`, `/memory`)
+- Memory management: `/memory`, `/clear`
-- Local and remote Skills install/uninstall/reload (`/skills`)
+- Safe output for QQ messages: Markdown/URLs stripped automatically
-- Markdown stripped from replies before sending (to fit QQ chat)
+- Engineering hardening: message dedup, retry on failure, permission boundaries, structured logging
 
 ## Quick start
 
@@ -20,7 +22,9 @@ pip install -r requirements.txt
 
 2. Configure environment variables
 
-Copy `.env.example` to `.env` and fill in the QQ bot and AI settings.
+```bash
+copy .env.example .env
+```
 
 3. Start
 
@@ -28,89 +32,42 @@ pip install -r requirements.txt
 python main.py
 ```
 
-## Command reference
+## Commands
 
-### General
+- Basics
+  - `/help`
+  - `/clear` `/clear short` `/clear long` `/clear all`
+- Persona
+  - `/personality`
+  - `/personality list`
+  - `/personality set <key> [global|user|group|session]`
+  - `/personality add <name> <Introduction>`
+  - `/personality remove <key>`
+- Models
+  - `/models`
+  - `/models current`
+  - `/models add <model_name>`
+  - `/models add <key> <provider> <model_name> [api_base]`
+  - `/models switch <key|index>`
+  - `/models remove <key|index>`
+- Memory
+  - `/memory`
+  - `/memory get <id>`
+  - `/memory add <content|json>`
+  - `/memory update <id> <content|json>`
+  - `/memory delete <id>`
+  - `/memory search <query> [limit]`
 
-- `/help`
-- `/clear` (defaults to `/clear short`)
-- `/clear short`
-- `/clear long`
-- `/clear all`
+## Key configuration
 
-### Persona commands
-
-- `/personality`
-- `/personality list`
-- `/personality set <key>`
-- `/personality add <name> <Introduction>`
-- `/personality remove <key>`
-
-Notes:
-- `add` creates the persona and switches to it
-- `Introduction` serves as the persona description and custom instructions
-
-### Skills commands
-
-- `/skills`
-- `/skills install <source> [skill_name]`
-- `/skills uninstall <skill_name>`
-- `/skills reload <skill_name>`
-
-Supported `source` forms:
-- a local skill name (e.g. `weather`)
-- a URL (zip archive)
-- GitHub shorthand (`owner/repo` or `owner/repo#branch`)
-- a GitHub repository URL (e.g. `https://github.com/op7418/Humanizer-zh.git`)
-
-Compatibility notes:
-- If the source contains the standard skill layout (`skill.json` + `main.py`), it installs as before
-- If it only contains `SKILL.md`, an adapter skill is generated with a `read_skill_doc` tool to read the document
-
-### Model commands
-
-- `/models`
-- `/models current`
-- `/models add <model_name>`
-- `/models add <key> <provider> <model_name> [api_base]`
-- `/models switch <key|index>`
-- `/models remove <key|index>`
-
-Notes:
-- `/models add <model_name>` only replaces the model name, keeping the current API base and API key
-
-### Long-term memory commands
-
-- `/memory`
-- `/memory get <id>`
-- `/memory add <content|json>`
-- `/memory update <id> <content|json>`
-- `/memory delete <id>`
-- `/memory search <query> [limit]`
-
-## Directory layout
-
-```text
-QQbot/
-├─ src/
-│  ├─ ai/
-│  ├─ handlers/
-│  ├─ core/
-│  └─ utils/
-├─ skills/
-├─ config/
-├─ docs/
-└─ tests/
-```
+- `AI_MEMORY_SCOPE=user|session`: memory scope
+- `BOT_ADMIN_IDS`: admin allow-list (comma separated)
+- `AI_CHAT_RETRIES` / `AI_CHAT_RETRY_BACKOFF_SECONDS`: retry failed chat calls
+- `MESSAGE_DEDUP_SECONDS` / `MESSAGE_DEDUP_MAX_SIZE`: message dedup window
+- `LOG_FORMAT=text|json`: log output format
 
 ## Tests
 
 ```bash
-python -m pytest -q
+pytest -q
 ```
-
-If you use a conda environment, run this first:
-
-```bash
-conda activate qqbot
-```
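The README's key configuration lists `MESSAGE_DEDUP_SECONDS` / `MESSAGE_DEDUP_MAX_SIZE` as a message dedup window. A rough sketch of a TTL-plus-size-bounded dedup cache; the class and method names are hypothetical, not the project's API:

```python
import time
from collections import OrderedDict
from typing import Optional


class DedupCache:
    """Remember message IDs for a TTL window, bounded by max_size entries."""

    def __init__(self, ttl_seconds: float = 30.0, max_size: int = 4096):
        self.ttl = ttl_seconds
        self.max_size = max_size
        self._seen: "OrderedDict[str, float]" = OrderedDict()  # id -> first-seen time

    def seen_recently(self, message_id: str, now: Optional[float] = None) -> bool:
        """Return True if message_id was already seen inside the TTL window."""
        now = time.monotonic() if now is None else now
        # drop expired entries from the front (oldest first)
        while self._seen and now - next(iter(self._seen.values())) > self.ttl:
            self._seen.popitem(last=False)
        if message_id in self._seen:
            return True
        self._seen[message_id] = now
        # bound memory: evict the oldest entries beyond max_size
        while len(self._seen) > self.max_size:
            self._seen.popitem(last=False)
        return False


cache = DedupCache(ttl_seconds=30.0, max_size=4)
print(cache.seen_recently("m1", now=0.0))   # False (first sighting)
print(cache.seen_recently("m1", now=10.0))  # True (inside the window)
print(cache.seen_recently("m1", now=45.0))  # False (entry expired)
```

The two knobs map directly: the TTL is the dedup window, the size cap keeps memory bounded under bursty traffic.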
@@ -1,6 +0,0 @@
-{
-    "filesystem": {
-        "enabled": true,
-        "root_path": "data"
-    }
-}
main.py (41 changed lines)
@@ -1,10 +1,12 @@
 """
-QQ bot main entry point
+Project entrypoint.
 """
 
+from __future__ import annotations
+
 import sys
 from pathlib import Path
 
-# Add the project root to the Python path
 project_root = Path(__file__).parent
 sys.path.insert(0, str(project_root))
 
@@ -23,10 +25,6 @@ def _sqlite_supports_trigram(sqlite_module) -> bool:
 
 def _ensure_sqlite_for_chroma():
-    """
-    Ensure sqlite runtime supports FTS5 trigram tokenizer for Chroma.
-    On some cloud images, system sqlite lacks trigram support.
-    """
     try:
         import sqlite3
     except Exception:
@@ -55,35 +53,8 @@ def _ensure_sqlite_for_chroma():
 
 _ensure_sqlite_for_chroma()
 
-from src.core.bot import MyClient, build_intents
-from src.core.config import Config
-from src.utils.logger import setup_logger
-
-
-def main():
-    """Main function."""
-    # Set up logging
-    logger = setup_logger()
-
-    try:
-        # Validate configuration
-        Config.validate()
-        logger.info("Configuration validated")
-
-        # Create and start the bot (minimal permissions, to avoid 4014 disallowed intents)
-        logger.info("Starting the QQ bot...")
-        intents = build_intents()
-        client = MyClient(intents=intents)
-        client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET)
-
-    except ValueError as e:
-        logger.error(f"Configuration error: {e}")
-        logger.error("Check your .env file settings")
-        sys.exit(1)
-    except Exception as e:
-        logger.error(f"Startup failed: {e}", exc_info=True)
-        sys.exit(1)
+from src.core.bot import main as run_bot_main
 
 
 if __name__ == "__main__":
-    main()
+    run_bot_main()
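main.py keeps `_sqlite_supports_trigram` and `_ensure_sqlite_for_chroma` to guard Chroma's FTS5 trigram requirement. A hedged sketch of what such a probe can look like; the project's actual check is outside this hunk and may differ:

```python
import sqlite3


def sqlite_supports_trigram(sqlite_module=sqlite3) -> bool:
    """Probe whether the sqlite runtime accepts an FTS5 table with the
    trigram tokenizer (added in SQLite 3.34)."""
    try:
        conn = sqlite_module.connect(":memory:")
        conn.execute(
            "CREATE VIRTUAL TABLE probe USING fts5(body, tokenize='trigram')"
        )
        conn.close()
        return True
    except Exception:
        return False


# True on builds with FTS5 and SQLite >= 3.34; False on older system sqlite
print(sqlite_supports_trigram())
```

When the probe fails, swapping in `pysqlite3-binary` (as requirements.txt does on non-Windows hosts) restores trigram support without touching the system sqlite.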
@@ -1,15 +1,12 @@
-# ============================================
-# Core dependencies (required)
-# ============================================
+# Core runtime
 qq-botpy
 python-dotenv>=1.0.0
 
-# ============================================
-# AI feature dependencies (optional)
-# Comment these out if you do not need AI chat
-# ============================================
+# AI providers
 openai>=1.0.0
 anthropic>=0.18.0
 
+# Memory storage
 numpy>=1.24.0
-chromadb>=0.4.0  # vector database, used for memory storage
+chromadb>=0.4.0
-pysqlite3-binary>=0.5.3; platform_system != "Windows"  # restores sqlite trigram support on cloud hosts
+pysqlite3-binary>=0.5.3; platform_system != "Windows"
@@ -1,7 +0,0 @@
-# cmd_zip_skill
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkillSkill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465404375
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465404375 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465404375Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465404375 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465404375",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465434774
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465434774 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465434774Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465434774 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465434774",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465467809
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465467809 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465467809Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465467809 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465467809",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465652075
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465652075 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465652075Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465652075 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465652075",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465685352
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465685352 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465685352Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465685352 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465685352",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465936294
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465936294 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465936294Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465936294 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465936294",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772465966322
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772465966322 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772465966322Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772465966322 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772465966322",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,7 +0,0 @@
-# cmd_zip_skill_1772466071278
-
-## Description
-zip skill
-
-## Tools
-- example_tool(text)
@@ -1,13 +0,0 @@
-"""cmd_zip_skill_1772466071278 skill"""
-from src.ai.skills.base import Skill
-
-
-class CmdZipSkill1772466071278Skill(Skill):
-    async def initialize(self):
-        self.register_tool("example_tool", self.example_tool)
-
-    async def example_tool(self, text: str) -> str:
-        return f"cmd_zip_skill_1772466071278 收到: {text}"
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "cmd_zip_skill_1772466071278",
-    "version": "1.0.0",
-    "description": "zip skill",
-    "author": "test",
-    "dependencies": [],
-    "enabled": false
-}
@@ -1,12 +0,0 @@
-# skills_creator
-
-Tools exposed to AI:
-- create_skill
-- write_skill_main
-- update_skill_metadata
-- delete_skill
-- load_skill
-- reload_skill
-- list_skills
-
-This skill enables creating and managing other skills from chat-driven AI tool calls.
@@ -1,147 +0,0 @@
-"""skills_creator skill"""
-
-import json
-from pathlib import Path
-import shutil
-from typing import Any, Dict
-
-from src.ai.skills.base import Skill, SkillsManager, create_skill_template
-
-
-class SkillsCreatorSkill(Skill):
-    """Provide tools to create and manage local skills from chat."""
-
-    def __init__(self):
-        super().__init__()
-        self.skills_root = Path("skills")
-        self.protected_skills = {"skills_creator"}
-
-    async def initialize(self):
-        self.register_tool("create_skill", self.create_skill)
-        self.register_tool("write_skill_main", self.write_skill_main)
-        self.register_tool("update_skill_metadata", self.update_skill_metadata)
-        self.register_tool("delete_skill", self.delete_skill)
-        self.register_tool("load_skill", self.load_skill)
-        self.register_tool("reload_skill", self.reload_skill)
-        self.register_tool("list_skills", self.list_skills)
-
-    def _skill_key(self, skill_name: str) -> str:
-        return SkillsManager.normalize_skill_key(skill_name)
-
-    def _skill_path(self, skill_name: str) -> Path:
-        return self.skills_root / self._skill_key(skill_name)
-
-    async def create_skill(
-        self,
-        skill_name: str,
-        description: str = "custom skill",
-        author: str = "chat_user",
-        auto_load: bool = True,
-        overwrite: bool = False,
-    ) -> str:
-        skill_key = self._skill_key(skill_name)
-        skill_path = self._skill_path(skill_key)
-
-        if skill_path.exists() and not overwrite:
-            return f"技能已存在: {skill_key}"
-
-        if skill_path.exists() and overwrite:
-            shutil.rmtree(skill_path)
-
-        create_skill_template(skill_key, self.skills_root, description=description, author=author)
-
-        if auto_load and self.manager:
-            await self.manager.load_skill(skill_key)
-
-        return f"已创建技能: {skill_key}"
-
-    async def write_skill_main(self, skill_name: str, code: str, auto_reload: bool = True) -> str:
-        skill_key = self._skill_key(skill_name)
-        skill_path = self._skill_path(skill_key)
-        if not skill_path.exists():
-            return f"技能不存在: {skill_key}"
-
-        main_file = skill_path / "main.py"
-        main_file.write_text(code, encoding="utf-8")
-
-        if auto_reload and self.manager:
-            await self.manager.reload_skill(skill_key)
-
-        return f"已更新技能代码: {skill_key}/main.py"
-
-    async def update_skill_metadata(self, skill_name: str, fields: Dict[str, Any]) -> str:
-        skill_key = self._skill_key(skill_name)
-        skill_path = self._skill_path(skill_key)
-        if not skill_path.exists():
-            return f"技能不存在: {skill_key}"
-
-        metadata_file = skill_path / "skill.json"
-        if metadata_file.exists():
-            metadata = json.loads(metadata_file.read_text(encoding="utf-8"))
-        else:
-            metadata = {
-                "name": skill_key,
-                "version": "1.0.0",
-                "description": "",
-                "author": "chat_user",
-                "dependencies": [],
-                "enabled": True,
-            }
-
-        for key, value in fields.items():
-            metadata[key] = value
-
-        metadata.setdefault("name", skill_key)
-        metadata.setdefault("version", "1.0.0")
-        metadata.setdefault("description", "")
-        metadata.setdefault("author", "chat_user")
-        metadata.setdefault("dependencies", [])
-        metadata.setdefault("enabled", True)
-
-        metadata_file.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8")
-        return f"已更新技能元数据: {skill_key}/skill.json"
-
-    async def delete_skill(self, skill_name: str, delete_files: bool = True) -> str:
-        skill_key = self._skill_key(skill_name)
-        if skill_key in self.protected_skills:
-            return f"拒绝删除受保护技能: {skill_key}"
-
-        if self.manager:
-            removed = await self.manager.uninstall_skill(skill_key, delete_files=delete_files)
-            return f"删除结果({skill_key}): {removed}"
-
-        skill_path = self._skill_path(skill_key)
-        if delete_files and skill_path.exists():
-            shutil.rmtree(skill_path)
-            return f"已删除技能目录: {skill_key}"
-
-        return f"技能不存在或未删除: {skill_key}"
-
-    async def load_skill(self, skill_name: str) -> str:
-        skill_key = self._skill_key(skill_name)
-        if not self.manager:
-            return "SkillsManager 不可用"
-
-        success = await self.manager.load_skill(skill_key)
-        return f"加载技能 {skill_key}: {success}"
-
-    async def reload_skill(self, skill_name: str) -> str:
-        skill_key = self._skill_key(skill_name)
-        if not self.manager:
-            return "SkillsManager 不可用"
-
-        success = await self.manager.reload_skill(skill_key)
-        return f"重载技能 {skill_key}: {success}"
-
-    async def list_skills(self) -> str:
-        if not self.manager:
-            return "SkillsManager 不可用"
-
-        payload = {
-            "loaded": self.manager.list_skills(),
-            "available": self.manager.list_available_skills(),
-        }
-        return json.dumps(payload, ensure_ascii=False)
-
-    async def cleanup(self):
-        pass
@@ -1,8 +0,0 @@
-{
-    "name": "skills_creator",
-    "version": "1.0.0",
-    "description": "Create, edit, load, and remove skills from chat",
-    "author": "QQBot",
-    "dependencies": [],
-    "enabled": true
-}
@@ -1,14 +1,7 @@
-"""
-AI module - model access, persona system, memory system and long-task handling
-"""
+"""AI package exports."""
+
 from .client import AIClient
-from .personality import PersonalitySystem
 from .memory import MemorySystem
-from .task_manager import LongTaskManager
-
-__all__ = [
-    'AIClient',
-    'PersonalitySystem',
-    'MemorySystem',
-    'LongTaskManager'
-]
+from .personality import PersonalitySystem
+
+__all__ = ["AIClient", "MemorySystem", "PersonalitySystem"]
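The README describes persona scope priority as session > group > user > global, and `AIClient` wires the persona system below. A minimal sketch of that resolution order; the function name and data shape are hypothetical:

```python
from typing import Dict, Optional

# Highest-priority scope first, per the documented ordering.
SCOPE_PRIORITY = ["session", "group", "user", "global"]


def resolve_persona(assignments: Dict[str, str]) -> Optional[str]:
    """Pick the persona key from the highest-priority scope that has one."""
    for scope in SCOPE_PRIORITY:
        if scope in assignments:
            return assignments[scope]
    return None


print(resolve_persona({"global": "default", "user": "tutor"}))   # tutor
print(resolve_persona({"session": "pirate", "user": "tutor"}))   # pirate
```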
843
src/ai/client.py
843
src/ai/client.py
@@ -1,98 +1,97 @@
|
|||||||
"""
|
|
||||||
AI瀹㈡埛绔?- 鏁村悎鎵€鏈堿I鍔熻兘
|
|
||||||
"""
|
"""
|
||||||
import inspect
|
Unified AI client for chat, memory and persona.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Dict, Any, AsyncIterator, Tuple
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .base import ModelConfig, ModelProvider, Message, ToolRegistry
|
from typing import Any, Dict, List, Optional
|
||||||
from .models import OpenAIModel, AnthropicModel
|
|
||||||
from .personality import PersonalitySystem
|
import httpx
|
||||||
|
|
||||||
|
from .base import Message, ModelConfig, ModelProvider
|
||||||
from .memory import MemorySystem
|
from .memory import MemorySystem
|
||||||
from .task_manager import LongTaskManager
|
from .models import AnthropicModel, OpenAIModel
|
||||||
|
from .personality import PersonalitySystem
|
||||||
from src.utils.logger import setup_logger
|
from src.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger('AIClient')
|
logger = setup_logger("AIClient")
|
||||||
|
|
||||||
|
|
||||||
class AIClient:
|
class AIClient:
|
||||||
"""AI瀹㈡埛绔?- 缁熶竴鎺ュ彛"""
|
"""High-level application service for chat and memory/persona orchestration."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
embed_config: Optional[ModelConfig] = None,
|
embed_config: Optional[ModelConfig] = None,
|
||||||
data_dir: Path = Path("data/ai"),
|
data_dir: Path = Path("data/ai"),
|
||||||
use_vector_db: bool = True
|
use_vector_db: bool = True,
|
||||||
|
use_query_embedding: bool = False,
|
||||||
|
chat_retries: int = 1,
|
||||||
|
chat_retry_backoff: float = 0.8,
|
||||||
):
|
):
|
||||||
self.config = model_config
|
self.config = model_config
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# 初始化主模型
|
self.chat_retries = max(0, int(chat_retries))
|
||||||
|
self.chat_retry_backoff = max(0.0, float(chat_retry_backoff))
|
||||||
|
|
||||||
self.model = self._create_model(model_config)
|
self.model = self._create_model(model_config)
|
||||||
|
self.embed_model = self._create_model(embed_config) if embed_config else None
|
||||||
|
|
||||||
# 初始化嵌入模型(如果提供)
|
|
||||||
self.embed_model = None
|
|
||||||
if embed_config:
|
|
||||||
self.embed_model = self._create_model(embed_config)
|
|
||||||
logger.info(
|
|
||||||
f"嵌入模型初始化完成: {embed_config.provider.value}/{embed_config.model_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化工具注册表
|
|
||||||
self.tools = ToolRegistry()
|
|
||||||
self._tool_sources: Dict[str, str] = {}
|
|
||||||
|
|
||||||
# 初始化人格系统
|
|
||||||
self.personality = PersonalitySystem(
|
self.personality = PersonalitySystem(
|
||||||
config_path=data_dir / "personalities.json"
|
config_path=data_dir / "personalities.json",
|
||||||
|
state_path=data_dir / "personality_state.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化记忆系统
|
|
||||||
self.memory = MemorySystem(
|
self.memory = MemorySystem(
|
||||||
storage_path=data_dir / "long_term_memory.json",
|
storage_path=data_dir / "long_term_memory.json",
|
||||||
embed_func=self._embed_wrapper,
|
embed_func=self._embed_wrapper,
|
||||||
importance_evaluator=self._evaluate_memory_importance,
|
importance_evaluator=self._evaluate_memory_importance,
|
||||||
use_vector_db=use_vector_db
|
use_vector_db=use_vector_db,
|
||||||
)
|
use_query_embedding=use_query_embedding,
|
||||||
|
|
||||||
# 初始化长任务管理器
|
|
||||||
self.task_manager = LongTaskManager(
|
|
||||||
storage_path=data_dir / "tasks.json"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"AI 客户端初始化完成: {model_config.provider.value}/{model_config.model_name}"
|
"AI client initialized",
|
||||||
|
extra={
|
||||||
|
"provider": model_config.provider.value,
|
||||||
|
"model": model_config.model_name,
|
||||||
|
"use_vector_db": use_vector_db,
|
||||||
|
"use_query_embedding": use_query_embedding,
|
||||||
|
"chat_retries": self.chat_retries,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_model(self, config: ModelConfig):
|
def _create_model(self, config: Optional[ModelConfig]):
|
||||||
"""创建模型实例。"""
|
if config is None:
|
||||||
if config.provider == ModelProvider.OPENAI:
|
return None
|
||||||
|
|
||||||
|
if config.provider in {
|
||||||
|
ModelProvider.OPENAI,
|
||||||
|
ModelProvider.DEEPSEEK,
|
||||||
|
ModelProvider.QWEN,
|
||||||
|
}:
|
||||||
return OpenAIModel(config)
|
return OpenAIModel(config)
|
||||||
elif config.provider == ModelProvider.ANTHROPIC:
|
if config.provider == ModelProvider.ANTHROPIC:
|
||||||
return AnthropicModel(config)
|
return AnthropicModel(config)
|
||||||
elif config.provider in [ModelProvider.DEEPSEEK, ModelProvider.QWEN]:
|
raise ValueError(f"Unsupported model provider: {config.provider}")
|
||||||
# DeepSeek 和 Qwen 使用 OpenAI 兼容接口
|
|
||||||
return OpenAIModel(config)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的模型提供商: {config.provider}")
|
|
||||||
|
|
||||||
async def _embed_wrapper(self, text: str) -> List[float]:
|
async def _embed_wrapper(self, text: str) -> List[float]:
|
||||||
"""嵌入向量包装器。"""
|
|
||||||
try:
|
try:
|
||||||
# 如果有独立的嵌入模型,优先使用
|
|
||||||
if self.embed_model:
|
if self.embed_model:
|
||||||
return await self.embed_model.embed(text)
|
return await self.embed_model.embed(text)
|
||||||
# 否则尝试使用主模型
|
|
||||||
return await self.model.embed(text)
|
return await self.model.embed(text)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
# 如果都不支持嵌入,返回 None(记忆系统会降级)
|
logger.warning("Current model does not support embeddings; fallback to local embedding.")
|
||||||
logger.warning("Current model does not support embeddings; vector retrieval disabled")
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
logger.error(f"生成嵌入向量失败: {e}")
|
logger.warning(f"Embedding generation failed: {exc}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -120,17 +119,12 @@ class AIClient:
     async def _evaluate_memory_importance(
         self, content: str, metadata: Optional[Dict] = None
     ) -> float:
-        """
-        Ask the primary model to score memory importance; returns a value in [0, 1].
-        """
         system_prompt = (
-            "你是记忆重要性评估器。请根据输入内容判断该信息是否值得长期记忆。"
-            "输出一个 0 到 1 的数字,数字越大表示越重要。"
-            "只输出数字,不要输出任何解释、单位或多余文本。"
+            "You evaluate if content should be kept as long-term memory. "
+            "Return only a float between 0 and 1."
         )
         payload = json.dumps(
-            {"content": content, "metadata": metadata or {}},
-            ensure_ascii=False,
+            {"content": content, "metadata": metadata or {}}, ensure_ascii=False
         )
         messages = [
             Message(role="system", content=system_prompt),
@@ -146,636 +140,181 @@ class AIClient:
             )
             score = self._parse_importance_score(response.content)
             return max(0.0, min(1.0, score))
-        except Exception as e:
-            logger.warning(f"memory importance evaluation failed, fallback to neutral score: {e}")
+        except Exception as exc:
+            logger.warning(f"Memory importance evaluation failed, fallback to 0.5: {exc}")
             return 0.5
 
+    @staticmethod
+    def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
+        try:
+            text = json.dumps(payload, ensure_ascii=False, default=str)
+        except Exception:
+            text = str(payload)
+        if len(text) > max_len:
+            return text[:max_len] + "..."
+        return text
+
+    @staticmethod
+    def _is_retriable_error(exc: Exception) -> bool:
+        if isinstance(exc, (httpx.ReadTimeout, httpx.ConnectError, asyncio.TimeoutError)):
+            return True
+
+        lower = str(exc).lower()
+        retriable_signals = [
+            "timed out",
+            "timeout",
+            "temporarily unavailable",
+            "connection reset",
+            "service unavailable",
+            "rate limit",
+            "429",
+            "502",
+            "503",
+            "504",
+        ]
+        return any(signal in lower for signal in retriable_signals)
+
+    async def _chat_with_retry(self, messages: List[Message], **kwargs):
+        attempts = 1 + self.chat_retries
+        last_error: Optional[Exception] = None
+
+        for attempt in range(1, attempts + 1):
+            try:
+                return await self.model.chat(messages, None, **kwargs)
+            except Exception as exc:
+                last_error = exc
+                if attempt >= attempts or not self._is_retriable_error(exc):
+                    break
+
+                sleep_seconds = self.chat_retry_backoff * (2 ** (attempt - 1))
+                logger.warning(
+                    "Chat request failed, retrying",
+                    extra={
+                        "attempt": attempt,
+                        "max_attempts": attempts,
+                        "sleep_seconds": sleep_seconds,
+                        "error": str(exc),
+                    },
+                )
+                await asyncio.sleep(sleep_seconds)
+
+        if last_error:
+            raise last_error
+        raise RuntimeError("chat retry loop ended unexpectedly")
+
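For reference, the backoff in `_chat_with_retry` above doubles per attempt starting from `chat_retry_backoff`. A standalone sketch of that schedule and the string-based retriability check (the names `backoff_schedule` and `is_retriable` are illustrative, not part of this codebase):

```python
def backoff_schedule(retries: int, base: float = 1.0) -> list:
    # Sleep before retry N is base * 2**(N-1), mirroring _chat_with_retry.
    return [base * (2 ** (attempt - 1)) for attempt in range(1, retries + 1)]


def is_retriable(error_message: str) -> bool:
    # Same signal-list idea as _is_retriable_error, applied to a plain string.
    signals = ["timed out", "timeout", "rate limit", "429", "502", "503", "504"]
    lower = error_message.lower()
    return any(signal in lower for signal in signals)
```

With three retries and a 1-second base, the sleeps are 1, 2, and 4 seconds before giving up.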
     async def chat(
         self,
         user_id: str,
         user_message: str,
         system_prompt: Optional[str] = None,
         use_memory: bool = True,
-        use_tools: bool = True,
         stream: bool = False,
-        **kwargs
+        group_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        memory_key: Optional[str] = None,
+        **kwargs,
     ) -> str:
-        """Chat entry point."""
-        try:
-            # Build the message list
-            messages = []
-
-            # System prompt
-            if system_prompt is None:
-                system_prompt = self.personality.get_system_prompt()
-
-            # Inject memory context
-            if use_memory:
-                short_term, long_term = await self.memory.get_context(
-                    user_id=user_id,
-                    query=user_message
-                )
-
-                if short_term or long_term:
-                    memory_context = self.memory.format_context(short_term, long_term)
-                    system_prompt += f"\n\n{memory_context}"
-
-            messages.append(Message(role="system", content=system_prompt))
-
-            # Append the user message
-            messages.append(Message(role="user", content=user_message))
-
-            # Prepare tools
-            tools = None
-            tool_names: List[str] = []
-            if use_tools and self.tools.list():
-                tools = self.tools.to_openai_format()
-                tool_names = [tool.name for tool in self.tools.list()]
-            forced_tool_name = self._extract_forced_tool_name(user_message, tool_names)
-            if forced_tool_name:
-                kwargs = dict(kwargs)
-                kwargs["forced_tool_name"] = forced_tool_name
-                if tools:
-                    before_count = len(tools)
-                    tools = [
-                        tool
-                        for tool in tools
-                        if ((tool.get("function") or {}).get("name") == forced_tool_name)
-                    ]
-                    if len(tools) == 1:
-                        tool_names = [forced_tool_name]
-                        logger.info(
-                            "显式工具调用已收敛工具列表: "
-                            f"{before_count} -> {len(tools)}"
-                        )
-                logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
-
-            logger.info(
-                "LLM请求: "
-                f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, "
-                f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, "
-                f"tool_names={self._preview_log_payload(tool_names)}, "
-                f"forced_tool={forced_tool_name or '-'}"
-            )
-            logger.info(
-                "LLM输入: "
-                f"user_message={self._preview_log_payload(user_message)}"
-            )
+        if stream:
+            raise NotImplementedError("stream mode is not supported by AIClient.chat")
+
+        memory_user_key = memory_key or user_id
+        messages: List[Message] = []
+
+        if system_prompt is None:
+            system_prompt = self.personality.get_system_prompt(
+                user_id=user_id,
+                group_id=group_id,
+                session_id=session_id,
+            )
 
-            # Call the model
-            if stream:
-                return self._chat_stream(messages, tools, **kwargs)
-            else:
-                response = await self.model.chat(messages, tools, **kwargs)
-            response_tool_count = len(response.tool_calls or [])
-            response_tool_names = []
-            for tool_call in response.tool_calls or []:
-                if isinstance(tool_call, dict):
-                    function_info = tool_call.get("function") or {}
-                    response_tool_names.append(function_info.get("name"))
-                else:
-                    function_info = getattr(tool_call, "function", None)
-                    response_tool_names.append(
-                        getattr(function_info, "name", None) if function_info else None
-                    )
-            logger.info(
-                "LLM首轮输出: "
-                f"tool_calls={response_tool_count}, "
-                f"tool_names={self._preview_log_payload(response_tool_names)}, "
-                f"content={self._preview_log_payload(response.content)}"
-            )
-
-            # Handle tool calls
-            if response.tool_calls:
-                response = await self._handle_tool_calls(
-                    messages, response, tools, **kwargs
-                )
-            elif forced_tool_name:
-                forced_response = await self._run_forced_tool_fallback(
-                    forced_tool_name=forced_tool_name,
-                    user_message=user_message,
-                )
-                if forced_response is not None:
-                    response = forced_response
-
-            # Write memory
-            if use_memory:
-                stored_memory = await self.memory.add_qa_pair(
-                    user_id=user_id,
-                    question=user_message,
-                    answer=response.content,
-                    metadata={"source": "chat"},
-                )
-                if stored_memory:
-                    logger.info(
-                        "已写入长期记忆问答对:\n"
-                        f"{stored_memory.content}\n"
-                        f"memory_id={stored_memory.id}, "
-                        f"importance={stored_memory.importance:.2f}"
-                    )
-
-            return response.content
-
-        except Exception as e:
-            logger.error(f"对话失败: {type(e).__name__}: {e!r}")
-            raise
-
-    async def _run_forced_tool_fallback(
-        self, forced_tool_name: str, user_message: str
-    ) -> Optional[Message]:
-        """Execute forced tool locally when model did not emit tool_calls."""
-        tool_def = self.tools.get(forced_tool_name)
-        tool_source = self._tool_sources.get(forced_tool_name, "custom")
-        if not tool_def:
-            logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}")
-            return None
-
-        logger.warning(
-            "模型未返回 tool_calls,启用本地强制工具执行: "
-            f"source={tool_source}, name={forced_tool_name}"
-        )
-
-        try:
-            result = tool_def.function()
-            if inspect.isawaitable(result):
-                result = await result
-        except TypeError as exc:
-            logger.warning(
-                "本地强制工具执行失败(参数不匹配): "
-                f"name={forced_tool_name}, error={exc}"
-            )
-            return None
-        except Exception as exc:
-            logger.warning(
-                "本地强制工具执行失败: "
-                f"name={forced_tool_name}, error={exc}"
-            )
-            return None
-
-        result_text = str(result)
-        pipelined_text = await self._run_skill_doc_pipeline(
-            forced_tool_name=forced_tool_name,
-            skill_doc=result_text,
-            user_message=user_message,
-        )
-        if pipelined_text is not None:
-            result_text = pipelined_text
-
-        prefix_limit = self._extract_prefix_limit(user_message)
-        if prefix_limit:
-            result_text = result_text[:prefix_limit]
-
-        logger.info(
-            "本地强制工具执行成功: "
-            f"source={tool_source}, name={forced_tool_name}, "
-            f"result={self._preview_log_payload(result_text)}"
-        )
+        if use_memory:
+            short_term, long_term = await self.memory.get_context(
+                user_id=memory_user_key,
+                query=user_message,
+            )
+            if short_term or long_term:
+                memory_context = self.memory.format_context(short_term, long_term)
+                system_prompt = f"{system_prompt}\n\n{memory_context}".strip()
+
+        messages.append(Message(role="system", content=system_prompt))
+        messages.append(Message(role="user", content=user_message))
+
+        logger.info(
+            "LLM request",
+            extra={
+                "user_id": user_id,
+                "group_id": group_id,
+                "session_id": session_id,
+                "memory_key": memory_user_key,
+                "use_memory": use_memory,
+                "message_preview": self._preview_log_payload(user_message),
+            },
+        )
-        return Message(role="assistant", content=result_text)
-
-    async def _run_skill_doc_pipeline(
-        self, forced_tool_name: str, skill_doc: str, user_message: str
-    ) -> Optional[str]:
-        """Run an extra model step: execute instructions from skill doc on user text."""
-        if not forced_tool_name.endswith(".read_skill_doc"):
-            return None
-
-        target_text = self._extract_processing_payload(user_message)
-        if not target_text:
-            return None
-
-        logger.info(
-            "强制工具后续处理开始: "
-            f"name={forced_tool_name}, target_len={len(target_text)}"
-        )
-
-        messages = [
-            Message(
-                role="system",
-                content=(
-                    "你是技能执行器。请严格按下面技能文档处理用户文本。"
-                    "不要复述技能文档,不要解释工具调用过程,只输出最终处理结果。\n\n"
-                    "[技能文档开始]\n"
-                    f"{skill_doc}\n"
-                    "[技能文档结束]"
-                ),
-            ),
-            Message(
-                role="user",
-                content=(
-                    "请根据技能文档处理以下文本,保持原意并提升自然度:\n"
-                    f"{target_text}"
-                ),
-            ),
-        ]
-
-        try:
-            response = await self.model.chat(messages=messages, tools=None)
-            content = (response.content or "").strip()
-            if not content:
-                return None
-
-            logger.info(
-                "强制工具后续处理完成: "
-                f"name={forced_tool_name}, output_len={len(content)}"
-            )
-            return content
-        except Exception as exc:
-            logger.warning(
-                "强制工具后续处理失败,回退为工具原始输出: "
-                f"name={forced_tool_name}, error={exc}"
-            )
-            return None
-
-    async def _chat_stream(
-        self,
-        messages: List[Message],
-        tools: Optional[List[Dict]],
-        **kwargs
-    ) -> AsyncIterator[str]:
-        """Streaming chat."""
-        async for chunk in self.model.chat_stream(messages, tools, **kwargs):
-            yield chunk
-
-    async def _handle_tool_calls(
-        self,
-        messages: List[Message],
-        response: Message,
-        tools: Optional[List[Dict]],
-        **kwargs
-    ) -> Message:
-        """Handle tool calls."""
-        messages.append(response)
-        total_calls = len(response.tool_calls or [])
-        if total_calls:
-            logger.info(f"检测到工具调用请求: {total_calls} 个")
-
-        # Execute each tool call
-        for tool_call in response.tool_calls or []:
-            try:
-                tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
-            except Exception as e:
-                logger.warning(f"解析工具调用失败: {e}")
-                fallback_id = tool_call.get('id') if isinstance(tool_call, dict) else getattr(tool_call, 'id', None)
-                if fallback_id:
-                    messages.append(Message(
-                        role="tool",
-                        content=f"工具参数解析失败: {str(e)}",
-                        tool_call_id=fallback_id,
-                        name="tool"
-                    ))
-                continue
-            if not tool_name:
-                logger.warning(f"跳过无效工具调用: {tool_call}")
-                continue
-
-            tool_def = self.tools.get(tool_name)
-            tool_source = self._tool_sources.get(tool_name, "custom")
-            if not tool_def:
-                error_msg = f"未找到工具: {tool_name}"
-                logger.warning(error_msg)
-                messages.append(Message(
-                    role="tool",
-                    name=tool_name,
-                    content=error_msg,
-                    tool_call_id=tool_call_id
-                ))
-                continue
-
-            try:
-                logger.info(
-                    "工具调用开始: "
-                    f"source={tool_source}, name={tool_name}, "
-                    f"args={self._preview_log_payload(tool_args)}"
-                )
-                result = tool_def.function(**tool_args)
-                if inspect.isawaitable(result):
-                    result = await result
-                logger.info(
-                    "工具调用成功: "
-                    f"source={tool_source}, name={tool_name}, "
-                    f"result={self._preview_log_payload(result)}"
-                )
-                messages.append(Message(
-                    role="tool",
-                    name=tool_name,
-                    content=str(result),
-                    tool_call_id=tool_call_id
-                ))
-            except Exception as e:
-                logger.warning(
-                    "工具调用失败: "
-                    f"source={tool_source}, name={tool_name}, error={e}"
-                )
-                messages.append(Message(
-                    role="tool",
-                    name=tool_name,
-                    content=f"工具执行失败: {str(e)}",
-                    tool_call_id=tool_call_id
-                ))
-
-        # Call the model again for the final response
-        final_kwargs = dict(kwargs)
-        # Force only the first model turn, avoid recursive force after tool result.
-        final_kwargs.pop("forced_tool_name", None)
-        final_response = await self.model.chat(messages, tools, **final_kwargs)
-        logger.info(
-            "LLM最终输出: "
-            f"content={self._preview_log_payload(final_response.content)}"
-        )
-        return final_response
+        response = await self._chat_with_retry(messages, **kwargs)
+        logger.info(
+            "LLM response",
+            extra={"content_preview": self._preview_log_payload(response.content)},
+        )
+
+        if use_memory:
+            stored_memory = await self.memory.add_qa_pair(
+                user_id=memory_user_key,
+                question=user_message,
+                answer=response.content,
+                metadata={
+                    "source": "chat",
+                    "user_id": user_id,
+                    "group_id": group_id,
+                    "session_id": session_id,
+                },
+            )
+            if stored_memory:
+                logger.info(
+                    "Long-term memory stored",
+                    extra={
+                        "memory_id": stored_memory.id,
+                        "importance": stored_memory.importance,
+                    },
+                )
+
+        return response.content
+
+    def set_personality(
+        self, personality_name: str, scope: str = "global", scope_id: Optional[str] = None
+    ) -> bool:
+        return self.personality.set_personality(
+            key=personality_name,
+            scope=scope,
+            scope_id=scope_id,
+        )
-    def _parse_tool_call(self, tool_call: Any) -> Tuple[Optional[str], Dict[str, Any], Optional[str]]:
-        """Normalize tool-call structures returned by different SDKs."""
-        if isinstance(tool_call, dict):
-            tool_call_id = tool_call.get('id')
-            function = tool_call.get('function') or {}
-            tool_name = function.get('name')
-            raw_args = function.get('arguments')
-        else:
-            tool_call_id = getattr(tool_call, 'id', None)
-            function = getattr(tool_call, 'function', None)
-            tool_name = getattr(function, 'name', None) if function else None
-            raw_args = getattr(function, 'arguments', None) if function else None
-
-        tool_args = self._normalize_tool_args(raw_args)
-        return tool_name, tool_args, tool_call_id
-
-    def _normalize_tool_args(self, raw_args: Any) -> Dict[str, Any]:
-        """Coerce tool arguments into a dict."""
-        if raw_args is None:
-            return {}
-
-        if isinstance(raw_args, dict):
-            return raw_args
-
-        if isinstance(raw_args, str):
-            raw_args = raw_args.strip()
-            if not raw_args:
-                return {}
-            parsed = json.loads(raw_args)
-            if not isinstance(parsed, dict):
-                raise ValueError(f"工具参数必须是 JSON 对象,实际类型: {type(parsed)}")
-            return parsed
-
-        if hasattr(raw_args, 'model_dump'):
-            parsed = raw_args.model_dump()
-            if isinstance(parsed, dict):
-                return parsed
-
-        raise ValueError(f"不支持的工具参数类型: {type(raw_args)}")
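The removed `_normalize_tool_args` coerced SDK tool arguments (None, dict, or JSON string) into a plain dict. A standalone sketch of the same coercion (the `model_dump` branch for object-style arguments is omitted here, and the error strings are English stand-ins):

```python
import json
from typing import Any, Dict


def normalize_tool_args(raw_args: Any) -> Dict[str, Any]:
    # Tool arguments can arrive as None, a dict, or a JSON-encoded string.
    if raw_args is None:
        return {}
    if isinstance(raw_args, dict):
        return raw_args
    if isinstance(raw_args, str):
        raw_args = raw_args.strip()
        if not raw_args:
            return {}
        parsed = json.loads(raw_args)
        if not isinstance(parsed, dict):
            raise ValueError(f"tool arguments must be a JSON object, got {type(parsed)}")
        return parsed
    raise ValueError(f"unsupported tool argument type: {type(raw_args)}")
```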
-    @staticmethod
-    def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
-        """Short preview of args/results for logging."""
-        try:
-            text = json.dumps(payload, ensure_ascii=False, default=str)
-        except Exception:
-            text = str(payload)
-
-        if len(text) > max_len:
-            return text[:max_len] + "..."
-        return text
-
-    @staticmethod
-    def _extract_prefix_limit(user_message: str) -> Optional[int]:
-        """Extract requested output prefix length like '前100字'."""
-        if not user_message:
-            return None
-
-        match = re.search(r"前\s*(\d{1,4})\s*字", user_message)
-        if not match:
-            return None
-
-        try:
-            limit = int(match.group(1))
-        except (TypeError, ValueError):
-            return None
-
-        if limit <= 0:
-            return None
-        return min(limit, 5000)
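The removed `_extract_prefix_limit` parsed requests like "前100字" ("first 100 characters") and capped the result at 5000. A standalone sketch of that parser:

```python
import re
from typing import Optional


def extract_prefix_limit(user_message: str) -> Optional[int]:
    # Matches "前<N>字" with optional whitespace, N up to 4 digits.
    match = re.search(r"前\s*(\d{1,4})\s*字", user_message or "")
    if not match:
        return None
    limit = int(match.group(1))
    if limit <= 0:
        return None
    return min(limit, 5000)
```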
-    @staticmethod
-    def _extract_processing_payload(user_message: str) -> Optional[str]:
-        """Extract text payload like '处理以下文本:...' from user message."""
-        if not user_message:
-            return None
-
-        text = user_message.strip()
-        markers = [
-            "以下文本:",
-            "以下文本:",
-            "文本:",
-            "文本:",
-        ]
-        for marker in markers:
-            idx = text.find(marker)
-            if idx < 0:
-                continue
-            payload = text[idx + len(marker) :].strip()
-            if payload:
-                return payload
-
-        pattern = re.compile(
-            r"(?:处理|润色|改写|人性化处理|优化)[\s\S]{0,32}(?:如下|以下)[::]\s*([\s\S]+)$"
-        )
-        match = pattern.search(text)
-        if match:
-            payload = (match.group(1) or "").strip()
-            if payload:
-                return payload
-
-        return None
-
-    @staticmethod
-    def _compact_identifier(text: str) -> str:
-        """Compact identifier for fuzzy matching (e.g. humanizer_zh -> humanizerzh)."""
-        return re.sub(r"[^a-z0-9]+", "", (text or "").lower())
-
-    @staticmethod
-    def _extract_forced_tool_name(
-        user_message: str, available_tool_names: List[str]
-    ) -> Optional[str]:
-        if not user_message or not available_tool_names:
-            return None
-
-        triggers = [
-            "调用工具",
-            "使用工具",
-            "只调用",
-            "务必调用",
-            "必须调用",
-            "调用",
-            "使用",
-            "tool",
-        ]
-        if not any(trigger in user_message for trigger in triggers):
-            return None
-
-        pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)")
-        explicit_matches = [
-            name for name in pattern.findall(user_message) if name in available_tool_names
-        ]
-        if len(explicit_matches) == 1:
-            return explicit_matches[0]
-        if len(explicit_matches) > 1:
-            return None
-
-        contained = [name for name in available_tool_names if name in user_message]
-        if len(contained) == 1:
-            return contained[0]
-
-        # Allow a bare skill/tool prefix (e.g. humanizer_zh) when the prefix owns exactly one tool.
-        prefixes = sorted(
-            {name.split(".", 1)[0] for name in available_tool_names},
-            key=len,
-            reverse=True,
-        )
-        matched_prefixes = [
-            prefix
-            for prefix in prefixes
-            if re.search(rf"\b{re.escape(prefix)}\b", user_message)
-        ]
-        if len(matched_prefixes) == 1:
-            prefix_tools = [
-                name
-                for name in available_tool_names
-                if name.startswith(f"{matched_prefixes[0]}.")
-            ]
-            if len(prefix_tools) == 1:
-                return prefix_tools[0]
-
-        # Fuzzy match: tolerate missing underscores/dots (e.g. humanizerzh).
-        compact_message = AIClient._compact_identifier(user_message)
-        if compact_message:
-            compact_full_matches = []
-            for tool_name in available_tool_names:
-                compact_tool_name = AIClient._compact_identifier(tool_name)
-                if compact_tool_name and compact_tool_name in compact_message:
-                    compact_full_matches.append(tool_name)
-
-            if len(compact_full_matches) == 1:
-                return compact_full_matches[0]
-            if len(compact_full_matches) > 1:
-                return None
-
-            compact_prefix_map: Dict[str, List[str]] = {}
-            for tool_name in available_tool_names:
-                prefix = tool_name.split(".", 1)[0]
-                compact_prefix = AIClient._compact_identifier(prefix)
-                if not compact_prefix:
-                    continue
-                compact_prefix_map.setdefault(compact_prefix, []).append(tool_name)
-
-            compact_prefix_matches = [
-                compact_prefix
-                for compact_prefix in compact_prefix_map
-                if compact_prefix in compact_message
-            ]
-            if len(compact_prefix_matches) == 1:
-                matched_tools = compact_prefix_map[compact_prefix_matches[0]]
-                if len(matched_tools) == 1:
-                    return matched_tools[0]
-
-        return None
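The removed fuzzy matcher compacted identifiers to lowercase alphanumerics so a message containing `humanizerzh` could still select `humanizer_zh.read_skill_doc`. A standalone sketch of the full-name branch (`fuzzy_match_tool` is an illustrative name; the prefix-map branch is omitted):

```python
import re
from typing import List, Optional


def compact_identifier(text: str) -> str:
    # "humanizer_zh.read_skill_doc" -> "humanizerzhreadskilldoc"
    return re.sub(r"[^a-z0-9]+", "", (text or "").lower())


def fuzzy_match_tool(user_message: str, tool_names: List[str]) -> Optional[str]:
    compact_message = compact_identifier(user_message)
    hits = [
        name
        for name in tool_names
        if compact_identifier(name) and compact_identifier(name) in compact_message
    ]
    # Only a unique hit is trusted; ambiguity falls back to no forced tool.
    return hits[0] if len(hits) == 1 else None
```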
-    def set_personality(self, personality_name: str) -> bool:
-        """Set the active personality."""
-        return self.personality.set_personality(personality_name)
-
     def list_personalities(self) -> List[str]:
-        """List all personalities."""
         return self.personality.list_personalities()
 
     def switch_model(self, model_config: ModelConfig) -> bool:
-        """Runtime switch for primary chat model."""
-        new_model = self._create_model(model_config)
-        self.model = new_model
+        self.model = self._create_model(model_config)
         self.config = model_config
         logger.info(
-            f"已切换主模型: {model_config.provider.value}/{model_config.model_name}"
+            "Primary model switched",
+            extra={
+                "provider": model_config.provider.value,
+                "model": model_config.model_name,
+            },
         )
         return True
 
-    async def create_long_task(
-        self,
-        user_id: str,
-        title: str,
-        description: str,
-        steps: List[Dict],
-        metadata: Optional[Dict] = None
-    ) -> str:
-        """Create a long-running task."""
-        return self.task_manager.create_task(
-            user_id=user_id,
-            title=title,
-            description=description,
-            steps=steps,
-            metadata=metadata
-        )
-
-    async def start_task(
-        self,
-        task_id: str,
-        progress_callback: Optional[callable] = None
-    ):
-        """Start a task."""
-        await self.task_manager.start_task(task_id, progress_callback)
-
-    def get_task_status(self, task_id: str) -> Optional[Dict]:
-        """Get task status."""
-        return self.task_manager.get_task_status(task_id)
-
-    def register_tool(
-        self,
-        name: str,
-        description: str,
-        parameters: Dict,
-        function: callable,
-        source: str = "custom",
-    ):
-        """Register a tool."""
-        from .base import ToolDefinition
-        tool = ToolDefinition(
-            name=name,
-            description=description,
-            parameters=parameters,
-            function=function
-        )
-        self.tools.register(tool)
-        self._tool_sources[name] = source
-        logger.info(f"已注册工具: {name} (source={source})")
-
-    def unregister_tool(self, name: str) -> bool:
-        """Unregister a tool."""
-        removed = self.tools.unregister(name)
-        if removed:
-            self._tool_sources.pop(name, None)
-            logger.info(f"已卸载工具: {name}")
-        return removed
-
-    def unregister_tools_by_prefix(self, prefix: str) -> int:
-        """Unregister tools in bulk by prefix."""
-        removed_count = self.tools.unregister_by_prefix(prefix)
-        for tool_name in list(self._tool_sources.keys()):
-            if tool_name.startswith(prefix):
-                self._tool_sources.pop(tool_name, None)
-        if removed_count:
-            logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
-        return removed_count
-
     def clear_memory(self, user_id: str):
-        """Clear a user's short-term memory."""
         self.memory.clear_short_term(user_id)
-        logger.info(f"Cleared short-term memory for user {user_id}")
+        logger.info(f"Cleared short-term memory for {user_id}")
 
     async def clear_long_term_memory(self, user_id: str) -> bool:
         try:
             await self.memory.clear_long_term(user_id)
-            logger.info(f"Cleared long-term memory for user {user_id}")
+            logger.info(f"Cleared long-term memory for {user_id}")
             return True
-        except Exception as e:
-            logger.warning(f"Failed to clear long-term memory for user {user_id}: {e}")
+        except Exception as exc:
+            logger.warning(f"Failed to clear long-term memory for {user_id}: {exc}")
             return False
 
     async def list_long_term_memories(self, user_id: str, limit: int = 20):
@@ -823,10 +362,18 @@ class AIClient:
         return await self.memory.delete_long_term(user_id, memory_id)
 
     async def clear_all_memory(self, user_id: str) -> bool:
-        """Clear all of a user's memory (short-term + long-term)."""
         self.clear_memory(user_id)
-        try:
-            return await self.clear_long_term_memory(user_id)
-        except Exception:
-            return False
+        return await self.clear_long_term_memory(user_id)
+
+    async def close(self):
+        await self.memory.close()
+
+        for model in [self.model, self.embed_model]:
+            if model is None:
+                continue
+            close = getattr(model, "close", None)
+            if close is None:
+                continue
+            maybe_awaitable = close()
+            if asyncio.iscoroutine(maybe_awaitable):
+                await maybe_awaitable
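The `close()` added above awaits `.close()` only when it returns a coroutine, so models with sync or async cleanup are both handled. A minimal standalone sketch of that duck-typed pattern (the `SyncResource`/`AsyncResource` classes are hypothetical stand-ins):

```python
import asyncio


class SyncResource:
    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


class AsyncResource:
    def __init__(self) -> None:
        self.closed = False

    async def close(self) -> None:
        self.closed = True


async def close_all(resources) -> None:
    # Call .close() when present; await only if it returned a coroutine.
    for resource in resources:
        close = getattr(resource, "close", None)
        if close is None:
            continue
        maybe_awaitable = close()
        if asyncio.iscoroutine(maybe_awaitable):
            await maybe_awaitable
```

Objects without a `close` attribute are simply skipped, which matches the `getattr(..., None)` guard in the diff.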
@@ -1,13 +0,0 @@
-"""
-MCP module
-"""
-from .base import MCPServer, MCPClient, MCPManager, MCPResource, MCPTool, MCPPrompt
-
-__all__ = [
-    'MCPServer',
-    'MCPClient',
-    'MCPManager',
-    'MCPResource',
-    'MCPTool',
-    'MCPPrompt'
-]
@@ -1,230 +0,0 @@
"""
MCP (Model Context Protocol) support.
"""
import asyncio
import json
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, asdict
from pathlib import Path
from src.utils.logger import setup_logger

logger = setup_logger('MCPSystem')


@dataclass
class MCPResource:
    """MCP resource descriptor."""
    uri: str
    name: str
    description: str
    mime_type: str


@dataclass
class MCPTool:
    """MCP tool descriptor."""
    name: str
    description: str
    input_schema: Dict[str, Any]


@dataclass
class MCPPrompt:
    """MCP prompt descriptor."""
    name: str
    description: str
    arguments: List[Dict[str, Any]]


class MCPServer:
    """Base class for MCP servers."""

    def __init__(self, name: str, version: str):
        self.name = name
        self.version = version
        self.resources: Dict[str, MCPResource] = {}
        self.tools: Dict[str, Callable] = {}
        self.tool_specs: Dict[str, MCPTool] = {}
        self.prompts: Dict[str, MCPPrompt] = {}

    async def initialize(self):
        """Initialize the server."""
        pass

    async def shutdown(self):
        """Shut down the server."""
        pass

    def register_resource(self, resource: MCPResource):
        """Register a resource."""
        self.resources[resource.uri] = resource

    def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
        """Register a tool."""
        tool = MCPTool(name=name, description=description, input_schema=input_schema)
        self.tool_specs[name] = tool
        self.tools[name] = handler
        logger.info(f"✅ MCP tool registered: {self.name}.{name}")

    def register_prompt(self, prompt: MCPPrompt):
        """Register a prompt."""
        self.prompts[prompt.name] = prompt

    async def list_resources(self) -> List[MCPResource]:
        """List resources."""
        return list(self.resources.values())

    async def read_resource(self, uri: str) -> Optional[str]:
        """Read a resource."""
        raise NotImplementedError

    async def list_tools(self) -> List[MCPTool]:
        """List tools."""
        return list(self.tool_specs.values())

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool."""
        if name not in self.tools:
            raise ValueError(f"Tool not found: {name}")

        handler = self.tools[name]
        logger.info(
            "MCP tool call started: "
            f"server={self.name}, tool={name}, "
            f"args={json.dumps(arguments, ensure_ascii=False, default=str)}"
        )
        try:
            result = await handler(**arguments)
        except Exception as exc:
            logger.warning(f"MCP tool call failed: server={self.name}, tool={name}, error={exc}")
            raise
        logger.info(f"MCP tool call succeeded: server={self.name}, tool={name}")
        return result

    async def list_prompts(self) -> List[MCPPrompt]:
        """List prompts."""
        return list(self.prompts.values())

    async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]:
        """Get a prompt."""
        raise NotImplementedError


class MCPClient:
    """MCP client."""

    def __init__(self):
        self.servers: Dict[str, MCPServer] = {}

    async def connect_server(self, server: MCPServer):
        """Connect a server."""
        await server.initialize()
        self.servers[server.name] = server
        logger.info(f"✅ Connected MCP server: {server.name} v{server.version}")

    async def disconnect_server(self, server_name: str):
        """Disconnect a server."""
        if server_name in self.servers:
            await self.servers[server_name].shutdown()
            del self.servers[server_name]
            logger.info(f"✅ Disconnected MCP server: {server_name}")

    def get_server(self, name: str) -> Optional[MCPServer]:
        """Get a server by name."""
        return self.servers.get(name)

    def list_servers(self) -> List[str]:
        """List all connected servers."""
        return list(self.servers.keys())

    async def list_all_resources(self) -> Dict[str, List[MCPResource]]:
        """List resources across all servers."""
        result = {}
        for name, server in self.servers.items():
            result[name] = await server.list_resources()
        return result

    async def list_all_tools(self) -> Dict[str, List[MCPTool]]:
        """List tools across all servers."""
        result = {}
        for name, server in self.servers.items():
            result[name] = await server.list_tools()
        return result

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call a tool on a named server."""
        server = self.get_server(server_name)
        if not server:
            raise ValueError(f"Server not found: {server_name}")

        return await server.call_tool(tool_name, arguments)


class MCPManager:
    """MCP manager."""

    def __init__(self, config_path: Path):
        self.config_path = config_path
        self.client = MCPClient()
        self.server_configs: Dict[str, Dict] = {}
        self._load_config()

    def _load_config(self):
        """Load configuration."""
        if self.config_path.exists():
            with open(self.config_path, 'r', encoding='utf-8') as f:
                self.server_configs = json.load(f)

    def _save_config(self):
        """Save configuration."""
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.config_path, 'w', encoding='utf-8') as f:
            json.dump(self.server_configs, f, ensure_ascii=False, indent=2)

    async def register_server(self, server: MCPServer, config: Optional[Dict] = None):
        """Register a server."""
        await self.client.connect_server(server)

        if config:
            self.server_configs[server.name] = config
            self._save_config()

    async def unregister_server(self, server_name: str):
        """Unregister a server."""
        await self.client.disconnect_server(server_name)

        if server_name in self.server_configs:
            del self.server_configs[server_name]
            self._save_config()

    def get_client(self) -> MCPClient:
        """Get the client."""
        return self.client

    async def get_all_tools_for_ai(self) -> List[Dict]:
        """Get all tools in AI function-calling format."""
        all_tools = []
        tools_by_server = await self.client.list_all_tools()

        for server_name, tools in tools_by_server.items():
            for tool in tools:
                all_tools.append({
                    "type": "function",
                    "function": {
                        "name": f"{server_name}.{tool.name}",
                        "description": tool.description,
                        "parameters": tool.input_schema
                    }
                })

        return all_tools

    async def execute_tool(self, full_tool_name: str, arguments: Dict) -> Any:
        """Execute a tool by its namespaced name."""
        parts = full_tool_name.split('.', 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid tool name format: {full_tool_name}")

        server_name, tool_name = parts
        logger.info(f"MCP execute request: {full_tool_name}")
        return await self.client.call_tool(server_name, tool_name, arguments)
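The manager's `get_all_tools_for_ai` flattens every registered tool into the function-calling shape an AI provider expects, namespacing names as `server.tool` so `execute_tool` can route them back. A minimal self-contained sketch of that conversion (the `ToolSpec` dataclass here stands in for `MCPTool`; names are illustrative):

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToolSpec:
    """Stand-in for MCPTool: name, description, JSON input schema."""
    name: str
    description: str
    input_schema: Dict[str, Any]


def to_ai_tools(server_name: str, tools: List[ToolSpec]) -> List[Dict[str, Any]]:
    """Flatten tool specs into function-calling entries, namespaced as 'server.tool'."""
    return [
        {
            "type": "function",
            "function": {
                "name": f"{server_name}.{tool.name}",
                "description": tool.description,
                "parameters": tool.input_schema,
            },
        }
        for tool in tools
    ]


specs = [ToolSpec("read_file", "Read file contents", {"type": "object"})]
entries = to_ai_tools("filesystem", specs)
```

The `server.tool` name is split back with `split('.', 1)` on execution, which is why server names must not be empty.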
@@ -1,6 +0,0 @@
"""
MCP server implementations.
"""
from .filesystem import FileSystemMCPServer

__all__ = ['FileSystemMCPServer']
@@ -1,123 +0,0 @@
"""
Example MCP server: filesystem access.
"""
from pathlib import Path
from typing import Optional
from ..base import MCPServer, MCPResource


class FileSystemMCPServer(MCPServer):
    """Filesystem MCP server."""

    def __init__(self, root_path: Path):
        super().__init__(name="filesystem", version="1.0.0")
        self.root_path = root_path

    async def initialize(self):
        """Initialize: register the tools."""
        self.register_tool(
            name="read_file",
            description="Read file contents",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "File path"
                    }
                },
                "required": ["path"]
            },
            handler=self.read_file
        )

        self.register_tool(
            name="write_file",
            description="Write file contents",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "File path"
                    },
                    "content": {
                        "type": "string",
                        "description": "File contents"
                    }
                },
                "required": ["path", "content"]
            },
            handler=self.write_file
        )

        self.register_tool(
            name="list_directory",
            description="List directory contents",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Directory path"
                    }
                },
                "required": ["path"]
            },
            handler=self.list_directory
        )

    def _resolve_path(self, path: str) -> Path:
        """Resolve a user-supplied path."""
        full_path = (self.root_path / path).resolve()

        # Safety check: the resolved path must stay inside root_path
        if not str(full_path).startswith(str(self.root_path)):
            raise ValueError("Path escapes the allowed root")

        return full_path

    async def read_file(self, path: str) -> str:
        """Read a file."""
        file_path = self._resolve_path(path)

        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        if not file_path.is_file():
            raise ValueError(f"Not a file: {path}")

        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()

    async def write_file(self, path: str, content: str) -> str:
        """Write a file."""
        file_path = self._resolve_path(path)

        file_path.parent.mkdir(parents=True, exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)

        return f"File written: {path}"

    async def list_directory(self, path: str) -> list:
        """List a directory."""
        dir_path = self._resolve_path(path)

        if not dir_path.exists():
            raise FileNotFoundError(f"Directory not found: {path}")

        if not dir_path.is_dir():
            raise ValueError(f"Not a directory: {path}")

        items = []
        for item in dir_path.iterdir():
            items.append({
                "name": item.name,
                "type": "directory" if item.is_dir() else "file",
                "size": item.stat().st_size if item.is_file() else None
            })

        return items
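`_resolve_path` above guards against directory traversal with a string-prefix check after `resolve()`. A stricter variant (not what the deleted module did, just a sketch) is `Path.relative_to`, which raises `ValueError` for any path outside the root, including prefix look-alikes such as `/root2` against `/root`:

```python
import tempfile
from pathlib import Path


def resolve_inside(root: Path, user_path: str) -> Path:
    """Resolve user_path under root; raise ValueError if it escapes root."""
    root = root.resolve()
    candidate = (root / user_path).resolve()
    # relative_to raises ValueError when candidate is outside root
    candidate.relative_to(root)
    return candidate


root = Path(tempfile.mkdtemp())
inside = resolve_inside(root, "notes/a.txt")

try:
    resolve_inside(root, "../escape.txt")
    escaped = True
except ValueError:
    escaped = False
```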
309
src/ai/memory.py
@@ -1,25 +1,29 @@
 """
-记忆系统:短期记忆、长期记忆与 RAG 检索(向量数据库)。
+Memory system: short-term window + long-term retrieval.
 """
-import asyncio
+
+from __future__ import annotations
 
 import hashlib
 import shutil
 import time
 import uuid
-from typing import List, Dict, Optional, Tuple, Callable, Awaitable
+
+from collections import deque
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from pathlib import Path
-from collections import deque
-from .vector_store import VectorStore, VectorMemory, ChromaVectorStore, JSONVectorStore
+from typing import Awaitable, Callable, Dict, List, Optional, Tuple
+
+from .vector_store import ChromaVectorStore, JSONVectorStore, VectorMemory, VectorStore
 from src.utils.logger import setup_logger
 
-logger = setup_logger('MemorySystem')
+logger = setup_logger("MemorySystem")
 
 
 @dataclass
 class MemoryItem:
-    """记忆项(用于短期记忆)。"""
+    """In-memory short-term record."""
 
     content: str
     timestamp: datetime
     user_id: str
@@ -27,154 +31,136 @@ class MemoryItem:
     metadata: Dict = field(default_factory=dict)
 
     def to_dict(self) -> Dict:
-        """转换为字典。"""
         return {
-            'content': self.content,
-            'timestamp': self.timestamp.isoformat(),
-            'user_id': self.user_id,
-            'importance': self.importance,
-            'metadata': self.metadata
+            "content": self.content,
+            "timestamp": self.timestamp.isoformat(),
+            "user_id": self.user_id,
+            "importance": self.importance,
+            "metadata": self.metadata,
         }
 
 
 class ShortTermMemory:
-    """短期记忆(滑动窗口)。"""
+    """Short-term memory window."""
 
     def __init__(self, max_size: int = 20, max_age_minutes: int = 30):
         self.max_size = max_size
         self.max_age = timedelta(minutes=max_age_minutes)
-        self.memories: Dict[str, deque] = {}  # user_id -> deque of MemoryItem
+        self.memories: Dict[str, deque] = {}
 
     def add(self, user_id: str, content: str, metadata: Optional[Dict] = None):
-        """添加短期记忆。"""
         if user_id not in self.memories:
             self.memories[user_id] = deque(maxlen=self.max_size)
 
-        memory = MemoryItem(
-            content=content,
-            timestamp=datetime.now(),
-            user_id=user_id,
-            metadata=metadata or {}
+        self.memories[user_id].append(
+            MemoryItem(
+                content=content,
+                timestamp=datetime.now(),
+                user_id=user_id,
+                metadata=metadata or {},
+            )
         )
 
-        self.memories[user_id].append(memory)
-
     def get(self, user_id: str, limit: Optional[int] = None) -> List[MemoryItem]:
-        """获取短期记忆。"""
-        if user_id not in self.memories:
+        items = list(self.memories.get(user_id, []))
+        if not items:
             return []
 
-        # 过滤过期记忆
         now = datetime.now()
-        valid_memories = [
-            m for m in self.memories[user_id]
-            if now - m.timestamp <= self.max_age
-        ]
+        valid = [m for m in items if now - m.timestamp <= self.max_age]
+        self.memories[user_id] = deque(valid, maxlen=self.max_size)
 
-        if limit:
-            valid_memories = valid_memories[-limit:]
-
-        return valid_memories
+        if limit and limit > 0:
+            return valid[-limit:]
+        return valid
 
     def clear(self, user_id: str):
-        """清除用户短期记忆。"""
-        if user_id in self.memories:
-            self.memories.pop(user_id, None)
+        self.memories.pop(user_id, None)
 
 
 class MemorySystem:
-    """记忆系统:整合短期记忆与长期记忆。"""
+    """Memory system: short-term + long-term storage."""
 
     def __init__(
         self,
         storage_path: Path,
-        embed_func: Optional[callable] = None,
+        embed_func: Optional[Callable[[str], Awaitable[List[float]]]] = None,
         importance_evaluator: Optional[Callable[[str, Optional[Dict]], Awaitable[float]]] = None,
         importance_threshold: float = 0.6,
         use_vector_db: bool = True,
         use_query_embedding: bool = False,
+        max_long_term_per_user: int = 500,
+        dedup_window_seconds: int = 300,
     ):
         self.short_term = ShortTermMemory()
         self.embed_func = embed_func
         self.importance_evaluator = importance_evaluator
         self.importance_threshold = importance_threshold
-        # Only embed retrieval queries when explicitly enabled.
         self.use_query_embedding = use_query_embedding
+        self.max_long_term_per_user = max(10, int(max_long_term_per_user))
+        self.dedup_window = timedelta(seconds=max(1, int(dedup_window_seconds)))
 
-        # 初始化向量存储
         if use_vector_db:
             chroma_path = storage_path.parent / "chroma_db"
             chroma_store = self._init_chroma_store(chroma_path)
             if chroma_store is not None:
                 self.vector_store = chroma_store
-                logger.info("Using Chroma vector store")
             else:
                 self.vector_store = JSONVectorStore(storage_path)
         else:
-            # 使用 JSON 存储(向后兼容)
             self.vector_store = JSONVectorStore(storage_path)
-            logger.info("使用 JSON 存储")
 
     @staticmethod
     def _is_chroma_table_conflict(error: Exception) -> bool:
-        msg = str(error).lower()
-        return "table embeddings already exists" in msg
+        return "table embeddings already exists" in str(error).lower()
 
     @staticmethod
     def _is_chroma_trigram_error(error: Exception) -> bool:
-        msg = str(error).lower()
-        return "no such tokenizer: trigram" in msg
+        return "no such tokenizer: trigram" in str(error).lower()
 
     def _init_chroma_store(self, chroma_path: Path) -> Optional[VectorStore]:
-        """初始化 Chroma,遇到已知 sqlite schema 冲突时尝试修复。"""
         try:
             return ChromaVectorStore(chroma_path)
         except Exception as error:
             if self._is_chroma_trigram_error(error):
                 logger.warning(
-                    "Chroma 初始化失败,降级为 JSON 存储: sqlite 缺少 trigram tokenizer。"
-                    "请在运行环境升级 sqlite 或安装 pysqlite3-binary。"
+                    "Chroma unavailable (sqlite trigram unsupported), fallback to JSON store."
                 )
                 return None
 
             if not self._is_chroma_table_conflict(error):
-                logger.warning(f"Chroma 初始化失败,降级为 JSON 存储: {error}")
+                logger.warning(f"Chroma init failed, fallback to JSON: {error}")
                 return None
 
-            # 先做一次短暂重试,处理并发启动时的瞬时冲突。
-            logger.warning(f"Chroma 初始化出现 schema 冲突,正在重试: {error}")
+            logger.warning(f"Chroma schema conflict, retry once: {error}")
             time.sleep(0.2)
             try:
                 return ChromaVectorStore(chroma_path)
             except Exception as retry_error:
                 if not self._is_chroma_table_conflict(retry_error):
-                    logger.warning(f"Chroma 重试失败,降级为 JSON 存储: {retry_error}")
+                    logger.warning(f"Chroma retry failed, fallback to JSON: {retry_error}")
                     return None
 
-            backup_name = (
+            backup_path = chroma_path.parent / (
                 f"{chroma_path.name}_backup_conflict_"
                 f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
             )
-            backup_path = chroma_path.parent / backup_name
 
             try:
                 if chroma_path.exists():
                     shutil.move(str(chroma_path), str(backup_path))
                 chroma_path.mkdir(parents=True, exist_ok=True)
                 repaired = ChromaVectorStore(chroma_path)
                 logger.warning(
-                    f"检测到 Chroma 元数据库冲突,已重建目录并保留备份: {backup_path}"
+                    f"Chroma metadata repaired by rebuilding directory. Backup: {backup_path}"
                 )
                 return repaired
             except Exception as repair_error:
-                logger.warning(f"Chroma 修复失败,降级为 JSON 存储: {repair_error}")
+                logger.warning(f"Chroma repair failed, fallback to JSON: {repair_error}")
                 return None
 
     @staticmethod
     def _normalize_embedding(values: List[float], dim: int = 1024) -> List[float]:
         if not values:
             return [0.0] * dim
 
         normalized = [float(v) for v in values[:dim]]
         if len(normalized) < dim:
             normalized.extend([0.0] * (dim - len(normalized)))
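The rewritten `ShortTermMemory.get` both filters expired items and writes the filtered deque back, so expired entries stop occupying window slots. The bounded-window behaviour can be sketched in isolation (this `Window` class and its sizes are illustrative, not the bot's API):

```python
from collections import deque
from datetime import datetime, timedelta


class Window:
    """Bounded sliding window with age-based expiry, mirroring ShortTermMemory."""

    def __init__(self, max_size: int = 3, max_age_minutes: int = 30):
        self.max_size = max_size
        self.max_age = timedelta(minutes=max_age_minutes)
        self.items: deque = deque(maxlen=max_size)

    def add(self, content: str) -> None:
        self.items.append((datetime.now(), content))

    def get(self) -> list:
        now = datetime.now()
        valid = [(t, c) for t, c in self.items if now - t <= self.max_age]
        # Write the filtered view back so expired items do not linger.
        self.items = deque(valid, maxlen=self.max_size)
        return [c for _, c in valid]


w = Window(max_size=3)
for msg in ["a", "b", "c", "d"]:
    w.add(msg)
```

With `maxlen=3`, appending a fourth item silently evicts the oldest, so only the last three survive.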
@@ -191,14 +177,11 @@ class MemorySystem:
             return vec
 
         for idx, byte in enumerate(encoded):
-            bucket = idx % dim
-            vec[bucket] += (byte / 255.0)
+            vec[idx % dim] += byte / 255.0
 
         digest = hashlib.sha256(encoded).digest()
         for idx, byte in enumerate(digest):
-            bucket = idx % dim
-            vec[bucket] += ((byte / 255.0) - 0.5) * 0.1
+            vec[idx % dim] += ((byte / 255.0) - 0.5) * 0.1
 
         return vec
 
     async def _build_embedding(self, text: str) -> List[float]:
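The local fallback embedding shown in this hunk accumulates raw byte values into fixed buckets, then adds small SHA-256-derived offsets so near-identical strings still separate. Extracted as a standalone function with the same logic (the function name is ours; the source keeps it as a method):

```python
import hashlib
from typing import List


def local_embedding(text: str, dim: int = 1024) -> List[float]:
    """Deterministic, dependency-free embedding: byte buckets + hash perturbation."""
    vec = [0.0] * dim
    encoded = (text or "").encode("utf-8")
    if not encoded:
        return vec

    # Accumulate byte values into dim buckets.
    for idx, byte in enumerate(encoded):
        vec[idx % dim] += byte / 255.0

    # Small hash-derived offsets keep anagram-like inputs from colliding.
    digest = hashlib.sha256(encoded).digest()
    for idx, byte in enumerate(digest):
        vec[idx % dim] += ((byte / 255.0) - 0.5) * 0.1

    return vec


a = local_embedding("hello")
b = local_embedding("hello")
c = local_embedding("world")
```

It is not a semantic embedding; it only guarantees determinism and a fixed dimension, which is all the fallback path needs.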
@@ -207,9 +190,8 @@ class MemorySystem:
             embedding = await self.embed_func(text)
             if embedding:
                 return [float(v) for v in list(embedding)]
-        except Exception as e:
-            logger.warning(f"embedding generation failed: {e}")
+        except Exception as exc:
+            logger.warning(f"Embedding generation failed, fallback to local: {exc}")
 
         return self._local_embedding(text)
 
     async def _add_vector_memory(
@@ -231,10 +213,8 @@ class MemorySystem:
         ):
             return True
 
-        # Chroma collection may have a fixed historical embedding dimension.
-        candidate_dims = []
-        base_len = len(embedding or [])
-        for dim in [base_len, 1024, 1536, 768, 384, 3072]:
+        candidate_dims: List[int] = []
+        for dim in [len(embedding or []), 1024, 1536, 768, 384, 3072]:
             if dim and dim > 0 and dim not in candidate_dims:
                 candidate_dims.append(dim)
@@ -250,7 +230,6 @@ class MemorySystem:
             )
             if ok:
                 return True
-
         return False
 
     @staticmethod
@@ -261,15 +240,53 @@ class MemorySystem:
             value = 0.5
         return max(0.0, min(1.0, value))
 
+    async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float:
+        if not content or not content.strip():
+            return 0.0
+
+        if self.importance_evaluator:
+            try:
+                score = await self.importance_evaluator(content, metadata)
+                return self._normalize_importance(score)
+            except Exception as exc:
+                logger.warning(f"Importance evaluation failed, fallback to 0.5: {exc}")
+        return 0.5
+
+    async def _is_duplicate_long_term(self, user_id: str, content: str) -> bool:
+        normalized = " ".join((content or "").split())
+        if not normalized:
+            return True
+
+        all_memories = await self.vector_store.get_all(user_id)
+        recent_cutoff = datetime.now() - self.dedup_window
+        for memory in all_memories[-20:]:
+            if (
+                " ".join((memory.content or "").split()) == normalized
+                and memory.timestamp >= recent_cutoff
+            ):
+                return True
+        return False
+
+    async def _trim_user_long_term(self, user_id: str):
+        all_memories = await self.vector_store.get_all(user_id)
+        if len(all_memories) <= self.max_long_term_per_user:
+            return
+
+        all_memories.sort(key=lambda m: (m.importance, m.timestamp))
+        to_delete = len(all_memories) - self.max_long_term_per_user
+        for memory in all_memories[:to_delete]:
+            await self.vector_store.delete(memory.id)
+
     async def add_message(
         self,
         user_id: str,
         role: str,
         content: str,
-        metadata: Optional[Dict] = None
+        metadata: Optional[Dict] = None,
     ):
-        """向短期记忆添加单条消息(不做长期记忆评分)。"""
-        self.short_term.add(user_id, content, metadata)
+        payload = dict(metadata or {})
+        payload.setdefault("role", role)
+        self.short_term.add(user_id, content, payload)
 
     async def add_qa_pair(
         self,
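The new `_is_duplicate_long_term` compares whitespace-normalized content against the most recent stored memories inside a time window. The normalization step is the key trick: collapse every whitespace run to a single space so trivially reformatted text compares equal. Sketched standalone (function name ours):

```python
def normalize(content) -> str:
    """Collapse whitespace runs so trivially reformatted text compares equal."""
    return " ".join((content or "").split())


same = normalize("Q: hi\nA:  hello ") == normalize("Q: hi A: hello")
```

An empty or whitespace-only string normalizes to `""`, which the method treats as a duplicate so it is never stored.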
@@ -278,9 +295,6 @@ class MemorySystem:
         answer: str,
         metadata: Optional[Dict] = None,
     ) -> Optional[VectorMemory]:
-        """
-        添加最新问答对,并仅对该问答对做模型重要性评估。
-        """
         user_meta = {"role": "user"}
         assistant_meta = {"role": "assistant"}
         if isinstance(metadata, dict):
@@ -297,6 +311,8 @@ class MemorySystem:
         importance = await self._evaluate_importance(qa_content, qa_metadata)
         if importance < self.importance_threshold:
             return None
+        if await self._is_duplicate_long_term(user_id, qa_content):
+            return None
 
         embedding = await self._build_embedding(qa_content)
         memory_id = str(uuid.uuid4())
@@ -311,97 +327,88 @@ class MemorySystem:
|
|||||||
if not ok:
|
if not ok:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
await self._trim_user_long_term(user_id)
|
||||||
return await self.get_long_term(user_id, memory_id)
|
return await self.get_long_term(user_id, memory_id)
|
||||||
|
|
||||||
async def _evaluate_importance(self, content: str, metadata: Optional[Dict]) -> float:
|
@staticmethod
|
||||||
"""评估记忆重要性。"""
|
def _simple_text_score(query: str, text: str) -> float:
|
||||||
if not content or not content.strip():
|
query_tokens = [token for token in query.lower().split() if token]
|
||||||
|
if not query_tokens:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
text_lower = text.lower()
|
||||||
if self.importance_evaluator:
|
hit = sum(1 for token in query_tokens if token in text_lower)
|
||||||
try:
|
return hit / len(query_tokens)
|
||||||
score = await self.importance_evaluator(content, metadata)
|
|
||||||
return self._normalize_importance(score)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"importance evaluation failed, fallback to neutral score: {e}")
|
|
||||||
|
|
||||||
# 当模型评估不可用时,使用中性分数作为兜底。
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
async def get_context(
|
async def get_context(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
max_short_term: int = 10,
|
max_short_term: int = 10,
|
||||||
max_long_term: int = 5
|
max_long_term: int = 5,
|
||||||
) -> Tuple[List[MemoryItem], List[VectorMemory]]:
|
) -> Tuple[List[MemoryItem], List[VectorMemory]]:
|
||||||
"""获取上下文(短期 + 长期记忆)。"""
|
|
||||||
# 获取短期记忆
|
|
||||||
short_term_memories = self.short_term.get(user_id, limit=max_short_term)
|
short_term_memories = self.short_term.get(user_id, limit=max_short_term)
|
||||||
|
long_term_memories: List[VectorMemory] = []
|
||||||
# 获取相关长期记忆
|
|
||||||
long_term_memories = []
|
|
||||||
|
|
||||||
if query and self.use_query_embedding:
|
if query and self.use_query_embedding:
|
||||||
try:
|
try:
|
||||||
# 使用向量检索
|
|
||||||
query_embedding = await self._build_embedding(query)
|
query_embedding = await self._build_embedding(query)
|
||||||
if query_embedding:
|
long_term_memories = await self.vector_store.search(
|
||||||
long_term_memories = await self.vector_store.search(
|
user_id=user_id,
|
||||||
user_id=user_id,
|
query_embedding=query_embedding,
|
||||||
query_embedding=query_embedding,
|
limit=max_long_term,
|
||||||
limit=max_long_term
|
)
|
||||||
)
|
except Exception as exc:
|
||||||
except Exception as e:
|
logger.warning(f"Vector search failed, fallback to lexical: {exc}")
|
||||||
logger.warning(f"向量检索失败,改用重要性检索: {e}")
|
|
||||||
|
|
||||||
if query and not long_term_memories:
|
        if query and not long_term_memories:
-            query_lower = query.lower()
-            try:
-                candidates = await self.vector_store.get_all(user_id)
-                matches = [m for m in candidates if query_lower in m.content.lower()]
-                matches.sort(key=lambda m: (m.importance, m.timestamp), reverse=True)
-                long_term_memories = matches[:max_long_term]
-            except Exception:
-                pass
+            candidates = await self.vector_store.get_all(user_id)
+            scored = []
+            for memory in candidates:
+                score = self._simple_text_score(query, memory.content)
+                if score <= 0:
+                    continue
+                combined = score * 0.7 + memory.importance * 0.3
+                scored.append((combined, memory))
+            scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True)
+            long_term_memories = [item[1] for item in scored[:max_long_term]]
 
-        # Fall back to importance-based retrieval when vector search fails or returns nothing
         if not long_term_memories:
             long_term_memories = await self.vector_store.get_by_importance(
                 user_id=user_id,
-                limit=max_long_term
+                limit=max_long_term,
             )
 
-        # Update access records for long-term memories
        for memory in long_term_memories:
            await self.vector_store.update_access(memory.id)

        return short_term_memories, long_term_memories

    def format_context(
-        self,
-        short_term: List[MemoryItem],
-        long_term: List[VectorMemory]
+        self, short_term: List[MemoryItem], long_term: List[VectorMemory]
    ) -> str:
-        """Format context as text."""
-        context = ""
+        lines: List[str] = []

        if long_term:
-            context += "## 相关历史记忆\n"
-            for i, memory in enumerate(long_term, 1):
-                context += f"{i}. {memory.content}\n"
-            context += "\n"
+            lines.append("## 相关历史记忆")
+            for idx, memory in enumerate(long_term, 1):
+                lines.append(f"{idx}. {memory.content}")
+            lines.append("")

        if short_term:
-            context += "## 最近对话\n"
-            for memory in short_term:
-                context += f"- {memory.content}\n"
+            lines.append("## 最近对话")
+            for item in short_term:
+                role = str((item.metadata or {}).get("role") or "").strip().lower()
+                if role == "user":
+                    prefix = "用户"
+                elif role == "assistant":
+                    prefix = "助手"
+                else:
+                    prefix = "对话"
+                lines.append(f"- {prefix}: {item.content}")

-        return context
+        return "\n".join(lines).strip()

-    async def list_long_term(
-        self, user_id: str, limit: int = 20
-    ) -> List[VectorMemory]:
+    async def list_long_term(self, user_id: str, limit: int = 20) -> List[VectorMemory]:
        memories = await self.vector_store.get_all(user_id)
        memories.sort(key=lambda m: m.timestamp, reverse=True)
        if limit > 0:
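The new retrieval path replaces plain substring matching with a blended ranking: each candidate memory gets a text-relevance score, zero-score items are dropped, and the rest are ordered by 70% relevance / 30% importance. A minimal self-contained sketch of that ranking follows; `Memory` and `simple_text_score` are hypothetical stand-ins for the repo's `VectorMemory` and `_simple_text_score`, whose real implementations are not shown in this diff.

```python
from dataclasses import dataclass


@dataclass
class Memory:
    content: str
    importance: float  # assumed to be in [0.0, 1.0]
    timestamp: float


def simple_text_score(query: str, content: str) -> float:
    # Hypothetical scorer: fraction of query tokens found in the content.
    tokens = [t for t in query.lower().split() if t]
    if not tokens:
        return 0.0
    hits = sum(1 for t in tokens if t in content.lower())
    return hits / len(tokens)


def rank(query: str, candidates: list[Memory], limit: int) -> list[Memory]:
    # Same shape as the diff: filter zero scores, blend 0.7*score + 0.3*importance,
    # break ties by recency, keep the top `limit` items.
    scored = []
    for memory in candidates:
        score = simple_text_score(query, memory.content)
        if score <= 0:
            continue
        combined = score * 0.7 + memory.importance * 0.3
        scored.append((combined, memory))
    scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True)
    return [m for _, m in scored[:limit]]
```

The importance weight keeps a highly important but loosely matching memory competitive with an exact low-importance match.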
@@ -422,21 +429,24 @@ class MemorySystem:
        importance: float = 0.8,
        metadata: Optional[Dict] = None,
    ) -> Optional[VectorMemory]:
-        memory_id = str(uuid.uuid4())
-        importance = self._normalize_importance(importance)
-        embedding = await self._build_embedding(content)
+        if await self._is_duplicate_long_term(user_id, content):
+            return None
+
+        memory_id = str(uuid.uuid4())
+        normalized_importance = self._normalize_importance(importance)
+        embedding = await self._build_embedding(content)
        ok = await self._add_vector_memory(
            memory_id=memory_id,
            user_id=user_id,
            content=content,
            embedding=embedding,
-            importance=importance,
+            importance=normalized_importance,
            metadata=metadata or {},
        )
        if not ok:
            return None

+        await self._trim_user_long_term(user_id)
        return await self.get_long_term(user_id, memory_id)

    async def search_long_term(
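`add_long_term` now routes the caller-supplied importance through `_normalize_importance` before persisting. That helper's body is not part of this diff; a plausible minimal sketch, assuming it simply coerces and clamps into [0, 1] with a neutral fallback, would be:

```python
def normalize_importance(value) -> float:
    # Hypothetical version of _normalize_importance (not shown in the diff):
    # coerce to float and clamp into [0.0, 1.0]; bad input falls back to 0.5.
    try:
        value = float(value)
    except (TypeError, ValueError):
        return 0.5
    return max(0.0, min(1.0, value))
```

Clamping here keeps the `score * 0.7 + importance * 0.3` blend in the retrieval path bounded even if a caller passes an out-of-range importance.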
@@ -457,10 +467,15 @@ class MemorySystem:
            return results

        all_memories = await self.vector_store.get_all(user_id)
-        query_lower = query.lower()
-        matched = [m for m in all_memories if query_lower in m.content.lower()]
-        matched.sort(key=lambda m: (m.importance, m.timestamp), reverse=True)
-        return matched[:limit]
+        scored = []
+        for memory in all_memories:
+            score = self._simple_text_score(query, memory.content)
+            if score <= 0:
+                continue
+            combined = score * 0.7 + memory.importance * 0.3
+            scored.append((combined, memory))
+        scored.sort(key=lambda x: (x[0], x[1].timestamp), reverse=True)
+        return [item[1] for item in scored[:limit]]

    async def update_long_term(
        self,
@@ -503,7 +518,6 @@ class MemorySystem:
        )
        if not added:
            return None

        return await self.get_long_term(user_id, memory_id)

    async def delete_long_term(self, user_id: str, memory_id: str) -> bool:
@@ -513,13 +527,10 @@ class MemorySystem:
        return await self.vector_store.delete(memory_id)

    def clear_short_term(self, user_id: str):
-        """Clear short-term memory."""
        self.short_term.clear(user_id)

    async def clear_long_term(self, user_id: str):
-        """Clear long-term memory."""
        await self.vector_store.clear_user(user_id)

    async def close(self):
-        """Shut down the memory system."""
        await self.vector_store.close()
@@ -1,30 +1,33 @@
 """
-Anthropic Claude模型实现
+Anthropic Claude model implementation.
 """
-from typing import List, Optional, AsyncIterator
+
+from __future__ import annotations
 
+from typing import AsyncIterator, List, Optional

 from anthropic import AsyncAnthropic

 from ..base import BaseAIModel, Message, ModelConfig


 class AnthropicModel(BaseAIModel):
-    """Anthropic Claude模型实现"""
+    """Anthropic Claude model implementation."""

    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.client = AsyncAnthropic(
            api_key=config.api_key,
            base_url=config.api_base,
-            timeout=config.timeout
+            timeout=config.timeout,
        )

    async def chat(
        self,
        messages: List[Message],
        tools: Optional[List[dict]] = None,
-        **kwargs
+        **kwargs,
    ) -> Message:
-        """Synchronous chat."""
-        # Split out the system message
        system_message = None
        formatted_messages = []

@@ -32,10 +35,7 @@ class AnthropicModel(BaseAIModel):
            if msg.role == "system":
                system_message = msg.content
            else:
-                formatted_messages.append({
-                    "role": msg.role,
-                    "content": msg.content
-                })
+                formatted_messages.append({"role": msg.role, "content": msg.content})

        params = {
            "model": self.config.model_name,
@@ -43,46 +43,43 @@ class AnthropicModel(BaseAIModel):
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
        }

        if system_message:
            params["system"] = system_message

        if tools:
            params["tools"] = tools

        params.update(kwargs)

        response = await self.client.messages.create(**params)

        content = ""
        tool_calls = []

        for block in response.content:
            if block.type == "text":
                content += block.text
            elif block.type == "tool_use":
-                tool_calls.append({
-                    "id": block.id,
-                    "type": "function",
-                    "function": {
-                        "name": block.name,
-                        "arguments": block.input
-                    }
-                })
+                tool_calls.append(
+                    {
+                        "id": block.id,
+                        "type": "function",
+                        "function": {
+                            "name": block.name,
+                            "arguments": block.input,
+                        },
+                    }
+                )

        return Message(
            role="assistant",
            content=content,
-            tool_calls=tool_calls if tool_calls else None
+            tool_calls=tool_calls if tool_calls else None,
        )

    async def chat_stream(
        self,
        messages: List[Message],
        tools: Optional[List[dict]] = None,
-        **kwargs
+        **kwargs,
    ) -> AsyncIterator[str]:
-        """Streaming chat."""
        system_message = None
        formatted_messages = []

@@ -90,10 +87,7 @@ class AnthropicModel(BaseAIModel):
            if msg.role == "system":
                system_message = msg.content
            else:
-                formatted_messages.append({
-                    "role": msg.role,
-                    "content": msg.content
-                })
+                formatted_messages.append({"role": msg.role, "content": msg.content})

        params = {
            "model": self.config.model_name,
@@ -102,13 +96,10 @@ class AnthropicModel(BaseAIModel):
            "temperature": self.config.temperature,
            "stream": True,
        }

        if system_message:
            params["system"] = system_message

        if tools:
            params["tools"] = tools

        params.update(kwargs)

        async with self.client.messages.stream(**params) as stream:
@@ -116,5 +107,13 @@ class AnthropicModel(BaseAIModel):
                yield text

    async def embed(self, text: str) -> List[float]:
-        """Text embedding (not provided by Anthropic directly; use another service)."""
-        raise NotImplementedError("Anthropic不提供嵌入API,请使用OpenAI或其他服务")
+        raise NotImplementedError(
+            "Anthropic does not provide embedding API; use OpenAI-compatible embedding model."
+        )
+
+    async def close(self):
+        close_fn = getattr(self.client, "close", None)
+        if callable(close_fn):
+            maybe_awaitable = close_fn()
+            if hasattr(maybe_awaitable, "__await__"):
+                await maybe_awaitable
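The new `AnthropicModel.close` defends against `close()` being either a plain method or a coroutine function across SDK versions: it calls `close()` only if present and awaits the result only when it is awaitable. The pattern in isolation, with stand-in clients so it can be exercised without the SDK:

```python
import asyncio


async def close_quietly(client) -> None:
    # Same defensive shutdown as AnthropicModel.close: call close() if it
    # exists, and await the result only when it is actually awaitable.
    close_fn = getattr(client, "close", None)
    if callable(close_fn):
        maybe_awaitable = close_fn()
        if hasattr(maybe_awaitable, "__await__"):
            await maybe_awaitable


class SyncClient:
    def __init__(self):
        self.closed = False

    def close(self):  # synchronous close -> returns None, nothing to await
        self.closed = True


class AsyncClient:
    def __init__(self):
        self.closed = False

    async def close(self):  # async close -> returns a coroutine to await
        self.closed = True
```

Objects without a `close` attribute pass through silently, so the shutdown path never raises on unexpected client types.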
@@ -24,7 +24,7 @@ class OpenAIModel(BaseAIModel):
        self.logger = logger
        self._embedding_token_limit: Optional[int] = None

-        http_client = httpx.AsyncClient(
+        self._http_client = httpx.AsyncClient(
            timeout=config.timeout,
            limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
        )
@@ -33,7 +33,7 @@ class OpenAIModel(BaseAIModel):
            api_key=config.api_key,
            base_url=config.api_base,
            timeout=config.timeout,
-            http_client=http_client,
+            http_client=self._http_client,
        )

        self._supports_tools = False
@@ -590,3 +590,8 @@ class OpenAIModel(BaseAIModel):
            self.logger.error(f"text preview: {repr(candidate_text[:100])}")
            self.logger.error(f"full traceback:\n{traceback.format_exc()}")
            raise

+    async def close(self):
+        """Release network resources."""
+        if self._http_client:
+            await self._http_client.aclose()
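The `OpenAIModel` change is a lifecycle fix: the pooled `httpx.AsyncClient` created in `__init__` was previously a local variable, so nothing could ever close it deterministically; storing it as `self._http_client` lets the new `close()` call `aclose()`. A dependency-free sketch of that ownership pattern, with a fake client standing in for `httpx.AsyncClient`:

```python
import asyncio


class FakeAsyncClient:
    # Stand-in for httpx.AsyncClient so the lifecycle is testable offline.
    def __init__(self):
        self.closed = False

    async def aclose(self):
        self.closed = True


class Model:
    def __init__(self):
        # Keep a reference to the pooled client; a local variable here would
        # leak open connections because close() could never reach it.
        self._http_client = FakeAsyncClient()

    async def close(self):
        if self._http_client:
            await self._http_client.aclose()
```

Callers are then expected to `await model.close()` (or wrap the model in an async context manager) on shutdown.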
@@ -1,4 +1,6 @@
-"""Personality system for role-play profiles."""
+"""Personality system for role-play profiles with scope priority."""
+
+from __future__ import annotations

 from dataclasses import dataclass, field
 from enum import Enum
@@ -32,8 +34,6 @@ class PersonalityProfile:
    custom_instructions: str = ""

    def to_system_prompt(self) -> str:
-        """Build plain-text system prompt."""
-
        traits_text = ", ".join([t.value for t in self.traits]) if self.traits else "Friendly"
        lines = [
            "Role Setting",
@@ -55,13 +55,29 @@ class PersonalityProfile:


 class PersonalitySystem:
-    """Personality management and persistence."""
+    """Personality management, persistence and scope resolution."""

-    def __init__(self, config_path: Optional[Path] = None):
+    def __init__(
+        self,
+        config_path: Optional[Path] = None,
+        state_path: Optional[Path] = None,
+    ):
        self.config_path = config_path or Path("config/personalities.json")
+        self.state_path = state_path or self.config_path.with_name("personality_state.json")
+
        self.personalities: Dict[str, PersonalityProfile] = {}
-        self.current_personality: Optional[PersonalityProfile] = None
+        self._active_global_key: str = "default"
+        self._active_user_keys: Dict[str, str] = {}
+        self._active_group_keys: Dict[str, str] = {}
+        self._active_session_keys: Dict[str, str] = {}

        self._load_personalities()
+        self._load_state()
+        self._ensure_valid_active_keys()
+
+    @property
+    def current_personality(self) -> Optional[PersonalityProfile]:
+        return self.personalities.get(self._active_global_key)

    def _dict_to_profile(self, config: Dict) -> PersonalityProfile:
        trait_names = config.get("traits", [])
@@ -83,27 +99,30 @@ class PersonalitySystem:
            custom_instructions=str(config.get("custom_instructions", "")),
        )

-    def _load_personalities(self):
-        """Load personality config from disk or create defaults."""
+    def _profile_to_dict(self, profile: PersonalityProfile) -> Dict:
+        return {
+            "name": profile.name,
+            "description": profile.description,
+            "traits": [trait.name for trait in profile.traits],
+            "speaking_style": profile.speaking_style,
+            "example_responses": profile.example_responses,
+            "custom_instructions": profile.custom_instructions,
+        }
+
+    def _load_personalities(self):
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for key, config in data.items():
-                self.personalities[key] = self._dict_to_profile(config)
+                if isinstance(config, dict):
+                    self.personalities[key] = self._dict_to_profile(config)

-            if "default" in self.personalities:
-                self.current_personality = self.personalities["default"]
-            elif self.personalities:
-                first_key = next(iter(self.personalities.keys()))
-                self.current_personality = self.personalities[first_key]
-            return
+            if self.personalities:
+                return

        self._create_default_personalities()

    def _create_default_personalities(self):
-        """Create and persist built-in default profiles."""
-
        default = PersonalityProfile(
            name="Assistant",
            description="A friendly and practical AI assistant.",
@@ -146,87 +165,200 @@ class PersonalitySystem:
            "tech_expert": tech_expert,
            "creative": creative,
        }
-        self.current_personality = default
+        self._active_global_key = "default"
        self._save_personalities()

    def _save_personalities(self):
-        """Persist personalities to disk."""
-
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
-        data = {}
-        for key, profile in self.personalities.items():
-            data[key] = {
-                "name": profile.name,
-                "description": profile.description,
-                "traits": [trait.name for trait in profile.traits],
-                "speaking_style": profile.speaking_style,
-                "example_responses": profile.example_responses,
-                "custom_instructions": profile.custom_instructions,
-            }
+        data = {
+            key: self._profile_to_dict(profile)
+            for key, profile in self.personalities.items()
+        }

        with open(self.config_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

-    def set_personality(self, key: str) -> bool:
-        """Switch active personality by key."""
+    def _load_state(self):
+        if not self.state_path.exists():
+            return
+
+        try:
+            with open(self.state_path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except Exception:
+            return
+
+        self._active_global_key = str(data.get("global", self._active_global_key))
+        self._active_user_keys = {
+            str(k): str(v)
+            for k, v in (data.get("user") or {}).items()
+            if k is not None and v is not None
+        }
+        self._active_group_keys = {
+            str(k): str(v)
+            for k, v in (data.get("group") or {}).items()
+            if k is not None and v is not None
+        }
+        self._active_session_keys = {
+            str(k): str(v)
+            for k, v in (data.get("session") or {}).items()
+            if k is not None and v is not None
+        }
+
+    def _save_state(self):
+        self.state_path.parent.mkdir(parents=True, exist_ok=True)
+        data = {
+            "global": self._active_global_key,
+            "user": self._active_user_keys,
+            "group": self._active_group_keys,
+            "session": self._active_session_keys,
+        }
+        with open(self.state_path, "w", encoding="utf-8") as f:
+            json.dump(data, f, ensure_ascii=False, indent=2)
+
+    def _ensure_valid_active_keys(self):
+        if self._active_global_key not in self.personalities:
+            if "default" in self.personalities:
+                self._active_global_key = "default"
+            elif self.personalities:
+                self._active_global_key = next(iter(self.personalities.keys()))
+            else:
+                self._active_global_key = ""
+
+        self._active_user_keys = {
+            scope_id: key
+            for scope_id, key in self._active_user_keys.items()
+            if key in self.personalities
+        }
+        self._active_group_keys = {
+            scope_id: key
+            for scope_id, key in self._active_group_keys.items()
+            if key in self.personalities
+        }
+        self._active_session_keys = {
+            scope_id: key
+            for scope_id, key in self._active_session_keys.items()
+            if key in self.personalities
+        }
+        self._save_state()
+
+    def set_personality(
+        self, key: str, scope: str = "global", scope_id: Optional[str] = None
+    ) -> bool:
        if key not in self.personalities:
            return False

-        self.current_personality = self.personalities[key]
+        scope_normalized = (scope or "global").strip().lower()
+        if scope_normalized == "global":
+            self._active_global_key = key
+            self._save_state()
+            return True
+
+        if not scope_id:
+            return False
+
+        if scope_normalized == "user":
+            self._active_user_keys[scope_id] = key
+        elif scope_normalized == "group":
+            self._active_group_keys[scope_id] = key
+        elif scope_normalized == "session":
+            self._active_session_keys[scope_id] = key
+        else:
+            return False
+
+        self._save_state()
        return True

-    def get_system_prompt(self) -> str:
-        """Get current personality prompt."""
+    def get_active_personality(
+        self,
+        user_id: Optional[str] = None,
+        group_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> Optional[PersonalityProfile]:
+        # Priority: session > group > user > global > default
+        if session_id:
+            key = self._active_session_keys.get(session_id)
+            if key in self.personalities:
+                return self.personalities[key]

-        if self.current_personality:
-            return self.current_personality.to_system_prompt()
-        return ""
+        if group_id:
+            key = self._active_group_keys.get(group_id)
+            if key in self.personalities:
+                return self.personalities[key]
+
+        if user_id:
+            key = self._active_user_keys.get(user_id)
+            if key in self.personalities:
+                return self.personalities[key]
+
+        if self._active_global_key in self.personalities:
+            return self.personalities[self._active_global_key]
+
+        return self.personalities.get("default")
+
+    def get_system_prompt(
+        self,
+        user_id: Optional[str] = None,
+        group_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> str:
+        profile = self.get_active_personality(
+            user_id=user_id,
+            group_id=group_id,
+            session_id=session_id,
+        )
+        return profile.to_system_prompt() if profile else ""

    def add_personality(self, key: str, profile: PersonalityProfile) -> bool:
-        """Add a new personality profile."""
-
        key = key.strip()
        if not key:
            return False

        self.personalities[key] = profile
-        if not self.current_personality:
-            self.current_personality = profile
+        if not self._active_global_key:
+            self._active_global_key = key
+            self._save_state()
        self._save_personalities()
        return True

    def remove_personality(self, key: str) -> bool:
-        """Remove a personality profile."""
-
        if key == "default":
            return False

        if key not in self.personalities:
            return False

-        removed_profile = self.personalities[key]
        del self.personalities[key]
+        if not self.personalities:
+            self._create_default_personalities()

-        if self.current_personality == removed_profile:
-            if "default" in self.personalities:
-                self.current_personality = self.personalities["default"]
-            elif self.personalities:
-                first_key = next(iter(self.personalities.keys()))
-                self.current_personality = self.personalities[first_key]
-            else:
-                self.current_personality = None
+        if self._active_global_key == key:
+            self._active_global_key = (
+                "default"
+                if "default" in self.personalities
+                else next(iter(self.personalities.keys()))
+            )
+
+        self._active_user_keys = {
+            scope_id: active_key
+            for scope_id, active_key in self._active_user_keys.items()
+            if active_key != key
+        }
+        self._active_group_keys = {
+            scope_id: active_key
+            for scope_id, active_key in self._active_group_keys.items()
+            if active_key != key
+        }
+        self._active_session_keys = {
+            scope_id: active_key
+            for scope_id, active_key in self._active_session_keys.items()
+            if active_key != key
+        }

        self._save_personalities()
+        self._save_state()
        return True

    def list_personalities(self) -> List[str]:
-        """List all personality keys."""
-
        return sorted(self.personalities.keys())

    def get_personality(self, key: str) -> Optional[PersonalityProfile]:
-        """Get personality by key."""
-
        return self.personalities.get(key)
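The core of the personality refactor is the resolution order in `get_active_personality`: session overrides group, which overrides user, which overrides the global default. That lookup can be sketched in isolation; `ScopeResolver` below is a simplified stand-in for `PersonalitySystem` (no persistence, profiles reduced to strings):

```python
from typing import Optional


class ScopeResolver:
    """Minimal sketch of the scope-priority lookup: session > group > user > global."""

    def __init__(self, profiles: dict, global_key: str):
        self.profiles = profiles
        self.global_key = global_key
        self.user: dict = {}
        self.group: dict = {}
        self.session: dict = {}

    def resolve(self, user_id=None, group_id=None, session_id=None) -> Optional[str]:
        # Walk scopes from most to least specific; a stale key that no longer
        # exists in `profiles` is skipped rather than returned.
        for scope_map, scope_id in (
            (self.session, session_id),
            (self.group, group_id),
            (self.user, user_id),
        ):
            if scope_id:
                key = scope_map.get(scope_id)
                if key in self.profiles:
                    return self.profiles[key]
        return self.profiles.get(self.global_key)
```

This mirrors why the diff also prunes per-scope maps in `_ensure_valid_active_keys` and `remove_personality`: resolution only ever returns keys that still exist.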
@@ -1,6 +0,0 @@
-"""
-Skills system initialization.
-"""
-from .base import Skill, SkillsManager, SkillMetadata, create_skill_template
-
-__all__ = ['Skill', 'SkillsManager', 'SkillMetadata', 'create_skill_template']
@@ -1,750 +0,0 @@
|
|||||||
"""
|
|
||||||
Skills 系统 - 可扩展技能插件框架。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
||||||
import urllib.parse
|
|
||||||
import urllib.request
|
|
||||||
import zipfile
|
|
||||||
import os
|
|
||||||
import stat
|
|
||||||
|
|
||||||
from src.utils.logger import setup_logger
|
|
||||||
|
|
||||||
logger = setup_logger("SkillsSystem")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SkillMetadata:
|
|
||||||
"""技能元数据。"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
version: str
|
|
||||||
description: str
|
|
||||||
author: str
|
|
||||||
dependencies: List[str]
|
|
||||||
enabled: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class Skill:
|
|
||||||
"""技能基类。"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.metadata: Optional[SkillMetadata] = None
|
|
||||||
self.tools: Dict[str, Callable] = {}
|
|
||||||
self.manager = None
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
"""初始化技能。"""
|
|
||||||
|
|
||||||
async def cleanup(self):
|
|
||||||
"""清理技能。"""
|
|
||||||
|
|
||||||
def get_tools(self) -> Dict[str, Callable]:
|
|
||||||
"""获取技能提供的工具。"""
|
|
||||||
|
|
||||||
return self.tools
|
|
||||||
|
|
||||||
def register_tool(self, name: str, func: Callable):
|
|
||||||
"""注册工具。"""
|
|
||||||
|
|
||||||
self.tools[name] = func
|
|
||||||
|
|
||||||
|
|
||||||
class SkillsManager:
|
|
||||||
"""技能管理器。"""
|
|
||||||
|
|
||||||
_SKILL_KEY_PATTERN = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
_GITHUB_SHORTCUT_PATTERN = re.compile(
|
|
||||||
r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+(?:#[A-Za-z0-9_.-]+)?$"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, skills_dir: Path):
|
|
||||||
self.skills_dir = skills_dir
|
|
||||||
self.skills: Dict[str, Skill] = {}
|
|
||||||
self.skills_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
logger.info(f"✅ Skills 目录: {skills_dir}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def normalize_skill_key(cls, raw_name: str) -> str:
|
|
||||||
"""将任意输入规范化为可导入的 Python 包名。"""
|
|
||||||
|
|
||||||
key = raw_name.strip().lower().replace("-", "_").replace(" ", "_")
|
|
||||||
key = cls._SKILL_KEY_PATTERN.sub("_", key)
|
|
||||||
key = re.sub(r"_+", "_", key).strip("_")
|
|
||||||
|
|
||||||
if not key:
|
|
||||||
raise ValueError("技能名不能为空")
|
|
||||||
|
|
||||||
if key[0].isdigit():
|
|
||||||
key = f"skill_{key}"
|
|
||||||
|
|
||||||
return key
|
|
||||||
|
|
||||||
def _get_skill_path(self, skill_name: str) -> Path:
|
|
||||||
return self.skills_dir / self.normalize_skill_key(skill_name)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _on_rmtree_error(func, path, exc_info):
|
|
||||||
"""Handle Windows readonly/locked file deletion errors."""
|
|
||||||
try:
|
|
||||||
os.chmod(path, stat.S_IWRITE)
|
|
||||||
func(path)
|
|
||||||
except Exception:
|
|
||||||
# Keep original failure path for upper retry logic.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _read_metadata(self, skill_path: Path, fallback_name: str) -> Dict[str, Any]:
|
|
||||||
metadata_file = skill_path / "skill.json"
|
|
||||||
if metadata_file.exists():
|
|
||||||
with open(metadata_file, "r", encoding="utf-8") as f:
|
|
||||||
metadata = json.load(f)
|
|
||||||
else:
|
|
||||||
metadata = {}
|
|
||||||
|
|
||||||
metadata.setdefault("name", fallback_name)
|
|
||||||
metadata.setdefault("version", "1.0.0")
|
|
||||||
metadata.setdefault("description", f"{fallback_name} skill")
|
|
||||||
metadata.setdefault("author", "unknown")
|
|
||||||
metadata.setdefault("dependencies", [])
|
|
||||||
metadata.setdefault("enabled", True)
|
|
||||||
|
|
||||||
with open(metadata_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
def _ensure_skill_package_layout(self, skill_path: Path, skill_key: str):
    """Ensure the skill directory has the minimal runnable layout."""
    skill_path.mkdir(parents=True, exist_ok=True)

    init_file = skill_path / "__init__.py"
    if not init_file.exists():
        init_file.write_text("", encoding="utf-8")

    main_file = skill_path / "main.py"
    if not main_file.exists():
        template = f'''"""{skill_key} skill"""
from src.ai.skills.base import Skill


class {"".join(p.capitalize() for p in skill_key.split("_"))}Skill(Skill):
    async def initialize(self):
        self.register_tool("ping", self.ping)

    async def ping(self, text: str = "ok") -> str:
        return text

    async def cleanup(self):
        pass
'''
        main_file.write_text(template, encoding="utf-8")

    self._read_metadata(skill_path, skill_key)

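The generated template derives the skill's class name from its snake_case key. A minimal standalone sketch of that derivation (the helper name `skill_class_name` is illustrative, not part of the codebase):

```python
def skill_class_name(skill_key: str) -> str:
    # "web_search" -> "WebSearchSkill", matching the f-string used in the template
    return "".join(part.capitalize() for part in skill_key.split("_")) + "Skill"

print(skill_class_name("web_search"))  # -> WebSearchSkill
```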
async def load_skill(self, skill_name: str) -> bool:
    """Load a skill."""
    try:
        skill_name = self.normalize_skill_key(skill_name)

        if skill_name in self.skills:
            logger.info(f"✅ Skill already loaded: {skill_name}")
            return True

        skill_path = self._get_skill_path(skill_name)
        if not skill_path.exists():
            logger.error(f"❌ Skill not found: {skill_name}")
            return False

        metadata_file = skill_path / "skill.json"
        if not metadata_file.exists():
            logger.error(f"❌ Skill metadata not found: {skill_name}")
            return False

        with open(metadata_file, "r", encoding="utf-8") as f:
            metadata_dict = json.load(f)

        metadata = SkillMetadata(**metadata_dict)

        if not metadata.enabled:
            logger.info(f"⏸️ Skill disabled: {skill_name}")
            return False

        module_path = f"skills.{skill_name}.main"
        importlib.invalidate_caches()

        try:
            old_dont_write = sys.dont_write_bytecode
            sys.dont_write_bytecode = True
            try:
                if module_path in sys.modules:
                    module = importlib.reload(sys.modules[module_path])
                else:
                    module = importlib.import_module(module_path)
            finally:
                sys.dont_write_bytecode = old_dont_write
        except Exception as exc:
            logger.error(f"❌ Failed to import skill module {module_path}: {exc}")
            return False

        skill_class = None
        for _, obj in inspect.getmembers(module):
            if inspect.isclass(obj) and issubclass(obj, Skill) and obj != Skill:
                skill_class = obj
                break

        if not skill_class:
            logger.error(f"❌ No Skill subclass found in skill: {skill_name}")
            return False

        skill = skill_class()
        skill.metadata = metadata
        skill.manager = self

        await skill.initialize()
        self.skills[skill_name] = skill

        logger.info(f"✅ Loaded skill: {skill_name} v{metadata.version}")
        return True

    except Exception as exc:
        logger.error(f"❌ Failed to load skill {skill_name}: {exc}")
        return False

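The import/reload dance in `load_skill` (suppress `.pyc` writes, reload if the module is already cached, restore the flag in `finally`) can be exercised in isolation. A minimal sketch; `import_fresh` is a hypothetical helper, not part of the codebase:

```python
import importlib
import sys

def import_fresh(module_path: str):
    """Import or reload a module without writing .pyc files."""
    old_dont_write = sys.dont_write_bytecode
    sys.dont_write_bytecode = True
    try:
        if module_path in sys.modules:
            # Re-execute the module body so on-disk edits take effect
            return importlib.reload(sys.modules[module_path])
        return importlib.import_module(module_path)
    finally:
        sys.dont_write_bytecode = old_dont_write

json_mod = import_fresh("json")
print(json_mod.dumps({"ok": True}))  # -> {"ok": true}
```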
async def load_all_skills(self):
    """Load all available skills."""
    for skill_name in self.list_available_skills():
        await self.load_skill(skill_name)

async def unload_skill(self, skill_name: str) -> bool:
    """Unload a skill from memory only."""
    skill_name = self.normalize_skill_key(skill_name)
    if skill_name not in self.skills:
        return False

    skill = self.skills[skill_name]
    await skill.cleanup()
    del self.skills[skill_name]

    sys.modules.pop(f"skills.{skill_name}.main", None)
    sys.modules.pop(f"skills.{skill_name}", None)
    importlib.invalidate_caches()

    logger.info(f"✅ Unloaded skill: {skill_name}")
    return True

async def uninstall_skill(self, skill_name: str, delete_files: bool = True) -> bool:
    """Unload a skill and optionally delete its files."""
    skill_name = self.normalize_skill_key(skill_name)

    if skill_name in self.skills:
        await self.unload_skill(skill_name)

    if not delete_files:
        return True

    skill_path = self._get_skill_path(skill_name)
    if not skill_path.exists():
        return False

    removed = False
    for _ in range(3):
        try:
            shutil.rmtree(skill_path, ignore_errors=False, onerror=self._on_rmtree_error)
        except PermissionError:
            pass

        if not skill_path.exists():
            removed = True
            break

        time.sleep(0.2)

    if not removed:
        try:
            metadata_file = skill_path / "skill.json"
            metadata = {}
            if metadata_file.exists():
                with open(metadata_file, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
            metadata["enabled"] = False
            with open(metadata_file, "w", encoding="utf-8") as f:
                json.dump(metadata, f, ensure_ascii=False, indent=2)
            logger.warning(f"⚠️ Failed to delete directory; skill soft-uninstalled: {skill_name}")
            return True
        except Exception:
            return False

    importlib.invalidate_caches()
    logger.info(f"✅ Deleted skill directory: {skill_name}")
    return True

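The retry loop in `uninstall_skill` tolerates transient file locks (common on Windows) by re-attempting `rmtree` and checking whether the directory is actually gone. A self-contained sketch of the same pattern; `remove_with_retry` is an illustrative name:

```python
import shutil
import tempfile
import time
from pathlib import Path

def remove_with_retry(path: Path, attempts: int = 3, delay: float = 0.0) -> bool:
    """Try rmtree a few times; report whether the directory is gone."""
    for _ in range(attempts):
        try:
            shutil.rmtree(path)
        except (PermissionError, FileNotFoundError):
            # Swallow transient errors; success is judged by path.exists()
            pass
        if not path.exists():
            return True
        time.sleep(delay)
    return False

tmp = Path(tempfile.mkdtemp())
(tmp / "file.txt").write_text("data", encoding="utf-8")
print(remove_with_retry(tmp))  # -> True
```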
def get_skill(self, skill_name: str) -> Optional[Skill]:
    """Return a loaded skill instance."""
    skill_name = self.normalize_skill_key(skill_name)
    return self.skills.get(skill_name)

def list_skills(self) -> List[str]:
    """List loaded skills."""
    return sorted(self.skills.keys())

def list_available_skills(self) -> List[str]:
    """List skill directories that can be loaded."""
    if not self.skills_dir.exists():
        return []

    available: List[str] = []
    for skill_dir in self.skills_dir.iterdir():
        if not skill_dir.is_dir() or skill_dir.name.startswith("_"):
            continue

        if (skill_dir / "skill.json").exists() and (skill_dir / "main.py").exists():
            try:
                with open(skill_dir / "skill.json", "r", encoding="utf-8") as f:
                    metadata = json.load(f)
                if not metadata.get("enabled", True):
                    continue
                available.append(self.normalize_skill_key(skill_dir.name))
            except ValueError:
                continue
            except Exception:
                continue

    return sorted(set(available))

def get_all_tools(self) -> Dict[str, Callable]:
    """Collect tools from all loaded skills."""
    all_tools: Dict[str, Callable] = {}
    for skill_name, skill in self.skills.items():
        for tool_name, tool_func in skill.get_tools().items():
            all_tools[f"{skill_name}.{tool_name}"] = tool_func
    return all_tools

async def reload_skill(self, skill_name: str) -> bool:
    """Reload a skill."""
    skill_name = self.normalize_skill_key(skill_name)
    if skill_name in self.skills:
        await self.unload_skill(skill_name)
    return await self.load_skill(skill_name)

def _parse_github_repo_source(
    self,
    source: str,
) -> Optional[Tuple[str, str, str, Optional[str]]]:
    """Parse a GitHub repo URL into owner/repo/branch/subpath."""
    try:
        parsed = urllib.parse.urlparse(source.strip())
    except Exception:
        return None

    if parsed.scheme not in {"http", "https"}:
        return None

    if parsed.netloc.lower() != "github.com":
        return None

    parts = [part for part in parsed.path.split("/") if part]
    if len(parts) < 2:
        return None

    owner = parts[0]
    repo = parts[1]
    if repo.endswith(".git"):
        repo = repo[:-4]

    if not owner or not repo:
        return None

    branch = "main"
    subpath: Optional[str] = None

    if len(parts) >= 4 and parts[2] == "tree":
        branch = parts[3]
        if len(parts) > 4:
            subpath = "/".join(parts[4:])
    elif len(parts) > 2:
        # Deep links such as /issues or /blob are not supported
        return None

    return owner, repo, branch, subpath

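The URL rules above are easy to exercise standalone. The following free-function sketch mirrors the method's logic without the class (the name `parse_github_repo` is illustrative):

```python
import urllib.parse
from typing import Optional, Tuple

def parse_github_repo(source: str) -> Optional[Tuple[str, str, str, Optional[str]]]:
    """Parse a github.com URL into (owner, repo, branch, subpath), or None."""
    parsed = urllib.parse.urlparse(source.strip())
    if parsed.scheme not in {"http", "https"} or parsed.netloc.lower() != "github.com":
        return None
    parts = [p for p in parsed.path.split("/") if p]
    if len(parts) < 2:
        return None
    owner, repo = parts[0], parts[1]
    if repo.endswith(".git"):
        repo = repo[:-4]
    branch, subpath = "main", None
    if len(parts) >= 4 and parts[2] == "tree":
        branch = parts[3]
        if len(parts) > 4:
            subpath = "/".join(parts[4:])
    elif len(parts) > 2:
        return None  # /issues, /blob and other deep links are unsupported
    return owner, repo, branch, subpath

print(parse_github_repo("https://github.com/foo/bar/tree/dev/skills/demo"))
# -> ('foo', 'bar', 'dev', 'skills/demo')
```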
def _resolve_network_source(
    self,
    source: str,
) -> Tuple[str, Optional[str], Optional[str]]:
    """Resolve a network source into a download URL plus install hints."""
    source = source.strip()

    github_repo = self._parse_github_repo_source(source)
    if github_repo:
        owner, repo, branch, subpath = github_repo
        codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}"
        return codeload_url, self.normalize_skill_key(repo), subpath

    if source.startswith(("http://", "https://")):
        return source, None, None

    if self._GITHUB_SHORTCUT_PATTERN.match(source):
        repo_ref, _, branch = source.partition("#")
        owner, repo = repo_ref.split("/", 1)
        if repo.endswith(".git"):
            repo = repo[:-4]
        branch = branch or "main"
        codeload_url = f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}"
        return codeload_url, self.normalize_skill_key(repo), None

    raise ValueError("source must be a URL or owner/repo[#branch]")

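The `owner/repo[#branch]` shortcut expands into a codeload zip URL, with the branch defaulting to `main`. A minimal sketch of that expansion (the function name is assumed):

```python
def shortcut_to_codeload(source: str) -> str:
    """Expand owner/repo[#branch] into a codeload zip URL (branch defaults to main)."""
    repo_ref, _, branch = source.partition("#")
    owner, repo = repo_ref.split("/", 1)
    if repo.endswith(".git"):
        repo = repo[:-4]
    branch = branch or "main"
    return f"https://codeload.github.com/{owner}/{repo}/zip/refs/heads/{branch}"

print(shortcut_to_codeload("octocat/hello#dev"))
# -> https://codeload.github.com/octocat/hello/zip/refs/heads/dev
```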
def _download_zip(self, url: str, output_zip: Path):
    """Download a zip archive to a local path."""
    req = urllib.request.Request(url, headers={"User-Agent": "QQBot-Skills/1.0"})
    with urllib.request.urlopen(req, timeout=30) as resp:
        data = resp.read()
    output_zip.write_bytes(data)

def _find_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]:
    """Scan a directory tree for skill candidates."""
    candidates: List[Tuple[str, Path]] = []
    for metadata_file in root_dir.rglob("skill.json"):
        candidate_dir = metadata_file.parent
        if not (candidate_dir / "main.py").exists():
            continue

        try:
            with open(metadata_file, "r", encoding="utf-8") as f:
                metadata = json.load(f)
            raw_name = str(metadata.get("name") or candidate_dir.name)
        except Exception:
            raw_name = candidate_dir.name

        try:
            skill_key = self.normalize_skill_key(raw_name)
        except ValueError:
            continue

        candidates.append((skill_key, candidate_dir))

    uniq: Dict[str, Path] = {}
    for key, path in candidates:
        uniq[key] = path

    return sorted(uniq.items(), key=lambda x: x[0])

def _find_codex_skill_candidates(self, root_dir: Path) -> List[Tuple[str, Path]]:
    """Scan a directory tree for candidates that only ship a SKILL.md."""
    candidates: List[Tuple[str, Path]] = []
    for skill_doc in root_dir.rglob("SKILL.md"):
        candidate_dir = skill_doc.parent
        if (candidate_dir / "skill.json").exists() and (candidate_dir / "main.py").exists():
            continue

        try:
            skill_key = self.normalize_skill_key(candidate_dir.name)
        except ValueError:
            continue

        candidates.append((skill_key, candidate_dir))

    uniq: Dict[str, Path] = {}
    for key, path in candidates:
        uniq[key] = path

    return sorted(uniq.items(), key=lambda x: x[0])

def _scope_extract_root(self, extract_root: Path, subpath: Optional[str]) -> Path:
    """Narrow the extracted tree to the repo subpath, if one was given."""
    if not subpath:
        return extract_root

    clean_subpath = Path(subpath.strip("/\\"))
    if str(clean_subpath) == ".":
        return extract_root

    direct = extract_root / clean_subpath
    if direct.exists():
        return direct

    for child in extract_root.iterdir():
        candidate = child / clean_subpath
        if candidate.exists():
            return candidate

    return extract_root

@staticmethod
def _extract_markdown_title(markdown_content: str) -> str:
    for line in markdown_content.splitlines():
        stripped = line.strip()
        if stripped.startswith("#"):
            return stripped.lstrip("#").strip()
    return ""

def _install_codex_skill_adapter(self, source_dir: Path, target_path: Path, skill_key: str):
    """Convert a Codex SKILL.md into a loadable project skill."""
    skill_doc = source_dir / "SKILL.md"
    if not skill_doc.exists():
        raise FileNotFoundError(f"SKILL.md not found in {source_dir}")

    content = skill_doc.read_text(encoding="utf-8")
    title = self._extract_markdown_title(content)
    description = (
        f"Imported from SKILL.md: {title}" if title else f"Imported from SKILL.md ({skill_key})"
    )

    target_path.mkdir(parents=True, exist_ok=True)
    (target_path / "SKILL.md").write_text(content, encoding="utf-8")
    (target_path / "__init__.py").write_text("", encoding="utf-8")

    metadata = {
        "name": skill_key,
        "version": "1.0.0",
        "description": description,
        "author": "imported",
        "dependencies": [],
        "enabled": True,
    }
    with open(target_path / "skill.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)

    class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill"
    main_code = f'''"""{skill_key} adapter skill (generated from SKILL.md)"""
from pathlib import Path

from src.ai.skills.base import Skill


class {class_name}(Skill):
    async def initialize(self):
        self.register_tool("read_skill_doc", self.read_skill_doc)

    async def read_skill_doc(self) -> str:
        skill_doc = Path(__file__).with_name("SKILL.md")
        if not skill_doc.exists():
            return "SKILL.md not found."
        return skill_doc.read_text(encoding="utf-8")

    async def cleanup(self):
        pass
'''
    (target_path / "main.py").write_text(main_code, encoding="utf-8")

def install_skill_from_source(
    self,
    source: str,
    skill_name: Optional[str] = None,
    overwrite: bool = False,
) -> Tuple[bool, str]:
    """Install a skill directory from a network or local source (writes to disk only; does not auto-load)."""
    desired_key = self.normalize_skill_key(skill_name) if skill_name else None

    with tempfile.TemporaryDirectory(prefix="qqbot_skill_") as tmp:
        tmp_dir = Path(tmp)
        extract_root: Optional[Path] = None
        source_hint_key: Optional[str] = None
        source_subpath: Optional[str] = None

        source_path = Path(source)
        if source_path.exists():
            if source_path.is_dir():
                extract_root = source_path
            elif source_path.is_file() and source_path.suffix.lower() == ".zip":
                with zipfile.ZipFile(source_path, "r") as zf:
                    zf.extractall(tmp_dir / "extract")
                extract_root = tmp_dir / "extract"
            else:
                return False, "A local source must be a directory or a zip file"
        else:
            try:
                url, source_hint_key, source_subpath = self._resolve_network_source(source)
            except ValueError as exc:
                return False, str(exc)

            download_zip = tmp_dir / "download.zip"
            try:
                self._download_zip(url, download_zip)
            except Exception as exc:
                # If the GitHub shortcut's default "main" branch fails, try "master"
                if "codeload.github.com" in url and url.endswith("/main"):
                    fallback = url[:-4] + "master"
                    try:
                        self._download_zip(fallback, download_zip)
                    except Exception:
                        return False, f"Failed to download skill: {exc}"
                else:
                    return False, f"Failed to download skill: {exc}"

            try:
                with zipfile.ZipFile(download_zip, "r") as zf:
                    zf.extractall(tmp_dir / "extract")
            except Exception as exc:
                return False, f"Failed to extract skill: {exc}"

            extract_root = tmp_dir / "extract"

        extract_root = self._scope_extract_root(extract_root, source_subpath)

        candidates = self._find_skill_candidates(extract_root)
        use_codex_adapter = False

        selected_key: Optional[str] = None
        selected_path: Optional[Path] = None

        if candidates:
            if desired_key:
                for key, path in candidates:
                    if key == desired_key:
                        selected_key, selected_path = key, path
                        break

                if not selected_path:
                    names = ", ".join([k for k, _ in candidates])
                    return False, f"Skill {desired_key} not found in source; available: {names}"
            else:
                if len(candidates) > 1:
                    names = ", ".join([k for k, _ in candidates])
                    return False, f"Multiple skills detected; please specify skill_name. Available: {names}"
                selected_key, selected_path = candidates[0]
        else:
            codex_candidates = self._find_codex_skill_candidates(extract_root)
            if not codex_candidates:
                return False, "No installable skill found (requires skill.json plus main.py, or SKILL.md)"

            use_codex_adapter = True
            if desired_key:
                for key, path in codex_candidates:
                    if key == desired_key:
                        selected_key, selected_path = key, path
                        break

                if not selected_path:
                    if len(codex_candidates) == 1:
                        selected_key = desired_key
                        selected_path = codex_candidates[0][1]
                    else:
                        names = ", ".join([k for k, _ in codex_candidates])
                        return False, f"Skill {desired_key} not found in source; available: {names}"
            else:
                if source_hint_key:
                    for key, path in codex_candidates:
                        if key == source_hint_key:
                            selected_key, selected_path = key, path
                            break

                if not selected_path:
                    if len(codex_candidates) > 1:
                        names = ", ".join([k for k, _ in codex_candidates])
                        return False, f"Multiple SKILL.md skills detected; please specify skill_name. Available: {names}"
                    selected_key, selected_path = codex_candidates[0]
                    if source_hint_key:
                        selected_key = source_hint_key

        assert selected_key is not None and selected_path is not None
        target_path = self._get_skill_path(selected_key)

        if target_path.exists():
            if not overwrite:
                return False, f"Skill already exists: {selected_key}"
            shutil.rmtree(target_path)

        if use_codex_adapter:
            self._install_codex_skill_adapter(selected_path, target_path, selected_key)
        else:
            shutil.copytree(
                selected_path,
                target_path,
                ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".git", ".github"),
            )
            self._ensure_skill_package_layout(target_path, selected_key)

        importlib.invalidate_caches()

        logger.info(f"✅ Installed skill: {selected_key} <- {source}")
        return True, selected_key

def create_skill_template(
    skill_name: str,
    output_dir: Path,
    description: str = "Skill description",
    author: str = "QQBot",
):
    """Create a skill template."""
    skill_key = SkillsManager.normalize_skill_key(skill_name)
    skill_dir = output_dir / skill_key
    skill_dir.mkdir(parents=True, exist_ok=True)

    metadata = {
        "name": skill_key,
        "version": "1.0.0",
        "description": description,
        "author": author,
        "dependencies": [],
        "enabled": True,
    }

    with open(skill_dir / "skill.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)

    class_name = "".join(word.capitalize() for word in skill_key.split("_")) + "Skill"
    main_code = f'''"""{skill_key} skill"""
from src.ai.skills.base import Skill


class {class_name}(Skill):
    async def initialize(self):
        self.register_tool("example_tool", self.example_tool)

    async def example_tool(self, text: str) -> str:
        return f"{skill_key} received: {{text}}"

    async def cleanup(self):
        pass
'''

    with open(skill_dir / "main.py", "w", encoding="utf-8") as f:
        f.write(main_code)

    with open(skill_dir / "__init__.py", "w", encoding="utf-8") as f:
        f.write("")

    readme = f"""# {skill_key}

## Description
{description}

## Tools
- example_tool(text)
"""

    with open(skill_dir / "README.md", "w", encoding="utf-8") as f:
        f.write(readme)

    logger.info(f"✅ Created skill template: {skill_dir}")

@@ -1,327 +0,0 @@
"""
Long task manager - handles complex multi-step tasks.
"""
import asyncio
from typing import List, Dict, Optional, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import json
from pathlib import Path
import uuid


class TaskStatus(Enum):
    """Task status."""
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class TaskStep:
    """A single step of a task."""
    step_id: str
    description: str
    action: str
    params: Dict[str, Any]
    status: TaskStatus = TaskStatus.PENDING
    result: Optional[Any] = None
    error: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict:
        return {
            'step_id': self.step_id,
            'description': self.description,
            'action': self.action,
            'params': self.params,
            'status': self.status.value,
            'result': self.result,
            'error': self.error,
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None
        }


@dataclass
class LongTask:
    """A long-running task."""
    task_id: str
    user_id: str
    title: str
    description: str
    steps: List[TaskStep] = field(default_factory=list)
    status: TaskStatus = TaskStatus.PENDING
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    progress: float = 0.0
    metadata: Dict = field(default_factory=dict)

    def to_dict(self) -> Dict:
        return {
            'task_id': self.task_id,
            'user_id': self.user_id,
            'title': self.title,
            'description': self.description,
            'steps': [step.to_dict() for step in self.steps],
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'progress': self.progress,
            'metadata': self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'LongTask':
        steps = [
            TaskStep(
                step_id=s['step_id'],
                description=s['description'],
                action=s['action'],
                params=s['params'],
                status=TaskStatus(s['status']),
                result=s.get('result'),
                error=s.get('error'),
                started_at=datetime.fromisoformat(s['started_at']) if s.get('started_at') else None,
                completed_at=datetime.fromisoformat(s['completed_at']) if s.get('completed_at') else None
            )
            for s in data['steps']
        ]

        return cls(
            task_id=data['task_id'],
            user_id=data['user_id'],
            title=data['title'],
            description=data['description'],
            steps=steps,
            status=TaskStatus(data['status']),
            created_at=datetime.fromisoformat(data['created_at']),
            started_at=datetime.fromisoformat(data['started_at']) if data.get('started_at') else None,
            completed_at=datetime.fromisoformat(data['completed_at']) if data.get('completed_at') else None,
            progress=data.get('progress', 0.0),
            metadata=data.get('metadata', {})
        )


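The `to_dict`/`from_dict` pair above relies on `Enum.value` round-tripping cleanly through JSON-friendly dicts. A reduced, self-contained sketch of that pattern:

```python
from dataclasses import dataclass
from enum import Enum

class Status(Enum):
    PENDING = "pending"
    COMPLETED = "completed"

@dataclass
class Step:
    step_id: str
    status: Status

    def to_dict(self) -> dict:
        # Enum members serialize via .value, exactly as TaskStep does
        return {"step_id": self.step_id, "status": self.status.value}

    @classmethod
    def from_dict(cls, data: dict) -> "Step":
        # Status("completed") maps the stored string back to the member
        return cls(step_id=data["step_id"], status=Status(data["status"]))

step = Step("s1", Status.COMPLETED)
print(Step.from_dict(step.to_dict()) == step)  # -> True
```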
class LongTaskManager:
    """Manager for long-running, multi-step tasks."""

    def __init__(self, storage_path: Path):
        self.storage_path = storage_path
        self.tasks: Dict[str, LongTask] = {}
        self.action_handlers: Dict[str, Callable] = {}
        self.running_tasks: Dict[str, asyncio.Task] = {}
        self._load()

    def _load(self):
        """Load persisted tasks from disk."""
        if self.storage_path.exists():
            with open(self.storage_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for task_data in data:
                task = LongTask.from_dict(task_data)
                self.tasks[task.task_id] = task

    def _save(self):
        """Persist tasks to disk."""
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        data = [task.to_dict() for task in self.tasks.values()]
        with open(self.storage_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def register_action(self, action_name: str, handler: Callable):
        """Register an action handler."""
        self.action_handlers[action_name] = handler

    def create_task(
        self,
        user_id: str,
        title: str,
        description: str,
        steps: List[Dict],
        metadata: Optional[Dict] = None
    ) -> str:
        """Create a task from step specs and return its id."""
        task_id = str(uuid.uuid4())

        task_steps = [
            TaskStep(
                step_id=str(uuid.uuid4()),
                description=step['description'],
                action=step['action'],
                params=step.get('params', {})
            )
            for step in steps
        ]

        task = LongTask(
            task_id=task_id,
            user_id=user_id,
            title=title,
            description=description,
            steps=task_steps,
            metadata=metadata or {}
        )

        self.tasks[task_id] = task
        self._save()

        return task_id

    async def execute_task(
        self,
        task_id: str,
        progress_callback: Optional[Callable[[str, float, str], None]] = None
    ) -> bool:
        """Execute a task."""
        if task_id not in self.tasks:
            return False

        task = self.tasks[task_id]

        if task.status == TaskStatus.RUNNING:
            return False

        task.status = TaskStatus.RUNNING
        task.started_at = datetime.now()
        self._save()

        try:
            total_steps = len(task.steps)

            for i, step in enumerate(task.steps):
                # Stop if the task was cancelled
                if task.status == TaskStatus.CANCELLED:
                    break

                step.status = TaskStatus.RUNNING
                step.started_at = datetime.now()

                # Run the step
                try:
                    handler = self.action_handlers.get(step.action)
                    if not handler:
                        raise ValueError(f"No handler found for action: {step.action}")

                    result = await handler(**step.params)
                    step.result = result
                    step.status = TaskStatus.COMPLETED

                except Exception as e:
                    step.error = str(e)
                    step.status = TaskStatus.FAILED
                    task.status = TaskStatus.FAILED
                    self._save()
                    return False

                finally:
                    step.completed_at = datetime.now()

                # Update progress
                task.progress = (i + 1) / total_steps
                self._save()

                if progress_callback:
                    await progress_callback(
                        task_id,
                        task.progress,
                        f"Completed step {i + 1}/{total_steps}: {step.description}"
                    )

            task.status = TaskStatus.COMPLETED
            task.completed_at = datetime.now()
            task.progress = 1.0
            self._save()

            return True

        except Exception:
            task.status = TaskStatus.FAILED
            self._save()
            raise

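The step loop in `execute_task` dispatches on registered handlers and reports fractional `(i + 1) / total` progress. A reduced async sketch of that flow (`run_steps` and the `echo` handler are illustrative, not from the codebase):

```python
import asyncio

async def echo(**params):
    return params.get("text", "")

handlers = {"echo": echo}

async def run_steps(steps):
    """Run steps sequentially, computing (i + 1) / total progress like execute_task."""
    results = []
    total = len(steps)
    for i, step in enumerate(steps):
        handler = handlers.get(step["action"])
        if handler is None:
            raise ValueError(f"No handler found for action: {step['action']}")
        results.append(await handler(**step.get("params", {})))
        print(f"progress: {(i + 1) / total:.2f}")
    return results

out = asyncio.run(run_steps([{"action": "echo", "params": {"text": "hi"}}]))
print(out)  # -> ['hi']
```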
    async def start_task(
        self,
        task_id: str,
        progress_callback: Optional[Callable[[str, float, str], None]] = None
    ):
        """Start a task in the background."""
        if task_id in self.running_tasks:
            return

        async def run():
            try:
                await self.execute_task(task_id, progress_callback)
            finally:
                if task_id in self.running_tasks:
                    del self.running_tasks[task_id]

        self.running_tasks[task_id] = asyncio.create_task(run())

    def pause_task(self, task_id: str) -> bool:
        """Pause a running task."""
        if task_id not in self.tasks:
            return False

        task = self.tasks[task_id]
        if task.status == TaskStatus.RUNNING:
            task.status = TaskStatus.PAUSED
            self._save()
            return True

        return False

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task."""
        if task_id not in self.tasks:
            return False

        task = self.tasks[task_id]
        task.status = TaskStatus.CANCELLED
        self._save()

        # Cancel the running asyncio task, if any
        if task_id in self.running_tasks:
            self.running_tasks[task_id].cancel()

        return True

    def get_task(self, task_id: str) -> Optional[LongTask]:
        """Get a task by id."""
        return self.tasks.get(task_id)

    def get_user_tasks(self, user_id: str) -> List[LongTask]:
        """Get all tasks belonging to a user."""
        return [
            task for task in self.tasks.values()
            if task.user_id == user_id
        ]

def get_task_status(self, task_id: str) -> Optional[Dict]:
|
|
||||||
"""获取任务状态"""
|
|
||||||
task = self.get_task(task_id)
|
|
||||||
if not task:
|
|
||||||
return None
|
|
||||||
|
|
||||||
completed_steps = sum(1 for step in task.steps if step.status == TaskStatus.COMPLETED)
|
|
||||||
total_steps = len(task.steps)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'task_id': task.task_id,
|
|
||||||
'title': task.title,
|
|
||||||
'status': task.status.value,
|
|
||||||
'progress': task.progress,
|
|
||||||
'completed_steps': completed_steps,
|
|
||||||
'total_steps': total_steps,
|
|
||||||
'current_step': next(
|
|
||||||
(step.description for step in task.steps if step.status == TaskStatus.RUNNING),
|
|
||||||
None
|
|
||||||
)
|
|
||||||
}
|
|
||||||
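The start/cancel pattern above — wrap the work in an inner coroutine, register the resulting `asyncio.Task`, and drop the registry entry in a `finally` block so it is cleaned up on success, error, or cancellation — can be sketched standalone. `MiniRunner` and all names below are illustrative, not part of the project:

```python
import asyncio


class MiniRunner:
    """Minimal sketch of the task-registry pattern (illustrative only)."""

    def __init__(self):
        self.running = {}

    def start(self, task_id, coro_fn):
        if task_id in self.running:  # ignore duplicate starts
            return

        async def run():
            try:
                await coro_fn()
            finally:
                # Always drop the registry entry, even on cancel or error.
                self.running.pop(task_id, None)

        self.running[task_id] = asyncio.create_task(run())

    def cancel(self, task_id):
        task = self.running.get(task_id)
        if task:
            task.cancel()


async def _work(done):
    await asyncio.sleep(0.01)
    done.append("finished")


async def demo():
    runner = MiniRunner()
    done = []
    runner.start("t1", lambda: _work(done))
    assert "t1" in runner.running  # registered while in flight
    await asyncio.sleep(0.05)      # let the work complete
    return done, dict(runner.running)


done, registry = asyncio.run(demo())
print(done, registry)
```

After the work finishes, the `finally` block has already emptied the registry, so a second `start("t1", ...)` would be accepted again.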
@@ -1,63 +1,77 @@
"""
JSON-backed vector store implementation.
"""

from __future__ import annotations

import asyncio
import json
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np

from .base import VectorMemory, VectorStore
from src.utils.logger import setup_logger

logger = setup_logger("JSONStore")


class JSONVectorStore(VectorStore):
    """JSON file storage implementation."""

    def __init__(self, storage_path: Path):
        self.storage_path = storage_path
        self.memories: Dict[str, List[VectorMemory]] = {}
        self._lock = asyncio.Lock()
        self._load()
        logger.info(f"JSON storage initialized: {storage_path}")

    def _load(self):
        if not self.storage_path.exists():
            return

        try:
            with open(self.storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:
            logger.error(f"Failed to load memory file: {exc}")
            self.memories = {}
            return

        loaded = 0
        memories: Dict[str, List[VectorMemory]] = {}
        for user_id, items in (data or {}).items():
            if not isinstance(items, list):
                continue

            normalized: List[VectorMemory] = []
            for item in items:
                if not isinstance(item, dict):
                    continue
                if "id" not in item:
                    item["id"] = str(uuid.uuid4())
                try:
                    normalized.append(VectorMemory.from_dict(item))
                    loaded += 1
                except Exception:
                    continue

            memories[str(user_id)] = normalized

        self.memories = memories
        logger.info(f"Loaded {loaded} memories from JSON store")

    def _save_locked(self):
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            user_id: [memory.to_dict() for memory in user_memories]
            for user_id, user_memories in self.memories.items()
        }
        with open(self.storage_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    async def add(
        self,
@@ -66,9 +80,8 @@ class JSONVectorStore(VectorStore):
        content: str,
        embedding: List[float],
        importance: float,
        metadata: Optional[Dict] = None,
    ) -> bool:
        try:
            memory = VectorMemory(
                id=id,
@@ -79,19 +92,14 @@ class JSONVectorStore(VectorStore):
                timestamp=datetime.now(),
                metadata=metadata or {},
                access_count=0,
                last_access=None,
            )
            async with self._lock:
                self.memories.setdefault(user_id, []).append(memory)
                self._save_locked()
            return True
        except Exception as exc:
            logger.error(f"Failed to add memory: {exc}")
            return False

    async def search(
@@ -99,100 +107,93 @@ class JSONVectorStore(VectorStore):
        user_id: str,
        query_embedding: List[float],
        limit: int = 5,
        min_importance: float = 0.3,
    ) -> List[VectorMemory]:
        async with self._lock:
            source = list(self.memories.get(user_id, []))

        candidates = [m for m in source if m.importance >= min_importance]
        if not candidates:
            return []

        scored_memories = []
        for memory in candidates:
            if memory.embedding:
                similarity = self._cosine_similarity(query_embedding, memory.embedding)
                if similarity is not None:
                    scored_memories.append((similarity, memory))

        if not scored_memories:
            return await self.get_by_importance(user_id, limit, min_importance)

        scored_memories.sort(reverse=True, key=lambda x: x[0])
        return [memory for _, memory in scored_memories[:limit]]

    async def get_by_importance(
        self,
        user_id: str,
        limit: int = 5,
        min_importance: float = 0.3,
    ) -> List[VectorMemory]:
        async with self._lock:
            source = list(self.memories.get(user_id, []))

        memories = [m for m in source if m.importance >= min_importance]
        memories.sort(key=lambda m: (m.importance, m.timestamp), reverse=True)
        return memories[:limit]

    @staticmethod
    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> Optional[float]:
        arr1 = np.array(vec1, dtype=float)
        arr2 = np.array(vec2, dtype=float)
        denom = np.linalg.norm(arr1) * np.linalg.norm(arr2)
        if denom == 0:
            return None
        return float(np.dot(arr1, arr2) / denom)

    async def update_access(self, memory_id: str) -> bool:
        try:
            async with self._lock:
                for memories in self.memories.values():
                    for memory in memories:
                        if memory.id == memory_id:
                            memory.access_count += 1
                            memory.last_access = datetime.now()
                            self._save_locked()
                            return True
            return False
        except Exception as exc:
            logger.error(f"Failed to update access: {exc}")
            return False

    async def delete(self, memory_id: str) -> bool:
        try:
            async with self._lock:
                for user_id, memories in self.memories.items():
                    for idx, memory in enumerate(memories):
                        if memory.id == memory_id:
                            del self.memories[user_id][idx]
                            self._save_locked()
                            return True
            return False
        except Exception as exc:
            logger.error(f"Failed to delete memory: {exc}")
            return False

    async def get_all(self, user_id: str) -> List[VectorMemory]:
        async with self._lock:
            return list(self.memories.get(user_id, []))

    async def clear_user(self, user_id: str) -> bool:
        try:
            async with self._lock:
                self.memories.pop(user_id, None)
                self._save_locked()
            return True
        except Exception as exc:
            logger.error(f"Failed to clear user memories: {exc}")
            return False

    async def close(self):
        async with self._lock:
            self._save_locked()
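The zero-norm guard in `_cosine_similarity` above matters because an all-zero embedding would otherwise divide by zero; returning `None` lets the caller fall back to importance-based ranking. A pure-Python sketch of the same computation (numpy-free, illustrative only):

```python
import math
from typing import List, Optional


def cosine_similarity(vec1: List[float], vec2: List[float]) -> Optional[float]:
    """Return cosine similarity, or None when either vector has zero norm."""
    dot = sum(a * b for a, b in zip(vec1, vec2))
    denom = math.sqrt(sum(a * a for a in vec1)) * math.sqrt(sum(b * b for b in vec2))
    if denom == 0:
        return None  # degenerate embedding: caller should rank by importance instead
    return dot / denom


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 (same direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
print(cosine_similarity([0.0, 0.0], [1.0, 2.0]))  # None (zero norm)
```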
@@ -1,105 +1,71 @@
"""
QQ bot application entry module.
"""

from __future__ import annotations

import botpy
from botpy.message import Message

from src.core.config import Config
from src.handlers.message_handler_ai import MessageHandler
from src.utils.logger import setup_logger

logger = setup_logger("QQBot")


def build_intents() -> botpy.Intents:
    intents = botpy.Intents.none()
    intents.public_guild_messages = True

    if hasattr(intents, "public_messages"):
        intents.public_messages = True
        logger.info("Enabled public_messages for group/C2C events")
    else:
        logger.warning("Current botpy version does not expose public_messages")

    return intents


class MyClient(botpy.Client):
    """QQ bot client wrapper."""

    def __init__(self, intents: botpy.Intents):
        super().__init__(intents=intents)
        self.message_handler = MessageHandler(self)

    async def on_ready(self):
        logger.info(f"Bot is ready: {self.robot.name} (ID: {self.robot.id})")

    async def on_at_message_create(self, message: Message):
        await self.message_handler.handle_at_message(message)

    async def on_message_create(self, message: Message):
        await self.message_handler.handle_at_message(message)

    async def on_direct_message_create(self, message: Message):
        await self.message_handler.handle_at_message(message)

    async def on_group_at_message_create(self, message: Message):
        await self.message_handler.handle_at_message(message)

    async def on_c2c_message_create(self, message: Message):
        await self.message_handler.handle_at_message(message)

    async def on_error(self, error):
        logger.error(f"Bot error: {error}", exc_info=True)

    async def on_close(self):
        ai_client = getattr(self.message_handler, "ai_client", None)
        if ai_client:
            try:
                await ai_client.close()
            except Exception as exc:
                logger.warning(f"Failed to close AI resources cleanly: {exc}")


def main():
    Config.validate()
    intents = build_intents()
    client = MyClient(intents=intents)
    client.run(appid=Config.BOT_APPID, secret=Config.BOT_SECRET)


if __name__ == "__main__":
    main()
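`build_intents` above degrades gracefully across library versions by probing optional attributes with `hasattr` instead of assuming they exist. The same defensive feature-detection pattern, sketched with dummy classes (these are illustrative, not the botpy `Intents` class):

```python
class DummyIntents:
    """Stands in for an older library version without the optional flag."""
    public_guild_messages = False


class NewerIntents(DummyIntents):
    """Stands in for a newer version that exposes the optional flag."""
    public_messages = False


def enable_optional(intents) -> bool:
    """Enable the baseline flag; enable the optional one only if it exists."""
    intents.public_guild_messages = True
    if hasattr(intents, "public_messages"):  # present on newer versions only
        intents.public_messages = True
        return True
    return False  # feature unavailable: log a warning and keep running


print(enable_optional(DummyIntents()))   # False: attribute absent
print(enable_optional(NewerIntents()))   # True: attribute present, now enabled
```

The payoff is that upgrading or downgrading the dependency changes behavior (a log line) rather than raising `AttributeError` at startup.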
@@ -1,72 +1,123 @@
"""
Centralized project configuration.
"""

from __future__ import annotations

import os
from typing import Optional, Set

from dotenv import load_dotenv

load_dotenv()


def _read_env(name: str, default: Optional[str] = None) -> Optional[str]:
    value = os.getenv(name)
    if value is None:
        return default

    value = value.strip()
    if not value or value.startswith("#"):
        return default

    return value


def _read_bool(name: str, default: bool) -> bool:
    raw = _read_env(name, None)
    if raw is None:
        return default
    lowered = raw.lower()
    if lowered in {"1", "true", "yes", "on"}:
        return True
    if lowered in {"0", "false", "no", "off"}:
        return False
    return default


def _read_int(name: str, default: int) -> int:
    raw = _read_env(name, None)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


def _read_float(name: str, default: float) -> float:
    raw = _read_env(name, None)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        return default


def _read_csv_set(name: str) -> Set[str]:
    raw = _read_env(name, "") or ""
    return {item.strip() for item in raw.split(",") if item.strip()}


class Config:
    """Application runtime configuration."""

    BOT_APPID = _read_env("BOT_APPID", "") or ""
    BOT_SECRET = _read_env("BOT_SECRET", "") or ""

    ENV = _read_env("APP_ENV", "dev") or "dev"
    LOG_LEVEL = _read_env("LOG_LEVEL", "INFO") or "INFO"
    LOG_FORMAT = (_read_env("LOG_FORMAT", "text") or "text").lower()
    SANDBOX_MODE = _read_bool("SANDBOX_MODE", False)

    AI_PROVIDER = _read_env("AI_PROVIDER", "openai") or "openai"
    AI_MODEL = _read_env("AI_MODEL", "gpt-4") or "gpt-4"
    AI_API_KEY = _read_env("AI_API_KEY", "") or ""
    AI_API_BASE = _read_env("AI_API_BASE", None)

    AI_EMBED_PROVIDER = _read_env("AI_EMBED_PROVIDER", "openai") or "openai"
    AI_EMBED_MODEL = (
        _read_env("AI_EMBED_MODEL", "text-embedding-3-small")
        or "text-embedding-3-small"
    )
    AI_EMBED_API_KEY = _read_env("AI_EMBED_API_KEY", None)
    AI_EMBED_API_BASE = _read_env("AI_EMBED_API_BASE", None)

    AI_USE_VECTOR_DB = _read_bool("AI_USE_VECTOR_DB", True)
    AI_USE_QUERY_EMBEDDING = _read_bool("AI_USE_QUERY_EMBEDDING", False)

    # `user`: one memory bucket per user
    # `session`: memory bucket scoped to chat session (user + channel/group)
    AI_MEMORY_SCOPE = (_read_env("AI_MEMORY_SCOPE", "session") or "session").lower()

    AI_CHAT_RETRIES = max(0, _read_int("AI_CHAT_RETRIES", 1))
    AI_CHAT_RETRY_BACKOFF_SECONDS = max(
        0.0, _read_float("AI_CHAT_RETRY_BACKOFF_SECONDS", 0.8)
    )

    MESSAGE_DEDUP_SECONDS = max(1, _read_int("MESSAGE_DEDUP_SECONDS", 30))
    MESSAGE_DEDUP_MAX_SIZE = max(128, _read_int("MESSAGE_DEDUP_MAX_SIZE", 4096))

    BOT_ADMIN_IDS = _read_csv_set("BOT_ADMIN_IDS")

    @classmethod
    def is_admin(cls, user_id: Optional[str]) -> bool:
        if not cls.BOT_ADMIN_IDS:
            return True
        if not user_id:
            return False
        return str(user_id) in cls.BOT_ADMIN_IDS

    @classmethod
    def validate(cls) -> bool:
        if not cls.BOT_APPID:
            raise ValueError("BOT_APPID 未配置")
        if not cls.BOT_SECRET:
            raise ValueError("BOT_SECRET 未配置")

        if cls.AI_MEMORY_SCOPE not in {"user", "session"}:
            raise ValueError("AI_MEMORY_SCOPE 仅支持 user 或 session")

        return True
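The env-var cleaning rules above — strip whitespace, treat empty strings and `#`-prefixed placeholders as unset, accept a small set of truthy/falsy spellings — are easy to exercise standalone. This is a sketch of those helpers, not an import of the project module:

```python
import os
from typing import Optional


def read_env(name: str, default: Optional[str] = None) -> Optional[str]:
    """Read an env var, treating '', whitespace, and '#...' placeholders as unset."""
    value = os.getenv(name)
    if value is None:
        return default
    value = value.strip()
    if not value or value.startswith("#"):
        return default
    return value


def read_bool(name: str, default: bool) -> bool:
    """Parse common boolean spellings; fall back to the default otherwise."""
    raw = read_env(name)
    if raw is None:
        return default
    lowered = raw.lower()
    if lowered in {"1", "true", "yes", "on"}:
        return True
    if lowered in {"0", "false", "no", "off"}:
        return False
    return default


os.environ["DEMO_SANDBOX"] = "  True "
os.environ["DEMO_PLACEHOLDER"] = "# fill me in"
print(read_bool("DEMO_SANDBOX", False))       # True: whitespace stripped, case-folded
print(read_env("DEMO_PLACEHOLDER", "unset"))  # 'unset': comment placeholder ignored
```

The placeholder rule means a user who leaves `AI_EMBED_API_KEY=# optional` in their `.env` gets the documented fallback rather than a literal `#`-string key.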
@@ -1,6 +1,5 @@
"""Message handlers."""

from .message_handler_ai import MessageHandler

__all__ = ["MessageHandler"]
File diff suppressed because it is too large
@@ -1,63 +1,101 @@
"""
Logging helpers.
"""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict


class _JsonFormatter(logging.Formatter):
    """Simple JSON formatter for structured logs."""

    _BASE_FIELDS = {
        "name",
        "msg",
        "args",
        "levelname",
        "levelno",
        "pathname",
        "filename",
        "module",
        "exc_info",
        "exc_text",
        "stack_info",
        "lineno",
        "funcName",
        "created",
        "msecs",
        "relativeCreated",
        "thread",
        "threadName",
        "processName",
        "process",
    }

    def format(self, record: logging.LogRecord) -> str:
        payload: Dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "file": record.filename,
            "line": record.lineno,
        }

        for key, value in record.__dict__.items():
            if key in self._BASE_FIELDS or key.startswith("_"):
                continue
            payload[key] = value

        if record.exc_info:
            payload["exc"] = self.formatException(record.exc_info)

        return json.dumps(payload, ensure_ascii=False)


def setup_logger(name: str = "QQBot", level: str | None = None) -> logging.Logger:
    log_dir = Path(__file__).parent.parent.parent / "logs"
    log_dir.mkdir(exist_ok=True)

    if level is None:
        level = os.getenv("LOG_LEVEL", "INFO")
    log_format = (os.getenv("LOG_FORMAT", "text") or "text").lower()

    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, str(level).upper(), logging.INFO))
    logger.propagate = False

    if logger.handlers:
        return logger

    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(log_dir / "bot.log", encoding="utf-8")

    if log_format == "json":
        formatter = _JsonFormatter()
        console_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)
    else:
        console_fmt = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        file_fmt = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        console_handler.setFormatter(console_fmt)
        file_handler.setFormatter(file_fmt)

    console_handler.setLevel(logging.INFO)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger
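The `_JsonFormatter` above builds a dict from the `LogRecord` and serializes it per line. A minimal standalone sketch of the same idea (only the three core fields; `JsonFormatter` here is illustrative, not the project class):

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Minimal sketch of a per-line JSON log formatter (illustrative only)."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),  # applies %-style args
        }
        return json.dumps(payload, ensure_ascii=False)


# Build a record by hand so the demo needs no handler wiring.
record = logging.LogRecord(
    name="QQBot", level=logging.INFO, pathname="demo.py", lineno=1,
    msg="user %s joined", args=("alice",), exc_info=None,
)
line = JsonFormatter().format(record)
print(line)
```

One JSON object per line keeps the file greppable and trivially ingestible by log shippers, which is why the real formatter also folds extra record attributes into the same payload.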
7
tests/conftest.py
Normal file
@@ -0,0 +1,7 @@
from pathlib import Path
import sys


PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
628
tests/test_ai.py
@@ -1,628 +0,0 @@
|
|||||||
"""AI integration tests."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
import stat
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
project_root = Path(__file__).parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from src.ai import AIClient
|
|
||||||
from src.ai.base import ModelConfig, ModelProvider
|
|
||||||
from src.ai.memory import MemorySystem
|
|
||||||
from src.ai.skills import SkillsManager, create_skill_template
|
|
||||||
from src.handlers.message_handler_ai import MessageHandler
|
|
||||||
|
|
||||||
load_dotenv(project_root / ".env")
|
|
||||||
|
|
||||||
|
|
||||||
TEST_DATA_DIR = Path("data/ai_test")
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_rmtree(path: Path):
|
|
||||||
if not path.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
def _onerror(func, target, exc_info):
|
|
||||||
try:
|
|
||||||
os.chmod(target, stat.S_IWRITE)
|
|
||||||
func(target)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for _ in range(3):
|
|
||||||
try:
|
|
||||||
shutil.rmtree(path, onerror=_onerror)
|
|
||||||
return
|
|
||||||
except PermissionError:
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_unlink(path: Path):
|
|
||||||
if not path.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
for _ in range(3):
|
|
||||||
try:
|
|
||||||
path.unlink()
|
|
||||||
return
|
|
||||||
except PermissionError:
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_env(name: str, default=None):
|
|
||||||
value = os.getenv(name)
|
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
|
|
||||||
value = value.strip()
|
|
||||||
if not value or value.startswith("#"):
|
|
||||||
return default
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def get_ai_config() -> ModelConfig:
    provider_map = {
        "openai": ModelProvider.OPENAI,
        "anthropic": ModelProvider.ANTHROPIC,
        "deepseek": ModelProvider.DEEPSEEK,
        "qwen": ModelProvider.QWEN,
        "siliconflow": ModelProvider.OPENAI,
    }

    provider_str = (_read_env("AI_PROVIDER", "openai") or "openai").lower()
    provider = provider_map.get(provider_str, ModelProvider.OPENAI)

    return ModelConfig(
        provider=provider,
        model_name=_read_env("AI_MODEL", "gpt-3.5-turbo") or "gpt-3.5-turbo",
        api_key=_read_env("AI_API_KEY", "") or "",
        api_base=_read_env("AI_API_BASE"),
        temperature=0.7,
    )

def get_embed_config() -> ModelConfig:
    provider_map = {
        "openai": ModelProvider.OPENAI,
        "anthropic": ModelProvider.ANTHROPIC,
        "deepseek": ModelProvider.DEEPSEEK,
        "qwen": ModelProvider.QWEN,
        "siliconflow": ModelProvider.OPENAI,
    }

    provider_str = (_read_env("AI_EMBED_PROVIDER", "openai") or "openai").lower()
    provider = provider_map.get(provider_str, ModelProvider.OPENAI)

    # Fall back to the chat model's credentials when no dedicated
    # embedding key/base is configured.
    api_key = _read_env("AI_EMBED_API_KEY") or _read_env("AI_API_KEY", "") or ""
    api_base = _read_env("AI_EMBED_API_BASE") or _read_env("AI_API_BASE")

    return ModelConfig(
        provider=provider,
        model_name=_read_env("AI_EMBED_MODEL", "text-embedding-3-small")
        or "text-embedding-3-small",
        api_key=api_key,
        api_base=api_base,
        temperature=0.0,
    )

class FakeMessage:
    def __init__(self, content: str):
        from types import SimpleNamespace

        self.content = content
        self.author = SimpleNamespace(id="test_user")
        self.replies = []

    async def reply(self, content: str):
        self.replies.append(content)

def make_handler() -> MessageHandler:
    from types import SimpleNamespace

    fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
    handler = MessageHandler(fake_bot)
    handler.ai_client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
    handler.skills_manager = SkillsManager(Path("skills"))
    handler.model_profiles_path = TEST_DATA_DIR / "models_test.json"
    TEST_DATA_DIR.mkdir(parents=True, exist_ok=True)
    _safe_unlink(handler.model_profiles_path)
    handler._ai_initialized = True
    return handler

async def _test_basic_chat():
    print("=== test_basic_chat ===")

    config = get_ai_config()
    if not config.api_key:
        print("skip: AI_API_KEY not configured")
        return

    embed_config = get_embed_config()
    client = AIClient(config, embed_config=embed_config, data_dir=TEST_DATA_DIR)
    response = await client.chat(
        user_id="test_user",
        user_message="你好，请介绍一下你自己",
        use_memory=False,
        use_tools=False,
    )
    assert response
    print("ok: chat response length", len(response))

async def _test_memory():
    print("=== test_memory ===")

    config = get_ai_config()
    if not config.api_key:
        print("skip: AI_API_KEY not configured")
        return

    client = AIClient(config, embed_config=get_embed_config(), data_dir=TEST_DATA_DIR)
    await client.chat(user_id="test_user", user_message="我叫张三", use_memory=True)
    await client.chat(user_id="test_user", user_message="what is my name", use_memory=True)

    short_term, long_term = await client.memory.get_context("test_user")
    assert len(short_term) >= 2
    # Importance is now scored by the model, so whether an entry reaches
    # long-term memory depends on that score; do not assert a fixed count.
    assert isinstance(long_term, list)
    print("ok: memory short/long", len(short_term), len(long_term))

async def _test_personality():
    print("=== test_personality ===")

    client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
    names = client.list_personalities()
    assert names
    assert client.set_personality(names[0])

    key = "roleplay_test"
    added = client.personality.add_personality(
        key,
        client.personality.get_personality("default"),
    )
    assert added
    assert key in client.list_personalities()
    assert client.personality.remove_personality(key)
    assert key not in client.list_personalities()

    print("ok: personality add/remove")

async def _test_skills():
    print("=== test_skills ===")

    manager = SkillsManager(Path("skills"))
    assert await manager.load_skill("weather")

    tools = manager.get_all_tools()
    assert "weather.get_weather" in tools
    weather = await tools["weather.get_weather"](city="北京")
    assert weather

    assert await manager.load_skill("skills_creator")
    tools = manager.get_all_tools()
    assert "skills_creator.create_skill" in tools

    await manager.unload_skill("weather")
    await manager.unload_skill("skills_creator")
    print("ok: skills load/unload")

async def _test_skill_commands():
    print("=== test_skill_commands ===")

    handler = make_handler()
    skill_key = f"cmd_zip_skill_{int(time.time() * 1000)}"

    # Prepare a zip package source for install testing
    tmp_root = TEST_DATA_DIR / "tmp_skill_pkg"
    if tmp_root.exists():
        _safe_rmtree(tmp_root)
    tmp_root.mkdir(parents=True, exist_ok=True)

    create_skill_template(skill_key, tmp_root, description="zip skill", author="test")

    zip_path = TEST_DATA_DIR / f"{skill_key}.zip"

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for file in (tmp_root / skill_key).rglob("*"):
            if file.is_file():
                zf.write(file, file.relative_to(tmp_root))

    install_msg = FakeMessage(f"/skills install {zip_path}")
    await handler._handle_command(install_msg, install_msg.content)
    assert install_msg.replies, "install command no reply"

    list_msg = FakeMessage("/skills")
    await handler._handle_command(list_msg, list_msg.content)
    assert list_msg.replies, "list command no reply"

    reload_msg = FakeMessage(f"/skills reload {skill_key}")
    await handler._handle_command(reload_msg, reload_msg.content)
    assert reload_msg.replies, "reload command no reply"

    uninstall_msg = FakeMessage(f"/skills uninstall {skill_key}")
    await handler._handle_command(uninstall_msg, uninstall_msg.content)
    assert uninstall_msg.replies, "uninstall command no reply"

    if tmp_root.exists():
        _safe_rmtree(tmp_root)
    _safe_unlink(zip_path)

    print("ok: skills install/reload/uninstall command")

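The install test above builds its zip with `zf.write(file, file.relative_to(tmp_root))`, so each archive entry keeps the skill folder as its top-level directory rather than an absolute path. A standalone sketch of that packaging step (directory names here are invented for the demo):

```python
import tempfile
import zipfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
(root / "my_skill").mkdir()
(root / "my_skill" / "main.py").write_text("print('hi')")

zip_path = root / "my_skill.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
    for file in (root / "my_skill").rglob("*"):
        if file.is_file():
            # The second argument (arcname) strips the temp-dir prefix, so the
            # entry is stored as "my_skill/main.py", not an absolute path.
            zf.write(file, file.relative_to(root))

print(zipfile.ZipFile(zip_path).namelist())  # ['my_skill/main.py']
```

Without the `arcname` argument, `ZipFile.write` would embed the full on-disk path, which would break extraction into a skills directory.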
async def _test_personality_commands():
    print("=== test_personality_commands ===")

    handler = make_handler()
    intro = "You are a hot-blooded anime hero. Speak directly and stay in-character."

    add_msg = FakeMessage(f"/personality add roleplay_hero {intro}")
    await handler._handle_command(add_msg, add_msg.content)
    assert add_msg.replies

    set_msg = FakeMessage("/personality set roleplay_hero")
    await handler._handle_command(set_msg, set_msg.content)
    assert set_msg.replies
    assert intro in handler.ai_client.personality.get_system_prompt()

    remove_msg = FakeMessage("/personality remove roleplay_hero")
    await handler._handle_command(remove_msg, remove_msg.content)
    assert remove_msg.replies

    assert "roleplay_hero" not in handler.ai_client.list_personalities()
    print("ok: personality add/set/remove command")

async def _test_model_commands():
    print("=== test_model_commands ===")

    handler = make_handler()

    list_msg = FakeMessage("/models")
    await handler._handle_command(list_msg, list_msg.content)
    assert list_msg.replies
    assert "default" in list_msg.replies[-1].lower()

    add_msg = FakeMessage("/models add roleplay_llm openai gpt-4o-mini")
    await handler._handle_command(add_msg, add_msg.content)
    assert add_msg.replies
    assert handler.active_model_key == "roleplay_llm"
    assert "roleplay_llm" in handler.model_profiles

    switch_msg = FakeMessage("/models switch default")
    await handler._handle_command(switch_msg, switch_msg.content)
    assert switch_msg.replies
    assert handler.active_model_key == "default"

    current_msg = FakeMessage("/models current")
    await handler._handle_command(current_msg, current_msg.content)
    assert current_msg.replies

    # Shortcut form: "/models add <model_name>" inherits provider/base/key
    # from the currently active config.
    old_config = handler.ai_client.config
    shortcut_model = "Qwen/Qwen2.5-7B-Instruct"
    shortcut_key = handler._normalize_model_key(shortcut_model)
    shortcut_add_msg = FakeMessage(f"/models add {shortcut_model}")
    await handler._handle_command(shortcut_add_msg, shortcut_add_msg.content)
    assert shortcut_add_msg.replies
    assert handler.active_model_key == shortcut_key
    assert handler.ai_client.config.model_name == shortcut_model
    assert handler.ai_client.config.provider == old_config.provider
    assert handler.ai_client.config.api_base == old_config.api_base
    assert handler.ai_client.config.api_key == old_config.api_key

    shortcut_remove_msg = FakeMessage(f"/models remove {shortcut_key}")
    await handler._handle_command(shortcut_remove_msg, shortcut_remove_msg.content)
    assert shortcut_remove_msg.replies
    assert shortcut_key not in handler.model_profiles

    remove_msg = FakeMessage("/models remove roleplay_llm")
    await handler._handle_command(remove_msg, remove_msg.content)
    assert remove_msg.replies
    assert "roleplay_llm" not in handler.model_profiles

    _safe_unlink(handler.model_profiles_path)
    print("ok: model add/switch/remove command")

async def _test_memory_commands():
    print("=== test_memory_commands ===")

    handler = make_handler()
    user_id = "test_user"

    await handler.ai_client.clear_all_memory(user_id)

    add_msg = FakeMessage("/memory add this is a long-term memory test")
    await handler._handle_command(add_msg, add_msg.content)
    assert add_msg.replies
    assert "已新增长期记忆" in add_msg.replies[-1]

    memory_id = add_msg.replies[-1].split(": ", 1)[1].split(" ", 1)[0]
    assert memory_id

    list_msg = FakeMessage("/memory list 5")
    await handler._handle_command(list_msg, list_msg.content)
    assert list_msg.replies
    assert memory_id in list_msg.replies[-1]

    get_msg = FakeMessage(f"/memory get {memory_id}")
    await handler._handle_command(get_msg, get_msg.content)
    assert get_msg.replies
    assert memory_id in get_msg.replies[-1]

    search_msg = FakeMessage("/memory search 长期记忆")
    await handler._handle_command(search_msg, search_msg.content)
    assert search_msg.replies
    assert memory_id in search_msg.replies[-1]

    update_msg = FakeMessage(f"/memory update {memory_id} 这是更新后的长期记忆")
    await handler._handle_command(update_msg, update_msg.content)
    assert update_msg.replies
    assert "已更新长期记忆" in update_msg.replies[-1]

    # Build short-term memory then clear only short-term.
    await handler.ai_client.memory.add_message(
        user_id=user_id,
        role="user",
        content="short memory for clear short test",
    )
    assert handler.ai_client.memory.short_term.get(user_id)

    clear_short_msg = FakeMessage("/clear short")
    await handler._handle_command(clear_short_msg, clear_short_msg.content)
    assert clear_short_msg.replies
    assert not handler.ai_client.memory.short_term.get(user_id)

    # Long-term memory should still exist after clearing short-term only.
    still_exists = await handler.ai_client.get_long_term_memory(user_id, memory_id)
    assert still_exists is not None

    delete_msg = FakeMessage(f"/memory delete {memory_id}")
    await handler._handle_command(delete_msg, delete_msg.content)
    assert delete_msg.replies
    assert "已删除长期记忆" in delete_msg.replies[-1]

    removed = await handler.ai_client.get_long_term_memory(user_id, memory_id)
    assert removed is None

    print("ok: memory command CRUD + clear short")

async def _test_plain_text_output():
    print("=== test_plain_text_output ===")

    handler = make_handler()
    md_text = "# 标题\n**加粗** 和 `代码`\n- 列表\n[链接](https://example.com)"
    plain = handler._plain_text(md_text)

    assert "#" not in plain
    assert "**" not in plain
    assert "`" not in plain
    assert "[" not in plain
    assert "](" not in plain
    print("ok: markdown stripped")

async def _test_skills_creator_autoload():
    print("=== test_skills_creator_autoload ===")

    from types import SimpleNamespace

    fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
    handler = MessageHandler(fake_bot)
    handler.model_profiles_path = TEST_DATA_DIR / "models_autoload_test.json"
    _safe_unlink(handler.model_profiles_path)
    await handler._init_ai()

    assert handler.skills_manager is not None
    assert "skills_creator" in handler.skills_manager.list_skills()

    tool_names = [tool.name for tool in handler.ai_client.tools.list()]
    assert "skills_creator.create_skill" in tool_names
    print("ok: skills_creator autoloaded")

async def _test_mcp():
    print("=== test_mcp ===")

    from src.ai.mcp import MCPManager
    from src.ai.mcp.servers import FileSystemMCPServer

    manager = MCPManager(Path("config/mcp.json"))
    fs_server = FileSystemMCPServer(root_path=Path("data"))
    await manager.register_server(fs_server)

    tools = await manager.get_all_tools_for_ai()
    assert len(tools) >= 1
    print("ok: mcp tools", len(tools))

async def _test_long_task():
    print("=== test_long_task ===")

    client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)

    async def step1():
        await asyncio.sleep(0.1)
        return "step1"

    async def step2():
        await asyncio.sleep(0.1)
        return "step2"

    client.task_manager.register_action("step1", step1)
    client.task_manager.register_action("step2", step2)

    task_id = await client.create_long_task(
        user_id="test_user",
        title="test",
        description="test task",
        steps=[
            {"description": "s1", "action": "step1", "params": {}},
            {"description": "s2", "action": "step2", "params": {}},
        ],
    )

    await client.start_task(task_id)
    await asyncio.sleep(0.5)

    status = client.get_task_status(task_id)
    assert status is not None
    assert status["status"] in {"completed", "running"}
    print("ok: long task", status["status"])

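The long-task test above drives a register-actions / create / start / poll-status cycle. A tiny self-contained sketch of that pattern, assuming nothing about the real `task_manager` beyond what the test exercises (`MiniTaskRunner` and its method names are invented for illustration):

```python
import asyncio


class MiniTaskRunner:
    """Minimal sketch: named async actions executed as ordered task steps."""

    def __init__(self):
        self.actions = {}
        self.status = {}

    def register_action(self, name, func):
        self.actions[name] = func

    async def run(self, task_id, steps):
        self.status[task_id] = "running"
        for step in steps:
            # Look up the registered coroutine function by name and await it.
            await self.actions[step["action"]](**step.get("params", {}))
        self.status[task_id] = "completed"


async def demo():
    runner = MiniTaskRunner()

    async def step1():
        await asyncio.sleep(0)
        return "step1"

    runner.register_action("step1", step1)
    await runner.run("t1", [{"description": "s1", "action": "step1", "params": {}}])
    return runner.status["t1"]


print(asyncio.run(demo()))  # completed
```

The test's `status["status"] in {"completed", "running"}` check mirrors the fact that a real runner executes steps in the background, so polling may observe either state.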
async def _test_memory_importance_evaluator():
    print("=== test_memory_importance_evaluator ===")

    called = {"value": False}

    async def fake_importance_eval(content, metadata):
        called["value"] = True
        assert "用户:" in content
        assert "助手:" in content
        return 0.91

    store_path = TEST_DATA_DIR / "importance_test.json"
    _safe_unlink(store_path)

    memory = MemorySystem(
        storage_path=store_path,
        importance_evaluator=fake_importance_eval,
        use_vector_db=False,
    )

    stored = await memory.add_qa_pair(
        user_id="u1",
        question="请记住我的昵称是小明",
        answer="好的，我记住了你的昵称是小明",
        metadata={"source": "test"},
    )
    assert called["value"]
    assert stored is not None
    assert "用户:" in stored.content
    assert "助手:" in stored.content
    assert "小明" in stored.content

    long_term = await memory.list_long_term("u1")
    assert len(long_term) == 1

    # add_message only writes short-term memory; it must not trigger a
    # long-term importance evaluation.
    await memory.add_message(user_id="u1", role="user", content="单条短期消息")
    long_term_after_single = await memory.list_long_term("u1")
    assert len(long_term_after_single) == 1

    memory_without_eval = MemorySystem(
        storage_path=TEST_DATA_DIR / "importance_fallback_test.json",
        use_vector_db=False,
    )
    fallback_score = await memory_without_eval._evaluate_importance("任意内容", None)
    assert fallback_score == 0.5

    await memory.close()
    await memory_without_eval.close()
    _safe_unlink(store_path)
    _safe_unlink(TEST_DATA_DIR / "importance_fallback_test.json")
    print("ok: memory importance evaluator")

def test_basic_chat():
    asyncio.run(_test_basic_chat())


def test_memory():
    asyncio.run(_test_memory())


def test_personality():
    asyncio.run(_test_personality())


def test_skills():
    asyncio.run(_test_skills())


def test_skill_commands():
    asyncio.run(_test_skill_commands())


def test_personality_commands():
    asyncio.run(_test_personality_commands())


def test_model_commands():
    asyncio.run(_test_model_commands())


def test_memory_commands():
    asyncio.run(_test_memory_commands())


def test_plain_text_output():
    asyncio.run(_test_plain_text_output())


def test_skills_creator_autoload():
    asyncio.run(_test_skills_creator_autoload())


def test_mcp():
    asyncio.run(_test_mcp())


def test_long_task():
    asyncio.run(_test_long_task())


def test_memory_importance_evaluator():
    asyncio.run(_test_memory_importance_evaluator())

async def main():
    print("开始 AI 功能测试")

    await _test_personality()
    await _test_skills()
    await _test_skill_commands()
    await _test_personality_commands()
    await _test_model_commands()
    await _test_memory_commands()
    await _test_plain_text_output()
    await _test_skills_creator_autoload()
    await _test_mcp()
    await _test_long_task()
    await _test_memory_importance_evaluator()

    config = get_ai_config()
    if config.api_key:
        await _test_basic_chat()
        await _test_memory()
    else:
        print("跳过需要 API Key 的对话/记忆测试")

    print("测试完成")


if __name__ == "__main__":
    asyncio.run(main())

@@ -1,73 +0,0 @@
"""Tests for AIClient forced tool name extraction."""

from src.ai.client import AIClient


def test_extract_forced_tool_name_full_name():
    tools = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars"

    forced = AIClient._extract_forced_tool_name(message, tools)

    assert forced == "humanizer_zh.read_skill_doc"


def test_extract_forced_tool_name_unique_prefix():
    tools = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    message = "please call tool humanizer_zh only"

    forced = AIClient._extract_forced_tool_name(message, tools)

    assert forced == "humanizer_zh.read_skill_doc"


def test_extract_forced_tool_name_compact_prefix_without_underscore():
    tools = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    message = "调用humanizerzh人性化处理以下文本"

    forced = AIClient._extract_forced_tool_name(message, tools)

    assert forced == "humanizer_zh.read_skill_doc"


def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
    tools = [
        "skills_creator.create_skill",
        "skills_creator.reload_skill",
    ]
    message = "please call tool skills_creator"

    forced = AIClient._extract_forced_tool_name(message, tools)

    assert forced is None


def test_extract_prefix_limit_from_user_message():
    assert AIClient._extract_prefix_limit("直接返回前100字") == 100
    assert AIClient._extract_prefix_limit("前 256 字") == 256
    assert AIClient._extract_prefix_limit("返回全文") is None


def test_extract_processing_payload_with_marker():
    message = "调用humanizer_zh.read_skill_doc人性化处理以下文本:\n第一段。\n第二段。"
    payload = AIClient._extract_processing_payload(message)
    assert payload == "第一段。\n第二段。"


def test_extract_processing_payload_with_generic_pattern():
    message = "请按技能规则优化如下:\n这是待处理文本。"
    payload = AIClient._extract_processing_payload(message)
    assert payload == "这是待处理文本。"


def test_extract_processing_payload_returns_none_when_absent():
    assert AIClient._extract_processing_payload("请调用工具 humanizer_zh.read_skill_doc") is None
@@ -1,44 +0,0 @@
import asyncio
from pathlib import Path

from src.ai.mcp.base import MCPManager, MCPServer


class _DummyMCPServer(MCPServer):
    def __init__(self):
        super().__init__(name="dummy", version="1.0.0")

    async def initialize(self):
        self.register_tool(
            name="echo",
            description="Echo text",
            input_schema={
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
            handler=self.echo,
        )

    async def echo(self, text: str) -> str:
        return text


def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
    manager = MCPManager(tmp_path / "mcp.json")
    asyncio.run(manager.register_server(_DummyMCPServer()))

    tools = asyncio.run(manager.get_all_tools_for_ai())
    assert len(tools) == 1
    function_info = tools[0]["function"]
    assert function_info["name"] == "dummy.echo"
    assert function_info["description"] == "Echo text"
    assert function_info["parameters"]["required"] == ["text"]


def test_mcp_manager_execute_tool(tmp_path: Path):
    manager = MCPManager(tmp_path / "mcp.json")
    asyncio.run(manager.register_server(_DummyMCPServer()))

    result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
    assert result == "hello"
tests/test_message_dedup.py (new file)
@@ -0,0 +1,22 @@
from types import SimpleNamespace

from src.handlers.message_handler_ai import MessageHandler


def _build_handler() -> MessageHandler:
    fake_bot = SimpleNamespace(robot=SimpleNamespace(id="bot_1", name="TestBot"))
    return MessageHandler(fake_bot)


def test_message_dedup_by_message_id():
    handler = _build_handler()
    msg = SimpleNamespace(id="m1", content="hello", author=SimpleNamespace(id="u1"))
    assert handler._is_duplicate_message(msg) is False
    assert handler._is_duplicate_message(msg) is True


def test_message_dedup_fallback_without_message_id():
    handler = _build_handler()
    msg = SimpleNamespace(content="hello", author=SimpleNamespace(id="u1"), group_id="g1")
    assert handler._is_duplicate_message(msg) is False
    assert handler._is_duplicate_message(msg) is True
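The dedup tests above exercise two keying strategies: the platform message id when present, else a composite of group, author, and content. A minimal sketch of one way such a `_is_duplicate_message` could be implemented (the `Deduper` class is an assumption for illustration, not the real `MessageHandler` logic):

```python
from types import SimpleNamespace


class Deduper:
    """Dedup by message id when present, else by (group, author, content)."""

    def __init__(self):
        self._seen = set()

    def is_duplicate(self, msg) -> bool:
        # Fall back to a composite key when the message carries no id.
        key = getattr(msg, "id", None) or (
            getattr(msg, "group_id", None),
            msg.author.id,
            msg.content,
        )
        if key in self._seen:
            return True
        self._seen.add(key)
        return False


dedup = Deduper()
msg = SimpleNamespace(content="hello", author=SimpleNamespace(id="u1"), group_id="g1")
print(dedup.is_duplicate(msg), dedup.is_duplicate(msg))  # False True
```

A production version would additionally bound the `_seen` set (e.g. an LRU or time-window eviction) so it does not grow without limit.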
@@ -10,16 +10,14 @@ from src.handlers.message_handler_ai import MessageHandler


def _handler() -> MessageHandler:
    fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot", name="TestBot"))
    return MessageHandler(fake_bot)


def test_plain_text_removes_markdown_link_url():
    handler = _handler()
    text = "参考 [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia) 获取详情。"

    result = handler._plain_text(text)

    assert "Wikipedia" in result
    assert "http" not in result.lower()


@@ -27,9 +25,7 @@ def test_plain_text_removes_markdown_link_url():
def test_plain_text_removes_bare_url():
    handler = _handler()
    text = "访问 https://example.com/path?a=1 或 www.example.org 查看。"

    result = handler._plain_text(text)

    assert "http" not in result.lower()
    assert "www." not in result.lower()
    assert "[链接已省略]" in result
tests/test_personality_scope_priority.py (new file)
@@ -0,0 +1,44 @@
from pathlib import Path

from src.ai.personality import PersonalityProfile, PersonalitySystem, PersonalityTrait


def _profile(name: str) -> PersonalityProfile:
    return PersonalityProfile(
        name=name,
        description=f"{name} profile",
        traits=[PersonalityTrait.FRIENDLY],
        speaking_style="plain",
    )


def test_scope_priority_session_over_group_over_user_over_global(tmp_path: Path):
    cfg = tmp_path / "personalities.json"
    state = tmp_path / "personality_state.json"
    system = PersonalitySystem(config_path=cfg, state_path=state)

    system.add_personality("p_global", _profile("global"))
    system.add_personality("p_user", _profile("user"))
    system.add_personality("p_group", _profile("group"))
    system.add_personality("p_session", _profile("session"))

    assert system.set_personality("p_global", scope="global")
    assert system.set_personality("p_user", scope="user", scope_id="u1")
    assert system.set_personality("p_group", scope="group", scope_id="g1")
    assert system.set_personality("p_session", scope="session", scope_id="g1:u1")

    profile = system.get_active_personality(user_id="u1", group_id="g1", session_id="g1:u1")
    assert profile is not None
    assert profile.name == "session"

    profile_no_session = system.get_active_personality(user_id="u1", group_id="g1", session_id="other")
    assert profile_no_session is not None
    assert profile_no_session.name == "group"

    profile_user_only = system.get_active_personality(user_id="u1", group_id=None, session_id=None)
    assert profile_user_only is not None
    assert profile_user_only.name == "user"

    profile_global = system.get_active_personality(user_id="u2", group_id=None, session_id=None)
    assert profile_global is not None
    assert profile_global.name == "global"
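The scope test above pins the resolution order session > group > user > global. A minimal sketch of a resolver with that precedence, assumed for illustration only (the real `PersonalitySystem` internals may differ):

```python
def resolve(scopes: dict, user_id=None, group_id=None, session_id=None):
    """Return the most specific personality bound for the given identifiers."""
    # Check scopes from most to least specific; first hit wins.
    for scope, key in (("session", session_id), ("group", group_id), ("user", user_id)):
        if key is not None and (scope, key) in scopes:
            return scopes[(scope, key)]
    return scopes.get(("global", None))


scopes = {
    ("global", None): "p_global",
    ("user", "u1"): "p_user",
    ("group", "g1"): "p_group",
    ("session", "g1:u1"): "p_session",
}
print(resolve(scopes, user_id="u1", group_id="g1", session_id="g1:u1"))  # p_session
print(resolve(scopes, user_id="u1"))  # p_user
```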
@@ -1,74 +0,0 @@
import io
from pathlib import Path
import zipfile

from src.ai.skills.base import SkillsManager


def _build_codex_skill_zip_bytes(markdown_text: str, root_name: str = "Humanizer-zh-main") -> bytes:
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr(f"{root_name}/SKILL.md", markdown_text)
    return buffer.getvalue()


def test_resolve_network_source_supports_github_git_url(tmp_path: Path):
    manager = SkillsManager(tmp_path / "skills")
    url, hint_key, subpath = manager._resolve_network_source(
        "https://github.com/op7418/Humanizer-zh.git"
    )

    assert url == "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main"
    assert hint_key == "humanizer_zh"
    assert subpath is None


def test_install_skill_from_local_skill_markdown_source(tmp_path: Path):
    source_dir = tmp_path / "Humanizer-zh-main"
    source_dir.mkdir(parents=True, exist_ok=True)
    (source_dir / "SKILL.md").write_text(
        "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n",
        encoding="utf-8",
    )

    manager = SkillsManager(tmp_path / "skills")
    ok, installed_key = manager.install_skill_from_source(str(source_dir), skill_name="humanizer_zh")

    assert ok
    assert installed_key == "humanizer_zh"
    installed_dir = tmp_path / "skills" / "humanizer_zh"
    assert (installed_dir / "skill.json").exists()
    assert (installed_dir / "main.py").exists()
    assert (installed_dir / "SKILL.md").exists()

    main_code = (installed_dir / "main.py").read_text(encoding="utf-8")
    assert "read_skill_doc" in main_code
    skill_text = (installed_dir / "SKILL.md").read_text(encoding="utf-8")
    assert "Humanizer-zh" in skill_text


def test_install_skill_from_github_git_url_uses_repo_zip_and_markdown_adapter(
    tmp_path: Path, monkeypatch
):
    manager = SkillsManager(tmp_path / "skills")
    zip_bytes = _build_codex_skill_zip_bytes(
        "# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n"
    )
    captured_urls = []

    def fake_download(url: str, output_zip: Path):
        captured_urls.append(url)
        output_zip.write_bytes(zip_bytes)

    monkeypatch.setattr(manager, "_download_zip", fake_download)

    ok, installed_key = manager.install_skill_from_source(
        "https://github.com/op7418/Humanizer-zh.git"
    )

    assert ok
    assert installed_key == "humanizer_zh"
    assert captured_urls == [
        "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main"
    ]
    assert (tmp_path / "skills" / "humanizer_zh" / "SKILL.md").exists()
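The deleted GitHub-URL test stubs the network by replacing a download method on the manager instance. The same instance-level stubbing pattern can be shown self-contained with `unittest.mock.patch.object` instead of pytest's `monkeypatch` (the `Downloader` class below is a hypothetical stand-in, not the project's `SkillsManager`):

```python
import io
import zipfile
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch


class Downloader:
    """Hypothetical stand-in for a class with a network download step."""

    def _download_zip(self, url: str, output_zip: Path) -> None:
        raise RuntimeError("network disabled in this sketch")

    def fetch(self, url: str, output_zip: Path) -> list:
        self._download_zip(url, output_zip)
        return zipfile.ZipFile(output_zip).namelist()


# Build a zip archive in memory, as _build_codex_skill_zip_bytes did.
payload = io.BytesIO()
with zipfile.ZipFile(payload, "w") as zf:
    zf.writestr("repo-main/SKILL.md", "# demo")

dl = Downloader()
with TemporaryDirectory() as tmp, patch.object(
    # Patch only this instance's attribute; no real network I/O runs.
    dl, "_download_zip", lambda url, out: out.write_bytes(payload.getvalue())
):
    names = dl.fetch("https://example.invalid/repo.zip", Path(tmp) / "a.zip")

assert names == ["repo-main/SKILL.md"]
```

Patching the bound attribute on one instance keeps other instances untouched, which is why both this sketch and the original test target the object rather than the class.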