From 4a2666b1f2f32446c4537ebebc573819e195f4ce Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Tue, 3 Mar 2026 14:25:21 +0800 Subject: [PATCH] Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message handling and model interactions, ensuring better user feedback and robustness. --- src/ai/client.py | 83 +++++++++++++++++++++++++++++ src/ai/models/openai_model.py | 34 ++++++------ src/handlers/message_handler_ai.py | 2 +- tests/test_ai_client_forced_tool.py | 6 +++ tests/test_openai_model_compat.py | 9 ++-- 5 files changed, 114 insertions(+), 20 deletions(-) diff --git a/src/ai/client.py b/src/ai/client.py index 7529c9d..5e4be00 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -195,6 +195,19 @@ class AIClient: if forced_tool_name: kwargs = dict(kwargs) kwargs["forced_tool_name"] = forced_tool_name + if tools: + before_count = len(tools) + tools = [ + tool + for tool in tools + if ((tool.get("function") or {}).get("name") == forced_tool_name) + ] + if len(tools) == 1: + tool_names = [forced_tool_name] + logger.info( + "显式工具调用已收敛工具列表: " + f"{before_count} -> {len(tools)}" + ) logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") logger.info( @@ -237,6 +250,13 @@ class AIClient: response = await self._handle_tool_calls( messages, response, tools, **kwargs ) + elif forced_tool_name: + forced_response = await self._run_forced_tool_fallback( + forced_tool_name=forced_tool_name, + user_message=user_message, + ) + if forced_response is not None: + response = forced_response # 写入记忆 if use_memory: @@ -259,6 +279,50 @@ class AIClient: except Exception as e: logger.error(f"对话失败: {type(e).__name__}: {e!r}") raise + + async def _run_forced_tool_fallback( + self, forced_tool_name: str, user_message: str + ) -> Optional[Message]: + """Execute forced tool locally when model did not emit tool_calls.""" + tool_def = self.tools.get(forced_tool_name) + tool_source = self._tool_sources.get(forced_tool_name, "custom") + if not tool_def: + logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}") + return None + + logger.warning( + "模型未返回 tool_calls,启用本地强制工具执行: " + f"source={tool_source}, name={forced_tool_name}" + ) + + try: + result = tool_def.function() + if inspect.isawaitable(result): + result = await result + except TypeError as exc: + logger.warning( + "本地强制工具执行失败(参数不匹配): " + f"name={forced_tool_name}, error={exc}" + ) + return None + except Exception as exc: + logger.warning( + "本地强制工具执行失败: " + f"name={forced_tool_name}, error={exc}" + ) + return None + + result_text = str(result) + prefix_limit = self._extract_prefix_limit(user_message) + if prefix_limit: + result_text = result_text[:prefix_limit] + + logger.info( + "本地强制工具执行成功: " + f"source={tool_source}, name={forced_tool_name}, " + f"result={self._preview_log_payload(result_text)}" + ) + return Message(role="assistant", content=result_text) async def _chat_stream( self, @@ -410,6 +474,25 @@ class AIClient: return text[:max_len] + "..." return text + @staticmethod + def _extract_prefix_limit(user_message: str) -> Optional[int]: + """Extract requested output prefix length like '前100字'.""" + if not user_message: + return None + + match = re.search(r"前\s*(\d{1,4})\s*字", user_message) + if not match: + return None + + try: + limit = int(match.group(1)) + except (TypeError, ValueError): + return None + + if limit <= 0: + return None + return min(limit, 5000) + @staticmethod def _extract_forced_tool_name( user_message: str, available_tool_names: List[str] diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index f872326..40c5707 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -51,26 +51,30 @@ class OpenAIModel(BaseAIModel): self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}") def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]: - """Build SDK-compatible tool/function request parameters.""" + """Build request parameters for tool-calling. + + Prefer modern `tools` API first, then fallback to legacy `functions` + only when runtime rejects `tools`. + """ if not tools: return {} - if self._supports_tools: - return {"tools": tools} + return {"tools": tools} - if self._supports_functions: - functions = [] - for tool in tools: - schema = self._extract_function_schema(tool) - if schema: - functions.append(schema) + def _build_legacy_function_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]: + """Build legacy function-calling params from tool schema.""" + if not tools: + return {} - if functions: - return {"functions": functions, "function_call": "auto"} + functions = [] + for tool in tools: + schema = self._extract_function_schema(tool) + if schema: + functions.append(schema) + + if functions: + return {"functions": functions, "function_call": "auto"} - self.logger.warning( - "Tool calling is not supported by current OpenAI SDK; tools were ignored." - ) return {} def _build_forced_tool_params( @@ -164,7 +168,7 @@ class OpenAIModel(BaseAIModel): forced_tool_name = self._extract_forced_tool_name_from_choice( retry_params.pop("tool_choice", None) ) - retry_params.update(self._build_tool_params(tools)) + retry_params.update(self._build_legacy_function_params(tools)) if forced_tool_name and "functions" in retry_params: retry_params["function_call"] = {"name": forced_tool_name} return await self.client.chat.completions.create(**retry_params) diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 4d72835..d48c29d 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -102,7 +102,7 @@ class MessageHandler: f"{command_name} add (保持 provider/api_base/api_key 不变)\n" f"{command_name} add [api_base]\n" f"{command_name} add \n" - " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n" + " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n" f"{command_name} switch \n" f"{command_name} remove " ) diff --git a/tests/test_ai_client_forced_tool.py b/tests/test_ai_client_forced_tool.py index 2130e2a..07ad623 100644 --- a/tests/test_ai_client_forced_tool.py +++ b/tests/test_ai_client_forced_tool.py @@ -37,3 +37,9 @@ def test_extract_forced_tool_name_ambiguous_prefix_returns_none(): forced = AIClient._extract_forced_tool_name(message, tools) assert forced is None + + +def test_extract_prefix_limit_from_user_message(): + assert AIClient._extract_prefix_limit("直接返回前100字") == 100 + assert AIClient._extract_prefix_limit("前 256 字") == 256 + assert AIClient._extract_prefix_limit("返回全文") is None diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py index 175a6e1..4f82662 100644 --- a/tests/test_openai_model_compat.py +++ b/tests/test_openai_model_compat.py @@ -313,9 +313,9 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): sent = model.client.completions.last_params assert model._supports_tools is False assert model._supports_functions is True - assert sent["function_call"] == "auto" - assert isinstance(sent["functions"], list) and sent["functions"] - assert sent["functions"][0]["name"] == "demo_tool" + assert sent["tools"] == tools + assert sent["function_call"] is None + assert sent["functions"] is None assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" @@ -334,7 +334,8 @@ def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch): ) sent = model.client.completions.last_params - assert sent["function_call"] == {"name": "demo_tool"} + assert sent["tool_choice"]["type"] == "function" + assert sent["tool_choice"]["function"]["name"] == "demo_tool" def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):