Refactor configuration and enhance AI capabilities
Updated .env.example to improve clarity and added new configuration options for memory and reliability settings. Refactored main.py to streamline the bot's entry point and improved error handling. Enhanced README to reflect new features and command structure. Removed deprecated cmd_zip_skill and skills_creator modules to clean up the codebase. Updated AIClient and MemorySystem for better performance and flexibility in handling user interactions.
This commit is contained in:
7
tests/conftest.py
Normal file
7
tests/conftest.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
628
tests/test_ai.py
628
tests/test_ai.py
@@ -1,628 +0,0 @@
|
||||
"""AI integration tests."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import stat
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import zipfile
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.ai import AIClient
|
||||
from src.ai.base import ModelConfig, ModelProvider
|
||||
from src.ai.memory import MemorySystem
|
||||
from src.ai.skills import SkillsManager, create_skill_template
|
||||
from src.handlers.message_handler_ai import MessageHandler
|
||||
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
|
||||
TEST_DATA_DIR = Path("data/ai_test")
|
||||
|
||||
|
||||
def _safe_rmtree(path: Path):
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
def _onerror(func, target, exc_info):
|
||||
try:
|
||||
os.chmod(target, stat.S_IWRITE)
|
||||
func(target)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for _ in range(3):
|
||||
try:
|
||||
shutil.rmtree(path, onerror=_onerror)
|
||||
return
|
||||
except PermissionError:
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def _safe_unlink(path: Path):
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
for _ in range(3):
|
||||
try:
|
||||
path.unlink()
|
||||
return
|
||||
except PermissionError:
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def _read_env(name: str, default=None):
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
value = value.strip()
|
||||
if not value or value.startswith("#"):
|
||||
return default
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def get_ai_config() -> ModelConfig:
|
||||
provider_map = {
|
||||
"openai": ModelProvider.OPENAI,
|
||||
"anthropic": ModelProvider.ANTHROPIC,
|
||||
"deepseek": ModelProvider.DEEPSEEK,
|
||||
"qwen": ModelProvider.QWEN,
|
||||
"siliconflow": ModelProvider.OPENAI,
|
||||
}
|
||||
|
||||
provider_str = (_read_env("AI_PROVIDER", "openai") or "openai").lower()
|
||||
provider = provider_map.get(provider_str, ModelProvider.OPENAI)
|
||||
|
||||
return ModelConfig(
|
||||
provider=provider,
|
||||
model_name=_read_env("AI_MODEL", "gpt-3.5-turbo") or "gpt-3.5-turbo",
|
||||
api_key=_read_env("AI_API_KEY", "") or "",
|
||||
api_base=_read_env("AI_API_BASE"),
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
|
||||
def get_embed_config() -> ModelConfig:
|
||||
provider_map = {
|
||||
"openai": ModelProvider.OPENAI,
|
||||
"anthropic": ModelProvider.ANTHROPIC,
|
||||
"deepseek": ModelProvider.DEEPSEEK,
|
||||
"qwen": ModelProvider.QWEN,
|
||||
"siliconflow": ModelProvider.OPENAI,
|
||||
}
|
||||
|
||||
provider_str = (_read_env("AI_EMBED_PROVIDER", "openai") or "openai").lower()
|
||||
provider = provider_map.get(provider_str, ModelProvider.OPENAI)
|
||||
|
||||
api_key = _read_env("AI_EMBED_API_KEY") or _read_env("AI_API_KEY", "") or ""
|
||||
api_base = _read_env("AI_EMBED_API_BASE") or _read_env("AI_API_BASE")
|
||||
|
||||
return ModelConfig(
|
||||
provider=provider,
|
||||
model_name=_read_env("AI_EMBED_MODEL", "text-embedding-3-small")
|
||||
or "text-embedding-3-small",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
class FakeMessage:
|
||||
def __init__(self, content: str):
|
||||
from types import SimpleNamespace
|
||||
|
||||
self.content = content
|
||||
self.author = SimpleNamespace(id="test_user")
|
||||
self.replies = []
|
||||
|
||||
async def reply(self, content: str):
|
||||
self.replies.append(content)
|
||||
|
||||
|
||||
def make_handler() -> MessageHandler:
|
||||
from types import SimpleNamespace
|
||||
|
||||
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
|
||||
handler = MessageHandler(fake_bot)
|
||||
handler.ai_client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
|
||||
handler.skills_manager = SkillsManager(Path("skills"))
|
||||
handler.model_profiles_path = TEST_DATA_DIR / "models_test.json"
|
||||
TEST_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_safe_unlink(handler.model_profiles_path)
|
||||
handler._ai_initialized = True
|
||||
return handler
|
||||
|
||||
|
||||
async def _test_basic_chat():
|
||||
print("=== test_basic_chat ===")
|
||||
|
||||
config = get_ai_config()
|
||||
if not config.api_key:
|
||||
print("skip: AI_API_KEY not configured")
|
||||
return
|
||||
|
||||
embed_config = get_embed_config()
|
||||
client = AIClient(config, embed_config=embed_config, data_dir=TEST_DATA_DIR)
|
||||
response = await client.chat(
|
||||
user_id="test_user",
|
||||
user_message="你好,请介绍一下你自己",
|
||||
use_memory=False,
|
||||
use_tools=False,
|
||||
)
|
||||
assert response
|
||||
print("ok: chat response length", len(response))
|
||||
|
||||
|
||||
async def _test_memory():
|
||||
print("=== test_memory ===")
|
||||
|
||||
config = get_ai_config()
|
||||
if not config.api_key:
|
||||
print("skip: AI_API_KEY not configured")
|
||||
return
|
||||
|
||||
client = AIClient(config, embed_config=get_embed_config(), data_dir=TEST_DATA_DIR)
|
||||
await client.chat(user_id="test_user", user_message="鎴戝彨寮犱笁", use_memory=True)
|
||||
await client.chat(user_id="test_user", user_message="what is my name", use_memory=True)
|
||||
|
||||
short_term, long_term = await client.memory.get_context("test_user")
|
||||
assert len(short_term) >= 2
|
||||
# 重要性改为模型评估后,是否入长期记忆取决于模型打分,不再固定断言数量。
|
||||
assert isinstance(long_term, list)
|
||||
print("ok: memory short/long", len(short_term), len(long_term))
|
||||
|
||||
|
||||
async def _test_personality():
|
||||
print("=== test_personality ===")
|
||||
|
||||
client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
|
||||
names = client.list_personalities()
|
||||
assert names
|
||||
assert client.set_personality(names[0])
|
||||
|
||||
key = "roleplay_test"
|
||||
added = client.personality.add_personality(
|
||||
key,
|
||||
client.personality.get_personality("default"),
|
||||
)
|
||||
assert added
|
||||
assert key in client.list_personalities()
|
||||
assert client.personality.remove_personality(key)
|
||||
assert key not in client.list_personalities()
|
||||
|
||||
print("ok: personality add/remove")
|
||||
|
||||
|
||||
async def _test_skills():
|
||||
print("=== test_skills ===")
|
||||
|
||||
manager = SkillsManager(Path("skills"))
|
||||
assert await manager.load_skill("weather")
|
||||
|
||||
tools = manager.get_all_tools()
|
||||
assert "weather.get_weather" in tools
|
||||
weather = await tools["weather.get_weather"](city="鍖椾含")
|
||||
assert weather
|
||||
|
||||
assert await manager.load_skill("skills_creator")
|
||||
tools = manager.get_all_tools()
|
||||
assert "skills_creator.create_skill" in tools
|
||||
|
||||
await manager.unload_skill("weather")
|
||||
await manager.unload_skill("skills_creator")
|
||||
print("ok: skills load/unload")
|
||||
|
||||
|
||||
async def _test_skill_commands():
|
||||
print("=== test_skill_commands ===")
|
||||
|
||||
handler = make_handler()
|
||||
skill_key = f"cmd_zip_skill_{int(time.time() * 1000)}"
|
||||
|
||||
# Prepare a zip package source for install testing
|
||||
tmp_root = TEST_DATA_DIR / "tmp_skill_pkg"
|
||||
if tmp_root.exists():
|
||||
_safe_rmtree(tmp_root)
|
||||
tmp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
create_skill_template(skill_key, tmp_root, description="zip skill", author="test")
|
||||
|
||||
zip_path = TEST_DATA_DIR / f"{skill_key}.zip"
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for file in (tmp_root / skill_key).rglob("*"):
|
||||
if file.is_file():
|
||||
zf.write(file, file.relative_to(tmp_root))
|
||||
|
||||
install_msg = FakeMessage(f"/skills install {zip_path}")
|
||||
await handler._handle_command(install_msg, install_msg.content)
|
||||
assert install_msg.replies, "install command no reply"
|
||||
|
||||
list_msg = FakeMessage("/skills")
|
||||
await handler._handle_command(list_msg, list_msg.content)
|
||||
assert list_msg.replies, "list command no reply"
|
||||
|
||||
reload_msg = FakeMessage(f"/skills reload {skill_key}")
|
||||
await handler._handle_command(reload_msg, reload_msg.content)
|
||||
assert reload_msg.replies, "reload command no reply"
|
||||
|
||||
uninstall_msg = FakeMessage(f"/skills uninstall {skill_key}")
|
||||
await handler._handle_command(uninstall_msg, uninstall_msg.content)
|
||||
assert uninstall_msg.replies, "uninstall command no reply"
|
||||
|
||||
if tmp_root.exists():
|
||||
_safe_rmtree(tmp_root)
|
||||
_safe_unlink(zip_path)
|
||||
|
||||
print("ok: skills install/reload/uninstall command")
|
||||
|
||||
|
||||
async def _test_personality_commands():
|
||||
print("=== test_personality_commands ===")
|
||||
|
||||
handler = make_handler()
|
||||
intro = "You are a hot-blooded anime hero. Speak directly and stay in-character."
|
||||
|
||||
add_cmd = (
|
||||
"/personality add roleplay_hero "
|
||||
f"{intro}"
|
||||
)
|
||||
add_msg = FakeMessage(add_cmd)
|
||||
await handler._handle_command(add_msg, add_msg.content)
|
||||
assert add_msg.replies
|
||||
|
||||
set_msg = FakeMessage("/personality set roleplay_hero")
|
||||
await handler._handle_command(set_msg, set_msg.content)
|
||||
assert set_msg.replies
|
||||
assert intro in handler.ai_client.personality.get_system_prompt()
|
||||
|
||||
remove_msg = FakeMessage("/personality remove roleplay_hero")
|
||||
await handler._handle_command(remove_msg, remove_msg.content)
|
||||
assert remove_msg.replies
|
||||
|
||||
assert "roleplay_hero" not in handler.ai_client.list_personalities()
|
||||
print("ok: personality add/set/remove command")
|
||||
|
||||
|
||||
async def _test_model_commands():
|
||||
print("=== test_model_commands ===")
|
||||
|
||||
handler = make_handler()
|
||||
|
||||
list_msg = FakeMessage("/models")
|
||||
await handler._handle_command(list_msg, list_msg.content)
|
||||
assert list_msg.replies
|
||||
assert "default" in list_msg.replies[-1].lower()
|
||||
|
||||
add_msg = FakeMessage("/models add roleplay_llm openai gpt-4o-mini")
|
||||
await handler._handle_command(add_msg, add_msg.content)
|
||||
assert add_msg.replies
|
||||
assert handler.active_model_key == "roleplay_llm"
|
||||
assert "roleplay_llm" in handler.model_profiles
|
||||
|
||||
switch_msg = FakeMessage("/models switch default")
|
||||
await handler._handle_command(switch_msg, switch_msg.content)
|
||||
assert switch_msg.replies
|
||||
assert handler.active_model_key == "default"
|
||||
|
||||
current_msg = FakeMessage("/models current")
|
||||
await handler._handle_command(current_msg, current_msg.content)
|
||||
assert current_msg.replies
|
||||
|
||||
old_config = handler.ai_client.config
|
||||
shortcut_model = "Qwen/Qwen2.5-7B-Instruct"
|
||||
shortcut_key = handler._normalize_model_key(shortcut_model)
|
||||
shortcut_add_msg = FakeMessage(f"/models add {shortcut_model}")
|
||||
await handler._handle_command(shortcut_add_msg, shortcut_add_msg.content)
|
||||
assert shortcut_add_msg.replies
|
||||
assert handler.active_model_key == shortcut_key
|
||||
assert handler.ai_client.config.model_name == shortcut_model
|
||||
assert handler.ai_client.config.provider == old_config.provider
|
||||
assert handler.ai_client.config.api_base == old_config.api_base
|
||||
assert handler.ai_client.config.api_key == old_config.api_key
|
||||
|
||||
shortcut_remove_msg = FakeMessage(f"/models remove {shortcut_key}")
|
||||
await handler._handle_command(shortcut_remove_msg, shortcut_remove_msg.content)
|
||||
assert shortcut_remove_msg.replies
|
||||
assert shortcut_key not in handler.model_profiles
|
||||
|
||||
remove_msg = FakeMessage("/models remove roleplay_llm")
|
||||
await handler._handle_command(remove_msg, remove_msg.content)
|
||||
assert remove_msg.replies
|
||||
assert "roleplay_llm" not in handler.model_profiles
|
||||
|
||||
_safe_unlink(handler.model_profiles_path)
|
||||
print("ok: model add/switch/remove command")
|
||||
|
||||
|
||||
async def _test_memory_commands():
|
||||
print("=== test_memory_commands ===")
|
||||
|
||||
handler = make_handler()
|
||||
user_id = "test_user"
|
||||
|
||||
await handler.ai_client.clear_all_memory(user_id)
|
||||
|
||||
add_msg = FakeMessage("/memory add this is a long-term memory test")
|
||||
await handler._handle_command(add_msg, add_msg.content)
|
||||
assert add_msg.replies
|
||||
assert "已新增长期记忆" in add_msg.replies[-1]
|
||||
|
||||
memory_id = add_msg.replies[-1].split(": ", 1)[1].split(" ", 1)[0]
|
||||
assert memory_id
|
||||
|
||||
list_msg = FakeMessage("/memory list 5")
|
||||
await handler._handle_command(list_msg, list_msg.content)
|
||||
assert list_msg.replies
|
||||
assert memory_id in list_msg.replies[-1]
|
||||
|
||||
get_msg = FakeMessage(f"/memory get {memory_id}")
|
||||
await handler._handle_command(get_msg, get_msg.content)
|
||||
assert get_msg.replies
|
||||
assert memory_id in get_msg.replies[-1]
|
||||
|
||||
search_msg = FakeMessage("/memory search 长期记忆")
|
||||
await handler._handle_command(search_msg, search_msg.content)
|
||||
assert search_msg.replies
|
||||
assert memory_id in search_msg.replies[-1]
|
||||
|
||||
update_msg = FakeMessage(f"/memory update {memory_id} 这是更新后的长期记忆")
|
||||
await handler._handle_command(update_msg, update_msg.content)
|
||||
assert update_msg.replies
|
||||
assert "已更新长期记忆" in update_msg.replies[-1]
|
||||
|
||||
# Build short-term memory then clear only short-term.
|
||||
await handler.ai_client.memory.add_message(
|
||||
user_id=user_id,
|
||||
role="user",
|
||||
content="short memory for clear short test",
|
||||
)
|
||||
assert handler.ai_client.memory.short_term.get(user_id)
|
||||
|
||||
clear_short_msg = FakeMessage("/clear short")
|
||||
await handler._handle_command(clear_short_msg, clear_short_msg.content)
|
||||
assert clear_short_msg.replies
|
||||
assert not handler.ai_client.memory.short_term.get(user_id)
|
||||
|
||||
# Long-term memory should still exist after clearing short-term only.
|
||||
still_exists = await handler.ai_client.get_long_term_memory(user_id, memory_id)
|
||||
assert still_exists is not None
|
||||
|
||||
delete_msg = FakeMessage(f"/memory delete {memory_id}")
|
||||
await handler._handle_command(delete_msg, delete_msg.content)
|
||||
assert delete_msg.replies
|
||||
assert "已删除长期记忆" in delete_msg.replies[-1]
|
||||
|
||||
removed = await handler.ai_client.get_long_term_memory(user_id, memory_id)
|
||||
assert removed is None
|
||||
|
||||
print("ok: memory command CRUD + clear short")
|
||||
|
||||
|
||||
async def _test_plain_text_output():
|
||||
print("=== test_plain_text_output ===")
|
||||
|
||||
handler = make_handler()
|
||||
md_text = "# 标题\n**加粗** 和 `代码`\n- 列表\n[链接](https://example.com)"
|
||||
plain = handler._plain_text(md_text)
|
||||
|
||||
assert "#" not in plain
|
||||
assert "**" not in plain
|
||||
assert "`" not in plain
|
||||
assert "[" not in plain
|
||||
assert "](" not in plain
|
||||
print("ok: markdown stripped")
|
||||
|
||||
|
||||
async def _test_skills_creator_autoload():
|
||||
print("=== test_skills_creator_autoload ===")
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
|
||||
handler = MessageHandler(fake_bot)
|
||||
handler.model_profiles_path = TEST_DATA_DIR / "models_autoload_test.json"
|
||||
_safe_unlink(handler.model_profiles_path)
|
||||
await handler._init_ai()
|
||||
|
||||
assert handler.skills_manager is not None
|
||||
assert "skills_creator" in handler.skills_manager.list_skills()
|
||||
|
||||
tool_names = [tool.name for tool in handler.ai_client.tools.list()]
|
||||
assert "skills_creator.create_skill" in tool_names
|
||||
print("ok: skills_creator autoloaded")
|
||||
|
||||
|
||||
async def _test_mcp():
|
||||
print("=== test_mcp ===")
|
||||
|
||||
from src.ai.mcp import MCPManager
|
||||
from src.ai.mcp.servers import FileSystemMCPServer
|
||||
|
||||
manager = MCPManager(Path("config/mcp.json"))
|
||||
fs_server = FileSystemMCPServer(root_path=Path("data"))
|
||||
await manager.register_server(fs_server)
|
||||
|
||||
tools = await manager.get_all_tools_for_ai()
|
||||
assert len(tools) >= 1
|
||||
print("ok: mcp tools", len(tools))
|
||||
|
||||
|
||||
async def _test_long_task():
|
||||
print("=== test_long_task ===")
|
||||
|
||||
client = AIClient(get_ai_config(), data_dir=TEST_DATA_DIR)
|
||||
|
||||
async def step1():
|
||||
await asyncio.sleep(0.1)
|
||||
return "step1"
|
||||
|
||||
async def step2():
|
||||
await asyncio.sleep(0.1)
|
||||
return "step2"
|
||||
|
||||
client.task_manager.register_action("step1", step1)
|
||||
client.task_manager.register_action("step2", step2)
|
||||
|
||||
task_id = await client.create_long_task(
|
||||
user_id="test_user",
|
||||
title="test",
|
||||
description="test task",
|
||||
steps=[
|
||||
{"description": "s1", "action": "step1", "params": {}},
|
||||
{"description": "s2", "action": "step2", "params": {}},
|
||||
],
|
||||
)
|
||||
|
||||
await client.start_task(task_id)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
status = client.get_task_status(task_id)
|
||||
assert status is not None
|
||||
assert status["status"] in {"completed", "running"}
|
||||
print("ok: long task", status["status"])
|
||||
|
||||
|
||||
async def _test_memory_importance_evaluator():
|
||||
print("=== test_memory_importance_evaluator ===")
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
async def fake_importance_eval(content, metadata):
|
||||
called["value"] = True
|
||||
assert "用户:" in content
|
||||
assert "助手:" in content
|
||||
return 0.91
|
||||
|
||||
store_path = TEST_DATA_DIR / "importance_test.json"
|
||||
_safe_unlink(store_path)
|
||||
|
||||
memory = MemorySystem(
|
||||
storage_path=store_path,
|
||||
importance_evaluator=fake_importance_eval,
|
||||
use_vector_db=False,
|
||||
)
|
||||
|
||||
stored = await memory.add_qa_pair(
|
||||
user_id="u1",
|
||||
question="请记住我的昵称是小明",
|
||||
answer="好的,我记住了你的昵称是小明",
|
||||
metadata={"source": "test"},
|
||||
)
|
||||
assert called["value"]
|
||||
assert stored is not None
|
||||
assert "用户:" in stored.content
|
||||
assert "助手:" in stored.content
|
||||
assert "小明" in stored.content
|
||||
|
||||
long_term = await memory.list_long_term("u1")
|
||||
assert len(long_term) == 1
|
||||
|
||||
# add_message 仅写入短期记忆,不触发长期记忆评分写入。
|
||||
await memory.add_message(user_id="u1", role="user", content="单条短期消息")
|
||||
long_term_after_single = await memory.list_long_term("u1")
|
||||
assert len(long_term_after_single) == 1
|
||||
|
||||
memory_without_eval = MemorySystem(
|
||||
storage_path=TEST_DATA_DIR / "importance_fallback_test.json",
|
||||
use_vector_db=False,
|
||||
)
|
||||
fallback_score = await memory_without_eval._evaluate_importance("任意内容", None)
|
||||
assert fallback_score == 0.5
|
||||
|
||||
await memory.close()
|
||||
await memory_without_eval.close()
|
||||
_safe_unlink(store_path)
|
||||
_safe_unlink(TEST_DATA_DIR / "importance_fallback_test.json")
|
||||
print("ok: memory importance evaluator")
|
||||
|
||||
|
||||
def test_basic_chat():
|
||||
asyncio.run(_test_basic_chat())
|
||||
|
||||
|
||||
def test_memory():
|
||||
asyncio.run(_test_memory())
|
||||
|
||||
|
||||
def test_personality():
|
||||
asyncio.run(_test_personality())
|
||||
|
||||
|
||||
def test_skills():
|
||||
asyncio.run(_test_skills())
|
||||
|
||||
|
||||
def test_skill_commands():
|
||||
asyncio.run(_test_skill_commands())
|
||||
|
||||
|
||||
def test_personality_commands():
|
||||
asyncio.run(_test_personality_commands())
|
||||
|
||||
|
||||
def test_model_commands():
|
||||
asyncio.run(_test_model_commands())
|
||||
|
||||
|
||||
def test_memory_commands():
|
||||
asyncio.run(_test_memory_commands())
|
||||
|
||||
|
||||
def test_plain_text_output():
|
||||
asyncio.run(_test_plain_text_output())
|
||||
|
||||
|
||||
def test_skills_creator_autoload():
|
||||
asyncio.run(_test_skills_creator_autoload())
|
||||
|
||||
|
||||
def test_mcp():
|
||||
asyncio.run(_test_mcp())
|
||||
|
||||
|
||||
def test_long_task():
|
||||
asyncio.run(_test_long_task())
|
||||
|
||||
|
||||
def test_memory_importance_evaluator():
|
||||
asyncio.run(_test_memory_importance_evaluator())
|
||||
|
||||
|
||||
async def main():
|
||||
print("寮€濮?AI 鍔熻兘娴嬭瘯")
|
||||
|
||||
await _test_personality()
|
||||
await _test_skills()
|
||||
await _test_skill_commands()
|
||||
await _test_personality_commands()
|
||||
await _test_model_commands()
|
||||
await _test_memory_commands()
|
||||
await _test_plain_text_output()
|
||||
await _test_skills_creator_autoload()
|
||||
await _test_mcp()
|
||||
await _test_long_task()
|
||||
await _test_memory_importance_evaluator()
|
||||
|
||||
config = get_ai_config()
|
||||
if config.api_key:
|
||||
await _test_basic_chat()
|
||||
await _test_memory()
|
||||
else:
|
||||
print("跳过需要 API Key 的对话/记忆测试")
|
||||
|
||||
print("娴嬭瘯瀹屾垚")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Tests for AIClient forced tool name extraction."""
|
||||
|
||||
from src.ai.client import AIClient
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_full_name():
|
||||
tools = [
|
||||
"humanizer_zh.read_skill_doc",
|
||||
"skills_creator.create_skill",
|
||||
]
|
||||
message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars"
|
||||
|
||||
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||
|
||||
assert forced == "humanizer_zh.read_skill_doc"
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_unique_prefix():
|
||||
tools = [
|
||||
"humanizer_zh.read_skill_doc",
|
||||
"skills_creator.create_skill",
|
||||
]
|
||||
message = "please call tool humanizer_zh only"
|
||||
|
||||
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||
|
||||
assert forced == "humanizer_zh.read_skill_doc"
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_compact_prefix_without_underscore():
|
||||
tools = [
|
||||
"humanizer_zh.read_skill_doc",
|
||||
"skills_creator.create_skill",
|
||||
]
|
||||
message = "调用humanizerzh人性化处理以下文本"
|
||||
|
||||
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||
|
||||
assert forced == "humanizer_zh.read_skill_doc"
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
|
||||
tools = [
|
||||
"skills_creator.create_skill",
|
||||
"skills_creator.reload_skill",
|
||||
]
|
||||
message = "please call tool skills_creator"
|
||||
|
||||
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||
|
||||
assert forced is None
|
||||
|
||||
|
||||
def test_extract_prefix_limit_from_user_message():
|
||||
assert AIClient._extract_prefix_limit("直接返回前100字") == 100
|
||||
assert AIClient._extract_prefix_limit("前 256 字") == 256
|
||||
assert AIClient._extract_prefix_limit("返回全文") is None
|
||||
|
||||
|
||||
def test_extract_processing_payload_with_marker():
|
||||
message = "调用humanizer_zh.read_skill_doc人性化处理以下文本:\n第一段。\n第二段。"
|
||||
payload = AIClient._extract_processing_payload(message)
|
||||
assert payload == "第一段。\n第二段。"
|
||||
|
||||
|
||||
def test_extract_processing_payload_with_generic_pattern():
|
||||
message = "请按技能规则优化如下:\n这是待处理文本。"
|
||||
payload = AIClient._extract_processing_payload(message)
|
||||
assert payload == "这是待处理文本。"
|
||||
|
||||
|
||||
def test_extract_processing_payload_returns_none_when_absent():
|
||||
assert AIClient._extract_processing_payload("请调用工具 humanizer_zh.read_skill_doc") is None
|
||||
@@ -1,44 +0,0 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from src.ai.mcp.base import MCPManager, MCPServer
|
||||
|
||||
|
||||
class _DummyMCPServer(MCPServer):
|
||||
def __init__(self):
|
||||
super().__init__(name="dummy", version="1.0.0")
|
||||
|
||||
async def initialize(self):
|
||||
self.register_tool(
|
||||
name="echo",
|
||||
description="Echo text",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
handler=self.echo,
|
||||
)
|
||||
|
||||
async def echo(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
|
||||
manager = MCPManager(tmp_path / "mcp.json")
|
||||
asyncio.run(manager.register_server(_DummyMCPServer()))
|
||||
|
||||
tools = asyncio.run(manager.get_all_tools_for_ai())
|
||||
assert len(tools) == 1
|
||||
function_info = tools[0]["function"]
|
||||
assert function_info["name"] == "dummy.echo"
|
||||
assert function_info["description"] == "Echo text"
|
||||
assert function_info["parameters"]["required"] == ["text"]
|
||||
|
||||
|
||||
def test_mcp_manager_execute_tool(tmp_path: Path):
|
||||
manager = MCPManager(tmp_path / "mcp.json")
|
||||
asyncio.run(manager.register_server(_DummyMCPServer()))
|
||||
|
||||
result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
|
||||
assert result == "hello"
|
||||
22
tests/test_message_dedup.py
Normal file
22
tests/test_message_dedup.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.handlers.message_handler_ai import MessageHandler
|
||||
|
||||
|
||||
def _build_handler() -> MessageHandler:
|
||||
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="bot_1", name="TestBot"))
|
||||
return MessageHandler(fake_bot)
|
||||
|
||||
|
||||
def test_message_dedup_by_message_id():
|
||||
handler = _build_handler()
|
||||
msg = SimpleNamespace(id="m1", content="hello", author=SimpleNamespace(id="u1"))
|
||||
assert handler._is_duplicate_message(msg) is False
|
||||
assert handler._is_duplicate_message(msg) is True
|
||||
|
||||
|
||||
def test_message_dedup_fallback_without_message_id():
|
||||
handler = _build_handler()
|
||||
msg = SimpleNamespace(content="hello", author=SimpleNamespace(id="u1"), group_id="g1")
|
||||
assert handler._is_duplicate_message(msg) is False
|
||||
assert handler._is_duplicate_message(msg) is True
|
||||
@@ -10,16 +10,14 @@ from src.handlers.message_handler_ai import MessageHandler
|
||||
|
||||
|
||||
def _handler() -> MessageHandler:
|
||||
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot"))
|
||||
fake_bot = SimpleNamespace(robot=SimpleNamespace(id="test_bot", name="TestBot"))
|
||||
return MessageHandler(fake_bot)
|
||||
|
||||
|
||||
def test_plain_text_removes_markdown_link_url():
|
||||
handler = _handler()
|
||||
text = "参考 [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia) 获取详情。"
|
||||
|
||||
result = handler._plain_text(text)
|
||||
|
||||
assert "Wikipedia" in result
|
||||
assert "http" not in result.lower()
|
||||
|
||||
@@ -27,9 +25,7 @@ def test_plain_text_removes_markdown_link_url():
|
||||
def test_plain_text_removes_bare_url():
|
||||
handler = _handler()
|
||||
text = "访问 https://example.com/path?a=1 或 www.example.org 查看。"
|
||||
|
||||
result = handler._plain_text(text)
|
||||
|
||||
assert "http" not in result.lower()
|
||||
assert "www." not in result.lower()
|
||||
assert "[链接已省略]" in result
|
||||
|
||||
44
tests/test_personality_scope_priority.py
Normal file
44
tests/test_personality_scope_priority.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.ai.personality import PersonalityProfile, PersonalitySystem, PersonalityTrait
|
||||
|
||||
|
||||
def _profile(name: str) -> PersonalityProfile:
|
||||
return PersonalityProfile(
|
||||
name=name,
|
||||
description=f"{name} profile",
|
||||
traits=[PersonalityTrait.FRIENDLY],
|
||||
speaking_style="plain",
|
||||
)
|
||||
|
||||
|
||||
def test_scope_priority_session_over_group_over_user_over_global(tmp_path: Path):
|
||||
cfg = tmp_path / "personalities.json"
|
||||
state = tmp_path / "personality_state.json"
|
||||
system = PersonalitySystem(config_path=cfg, state_path=state)
|
||||
|
||||
system.add_personality("p_global", _profile("global"))
|
||||
system.add_personality("p_user", _profile("user"))
|
||||
system.add_personality("p_group", _profile("group"))
|
||||
system.add_personality("p_session", _profile("session"))
|
||||
|
||||
assert system.set_personality("p_global", scope="global")
|
||||
assert system.set_personality("p_user", scope="user", scope_id="u1")
|
||||
assert system.set_personality("p_group", scope="group", scope_id="g1")
|
||||
assert system.set_personality("p_session", scope="session", scope_id="g1:u1")
|
||||
|
||||
profile = system.get_active_personality(user_id="u1", group_id="g1", session_id="g1:u1")
|
||||
assert profile is not None
|
||||
assert profile.name == "session"
|
||||
|
||||
profile_no_session = system.get_active_personality(user_id="u1", group_id="g1", session_id="other")
|
||||
assert profile_no_session is not None
|
||||
assert profile_no_session.name == "group"
|
||||
|
||||
profile_user_only = system.get_active_personality(user_id="u1", group_id=None, session_id=None)
|
||||
assert profile_user_only is not None
|
||||
assert profile_user_only.name == "user"
|
||||
|
||||
profile_global = system.get_active_personality(user_id="u2", group_id=None, session_id=None)
|
||||
assert profile_global is not None
|
||||
assert profile_global.name == "global"
|
||||
@@ -1,74 +0,0 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
import zipfile
|
||||
|
||||
from src.ai.skills.base import SkillsManager
|
||||
|
||||
|
||||
def _build_codex_skill_zip_bytes(markdown_text: str, root_name: str = "Humanizer-zh-main") -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(f"{root_name}/SKILL.md", markdown_text)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def test_resolve_network_source_supports_github_git_url(tmp_path: Path):
|
||||
manager = SkillsManager(tmp_path / "skills")
|
||||
url, hint_key, subpath = manager._resolve_network_source(
|
||||
"https://github.com/op7418/Humanizer-zh.git"
|
||||
)
|
||||
|
||||
assert url == "https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main"
|
||||
assert hint_key == "humanizer_zh"
|
||||
assert subpath is None
|
||||
|
||||
|
||||
def test_install_skill_from_local_skill_markdown_source(tmp_path: Path):
|
||||
source_dir = tmp_path / "Humanizer-zh-main"
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
(source_dir / "SKILL.md").write_text(
|
||||
"# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
manager = SkillsManager(tmp_path / "skills")
|
||||
ok, installed_key = manager.install_skill_from_source(str(source_dir), skill_name="humanizer_zh")
|
||||
|
||||
assert ok
|
||||
assert installed_key == "humanizer_zh"
|
||||
installed_dir = tmp_path / "skills" / "humanizer_zh"
|
||||
assert (installed_dir / "skill.json").exists()
|
||||
assert (installed_dir / "main.py").exists()
|
||||
assert (installed_dir / "SKILL.md").exists()
|
||||
|
||||
main_code = (installed_dir / "main.py").read_text(encoding="utf-8")
|
||||
assert "read_skill_doc" in main_code
|
||||
skill_text = (installed_dir / "SKILL.md").read_text(encoding="utf-8")
|
||||
assert "Humanizer-zh" in skill_text
|
||||
|
||||
|
||||
def test_install_skill_from_github_git_url_uses_repo_zip_and_markdown_adapter(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
manager = SkillsManager(tmp_path / "skills")
|
||||
zip_bytes = _build_codex_skill_zip_bytes(
|
||||
"# Humanizer-zh\n\nUse natural and human-like Chinese tone.\n"
|
||||
)
|
||||
captured_urls = []
|
||||
|
||||
def fake_download(url: str, output_zip: Path):
|
||||
captured_urls.append(url)
|
||||
output_zip.write_bytes(zip_bytes)
|
||||
|
||||
monkeypatch.setattr(manager, "_download_zip", fake_download)
|
||||
|
||||
ok, installed_key = manager.install_skill_from_source(
|
||||
"https://github.com/op7418/Humanizer-zh.git"
|
||||
)
|
||||
|
||||
assert ok
|
||||
assert installed_key == "humanizer_zh"
|
||||
assert captured_urls == [
|
||||
"https://codeload.github.com/op7418/Humanizer-zh/zip/refs/heads/main"
|
||||
]
|
||||
assert (tmp_path / "skills" / "humanizer_zh" / "SKILL.md").exists()
|
||||
Reference in New Issue
Block a user