229 lines
7.2 KiB
Python
229 lines
7.2 KiB
Python
"""
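MCP (Model Context Protocol) 支持

MCP (Model Context Protocol) support: resource/tool/prompt descriptors,
an MCPServer base class, an in-process MCPClient, and an MCPManager that
persists per-server configuration to a JSON file.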
"""
|
||
import asyncio
import inspect
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from src.utils.logger import setup_logger
|
||
|
||
# Module-wide logger shared by MCPServer, MCPClient and MCPManager below.
logger = setup_logger('MCPSystem')
|
||
|
||
|
||
@dataclass
class MCPResource:
    """Descriptor for a resource exposed by an MCP server.

    ``MCPServer.register_resource`` indexes resources by ``uri``, so the URI
    acts as the resource's identity within one server.
    """

    # Identifier used as the registry key in MCPServer.resources.
    uri: str
    # Display name of the resource.
    name: str
    # Free-text description of the resource.
    description: str
    # MIME type string advertised for the resource's content.
    mime_type: str
|
||
|
||
|
||
@dataclass
class MCPTool:
    """Metadata for a tool offered by an MCP server.

    The callable that implements the tool is stored separately by
    ``MCPServer.register_tool``; this record only describes it.
    """

    # Tool name, unique within one server's registry.
    name: str
    # Free-text description of what the tool does.
    description: str
    # Argument schema; MCPManager forwards it as the "parameters" entry of
    # the AI-facing tool description.
    input_schema: Dict[str, Any]
|
||
|
||
|
||
@dataclass
class MCPPrompt:
    """Descriptor for a prompt template exposed by an MCP server."""

    # Prompt name; MCPServer.register_prompt uses it as the registry key.
    name: str
    # Free-text description of the prompt.
    description: str
    # Argument descriptors as plain dicts; their shape is not validated here.
    arguments: List[Dict[str, Any]]
|
||
|
||
|
||
class MCPServer:
    """Base class for an in-process MCP server.

    Keeps registries of resources, tools and prompts. ``initialize`` and
    ``shutdown`` are no-op hooks; ``read_resource`` and ``get_prompt`` raise
    ``NotImplementedError`` and are meant to be overridden by subclasses.
    """

    def __init__(self, name: str, version: str):
        self.name = name
        self.version = version
        # Resource URI -> resource descriptor.
        self.resources: Dict[str, MCPResource] = {}
        # Tool name -> handler invoked by call_tool.
        self.tools: Dict[str, Callable] = {}
        # Tool name -> tool metadata returned by list_tools.
        self.tool_specs: Dict[str, MCPTool] = {}
        # Prompt name -> prompt descriptor.
        self.prompts: Dict[str, MCPPrompt] = {}

    async def initialize(self):
        """Initialize the server. No-op hook for subclasses."""

    async def shutdown(self):
        """Shut the server down. No-op hook for subclasses."""

    def register_resource(self, resource: MCPResource):
        """Register *resource* under its URI, replacing any previous entry."""
        self.resources[resource.uri] = resource

    def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
        """Register a tool's metadata together with the handler implementing it.

        Args:
            name: Tool name, unique within this server.
            description: Description returned as part of ``list_tools``.
            input_schema: Argument schema (presumably JSON Schema — it is
                forwarded as OpenAI-style "parameters" by MCPManager).
            handler: Called as ``handler(**arguments)`` by ``call_tool``;
                may be a coroutine function or a plain callable.
        """
        if name in self.tools:
            # Re-registering used to replace the handler silently.
            logger.warning(f"MCP工具被覆盖: {self.name}.{name}")
        self.tool_specs[name] = MCPTool(name=name, description=description, input_schema=input_schema)
        self.tools[name] = handler
        logger.info(f"✅ MCP工具注册: {self.name}.{name}")

    def register_prompt(self, prompt: MCPPrompt):
        """Register *prompt* under its name, replacing any previous entry."""
        self.prompts[prompt.name] = prompt

    async def list_resources(self) -> List[MCPResource]:
        """Return all registered resources."""
        return list(self.resources.values())

    async def read_resource(self, uri: str) -> Optional[str]:
        """Read the resource at *uri*. Must be implemented by subclasses."""
        raise NotImplementedError

    async def list_tools(self) -> List[MCPTool]:
        """Return metadata for all registered tools."""
        return list(self.tool_specs.values())

    async def call_tool(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> Any:
        """Invoke the tool registered as *name* with keyword *arguments*.

        The handler's return value is awaited only if it is awaitable, so
        both ``async def`` handlers and plain callables work.

        Args:
            name: Registered tool name.
            arguments: Keyword arguments for the handler; ``None`` means none.

        Returns:
            The handler's (awaited, if applicable) return value.

        Raises:
            ValueError: No tool named *name* is registered.
            Exception: Anything the handler raises is logged and re-raised.
        """
        if name not in self.tools:
            raise ValueError(f"工具不存在: {name}")

        if arguments is None:
            arguments = {}
        handler = self.tools[name]
        logger.info(
            f"MCP工具调用开始: server={self.name}, tool={name}, args={self._format_args(arguments)}"
        )
        try:
            result = handler(**arguments)
            # Synchronous handlers return plain values; only await awaitables.
            if inspect.isawaitable(result):
                result = await result
        except Exception as exc:
            logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}")
            raise
        logger.info(f"MCP工具调用成功: server={self.name}, tool={name}")
        return result

    @staticmethod
    def _format_args(arguments: Dict[str, Any]) -> str:
        """Render tool arguments for logging without ever raising.

        Non-JSON values are stringified so that logging can never prevent
        the tool call itself (the previous bare ``json.dumps`` raised
        ``TypeError`` before the handler ran).
        """
        try:
            return json.dumps(arguments, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # e.g. non-string keys or circular references
            return repr(arguments)

    async def list_prompts(self) -> List[MCPPrompt]:
        """Return all registered prompts."""
        return list(self.prompts.values())

    async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Optional[str]:
        """Render prompt *name* with *arguments*. Must be implemented by subclasses."""
        raise NotImplementedError
|
||
|
||
|
||
class MCPClient:
    """In-process MCP client that keeps connected servers keyed by name."""

    def __init__(self):
        # Server name -> connected server instance.
        self.servers: Dict[str, MCPServer] = {}

    async def connect_server(self, server: MCPServer):
        """Initialize *server* and register it under ``server.name``.

        If a different server object was already registered under that name,
        it is shut down (best effort) once the new one is in place, instead of
        being dropped from the registry without any cleanup.
        """
        await server.initialize()
        previous = self.servers.get(server.name)
        self.servers[server.name] = server
        if previous is not None and previous is not server:
            try:
                await previous.shutdown()
            except Exception as exc:
                # The replacement is already live; don't fail the connect.
                logger.warning(f"MCP旧服务器关闭失败: {server.name}, error={exc}")
        logger.info(f"✅ 连接MCP服务器: {server.name} v{server.version}")

    async def disconnect_server(self, server_name: str):
        """Shut down and unregister the server named *server_name*.

        Unknown names are ignored. The registry entry is removed even when
        ``shutdown`` raises (the exception still propagates), so a
        half-closed server is never handed out by ``get_server`` again.
        """
        server = self.servers.get(server_name)
        if server is None:
            return
        try:
            await server.shutdown()
        finally:
            # Only drop the entry if it still refers to this server; another
            # task may have connected a replacement while we were awaiting.
            if self.servers.get(server_name) is server:
                del self.servers[server_name]
        logger.info(f"✅ 断开MCP服务器: {server_name}")

    def get_server(self, name: str) -> Optional[MCPServer]:
        """Return the connected server called *name*, or ``None``."""
        return self.servers.get(name)

    def list_servers(self) -> List[str]:
        """Return the names of all connected servers."""
        return list(self.servers.keys())

    async def list_all_resources(self) -> Dict[str, List[MCPResource]]:
        """Return each connected server's resources, keyed by server name."""
        result: Dict[str, List[MCPResource]] = {}
        # Iterate a snapshot: the dict may be mutated by connect/disconnect
        # while we await a server, which would break live iteration.
        for name, server in list(self.servers.items()):
            result[name] = await server.list_resources()
        return result

    async def list_all_tools(self) -> Dict[str, List[MCPTool]]:
        """Return each connected server's tool metadata, keyed by server name."""
        result: Dict[str, List[MCPTool]] = {}
        # Snapshot for the same reason as list_all_resources.
        for name, server in list(self.servers.items()):
            result[name] = await server.list_tools()
        return result

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call *tool_name* on the server *server_name* with *arguments*.

        Raises:
            ValueError: No server named *server_name* is connected (the
                server's own ``call_tool`` may also raise for unknown tools).
        """
        server = self.get_server(server_name)
        if server is None:
            raise ValueError(f"服务器不存在: {server_name}")

        return await server.call_tool(tool_name, arguments)
|
||
|
||
|
||
class MCPManager:
    """Owns an MCPClient and persists per-server configuration as JSON.

    ``config_path`` holds a JSON object mapping server name to that server's
    config dict. It is read once at construction and rewritten when
    ``register_server`` stores a config or ``unregister_server`` removes one.
    """

    def __init__(self, config_path: Path):
        self.config_path = config_path
        self.client = MCPClient()
        # Server name -> config dict, mirrored to config_path.
        self.server_configs: Dict[str, Dict] = {}
        self._load_config()

    def _load_config(self):
        """Load ``server_configs`` from ``config_path`` if the file exists.

        A missing or blank file leaves the current configuration untouched.

        Raises:
            json.JSONDecodeError: The file is not valid JSON.
            ValueError: The top-level JSON value is not an object.
        """
        if not self.config_path.exists():
            return
        text = self.config_path.read_text(encoding='utf-8')
        if not text.strip():
            return
        data = json.loads(text)
        # A list/scalar here would break item assignment in register_server.
        if not isinstance(data, dict):
            raise ValueError(f"MCP配置格式错误, 顶层必须是JSON对象: {self.config_path}")
        self.server_configs = data

    def _save_config(self):
        """Persist ``server_configs`` to ``config_path``.

        The JSON is serialized before any file is touched, then written to a
        sibling temp file and moved over the target with ``Path.replace``
        (atomic on POSIX). A serialization error or failed write therefore
        never leaves a truncated config file behind.
        """
        payload = json.dumps(self.server_configs, ensure_ascii=False, indent=2)
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.config_path.with_name(self.config_path.name + '.tmp')
        try:
            tmp_path.write_text(payload, encoding='utf-8')
            tmp_path.replace(self.config_path)
        except OSError:
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    async def register_server(self, server: MCPServer, config: Optional[Dict] = None):
        """Connect *server* and, if *config* is given, persist it under ``server.name``."""
        await self.client.connect_server(server)

        # `is not None` so an explicitly passed empty config is still recorded.
        if config is not None:
            self.server_configs[server.name] = config
            self._save_config()

    async def unregister_server(self, server_name: str):
        """Disconnect *server_name* and drop its persisted config, if any."""
        await self.client.disconnect_server(server_name)

        if server_name in self.server_configs:
            del self.server_configs[server_name]
            self._save_config()

    def get_client(self) -> MCPClient:
        """Return the underlying MCPClient."""
        return self.client

    async def get_all_tools_for_ai(self) -> List[Dict]:
        """Return every connected tool in OpenAI-style function-calling format.

        Each entry is named ``"<server>.<tool>"`` so ``execute_tool`` can
        route a call back to the right server.
        """
        # NOTE(review): OpenAI function names must match ^[a-zA-Z0-9_-]{1,64}$,
        # so the "." separator may be rejected by that API. Changing it also
        # requires changing execute_tool's parsing — confirm with callers.
        all_tools = []
        tools_by_server = await self.client.list_all_tools()

        for server_name, tools in tools_by_server.items():
            for tool in tools:
                all_tools.append({
                    "type": "function",
                    "function": {
                        "name": f"{server_name}.{tool.name}",
                        "description": tool.description,
                        "parameters": tool.input_schema,
                    },
                })

        return all_tools

    async def execute_tool(self, full_tool_name: str, arguments: Any) -> Any:
        """Execute a tool addressed as ``"<server>.<tool>"``.

        Args:
            full_tool_name: Name as produced by ``get_all_tools_for_ai``; the
                first "." separates the server name from the tool name.
            arguments: Keyword-argument dict for the tool. ``None`` means no
                arguments; a JSON-encoded object string (as returned in
                OpenAI tool calls) is decoded first.

        Raises:
            ValueError: Malformed tool name, or string arguments that are not
                a JSON object (``json.JSONDecodeError`` is a ValueError).
        """
        server_name, sep, tool_name = full_tool_name.partition('.')
        if not sep or not server_name or not tool_name:
            raise ValueError(f"工具名格式错误: {full_tool_name}")

        if arguments is None:
            arguments = {}
        elif isinstance(arguments, str):
            arguments = json.loads(arguments) if arguments.strip() else {}
            if not isinstance(arguments, dict):
                raise ValueError(f"工具参数必须是JSON对象: {full_tool_name}")

        logger.info(f"MCP执行请求: {full_tool_name}")
        return await self.client.call_tool(server_name, tool_name, arguments)
|