diff --git a/src/ai/client.py b/src/ai/client.py index 861604d..7529c9d 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -191,12 +191,18 @@ class AIClient: if use_tools and self.tools.list(): tools = self.tools.to_openai_format() tool_names = [tool.name for tool in self.tools.list()] + forced_tool_name = self._extract_forced_tool_name(user_message, tool_names) + if forced_tool_name: + kwargs = dict(kwargs) + kwargs["forced_tool_name"] = forced_tool_name + logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") logger.info( "LLM请求: " f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, " f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, " - f"tool_names={self._preview_log_payload(tool_names)}" + f"tool_names={self._preview_log_payload(tool_names)}, " + f"forced_tool={forced_tool_name or '-'}" ) logger.info( "LLM输入: " @@ -251,7 +257,7 @@ class AIClient: return response.content except Exception as e: - logger.error(f"对话失败: {e}") + logger.error(f"对话失败: {type(e).__name__}: {e!r}") raise async def _chat_stream( @@ -342,7 +348,10 @@ class AIClient: )) # 再次调用模型获取最终响应 - final_response = await self.model.chat(messages, tools, **kwargs) + final_kwargs = dict(kwargs) + # Force only the first model turn, avoid recursive force after tool result. + final_kwargs.pop("forced_tool_name", None) + final_response = await self.model.chat(messages, tools, **final_kwargs) logger.info( "LLM最终输出: " f"content={self._preview_log_payload(final_response.content)}" @@ -400,6 +409,52 @@ class AIClient: if len(text) > max_len: return text[:max_len] + "..." return text + + @staticmethod + def _extract_forced_tool_name( + user_message: str, available_tool_names: List[str] + ) -> Optional[str]: + if not user_message or not available_tool_names: + return None + + triggers = ["调用工具", "使用工具", "只调用", "务必调用", "必须调用", "tool"] + if not any(trigger in user_message for trigger in triggers): + return None + + pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)") + explicit_matches = [ + name for name in pattern.findall(user_message) if name in available_tool_names + ] + if len(explicit_matches) == 1: + return explicit_matches[0] + if len(explicit_matches) > 1: + return None + + contained = [name for name in available_tool_names if name in user_message] + if len(contained) == 1: + return contained[0] + + # 允许只写 skill/tool 前缀(如 humanizer_zh),前提是前缀下只有一个工具。 + prefixes = sorted( + {name.split(".", 1)[0] for name in available_tool_names}, + key=len, + reverse=True, + ) + matched_prefixes = [ + prefix + for prefix in prefixes + if re.search(rf"\b{re.escape(prefix)}\b", user_message) + ] + if len(matched_prefixes) == 1: + prefix_tools = [ + name + for name in available_tool_names + if name.startswith(f"{matched_prefixes[0]}.") + ] + if len(prefix_tools) == 1: + return prefix_tools[0] + + return None def set_personality(self, personality_name: str) -> bool: """设置人格。""" diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index b005fca..f872326 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -1,6 +1,7 @@ """ OpenAI model implementation (including OpenAI-compatible providers). """ +import asyncio import inspect import json import re @@ -72,6 +73,58 @@ class OpenAIModel(BaseAIModel): ) return {} + def _build_forced_tool_params( + self, + params: Dict[str, Any], + forced_tool_name: Optional[str], + tools: Optional[List[dict]], + ) -> Dict[str, Any]: + """Build request params for forcing one specific tool call.""" + if not forced_tool_name: + return {} + + available_tool_names = self._extract_tool_names(tools) + if available_tool_names and forced_tool_name not in available_tool_names: + self.logger.warning( + "forced_tool_name is not in current tool list, ignored: " + f"{forced_tool_name}" + ) + return {} + + if "tools" in params: + return { + "tool_choice": { + "type": "function", + "function": {"name": forced_tool_name}, + } + } + + if "functions" in params: + return {"function_call": {"name": forced_tool_name}} + + self.logger.warning( + "forced_tool_name provided but tool params are unavailable, ignored: " + f"{forced_tool_name}" + ) + return {} + + @staticmethod + def _extract_tool_names(tools: Optional[List[dict]]) -> List[str]: + if not tools: + return [] + + names: List[str] = [] + for tool in tools: + if not isinstance(tool, dict): + continue + function_data = tool.get("function") + if not isinstance(function_data, dict): + continue + name = function_data.get("name") + if isinstance(name, str) and name: + names.append(name) + return names + @staticmethod def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]: if not isinstance(tool, dict): @@ -108,7 +161,12 @@ class OpenAIModel(BaseAIModel): self._supports_tools = False retry_params = dict(params) retry_params.pop("tools", None) + forced_tool_name = self._extract_forced_tool_name_from_choice( + retry_params.pop("tool_choice", None) + ) retry_params.update(self._build_tool_params(tools)) + if forced_tool_name and "functions" in retry_params: + retry_params["function_call"] = {"name": forced_tool_name} return await self.client.chat.completions.create(**retry_params) if "unexpected keyword argument 'functions'" in message and "functions" in params: @@ -122,6 +180,60 @@ class OpenAIModel(BaseAIModel): return await self.client.chat.completions.create(**retry_params) raise + except Exception as exc: + if self._is_timeout_error(exc): + return await self._retry_on_timeout(params) + raise + + @staticmethod + def _is_timeout_error(error: Exception) -> bool: + if isinstance(error, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)): + return True + + error_name = type(error).__name__.lower() + if "timeout" in error_name: + return True + + message = str(error).lower() + return "timed out" in message or "timeout" in message + + async def _retry_on_timeout(self, params: Dict[str, Any]): + base_timeout = float(self.config.timeout or 60) + retry_timeout = min(max(base_timeout * 2, 120.0), 300.0) + retry_params = dict(params) + retry_params["timeout"] = retry_timeout + self.logger.warning( + "chat request timed out, retry once with longer timeout: " + f"{base_timeout:.0f}s -> {retry_timeout:.0f}s" + ) + try: + return await self.client.chat.completions.create(**retry_params) + except Exception as retry_exc: + if self._is_timeout_error(retry_exc): + self.logger.error( + "chat request still timed out after retry: " + f"timeout={retry_timeout:.0f}s" + ) + raise + + @staticmethod + def _extract_forced_tool_name_from_choice(tool_choice: Any) -> Optional[str]: + if not tool_choice: + return None + + if isinstance(tool_choice, dict): + function_data = tool_choice.get("function") + if isinstance(function_data, dict): + name = function_data.get("name") + return name if isinstance(name, str) and name else None + return None + + function_data = getattr(tool_choice, "function", None) + if function_data: + name = getattr(function_data, "name", None) + return name if isinstance(name, str) and name else None + + return None async def chat( self, @@ -131,6 +243,7 @@ class OpenAIModel(BaseAIModel): ) -> Message: """Non-stream chat.""" formatted_messages = [self._format_message(msg) for msg in messages] + forced_tool_name = kwargs.pop("forced_tool_name", None) params = { "model": self.config.model_name, @@ -144,6 +257,7 @@ class OpenAIModel(BaseAIModel): params.update(self._build_tool_params(tools)) params.update(kwargs) + params.update(self._build_forced_tool_params(params, forced_tool_name, tools)) tool_mode = "none" tool_count = 0 @@ -156,7 +270,8 @@ class OpenAIModel(BaseAIModel): self.logger.info( "OpenAI chat request: " - f"model={self.config.model_name}, tool_mode={tool_mode}, tool_count={tool_count}" + f"model={self.config.model_name}, tool_mode={tool_mode}, " + f"tool_count={tool_count}, forced_tool={forced_tool_name or '-'}" ) response = await self._create_completion_with_fallback(params, tools) @@ -177,6 +292,7 @@ class OpenAIModel(BaseAIModel): ) -> AsyncIterator[str]: """Streaming chat.""" formatted_messages = [self._format_message(msg) for msg in messages] + forced_tool_name = kwargs.pop("forced_tool_name", None) params = { "model": self.config.model_name, @@ -188,6 +304,7 @@ class OpenAIModel(BaseAIModel): params.update(self._build_tool_params(tools)) params.update(kwargs) + params.update(self._build_forced_tool_params(params, forced_tool_name, tools)) stream = await self._create_completion_with_fallback(params, tools) diff --git a/src/handlers/message_handler_ai.py b/src/handlers/message_handler_ai.py index 0b2a2a7..4d72835 100644 --- a/src/handlers/message_handler_ai.py +++ b/src/handlers/message_handler_ai.py @@ -8,6 +8,8 @@ from pathlib import Path import re from typing import Any, Dict, Optional +import httpx + from botpy.message import Message from src.ai import AIClient @@ -619,6 +621,12 @@ class MessageHandler: import traceback logger.error(traceback.format_exc()) + if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)): + await self._reply_plain( + message, + "模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。", + ) + return await self._reply_plain(message, "消息处理失败,请稍后重试") async def _handle_skills_command(self, message: Message, command: str): diff --git a/tests/test_ai_client_forced_tool.py b/tests/test_ai_client_forced_tool.py new file mode 100644 index 0000000..2130e2a --- /dev/null +++ b/tests/test_ai_client_forced_tool.py @@ -0,0 +1,39 @@ +"""Tests for AIClient forced tool name extraction.""" + +from src.ai.client import AIClient + + +def test_extract_forced_tool_name_full_name(): + tools = [ + "humanizer_zh.read_skill_doc", + "skills_creator.create_skill", + ] + message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars" + + forced = AIClient._extract_forced_tool_name(message, tools) + + assert forced == "humanizer_zh.read_skill_doc" + + +def test_extract_forced_tool_name_unique_prefix(): + tools = [ + "humanizer_zh.read_skill_doc", + "skills_creator.create_skill", + ] + message = "please call tool humanizer_zh only" + + forced = AIClient._extract_forced_tool_name(message, tools) + + assert forced == "humanizer_zh.read_skill_doc" + + +def test_extract_forced_tool_name_ambiguous_prefix_returns_none(): + tools = [ + "skills_creator.create_skill", + "skills_creator.reload_skill", + ] + message = "please call tool skills_creator" + + forced = AIClient._extract_forced_tool_name(message, tools) + + assert forced is None diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py index 4534a94..175a6e1 100644 --- a/tests/test_openai_model_compat.py +++ b/tests/test_openai_model_compat.py @@ -3,6 +3,7 @@ import asyncio from types import SimpleNamespace +import httpx import src.ai.models.openai_model as openai_model_module from src.ai.base import Message, ModelConfig, ModelProvider from src.ai.models.openai_model import OpenAIModel @@ -216,6 +217,55 @@ class _LengthLimitedEmbedAsyncOpenAI: self.embeddings = _LengthLimitedEmbeddings() +class _TimeoutOnceCompletions: + def __init__(self): + self.calls = [] + + async def create( + self, + *, + model, + messages, + temperature=None, + max_tokens=None, + top_p=None, + frequency_penalty=None, + presence_penalty=None, + tools=None, + stream=False, + timeout=None, + **kwargs, + ): + self.calls.append( + { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "tools": tools, + "stream": stream, + "timeout": timeout, + **kwargs, + } + ) + + if len(self.calls) == 1: + raise httpx.ReadTimeout("timed out") + + message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + +class _TimeoutOnceAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _TimeoutOnceCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _FakeEmbeddings() + + def test_openai_model_uses_tools_when_supported(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) @@ -233,6 +283,24 @@ def test_openai_model_uses_tools_when_supported(monkeypatch): assert result.tool_calls[0]["function"]["name"] == "demo_tool" +def test_openai_model_forces_tool_choice_when_supported(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tools = _tool_defs() + asyncio.run( + model.chat( + messages=[Message(role="user", content="hi")], + tools=tools, + forced_tool_name="demo_tool", + ) + ) + + sent = model.client.completions.last_params + assert sent["tool_choice"]["type"] == "function" + assert sent["tool_choice"]["function"]["name"] == "demo_tool" + + def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) @@ -252,6 +320,23 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): assert result.tool_calls[0]["function"]["name"] == "demo_tool" +def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tools = _tool_defs() + asyncio.run( + model.chat( + messages=[Message(role="user", content="hi")], + tools=tools, + forced_tool_name="demo_tool", + ) + ) + + sent = model.client.completions.last_params + assert sent["function_call"] == {"name": "demo_tool"} + + def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) @@ -297,6 +382,41 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch): assert result.tool_calls[0]["function"]["name"] == "demo_tool" +def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch): + monkeypatch.setattr( + openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI + ) + + model = OpenAIModel(_model_config()) + asyncio.run( + model.chat( + messages=[Message(role="user", content="hi")], + tools=_tool_defs(), + forced_tool_name="demo_tool", + ) + ) + + calls = model.client.completions.calls + assert len(calls) == 2 + assert calls[0]["tool_choice"]["function"]["name"] == "demo_tool" + assert calls[1]["function_call"] == {"name": "demo_tool"} + + +def test_openai_model_retries_once_on_read_timeout(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI) + + model = OpenAIModel(_model_config()) + result = asyncio.run( + model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs()) + ) + + calls = model.client.completions.calls + assert len(calls) == 2 + assert calls[0]["timeout"] is None + assert calls[1]["timeout"] == 120.0 + assert result.content == "ok-after-timeout" + + def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch): monkeypatch.setattr( openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI