Files
QQbot/src/ai/mcp/base.py

229 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MCP (Model Context Protocol) 支持
"""
import asyncio
import inspect
import json
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable

from src.utils.logger import setup_logger
# Module-wide logger shared by every class in this file, obtained from the
# project's setup_logger helper under the name 'MCPSystem'.
logger = setup_logger('MCPSystem')
@dataclass
class MCPResource:
    """A resource that an MCP server exposes, identified by its URI."""

    # Identifier of the resource; MCPServer.register_resource stores
    # resources keyed by this value.
    uri: str
    # Display name of the resource.
    name: str
    # Human-readable description of the resource.
    description: str
    # MIME type string describing the resource content.
    mime_type: str
@dataclass
class MCPTool:
    """Public description of a callable tool registered on an MCP server."""

    # Tool name, unique within one server.
    name: str
    # Human-readable description of what the tool does.
    description: str
    # Schema describing the tool's keyword arguments; it is passed through
    # unchanged as the "parameters" entry when tools are exported for the AI.
    input_schema: Dict[str, Any]
@dataclass
class MCPPrompt:
    """A named prompt template that an MCP server can expose."""

    # Prompt name; MCPServer.register_prompt stores prompts keyed by it.
    name: str
    # Human-readable description of the prompt.
    description: str
    # Declared prompt arguments, one dict per argument.
    arguments: List[Dict[str, Any]]
class MCPServer:
    """Base class for an MCP (Model Context Protocol) server.

    Holds registries of resources, tools and prompts. ``read_resource`` and
    ``get_prompt`` raise ``NotImplementedError`` here and are meant to be
    overridden by subclasses.
    """

    def __init__(self, name: str, version: str):
        self.name = name
        self.version = version
        # Registered resources keyed by URI.
        self.resources: Dict[str, MCPResource] = {}
        # Tool handlers keyed by tool name; tool_specs holds the matching
        # public descriptions under the same keys.
        self.tools: Dict[str, Callable] = {}
        self.tool_specs: Dict[str, MCPTool] = {}
        # Registered prompts keyed by prompt name.
        self.prompts: Dict[str, MCPPrompt] = {}

    async def initialize(self):
        """Initialize the server. The base implementation does nothing."""
        pass

    async def shutdown(self):
        """Shut the server down. The base implementation does nothing."""
        pass

    def register_resource(self, resource: MCPResource):
        """Register ``resource`` under its URI, replacing any previous entry."""
        self.resources[resource.uri] = resource

    def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
        """Register a tool and its handler.

        Args:
            name: Tool name; re-registering the same name replaces the old tool.
            description: Human-readable description returned by ``list_tools``.
            input_schema: Schema of the tool's keyword arguments, returned by
                ``list_tools``.
            handler: Called as ``handler(**arguments)``. It may be a coroutine
                function or a plain callable; see ``call_tool``.
        """
        tool = MCPTool(name=name, description=description, input_schema=input_schema)
        self.tool_specs[name] = tool
        self.tools[name] = handler
        logger.info(f"✅ MCP工具注册: {self.name}.{name}")

    def register_prompt(self, prompt: MCPPrompt):
        """Register ``prompt`` under its name, replacing any previous entry."""
        self.prompts[prompt.name] = prompt

    async def list_resources(self) -> List[MCPResource]:
        """Return all registered resources."""
        return list(self.resources.values())

    async def read_resource(self, uri: str) -> Optional[str]:
        """Read the resource at ``uri``. Must be implemented by subclasses."""
        raise NotImplementedError

    async def list_tools(self) -> List[MCPTool]:
        """Return the public descriptions of all registered tools."""
        return list(self.tool_specs.values())

    async def call_tool(self, name: str, arguments: Optional[Dict[str, Any]]) -> Any:
        """Invoke the tool ``name`` with ``arguments`` as keyword arguments.

        Args:
            name: Registered tool name.
            arguments: Keyword arguments for the handler. ``None`` is treated
                as "no arguments".

        Returns:
            The handler's result. If the handler returns an awaitable (for
            example because it is an ``async def``), the awaited value is
            returned instead.

        Raises:
            ValueError: If no tool named ``name`` is registered.
            Exception: Anything the handler raises is logged and re-raised.
        """
        if name not in self.tools:
            raise ValueError(f"工具不存在: {name}")
        handler = self.tools[name]
        if arguments is None:
            # e.g. a model emitting a tool call without arguments; ``**None``
            # would otherwise raise TypeError.
            arguments = {}
        logger.info(
            f"MCP工具调用开始: server={self.name}, tool={name}, args={self._format_arguments(arguments)}"
        )
        try:
            result = handler(**arguments)
            # Support both sync and async handlers: only await real awaitables.
            if inspect.isawaitable(result):
                result = await result
        except Exception as exc:
            logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}")
            raise
        logger.info(f"MCP工具调用成功: server={self.name}, tool={name}")
        return result

    @staticmethod
    def _format_arguments(arguments: Dict[str, Any]) -> str:
        """Render tool arguments for logging; never lets a log line abort a call."""
        try:
            # default=str is deliberate here: this text is only logged, and
            # non-JSON values (bytes, datetimes, objects) must not raise.
            return json.dumps(arguments, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # Remaining failures (non-string keys, circular references).
            return repr(arguments)

    async def list_prompts(self) -> List[MCPPrompt]:
        """Return all registered prompts."""
        return list(self.prompts.values())

    async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]:
        """Render the prompt ``name`` with ``arguments``. Must be implemented by subclasses."""
        raise NotImplementedError
class MCPClient:
    """Registry of connected MCP servers, addressed by server name."""

    def __init__(self):
        # Connected servers keyed by their ``name`` attribute.
        self.servers: Dict[str, MCPServer] = {}

    async def connect_server(self, server: MCPServer):
        """Initialize ``server`` and register it under ``server.name``.

        If a different server object is already registered under that name,
        it is replaced and then shut down so its resources are released. A
        failure during that shutdown is logged, not raised, because the new
        server is already connected at that point.

        Raises:
            Exception: Whatever ``server.initialize()`` raises. In that case
                nothing is registered and any existing server is left in place.
        """
        await server.initialize()
        previous = self.servers.get(server.name)
        self.servers[server.name] = server
        logger.info(f"✅ 连接MCP服务器: {server.name} v{server.version}")
        if previous is not None and previous is not server:
            # Previously the old instance was silently dropped without shutdown.
            try:
                await previous.shutdown()
            except Exception as exc:
                logger.warning(f"关闭被替换的MCP服务器失败: {server.name}, error={exc}")

    async def disconnect_server(self, server_name: str):
        """Shut down and unregister ``server_name``. Unknown names are ignored.

        The server is removed only after ``shutdown()`` returns, so if shutdown
        raises it stays registered and the exception propagates.
        """
        if server_name not in self.servers:
            return
        await self.servers[server_name].shutdown()
        # pop instead of del: another coroutine may have removed the entry
        # while we were suspended in shutdown().
        self.servers.pop(server_name, None)
        logger.info(f"✅ 断开MCP服务器: {server_name}")

    def get_server(self, name: str) -> Optional[MCPServer]:
        """Return the connected server named ``name``, or ``None``."""
        return self.servers.get(name)

    def list_servers(self) -> List[str]:
        """Return the names of all connected servers."""
        return list(self.servers.keys())

    async def list_all_resources(self) -> Dict[str, List[MCPResource]]:
        """Return each connected server's resources, keyed by server name."""
        # Iterate a snapshot: other coroutines may connect/disconnect servers
        # while we are suspended in an await, which would otherwise raise
        # "dictionary changed size during iteration".
        return {name: await server.list_resources() for name, server in list(self.servers.items())}

    async def list_all_tools(self) -> Dict[str, List[MCPTool]]:
        """Return each connected server's tools, keyed by server name."""
        # Snapshot for the same reason as list_all_resources.
        return {name: await server.list_tools() for name, server in list(self.servers.items())}

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call ``tool_name`` on the server ``server_name``.

        Raises:
            ValueError: If no server with that name is connected, or (from the
                server) if the tool does not exist.
        """
        server = self.get_server(server_name)
        if server is None:
            raise ValueError(f"服务器不存在: {server_name}")
        return await server.call_tool(tool_name, arguments)
class MCPManager:
    """Owns an MCPClient and persists per-server configuration to a JSON file."""

    def __init__(self, config_path: Path):
        self.config_path = config_path
        self.client = MCPClient()
        # Per-server configuration keyed by server name, mirrored to config_path.
        self.server_configs: Dict[str, Dict] = {}
        self._load_config()

    def _load_config(self):
        """Load ``server_configs`` from ``config_path`` if the file exists.

        Errors from reading or parsing the file propagate to the caller.
        """
        # NOTE(review): the JSON top level is assumed to be an object mapping
        # server names to config dicts; nothing here validates that.
        if self.config_path.exists():
            with open(self.config_path, 'r', encoding='utf-8') as f:
                self.server_configs = json.load(f)

    def _save_config(self):
        """Write ``server_configs`` to ``config_path`` as indented UTF-8 JSON.

        The data is serialized first and written to a sibling temporary file
        that is then renamed over the target. A serialization error or an
        interrupted write therefore leaves the previous config file intact,
        instead of truncating it as writing in place did.
        """
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        text = json.dumps(self.server_configs, ensure_ascii=False, indent=2)
        tmp_path = self.config_path.with_name(self.config_path.name + '.tmp')
        tmp_path.write_text(text, encoding='utf-8')
        # Path.replace uses os.replace, which overwrites the target and is an
        # atomic rename on POSIX when both paths are on the same filesystem.
        tmp_path.replace(self.config_path)

    async def register_server(self, server: MCPServer, config: Optional[Dict] = None):
        """Connect ``server`` and, if ``config`` is non-empty, persist it under the server's name."""
        await self.client.connect_server(server)
        if config:
            self.server_configs[server.name] = config
            self._save_config()

    async def unregister_server(self, server_name: str):
        """Disconnect ``server_name`` and drop its persisted config, if any."""
        await self.client.disconnect_server(server_name)
        if server_name in self.server_configs:
            del self.server_configs[server_name]
            self._save_config()

    def get_client(self) -> MCPClient:
        """Return the underlying MCPClient."""
        return self.client

    async def get_all_tools_for_ai(self) -> List[Dict]:
        """Return every connected tool as a ``{"type": "function", ...}`` dict.

        Each tool is named ``"<server_name>.<tool_name>"``, the format that
        ``execute_tool`` accepts.
        """
        # NOTE(review): some function-calling APIs restrict tool names to
        # [a-zA-Z0-9_-]; confirm the target API accepts '.' in names.
        tools_by_server = await self.client.list_all_tools()
        return [
            {
                "type": "function",
                "function": {
                    "name": f"{server_name}.{tool.name}",
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for server_name, tools in tools_by_server.items()
            for tool in tools
        ]

    async def execute_tool(self, full_tool_name: str, arguments: Dict) -> Any:
        """Execute a tool given its ``"<server_name>.<tool_name>"`` name.

        The name is split at the first '.'. If that prefix is not a connected
        server, the longest connected server name that prefixes the full name
        (followed by '.') is used instead. Without this fallback, servers whose
        names contain '.' could never be reached.

        Raises:
            ValueError: If the name contains no '.', or (from the client) if
                the server or tool does not exist.
        """
        parts = full_tool_name.split('.', 1)
        if len(parts) != 2:
            raise ValueError(f"工具名格式错误: {full_tool_name}")
        server_name, tool_name = parts
        if self.client.get_server(server_name) is None:
            matches = [
                name for name in self.client.list_servers()
                if full_tool_name.startswith(name + '.')
            ]
            if matches:
                server_name = max(matches, key=len)
                tool_name = full_tool_name[len(server_name) + 1:]
        logger.info(f"MCP执行请求: {full_tool_name}")
        return await self.client.call_tool(server_name, tool_name, arguments)