"""
MCP (Model Context Protocol) support.

Defines lightweight descriptors for MCP resources, tools and prompts, an
in-process ``MCPServer`` base class, an ``MCPClient`` that keeps a registry of
connected servers, and an ``MCPManager`` that persists per-server config to a
JSON file and exposes every registered tool as a ``{"type": "function", ...}``
tool schema.
"""
import asyncio
import inspect
import json
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, asdict
from pathlib import Path

from src.utils.logger import setup_logger

logger = setup_logger('MCPSystem')


@dataclass
class MCPResource:
    """MCP resource descriptor."""
    uri: str  # unique key: MCPServer.register_resource stores resources by URI
    name: str
    description: str
    mime_type: str


@dataclass
class MCPTool:
    """MCP tool descriptor (metadata only; the handler lives in MCPServer.tools)."""
    name: str
    description: str
    # Passed through verbatim as the tool's "parameters" by
    # MCPManager.get_all_tools_for_ai; presumably a JSON Schema object — confirm.
    input_schema: Dict[str, Any]


@dataclass
class MCPPrompt:
    """MCP prompt descriptor."""
    name: str
    description: str
    arguments: List[Dict[str, Any]]


class MCPServer:
    """Base class for an in-process MCP server.

    Subclasses populate the registries through the ``register_*`` methods and
    override ``read_resource`` / ``get_prompt``, which raise
    ``NotImplementedError`` here.
    """

    def __init__(self, name: str, version: str):
        self.name = name
        self.version = version
        # Resources keyed by URI.
        self.resources: Dict[str, MCPResource] = {}
        # Tool name -> handler callable; kept in sync with tool_specs by register_tool.
        self.tools: Dict[str, Callable] = {}
        # Tool name -> descriptor returned by list_tools.
        self.tool_specs: Dict[str, MCPTool] = {}
        # Prompts keyed by prompt name.
        self.prompts: Dict[str, MCPPrompt] = {}

    async def initialize(self):
        """Initialize the server; awaited by MCPClient.connect_server. No-op by default."""

    async def shutdown(self):
        """Shut the server down; awaited by MCPClient on disconnect or replacement. No-op by default."""

    def register_resource(self, resource: MCPResource):
        """Register ``resource`` under its URI, replacing any previous entry."""
        self.resources[resource.uri] = resource

    def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
        """Register a tool descriptor and its handler under ``name``.

        Args:
            name: Tool name, unique within this server (re-registering replaces).
            description: Human-readable description exposed to callers.
            input_schema: Parameter schema exposed via list_tools.
            handler: Called as ``handler(**arguments)`` by call_tool; may be a
                coroutine function or a plain function.
        """
        tool = MCPTool(name=name, description=description, input_schema=input_schema)
        self.tool_specs[name] = tool
        self.tools[name] = handler
        logger.info(f"✅ MCP工具注册: {self.name}.{name}")

    def register_prompt(self, prompt: MCPPrompt):
        """Register ``prompt`` under its name, replacing any previous entry."""
        self.prompts[prompt.name] = prompt

    async def list_resources(self) -> List[MCPResource]:
        """Return all registered resources."""
        return list(self.resources.values())

    async def read_resource(self, uri: str) -> Optional[str]:
        """Read the resource at ``uri``; must be implemented by subclasses."""
        raise NotImplementedError

    async def list_tools(self) -> List[MCPTool]:
        """Return the descriptors of all registered tools."""
        return list(self.tool_specs.values())

    async def call_tool(self, name: str, arguments: Optional[Dict[str, Any]]) -> Any:
        """Invoke the tool ``name`` with ``arguments`` as keyword arguments.

        Args:
            name: Registered tool name.
            arguments: Keyword arguments for the handler; ``None`` means none.

        Returns:
            The handler's result (awaited first if the handler returned an awaitable).

        Raises:
            ValueError: If no tool named ``name`` is registered.
            Exception: Anything raised by the handler is logged and re-raised.
        """
        if name not in self.tools:
            raise ValueError(f"工具不存在: {name}")
        handler = self.tools[name]
        # Treat a missing argument dict as "no arguments" instead of failing on **None.
        arguments = arguments or {}
        # default=str keeps logging from raising on non-JSON-serializable values,
        # which previously aborted the call before the handler ever ran.
        logger.info(
            f"MCP工具调用开始: server={self.name}, tool={name}, "
            f"args={json.dumps(arguments, ensure_ascii=False, default=str)}"
        )
        try:
            result = handler(**arguments)
            # Support both coroutine handlers and plain synchronous callables.
            if inspect.isawaitable(result):
                result = await result
        except Exception as exc:
            logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}")
            raise
        logger.info(f"MCP工具调用成功: server={self.name}, tool={name}")
        return result

    async def list_prompts(self) -> List[MCPPrompt]:
        """Return all registered prompts."""
        return list(self.prompts.values())

    async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]:
        """Render prompt ``name`` with ``arguments``; must be implemented by subclasses."""
        raise NotImplementedError


class MCPClient:
    """Registry of connected MCP servers, keyed by server name."""

    def __init__(self):
        self.servers: Dict[str, MCPServer] = {}

    async def connect_server(self, server: MCPServer):
        """Initialize ``server`` and register it under ``server.name``.

        If a *different* server instance is already registered under the same
        name, it is replaced and then shut down so it is not silently leaked.
        If ``server.initialize()`` raises, the existing registration is untouched.
        """
        await server.initialize()
        previous = self.servers.get(server.name)
        self.servers[server.name] = server
        logger.info(f"✅ 连接MCP服务器: {server.name} v{server.version}")
        if previous is not None and previous is not server:
            await previous.shutdown()

    async def disconnect_server(self, server_name: str):
        """Shut down and unregister ``server_name``; unknown names are ignored."""
        if server_name in self.servers:
            await self.servers[server_name].shutdown()
            del self.servers[server_name]
            logger.info(f"✅ 断开MCP服务器: {server_name}")

    def get_server(self, name: str) -> Optional[MCPServer]:
        """Return the server registered under ``name``, or ``None``."""
        return self.servers.get(name)

    def list_servers(self) -> List[str]:
        """Return the names of all connected servers."""
        return list(self.servers.keys())

    async def list_all_resources(self) -> Dict[str, List[MCPResource]]:
        """Return each connected server's resources, keyed by server name."""
        result = {}
        for name, server in self.servers.items():
            result[name] = await server.list_resources()
        return result

    async def list_all_tools(self) -> Dict[str, List[MCPTool]]:
        """Return each connected server's tool descriptors, keyed by server name."""
        result = {}
        for name, server in self.servers.items():
            result[name] = await server.list_tools()
        return result

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call ``tool_name`` on the server named ``server_name``.

        Raises:
            ValueError: If the server is not connected, or (from MCPServer.call_tool)
                if the tool does not exist.
        """
        server = self.get_server(server_name)
        if not server:
            raise ValueError(f"服务器不存在: {server_name}")
        return await server.call_tool(tool_name, arguments)


class MCPManager:
    """Owns an MCPClient and persists per-server config dicts to a JSON file."""

    def __init__(self, config_path: Path):
        # Normalize so callers may also pass a str / os.PathLike.
        self.config_path = Path(config_path)
        self.client = MCPClient()
        # Server name -> config dict; mirrored to config_path by _save_config.
        self.server_configs: Dict[str, Dict] = {}
        self._load_config()

    def _load_config(self):
        """Load server configs from config_path if the file exists.

        A malformed JSON file propagates json.JSONDecodeError to the caller.
        """
        if self.config_path.exists():
            with open(self.config_path, 'r', encoding='utf-8') as f:
                self.server_configs = json.load(f)

    def _save_config(self):
        """Write server_configs to config_path, creating parent directories.

        The data is written to a sibling temp file first and then renamed over
        the real file, so a failure while serializing can no longer leave a
        truncated or empty config behind.
        """
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.config_path.with_name(self.config_path.name + '.tmp')
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.server_configs, f, ensure_ascii=False, indent=2)
            tmp_path.replace(self.config_path)
        except Exception:
            # Clean up the partial temp file; the previous config stays intact.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    async def register_server(self, server: MCPServer, config: Optional[Dict] = None):
        """Connect ``server`` and, when a non-empty ``config`` is given, persist it."""
        await self.client.connect_server(server)
        if config:
            self.server_configs[server.name] = config
            self._save_config()

    async def unregister_server(self, server_name: str):
        """Disconnect ``server_name`` and drop its persisted config, if any."""
        await self.client.disconnect_server(server_name)
        if server_name in self.server_configs:
            del self.server_configs[server_name]
            self._save_config()

    def get_client(self) -> MCPClient:
        """Return the underlying MCPClient."""
        return self.client

    async def get_all_tools_for_ai(self) -> List[Dict]:
        """Return every connected tool as a ``{"type": "function", ...}`` schema.

        Each tool is named ``"<server>.<tool>"``, the format execute_tool parses.
        """
        # NOTE(review): some function-calling APIs (e.g. OpenAI) document tool
        # names as limited to [a-zA-Z0-9_-]; the '.' separator may be rejected
        # there — confirm against the target provider before relying on it.
        all_tools = []
        tools_by_server = await self.client.list_all_tools()
        for server_name, tools in tools_by_server.items():
            for tool in tools:
                all_tools.append({
                    "type": "function",
                    "function": {
                        "name": f"{server_name}.{tool.name}",
                        "description": tool.description,
                        "parameters": tool.input_schema
                    }
                })
        return all_tools

    async def execute_tool(self, full_tool_name: str, arguments: Dict) -> Any:
        """Execute a tool given its ``"<server>.<tool>"`` name.

        Only the first '.' separates server from tool, so tool names may
        themselves contain dots but server names may not.

        Raises:
            ValueError: If the name has no '.', or the server/tool is unknown.
        """
        parts = full_tool_name.split('.', 1)
        if len(parts) != 2:
            raise ValueError(f"工具名格式错误: {full_tool_name}")
        server_name, tool_name = parts
        logger.info(f"MCP执行请求: {full_tool_name}")
        return await self.client.call_tool(server_name, tool_name, arguments)