diff --git a/src/ai/client.py b/src/ai/client.py index edca079..cc075de 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -43,6 +43,7 @@ class AIClient: # 初始化工具注册表 self.tools = ToolRegistry() + self._tool_sources: Dict[str, str] = {} # 初始化人格系统 self.personality = PersonalitySystem( @@ -242,8 +243,11 @@ class AIClient: ) -> Message: """处理工具调用。""" messages.append(response) + total_calls = len(response.tool_calls or []) + if total_calls: + logger.info(f"检测到工具调用请求: {total_calls} 个") - # 鎵ц宸ュ叿璋冪敤 + # 执行工具调用 for tool_call in response.tool_calls or []: try: tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call) @@ -263,6 +267,7 @@ class AIClient: continue tool_def = self.tools.get(tool_name) + tool_source = self._tool_sources.get(tool_name, "custom") if not tool_def: error_msg = f"未找到工具: {tool_name}" logger.warning(error_msg) @@ -275,9 +280,19 @@ class AIClient: continue try: + logger.info( + "工具调用开始: " + f"source={tool_source}, name={tool_name}, " + f"args={self._preview_log_payload(tool_args)}" + ) result = tool_def.function(**tool_args) if inspect.isawaitable(result): result = await result + logger.info( + "工具调用成功: " + f"source={tool_source}, name={tool_name}, " + f"result={self._preview_log_payload(result)}" + ) messages.append(Message( role="tool", name=tool_name, @@ -285,6 +300,10 @@ class AIClient: tool_call_id=tool_call_id )) except Exception as e: + logger.warning( + "工具调用失败: " + f"source={tool_source}, name={tool_name}, error={e}" + ) messages.append(Message( role="tool", name=tool_name, @@ -334,6 +353,18 @@ class AIClient: return parsed raise ValueError(f"不支持的工具参数类型: {type(raw_args)}") + + @staticmethod + def _preview_log_payload(payload: Any, max_len: int = 240) -> str: + """日志中展示参数/结果时使用的简短预览。""" + try: + text = json.dumps(payload, ensure_ascii=False, default=str) + except Exception: + text = str(payload) + + if len(text) > max_len: + return text[:max_len] + "..." + return text def set_personality(self, personality_name: str) -> bool: """设置人格。""" @@ -382,7 +413,14 @@ class AIClient: """获取任务状态。""" return self.task_manager.get_task_status(task_id) - def register_tool(self, name: str, description: str, parameters: Dict, function: callable): + def register_tool( + self, + name: str, + description: str, + parameters: Dict, + function: callable, + source: str = "custom", + ): """注册工具。""" from .base import ToolDefinition tool = ToolDefinition( @@ -392,18 +430,23 @@ class AIClient: function=function ) self.tools.register(tool) - logger.info(f"已注册工具: {name}") + self._tool_sources[name] = source + logger.info(f"已注册工具: {name} (source={source})") def unregister_tool(self, name: str) -> bool: """卸载工具。""" removed = self.tools.unregister(name) if removed: + self._tool_sources.pop(name, None) logger.info(f"已卸载工具: {name}") return removed def unregister_tools_by_prefix(self, prefix: str) -> int: """按前缀批量卸载工具。""" removed_count = self.tools.unregister_by_prefix(prefix) + for tool_name in list(self._tool_sources.keys()): + if tool_name.startswith(prefix): + self._tool_sources.pop(tool_name, None) if removed_count: logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}") return removed_count diff --git a/src/ai/mcp/base.py b/src/ai/mcp/base.py index 4f0b4be..770ae8c 100644 --- a/src/ai/mcp/base.py +++ b/src/ai/mcp/base.py @@ -44,6 +44,7 @@ class MCPServer: self.version = version self.resources: Dict[str, MCPResource] = {} self.tools: Dict[str, Callable] = {} + self.tool_specs: Dict[str, MCPTool] = {} self.prompts: Dict[str, MCPPrompt] = {} async def initialize(self): @@ -61,6 +62,7 @@ class MCPServer: def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable): """注册工具""" tool = MCPTool(name=name, description=description, input_schema=input_schema) + self.tool_specs[name] = tool self.tools[name] = handler logger.info(f"✅ MCP工具注册: {self.name}.{name}") @@ -78,10 +80,7 @@ class MCPServer: async def list_tools(self) -> List[MCPTool]: """列出工具""" - return [ - MCPTool(name=name, description="", input_schema={}) - for name in self.tools.keys() - ] + return list(self.tool_specs.values()) async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: """调用工具""" @@ -89,7 +88,16 @@ class MCPServer: raise ValueError(f"工具不存在: {name}") handler = self.tools[name] - return await handler(**arguments) + logger.info( + f"MCP工具调用开始: server={self.name}, tool={name}, args={json.dumps(arguments, ensure_ascii=False)}" + ) + try: + result = await handler(**arguments) + except Exception as exc: + logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}") + raise + logger.info(f"MCP工具调用成功: server={self.name}, tool={name}") + return result async def list_prompts(self) -> List[MCPPrompt]: """列出提示词""" @@ -216,4 +224,5 @@ class MCPManager: raise ValueError(f"工具名格式错误: {full_tool_name}") server_name, tool_name = parts + logger.info(f"MCP执行请求: {full_tool_name}") return await self.client.call_tool(server_name, tool_name, arguments) diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index dd9fac4..7645d78 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -3,6 +3,7 @@ OpenAI model implementation (including OpenAI-compatible providers). """ import inspect import json +import re from typing import Any, AsyncIterator, Dict, List, Optional import httpx @@ -20,6 +21,7 @@ class OpenAIModel(BaseAIModel): def __init__(self, config: ModelConfig): super().__init__(config) self.logger = logger + self._embedding_token_limit: Optional[int] = None http_client = httpx.AsyncClient( timeout=config.timeout, @@ -317,6 +319,70 @@ class OpenAIModel(BaseAIModel): or "maximum context length" in message ) + @staticmethod + def _extract_embedding_token_limit(error: Exception) -> Optional[int]: + message = str(error).lower() + patterns = [ + r"less than\s+(\d+)\s+tokens", + r"maximum context length.*?(\d+)\s+tokens", + r"max(?:imum)?(?: input)?(?: length)?\D+(\d+)\s+tokens", + ] + for pattern in patterns: + match = re.search(pattern, message) + if not match: + continue + + try: + limit = int(match.group(1)) + except (TypeError, ValueError): + continue + + if limit > 0: + return limit + + return None + + @staticmethod + def _estimate_embedding_tokens(text: str) -> int: + if not text: + return 0 + + ascii_count = 0 + cjk_count = 0 + other_count = 0 + + for ch in text: + code = ord(ch) + if code < 128: + ascii_count += 1 + elif 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF: + cjk_count += 1 + else: + other_count += 1 + + estimated = cjk_count + int(other_count * 0.7) + int(ascii_count / 4) + return max(1, estimated) + + def _apply_embedding_budget(self, text: str) -> str: + if not text or not self._embedding_token_limit: + return text + + safe_limit = max(32, int(self._embedding_token_limit * 0.9)) + estimated_tokens = self._estimate_embedding_tokens(text) + if estimated_tokens <= safe_limit: + return text + + ratio = safe_limit / max(estimated_tokens, 1) + target_length = max(64, int(len(text) * ratio)) + if target_length >= len(text): + target_length = max(64, len(text) - 1) + if target_length <= 0 or target_length >= len(text): + return text + + head = target_length // 2 + tail = target_length - head + return f"{text[:head]} {text[-tail:]}" + @staticmethod def _shrink_text_for_embedding(text: str) -> str: compact = " ".join((text or "").split()) @@ -338,6 +404,7 @@ class OpenAIModel(BaseAIModel): raw_text = str(text or "") candidate_text = raw_text.strip() or raw_text or " " + candidate_text = self._apply_embedding_budget(candidate_text) retry_count = 0 while True: @@ -350,16 +417,32 @@ class OpenAIModel(BaseAIModel): return response.data[0].embedding except Exception as e: if self._is_embedding_too_long_error(e): - next_text = self._shrink_text_for_embedding(candidate_text) + token_limit = self._extract_embedding_token_limit(e) + if token_limit: + if not self._embedding_token_limit: + self._embedding_token_limit = token_limit + else: + self._embedding_token_limit = min( + self._embedding_token_limit, token_limit + ) + + next_text = self._apply_embedding_budget( + self._shrink_text_for_embedding(candidate_text) + ) if ( next_text and len(next_text) < len(candidate_text) and retry_count < 5 ): retry_count += 1 - self.logger.warning( - "embedding input too long, retry with truncated text: " - f"{len(candidate_text)} -> {len(next_text)}" + limit_desc = ( + f", token_limit={self._embedding_token_limit}" + if self._embedding_token_limit + else "" + ) + self.logger.info( + "embedding input exceeded provider limit, retry with truncated text: " + f"{len(candidate_text)} -> {len(next_text)}{limit_desc}" ) candidate_text = next_text continue diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 912ea54..dd8559c 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -324,6 +324,43 @@ class MessageHandler: description=f"技能工具: {full_tool_name}", parameters={"type": "object", "properties": {}}, function=tool_func, + source="skills", + ) + count += 1 + + return count + + async def _register_mcp_tools(self) -> int: + if not self.mcp_manager or not self.ai_client: + return 0 + + tools = await self.mcp_manager.get_all_tools_for_ai() + count = 0 + + for item in tools: + function_info = item.get("function") if isinstance(item, dict) else None + if not isinstance(function_info, dict): + continue + + full_tool_name = function_info.get("name") + if not full_tool_name: + continue + + parameters = function_info.get("parameters") + if not isinstance(parameters, dict): + parameters = {"type": "object", "properties": {}} + + async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs): + if not self.mcp_manager: + raise RuntimeError("MCP manager not initialized") + return await self.mcp_manager.execute_tool(_full_tool_name, kwargs) + + self.ai_client.register_tool( + name=full_tool_name, + description=f"MCP工具: {full_tool_name}", + parameters=parameters, + function=_mcp_proxy, + source="mcp", ) count += 1 @@ -439,6 +476,8 @@ class MessageHandler: self.mcp_manager = MCPManager(Path("config/mcp.json")) fs_server = FileSystemMCPServer(root_path=Path("data")) await self.mcp_manager.register_server(fs_server) + mcp_tool_count = await self._register_mcp_tools() + logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools") except Exception as exc: logger.warning(f"MCP 初始化失败: {exc}") diff --git a/tests/test_mcp_tool_registration.py b/tests/test_mcp_tool_registration.py new file mode 100644 index 0000000..90bde10 --- /dev/null +++ b/tests/test_mcp_tool_registration.py @@ -0,0 +1,44 @@ +import asyncio +from pathlib import Path + +from src.ai.mcp.base import MCPManager, MCPServer + + +class _DummyMCPServer(MCPServer): + def __init__(self): + super().__init__(name="dummy", version="1.0.0") + + async def initialize(self): + self.register_tool( + name="echo", + description="Echo text", + input_schema={ + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + handler=self.echo, + ) + + async def echo(self, text: str) -> str: + return text + + +def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path): + manager = MCPManager(tmp_path / "mcp.json") + asyncio.run(manager.register_server(_DummyMCPServer())) + + tools = asyncio.run(manager.get_all_tools_for_ai()) + assert len(tools) == 1 + function_info = tools[0]["function"] + assert function_info["name"] == "dummy.echo" + assert function_info["description"] == "Echo text" + assert function_info["parameters"]["required"] == ["text"] + + +def test_mcp_manager_execute_tool(tmp_path: Path): + manager = MCPManager(tmp_path / "mcp.json") + asyncio.run(manager.register_server(_DummyMCPServer())) + + result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"})) + assert result == "hello" diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py index 55ee06d..4534a94 100644 --- a/tests/test_openai_model_compat.py +++ b/tests/test_openai_model_compat.py @@ -43,6 +43,18 @@ class _FakeEmbeddings: return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) +class _LengthLimitedEmbeddings: + def __init__(self): + self.inputs = [] + + async def create(self, **kwargs): + text = kwargs.get("input", "") + self.inputs.append(text) + if len(text) > 512: + raise RuntimeError("input must be less than 512 tokens") + return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) + + class _ModernCompletions: def __init__(self): self.last_params = None @@ -197,6 +209,13 @@ class _RuntimeRejectToolsAsyncOpenAI: self.embeddings = _FakeEmbeddings() +class _LengthLimitedEmbedAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _ModernCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _LengthLimitedEmbeddings() + + def test_openai_model_uses_tools_when_supported(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) @@ -276,3 +295,28 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch): assert model._supports_tools is False assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" + + +def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch): + monkeypatch.setattr( + openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI + ) + + model = OpenAIModel(_model_config()) + long_text = "你" * 600 + + first_embedding = asyncio.run(model.embed(long_text)) + assert first_embedding == [0.1, 0.2, 0.3] + assert model._embedding_token_limit == 512 + + inputs_after_first = list(model.client.embeddings.inputs) + assert len(inputs_after_first) == 2 + assert len(inputs_after_first[0]) == 600 + assert len(inputs_after_first[1]) < len(inputs_after_first[0]) + + second_embedding = asyncio.run(model.embed(long_text)) + assert second_embedding == [0.1, 0.2, 0.3] + + inputs_after_second = list(model.client.embeddings.inputs) + assert len(inputs_after_second) == 3 + assert len(inputs_after_second[-1]) < 512