From fd2a09f68177dc1d5b51f255046f94a272493f18 Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Tue, 3 Mar 2026 13:10:09 +0800 Subject: [PATCH] Enhance AIClient and MCPServer to support tool registration with source tracking. Added logging for tool calls and improved error handling. Introduced methods for embedding token limit extraction and budget application in OpenAIModel. Added tests for MCP tool registration and execution. --- src/ai/client.py | 49 +++++++++++++++- src/ai/mcp/base.py | 19 ++++-- src/ai/models/openai_model.py | 91 +++++++++++++++++++++++++++-- src/handlers/message_handler_ai.py | 39 +++++++++++++ tests/test_mcp_tool_registration.py | 44 ++++++++++++++ tests/test_openai_model_compat.py | 44 ++++++++++++++ 6 files changed, 274 insertions(+), 12 deletions(-) create mode 100644 tests/test_mcp_tool_registration.py diff --git a/src/ai/client.py b/src/ai/client.py index edca079..cc075de 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -43,6 +43,7 @@ class AIClient: # 初始化工具注册表 self.tools = ToolRegistry() + self._tool_sources: Dict[str, str] = {} # 初始化人格系统 self.personality = PersonalitySystem( @@ -242,8 +243,11 @@ class AIClient: ) -> Message: """处理工具调用。""" messages.append(response) + total_calls = len(response.tool_calls or []) + if total_calls: + logger.info(f"检测到工具调用请求: {total_calls} 个") - # 鎵ц宸ュ叿璋冪敤 + # 执行工具调用 for tool_call in response.tool_calls or []: try: tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call) @@ -263,6 +267,7 @@ class AIClient: continue tool_def = self.tools.get(tool_name) + tool_source = self._tool_sources.get(tool_name, "custom") if not tool_def: error_msg = f"未找到工具: {tool_name}" logger.warning(error_msg) @@ -275,9 +280,19 @@ class AIClient: continue try: + logger.info( + "工具调用开始: " + f"source={tool_source}, name={tool_name}, " + f"args={self._preview_log_payload(tool_args)}" + ) result = tool_def.function(**tool_args) if inspect.isawaitable(result): result = await result + logger.info( + "工具调用成功: " + f"source={tool_source}, name={tool_name}, " + f"result={self._preview_log_payload(result)}" + ) messages.append(Message( role="tool", name=tool_name, @@ -285,6 +300,10 @@ class AIClient: tool_call_id=tool_call_id )) except Exception as e: + logger.warning( + "工具调用失败: " + f"source={tool_source}, name={tool_name}, error={e}" + ) messages.append(Message( role="tool", name=tool_name, @@ -334,6 +353,18 @@ class AIClient: return parsed raise ValueError(f"不支持的工具参数类型: {type(raw_args)}") + + @staticmethod + def _preview_log_payload(payload: Any, max_len: int = 240) -> str: + """日志中展示参数/结果时使用的简短预览。""" + try: + text = json.dumps(payload, ensure_ascii=False, default=str) + except Exception: + text = str(payload) + + if len(text) > max_len: + return text[:max_len] + "..." + return text def set_personality(self, personality_name: str) -> bool: """设置人格。""" @@ -382,7 +413,14 @@ class AIClient: """获取任务状态。""" return self.task_manager.get_task_status(task_id) - def register_tool(self, name: str, description: str, parameters: Dict, function: callable): + def register_tool( + self, + name: str, + description: str, + parameters: Dict, + function: callable, + source: str = "custom", + ): """注册工具。""" from .base import ToolDefinition tool = ToolDefinition( @@ -392,18 +430,23 @@ class AIClient: function=function ) self.tools.register(tool) - logger.info(f"已注册工具: {name}") + self._tool_sources[name] = source + logger.info(f"已注册工具: {name} (source={source})") def unregister_tool(self, name: str) -> bool: """卸载工具。""" removed = self.tools.unregister(name) if removed: + self._tool_sources.pop(name, None) logger.info(f"已卸载工具: {name}") return removed def unregister_tools_by_prefix(self, prefix: str) -> int: """按前缀批量卸载工具。""" removed_count = self.tools.unregister_by_prefix(prefix) + for tool_name in list(self._tool_sources.keys()): + if tool_name.startswith(prefix): + self._tool_sources.pop(tool_name, None) if removed_count: logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}") return removed_count diff --git a/src/ai/mcp/base.py b/src/ai/mcp/base.py index 4f0b4be..770ae8c 100644 --- a/src/ai/mcp/base.py +++ b/src/ai/mcp/base.py @@ -44,6 +44,7 @@ class MCPServer: self.version = version self.resources: Dict[str, MCPResource] = {} self.tools: Dict[str, Callable] = {} + self.tool_specs: Dict[str, MCPTool] = {} self.prompts: Dict[str, MCPPrompt] = {} async def initialize(self): @@ -61,6 +62,7 @@ class MCPServer: def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable): """注册工具""" tool = MCPTool(name=name, description=description, input_schema=input_schema) + self.tool_specs[name] = tool self.tools[name] = handler logger.info(f"✅ MCP工具注册: {self.name}.{name}") @@ -78,10 +80,7 @@ class MCPServer: async def list_tools(self) -> List[MCPTool]: """列出工具""" - return [ - MCPTool(name=name, description="", input_schema={}) - for name in self.tools.keys() - ] + return list(self.tool_specs.values()) async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: """调用工具""" @@ -89,7 +88,16 @@ class MCPServer: raise ValueError(f"工具不存在: {name}") handler = self.tools[name] - return await handler(**arguments) + logger.info( + f"MCP工具调用开始: server={self.name}, tool={name}, args={json.dumps(arguments, ensure_ascii=False)}" + ) + try: + result = await handler(**arguments) + except Exception as exc: + logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}") + raise + logger.info(f"MCP工具调用成功: server={self.name}, tool={name}") + return result async def list_prompts(self) -> List[MCPPrompt]: """列出提示词""" @@ -216,4 +224,5 @@ class MCPManager: raise ValueError(f"工具名格式错误: {full_tool_name}") server_name, tool_name = parts + logger.info(f"MCP执行请求: {full_tool_name}") return await self.client.call_tool(server_name, tool_name, arguments) diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index dd9fac4..7645d78 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -3,6 +3,7 @@ OpenAI model implementation (including OpenAI-compatible providers). """ import inspect import json +import re from typing import Any, AsyncIterator, Dict, List, Optional import httpx @@ -20,6 +21,7 @@ class OpenAIModel(BaseAIModel): def __init__(self, config: ModelConfig): super().__init__(config) self.logger = logger + self._embedding_token_limit: Optional[int] = None http_client = httpx.AsyncClient( timeout=config.timeout, @@ -317,6 +319,70 @@ class OpenAIModel(BaseAIModel): or "maximum context length" in message ) + @staticmethod + def _extract_embedding_token_limit(error: Exception) -> Optional[int]: + message = str(error).lower() + patterns = [ + r"less than\s+(\d+)\s+tokens", + r"maximum context length.*?(\d+)\s+tokens", + r"max(?:imum)?(?: input)?(?: length)?\D+(\d+)\s+tokens", + ] + for pattern in patterns: + match = re.search(pattern, message) + if not match: + continue + + try: + limit = int(match.group(1)) + except (TypeError, ValueError): + continue + + if limit > 0: + return limit + + return None + + @staticmethod + def _estimate_embedding_tokens(text: str) -> int: + if not text: + return 0 + + ascii_count = 0 + cjk_count = 0 + other_count = 0 + + for ch in text: + code = ord(ch) + if code < 128: + ascii_count += 1 + elif 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF: + cjk_count += 1 + else: + other_count += 1 + + estimated = cjk_count + int(other_count * 0.7) + int(ascii_count / 4) + return max(1, estimated) + + def _apply_embedding_budget(self, text: str) -> str: + if not text or not self._embedding_token_limit: + return text + + safe_limit = max(32, int(self._embedding_token_limit * 0.9)) + estimated_tokens = self._estimate_embedding_tokens(text) + if estimated_tokens <= safe_limit: + return text + + ratio = safe_limit / max(estimated_tokens, 1) + target_length = max(64, int(len(text) * ratio)) + if target_length >= len(text): + target_length = max(64, len(text) - 1) + if target_length <= 0 or target_length >= len(text): + return text + + head = target_length // 2 + tail = target_length - head + return f"{text[:head]} {text[-tail:]}" + @staticmethod def _shrink_text_for_embedding(text: str) -> str: compact = " ".join((text or "").split()) @@ -338,6 +404,7 @@ class OpenAIModel(BaseAIModel): raw_text = str(text or "") candidate_text = raw_text.strip() or raw_text or " " + candidate_text = self._apply_embedding_budget(candidate_text) retry_count = 0 while True: @@ -350,16 +417,32 @@ class OpenAIModel(BaseAIModel): return response.data[0].embedding except Exception as e: if self._is_embedding_too_long_error(e): - next_text = self._shrink_text_for_embedding(candidate_text) + token_limit = self._extract_embedding_token_limit(e) + if token_limit: + if not self._embedding_token_limit: + self._embedding_token_limit = token_limit + else: + self._embedding_token_limit = min( + self._embedding_token_limit, token_limit + ) + + next_text = self._apply_embedding_budget( + self._shrink_text_for_embedding(candidate_text) + ) if ( next_text and len(next_text) < len(candidate_text) and retry_count < 5 ): retry_count += 1 - self.logger.warning( - "embedding input too long, retry with truncated text: " - f"{len(candidate_text)} -> {len(next_text)}" + limit_desc = ( + f", token_limit={self._embedding_token_limit}" + if self._embedding_token_limit + else "" + ) + self.logger.info( + "embedding input exceeded provider limit, retry with truncated text: " + f"{len(candidate_text)} -> {len(next_text)}{limit_desc}" ) candidate_text = next_text continue diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 912ea54..dd8559c 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -324,6 +324,43 @@ class MessageHandler: description=f"技能工具: {full_tool_name}", parameters={"type": "object", "properties": {}}, function=tool_func, + source="skills", + ) + count += 1 + + return count + + async def _register_mcp_tools(self) -> int: + if not self.mcp_manager or not self.ai_client: + return 0 + + tools = await self.mcp_manager.get_all_tools_for_ai() + count = 0 + + for item in tools: + function_info = item.get("function") if isinstance(item, dict) else None + if not isinstance(function_info, dict): + continue + + full_tool_name = function_info.get("name") + if not full_tool_name: + continue + + parameters = function_info.get("parameters") + if not isinstance(parameters, dict): + parameters = {"type": "object", "properties": {}} + + async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs): + if not self.mcp_manager: + raise RuntimeError("MCP manager not initialized") + return await self.mcp_manager.execute_tool(_full_tool_name, kwargs) + + self.ai_client.register_tool( + name=full_tool_name, + description=f"MCP工具: {full_tool_name}", + parameters=parameters, + function=_mcp_proxy, + source="mcp", ) count += 1 @@ -439,6 +476,8 @@ class MessageHandler: self.mcp_manager = MCPManager(Path("config/mcp.json")) fs_server = FileSystemMCPServer(root_path=Path("data")) await self.mcp_manager.register_server(fs_server) + mcp_tool_count = await self._register_mcp_tools() + logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools") except Exception as exc: logger.warning(f"MCP 初始化失败: {exc}") diff --git a/tests/test_mcp_tool_registration.py b/tests/test_mcp_tool_registration.py new file mode 100644 index 0000000..90bde10 --- /dev/null +++ b/tests/test_mcp_tool_registration.py @@ -0,0 +1,44 @@ +import asyncio +from pathlib import Path + +from src.ai.mcp.base import MCPManager, MCPServer + + +class _DummyMCPServer(MCPServer): + def __init__(self): + super().__init__(name="dummy", version="1.0.0") + + async def initialize(self): + self.register_tool( + name="echo", + description="Echo text", + input_schema={ + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + handler=self.echo, + ) + + async def echo(self, text: str) -> str: + return text + + +def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path): + manager = MCPManager(tmp_path / "mcp.json") + asyncio.run(manager.register_server(_DummyMCPServer())) + + tools = asyncio.run(manager.get_all_tools_for_ai()) + assert len(tools) == 1 + function_info = tools[0]["function"] + assert function_info["name"] == "dummy.echo" + assert function_info["description"] == "Echo text" + assert function_info["parameters"]["required"] == ["text"] + + +def test_mcp_manager_execute_tool(tmp_path: Path): + manager = MCPManager(tmp_path / "mcp.json") + asyncio.run(manager.register_server(_DummyMCPServer())) + + result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"})) + assert result == "hello" diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py index 55ee06d..4534a94 100644 --- a/tests/test_openai_model_compat.py +++ b/tests/test_openai_model_compat.py @@ -43,6 +43,18 @@ class _FakeEmbeddings: return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) +class _LengthLimitedEmbeddings: + def __init__(self): + self.inputs = [] + + async def create(self, **kwargs): + text = kwargs.get("input", "") + self.inputs.append(text) + if len(text) > 512: + raise RuntimeError("input must be less than 512 tokens") + return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) + + class _ModernCompletions: def __init__(self): self.last_params = None @@ -197,6 +209,13 @@ class _RuntimeRejectToolsAsyncOpenAI: self.embeddings = _FakeEmbeddings() +class _LengthLimitedEmbedAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _ModernCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _LengthLimitedEmbeddings() + + def test_openai_model_uses_tools_when_supported(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) @@ -276,3 +295,28 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch): assert model._supports_tools is False assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" + + +def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch): + monkeypatch.setattr( + openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI + ) + + model = OpenAIModel(_model_config()) + long_text = "你" * 600 + + first_embedding = asyncio.run(model.embed(long_text)) + assert first_embedding == [0.1, 0.2, 0.3] + assert model._embedding_token_limit == 512 + + inputs_after_first = list(model.client.embeddings.inputs) + assert len(inputs_after_first) == 2 + assert len(inputs_after_first[0]) == 600 + assert len(inputs_after_first[1]) < len(inputs_after_first[0]) + + second_embedding = asyncio.run(model.embed(long_text)) + assert second_embedding == [0.1, 0.2, 0.3] + + inputs_after_second = list(model.client.embeddings.inputs) + assert len(inputs_after_second) == 3 + assert len(inputs_after_second[-1]) < 512