diff --git a/src/ai/client.py b/src/ai/client.py index 7529c9d..5e4be00 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -195,6 +195,19 @@ class AIClient: if forced_tool_name: kwargs = dict(kwargs) kwargs["forced_tool_name"] = forced_tool_name + if tools: + before_count = len(tools) + tools = [ + tool + for tool in tools + if ((tool.get("function") or {}).get("name") == forced_tool_name) + ] + if len(tools) == 1: + tool_names = [forced_tool_name] + logger.info( + "显式工具调用已收敛工具列表: " + f"{before_count} -> {len(tools)}" + ) logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") logger.info( @@ -237,6 +250,13 @@ class AIClient: response = await self._handle_tool_calls( messages, response, tools, **kwargs ) + elif forced_tool_name: + forced_response = await self._run_forced_tool_fallback( + forced_tool_name=forced_tool_name, + user_message=user_message, + ) + if forced_response is not None: + response = forced_response # 写入记忆 if use_memory: @@ -259,6 +279,50 @@ class AIClient: except Exception as e: logger.error(f"对话失败: {type(e).__name__}: {e!r}") raise + + async def _run_forced_tool_fallback( + self, forced_tool_name: str, user_message: str + ) -> Optional[Message]: + """Execute forced tool locally when model did not emit tool_calls.""" + tool_def = self.tools.get(forced_tool_name) + tool_source = self._tool_sources.get(forced_tool_name, "custom") + if not tool_def: + logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}") + return None + + logger.warning( + "模型未返回 tool_calls,启用本地强制工具执行: " + f"source={tool_source}, name={forced_tool_name}" + ) + + try: + result = tool_def.function() + if inspect.isawaitable(result): + result = await result + except TypeError as exc: + logger.warning( + "本地强制工具执行失败(参数不匹配): " + f"name={forced_tool_name}, error={exc}" + ) + return None + except Exception as exc: + logger.warning( + "本地强制工具执行失败: " + f"name={forced_tool_name}, error={exc}" + ) + return None + + result_text = str(result) + prefix_limit = self._extract_prefix_limit(user_message) + if prefix_limit: + result_text = result_text[:prefix_limit] + + logger.info( + "本地强制工具执行成功: " + f"source={tool_source}, name={forced_tool_name}, " + f"result={self._preview_log_payload(result_text)}" + ) + return Message(role="assistant", content=result_text) async def _chat_stream( self, @@ -410,6 +474,25 @@ class AIClient: return text[:max_len] + "..." return text + @staticmethod + def _extract_prefix_limit(user_message: str) -> Optional[int]: + """Extract requested output prefix length like '前100字'.""" + if not user_message: + return None + + match = re.search(r"前\s*(\d{1,4})\s*字", user_message) + if not match: + return None + + try: + limit = int(match.group(1)) + except (TypeError, ValueError): + return None + + if limit <= 0: + return None + return min(limit, 5000) + @staticmethod def _extract_forced_tool_name( user_message: str, available_tool_names: List[str] diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index f872326..40c5707 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -51,26 +51,30 @@ class OpenAIModel(BaseAIModel): self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}") def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]: - """Build SDK-compatible tool/function request parameters.""" + """Build request parameters for tool-calling. + + Prefer modern `tools` API first, then fallback to legacy `functions` + only when runtime rejects `tools`. + """ if not tools: return {} - if self._supports_tools: - return {"tools": tools} + return {"tools": tools} - if self._supports_functions: - functions = [] - for tool in tools: - schema = self._extract_function_schema(tool) - if schema: - functions.append(schema) + def _build_legacy_function_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]: + """Build legacy function-calling params from tool schema.""" + if not tools: + return {} - if functions: - return {"functions": functions, "function_call": "auto"} + functions = [] + for tool in tools: + schema = self._extract_function_schema(tool) + if schema: + functions.append(schema) + + if functions: + return {"functions": functions, "function_call": "auto"} - self.logger.warning( - "Tool calling is not supported by current OpenAI SDK; tools were ignored." - ) return {} def _build_forced_tool_params( @@ -164,7 +168,7 @@ class OpenAIModel(BaseAIModel): forced_tool_name = self._extract_forced_tool_name_from_choice( retry_params.pop("tool_choice", None) ) - retry_params.update(self._build_tool_params(tools)) + retry_params.update(self._build_legacy_function_params(tools)) if forced_tool_name and "functions" in retry_params: retry_params["function_call"] = {"name": forced_tool_name} return await self.client.chat.completions.create(**retry_params) diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 4d72835..d48c29d 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -102,7 +102,7 @@ class MessageHandler: f"{command_name} add (保持 provider/api_base/api_key 不变)\n" f"{command_name} add [api_base]\n" f"{command_name} add \n" - " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n" + " json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n" f"{command_name} switch \n" f"{command_name} remove " ) diff --git a/tests/test_ai_client_forced_tool.py b/tests/test_ai_client_forced_tool.py index 2130e2a..07ad623 100644 --- a/tests/test_ai_client_forced_tool.py +++ b/tests/test_ai_client_forced_tool.py @@ -37,3 +37,9 @@ def test_extract_forced_tool_name_ambiguous_prefix_returns_none(): forced = AIClient._extract_forced_tool_name(message, tools) assert forced is None + + +def test_extract_prefix_limit_from_user_message(): + assert AIClient._extract_prefix_limit("直接返回前100字") == 100 + assert AIClient._extract_prefix_limit("前 256 字") == 256 + assert AIClient._extract_prefix_limit("返回全文") is None diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py index 175a6e1..4f82662 100644 --- a/tests/test_openai_model_compat.py +++ b/tests/test_openai_model_compat.py @@ -313,9 +313,9 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): sent = model.client.completions.last_params assert model._supports_tools is False assert model._supports_functions is True - assert sent["function_call"] == "auto" - assert isinstance(sent["functions"], list) and sent["functions"] - assert sent["functions"][0]["name"] == "demo_tool" + assert sent["tools"] == tools + assert sent["function_call"] is None + assert sent["functions"] is None assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" @@ -334,7 +334,8 @@ def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch): ) sent = model.client.completions.last_params - assert sent["function_call"] == {"name": "demo_tool"} + assert sent["tool_choice"]["type"] == "function" + assert sent["tool_choice"]["function"]["name"] == "demo_tool" def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):