From 7d7a4b8f545cdff7edc3a49cc3b00d2218f6fbcd Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Tue, 3 Mar 2026 14:14:16 +0800 Subject: [PATCH] Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message processing. --- src/ai/client.py | 61 +++++++++++++- src/ai/models/openai_model.py | 119 ++++++++++++++++++++++++++- src/handlers/message_handler_ai.py | 8 ++ tests/test_ai_client_forced_tool.py | 39 +++++++++ tests/test_openai_model_compat.py | 120 ++++++++++++++++++++++++++++ 5 files changed, 343 insertions(+), 4 deletions(-) create mode 100644 tests/test_ai_client_forced_tool.py diff --git a/src/ai/client.py b/src/ai/client.py index 861604d..7529c9d 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -191,12 +191,18 @@ class AIClient: if use_tools and self.tools.list(): tools = self.tools.to_openai_format() tool_names = [tool.name for tool in self.tools.list()] + forced_tool_name = self._extract_forced_tool_name(user_message, tool_names) + if forced_tool_name: + kwargs = dict(kwargs) + kwargs["forced_tool_name"] = forced_tool_name + logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") logger.info( "LLM请求: " f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, " f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, " - f"tool_names={self._preview_log_payload(tool_names)}" + f"tool_names={self._preview_log_payload(tool_names)}, " + f"forced_tool={forced_tool_name or '-'}" ) logger.info( "LLM输入: " @@ -251,7 +257,7 @@ class AIClient: return response.content except Exception as e: - logger.error(f"对话失败: {e}") + logger.error(f"对话失败: {type(e).__name__}: {e!r}") raise async def _chat_stream( @@ -342,7 +348,10 @@ class AIClient: )) # 再次调用模型获取最终响应 - final_response = await self.model.chat(messages, tools, **kwargs) + final_kwargs = dict(kwargs) + # Force only the first model turn, avoid recursive force after tool result. + final_kwargs.pop("forced_tool_name", None) + final_response = await self.model.chat(messages, tools, **final_kwargs) logger.info( "LLM最终输出: " f"content={self._preview_log_payload(final_response.content)}" @@ -400,6 +409,52 @@ class AIClient: if len(text) > max_len: return text[:max_len] + "..." return text + + @staticmethod + def _extract_forced_tool_name( + user_message: str, available_tool_names: List[str] + ) -> Optional[str]: + if not user_message or not available_tool_names: + return None + + triggers = ["调用工具", "使用工具", "只调用", "务必调用", "必须调用", "tool"] + if not any(trigger in user_message for trigger in triggers): + return None + + pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)") + explicit_matches = [ + name for name in pattern.findall(user_message) if name in available_tool_names + ] + if len(explicit_matches) == 1: + return explicit_matches[0] + if len(explicit_matches) > 1: + return None + + contained = [name for name in available_tool_names if name in user_message] + if len(contained) == 1: + return contained[0] + + # 允许只写 skill/tool 前缀(如 humanizer_zh),前提是前缀下只有一个工具。 + prefixes = sorted( + {name.split(".", 1)[0] for name in available_tool_names}, + key=len, + reverse=True, + ) + matched_prefixes = [ + prefix + for prefix in prefixes + if re.search(rf"\b{re.escape(prefix)}\b", user_message) + ] + if len(matched_prefixes) == 1: + prefix_tools = [ + name + for name in available_tool_names + if name.startswith(f"{matched_prefixes[0]}.") + ] + if len(prefix_tools) == 1: + return prefix_tools[0] + + return None def set_personality(self, personality_name: str) -> bool: """设置人格。""" diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index b005fca..f872326 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -1,6 +1,7 @@ """ OpenAI model implementation (including OpenAI-compatible providers). """ +import asyncio import inspect import json import re @@ -72,6 +73,58 @@ class OpenAIModel(BaseAIModel): ) return {} + def _build_forced_tool_params( + self, + params: Dict[str, Any], + forced_tool_name: Optional[str], + tools: Optional[List[dict]], + ) -> Dict[str, Any]: + """Build request params for forcing one specific tool call.""" + if not forced_tool_name: + return {} + + available_tool_names = self._extract_tool_names(tools) + if available_tool_names and forced_tool_name not in available_tool_names: + self.logger.warning( + "forced_tool_name is not in current tool list, ignored: " + f"{forced_tool_name}" + ) + return {} + + if "tools" in params: + return { + "tool_choice": { + "type": "function", + "function": {"name": forced_tool_name}, + } + } + + if "functions" in params: + return {"function_call": {"name": forced_tool_name}} + + self.logger.warning( + "forced_tool_name provided but tool params are unavailable, ignored: " + f"{forced_tool_name}" + ) + return {} + + @staticmethod + def _extract_tool_names(tools: Optional[List[dict]]) -> List[str]: + if not tools: + return [] + + names: List[str] = [] + for tool in tools: + if not isinstance(tool, dict): + continue + function_data = tool.get("function") + if not isinstance(function_data, dict): + continue + name = function_data.get("name") + if isinstance(name, str) and name: + names.append(name) + return names + @staticmethod def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]: if not isinstance(tool, dict): @@ -108,7 +161,12 @@ class OpenAIModel(BaseAIModel): self._supports_tools = False retry_params = dict(params) retry_params.pop("tools", None) + forced_tool_name = self._extract_forced_tool_name_from_choice( + retry_params.pop("tool_choice", None) + ) retry_params.update(self._build_tool_params(tools)) + if forced_tool_name and "functions" in retry_params: + retry_params["function_call"] = {"name": forced_tool_name} return await self.client.chat.completions.create(**retry_params) if "unexpected keyword argument 'functions'" in message and "functions" in params: @@ -122,6 +180,60 @@ class OpenAIModel(BaseAIModel): return await self.client.chat.completions.create(**retry_params) raise + except Exception as exc: + if self._is_timeout_error(exc): + return await self._retry_on_timeout(params) + raise + + @staticmethod + def _is_timeout_error(error: Exception) -> bool: + if isinstance(error, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)): + return True + + error_name = type(error).__name__.lower() + if "timeout" in error_name: + return True + + message = str(error).lower() + return "timed out" in message or "timeout" in message + + async def _retry_on_timeout(self, params: Dict[str, Any]): + base_timeout = float(self.config.timeout or 60) + retry_timeout = min(max(base_timeout * 2, 120.0), 300.0) + retry_params = dict(params) + retry_params["timeout"] = retry_timeout + self.logger.warning( + "chat request timed out, retry once with longer timeout: " + f"{base_timeout:.0f}s -> {retry_timeout:.0f}s" + ) + try: + return await self.client.chat.completions.create(**retry_params) + except Exception as retry_exc: + if self._is_timeout_error(retry_exc): + self.logger.error( + "chat request still timed out after retry: " + f"timeout={retry_timeout:.0f}s" + ) + raise + + @staticmethod + def _extract_forced_tool_name_from_choice(tool_choice: Any) -> Optional[str]: + if not tool_choice: + return None + + if isinstance(tool_choice, dict): + function_data = tool_choice.get("function") + if isinstance(function_data, dict): + name = function_data.get("name") + return name if isinstance(name, str) and name else None + return None + + function_data = getattr(tool_choice, "function", None) + if function_data: + name = getattr(function_data, "name", None) + return name if isinstance(name, str) and name else None + + return None async def chat( self, @@ -131,6 +243,7 @@ class OpenAIModel(BaseAIModel): ) -> Message: """Non-stream chat.""" formatted_messages = [self._format_message(msg) for msg in messages] + forced_tool_name = kwargs.pop("forced_tool_name", None) params = { "model": self.config.model_name, @@ -144,6 +257,7 @@ class OpenAIModel(BaseAIModel): params.update(self._build_tool_params(tools)) params.update(kwargs) + params.update(self._build_forced_tool_params(params, forced_tool_name, tools)) tool_mode = "none" tool_count = 0 @@ -156,7 +270,8 @@ class OpenAIModel(BaseAIModel): self.logger.info( "OpenAI chat request: " - f"model={self.config.model_name}, tool_mode={tool_mode}, tool_count={tool_count}" + f"model={self.config.model_name}, tool_mode={tool_mode}, " + f"tool_count={tool_count}, forced_tool={forced_tool_name or '-'}" ) response = await self._create_completion_with_fallback(params, tools) @@ -177,6 +292,7 @@ class OpenAIModel(BaseAIModel): ) -> AsyncIterator[str]: """Streaming chat.""" formatted_messages = [self._format_message(msg) for msg in messages] + forced_tool_name = kwargs.pop("forced_tool_name", None) params = { "model": self.config.model_name, @@ -188,6 +304,7 @@ class OpenAIModel(BaseAIModel): params.update(self._build_tool_params(tools)) params.update(kwargs) + params.update(self._build_forced_tool_params(params, forced_tool_name, tools)) stream = await self._create_completion_with_fallback(params, tools) diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 0b2a2a7..4d72835 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -8,6 +8,8 @@ from pathlib import Path import re from typing import Any, Dict, Optional +import httpx + from botpy.message import Message from src.ai import AIClient @@ -619,6 +621,12 @@ class MessageHandler: import traceback logger.error(traceback.format_exc()) + if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)): + await self._reply_plain( + message, + "模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。", + ) + return await self._reply_plain(message, "消息处理失败,请稍后重试") async def _handle_skills_command(self, message: Message, command: str): diff --git a/tests/test_ai_client_forced_tool.py b/tests/test_ai_client_forced_tool.py new file mode 100644 index 0000000..2130e2a --- /dev/null +++ b/tests/test_ai_client_forced_tool.py @@ -0,0 +1,39 @@ +"""Tests for AIClient forced tool name extraction.""" + +from src.ai.client import AIClient + + +def test_extract_forced_tool_name_full_name(): + tools = [ + "humanizer_zh.read_skill_doc", + "skills_creator.create_skill", + ] + message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars" + + forced = AIClient._extract_forced_tool_name(message, tools) + + assert forced == "humanizer_zh.read_skill_doc" + + +def test_extract_forced_tool_name_unique_prefix(): + tools = [ + "humanizer_zh.read_skill_doc", + "skills_creator.create_skill", + ] + message = "please call tool humanizer_zh only" + + forced = AIClient._extract_forced_tool_name(message, tools) + + assert forced == "humanizer_zh.read_skill_doc" + + +def test_extract_forced_tool_name_ambiguous_prefix_returns_none(): + tools = [ + "skills_creator.create_skill", + "skills_creator.reload_skill", + ] + message = "please call tool skills_creator" + + forced = AIClient._extract_forced_tool_name(message, tools) + + assert forced is None diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py index 4534a94..175a6e1 100644 --- a/tests/test_openai_model_compat.py +++ b/tests/test_openai_model_compat.py @@ -3,6 +3,7 @@ import asyncio from types import SimpleNamespace +import httpx import src.ai.models.openai_model as openai_model_module from src.ai.base import Message, ModelConfig, ModelProvider from src.ai.models.openai_model import OpenAIModel @@ -216,6 +217,55 @@ class _LengthLimitedEmbedAsyncOpenAI: self.embeddings = _LengthLimitedEmbeddings() +class _TimeoutOnceCompletions: + def __init__(self): + self.calls = [] + + async def create( + self, + *, + model, + messages, + temperature=None, + max_tokens=None, + top_p=None, + frequency_penalty=None, + presence_penalty=None, + tools=None, + stream=False, + timeout=None, + **kwargs, + ): + self.calls.append( + { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "tools": tools, + "stream": stream, + "timeout": timeout, + **kwargs, + } + ) + + if len(self.calls) == 1: + raise httpx.ReadTimeout("timed out") + + message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + +class _TimeoutOnceAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _TimeoutOnceCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _FakeEmbeddings() + + def test_openai_model_uses_tools_when_supported(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) @@ -233,6 +283,24 @@ def test_openai_model_uses_tools_when_supported(monkeypatch): assert result.tool_calls[0]["function"]["name"] == "demo_tool" +def test_openai_model_forces_tool_choice_when_supported(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tools = _tool_defs() + asyncio.run( + model.chat( + messages=[Message(role="user", content="hi")], + tools=tools, + forced_tool_name="demo_tool", + ) + ) + + sent = model.client.completions.last_params + assert sent["tool_choice"]["type"] == "function" + assert sent["tool_choice"]["function"]["name"] == "demo_tool" + + def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) @@ -252,6 +320,23 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): assert result.tool_calls[0]["function"]["name"] == "demo_tool" +def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tools = _tool_defs() + asyncio.run( + model.chat( + messages=[Message(role="user", content="hi")], + tools=tools, + forced_tool_name="demo_tool", + ) + ) + + sent = model.client.completions.last_params + assert sent["function_call"] == {"name": "demo_tool"} + + def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) @@ -297,6 +382,41 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch): assert result.tool_calls[0]["function"]["name"] == "demo_tool" +def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch): + monkeypatch.setattr( + openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI + ) + + model = OpenAIModel(_model_config()) + asyncio.run( + model.chat( + messages=[Message(role="user", content="hi")], + tools=_tool_defs(), + forced_tool_name="demo_tool", + ) + ) + + calls = model.client.completions.calls + assert len(calls) == 2 + assert calls[0]["tool_choice"]["function"]["name"] == "demo_tool" + assert calls[1]["function_call"] == {"name": "demo_tool"} + + +def test_openai_model_retries_once_on_read_timeout(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI) + + model = OpenAIModel(_model_config()) + result = asyncio.run( + model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs()) + ) + + calls = model.client.completions.calls + assert len(calls) == 2 + assert calls[0]["timeout"] is None + assert calls[1]["timeout"] == 120.0 + assert result.content == "ok-after-timeout" + + def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch): monkeypatch.setattr( openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI