Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message processing.

This commit is contained in:
Mimikko-zeus
2026-03-03 14:14:16 +08:00
parent 00501eb44d
commit 7d7a4b8f54
5 changed files with 343 additions and 4 deletions

View File

@@ -191,12 +191,18 @@ class AIClient:
if use_tools and self.tools.list():
tools = self.tools.to_openai_format()
tool_names = [tool.name for tool in self.tools.list()]
forced_tool_name = self._extract_forced_tool_name(user_message, tool_names)
if forced_tool_name:
kwargs = dict(kwargs)
kwargs["forced_tool_name"] = forced_tool_name
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
logger.info(
"LLM请求: "
f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, "
f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, "
f"tool_names={self._preview_log_payload(tool_names)}"
f"tool_names={self._preview_log_payload(tool_names)}, "
f"forced_tool={forced_tool_name or '-'}"
)
logger.info(
"LLM输入: "
@@ -251,7 +257,7 @@ class AIClient:
return response.content
except Exception as e:
logger.error(f"对话失败: {e}")
logger.error(f"对话失败: {type(e).__name__}: {e!r}")
raise
async def _chat_stream(
@@ -342,7 +348,10 @@ class AIClient:
))
# 再次调用模型获取最终响应
final_response = await self.model.chat(messages, tools, **kwargs)
final_kwargs = dict(kwargs)
# Force only the first model turn, avoid recursive force after tool result.
final_kwargs.pop("forced_tool_name", None)
final_response = await self.model.chat(messages, tools, **final_kwargs)
logger.info(
"LLM最终输出: "
f"content={self._preview_log_payload(final_response.content)}"
@@ -400,6 +409,52 @@ class AIClient:
if len(text) > max_len:
return text[:max_len] + "..."
return text
@staticmethod
def _extract_forced_tool_name(
user_message: str, available_tool_names: List[str]
) -> Optional[str]:
if not user_message or not available_tool_names:
return None
triggers = ["调用工具", "使用工具", "只调用", "务必调用", "必须调用", "tool"]
if not any(trigger in user_message for trigger in triggers):
return None
pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)")
explicit_matches = [
name for name in pattern.findall(user_message) if name in available_tool_names
]
if len(explicit_matches) == 1:
return explicit_matches[0]
if len(explicit_matches) > 1:
return None
contained = [name for name in available_tool_names if name in user_message]
if len(contained) == 1:
return contained[0]
# 允许只写 skill/tool 前缀(如 humanizer_zh前提是前缀下只有一个工具。
prefixes = sorted(
{name.split(".", 1)[0] for name in available_tool_names},
key=len,
reverse=True,
)
matched_prefixes = [
prefix
for prefix in prefixes
if re.search(rf"\b{re.escape(prefix)}\b", user_message)
]
if len(matched_prefixes) == 1:
prefix_tools = [
name
for name in available_tool_names
if name.startswith(f"{matched_prefixes[0]}.")
]
if len(prefix_tools) == 1:
return prefix_tools[0]
return None
def set_personality(self, personality_name: str) -> bool:
"""设置人格。"""

View File

@@ -1,6 +1,7 @@
"""
OpenAI model implementation (including OpenAI-compatible providers).
"""
import asyncio
import inspect
import json
import re
@@ -72,6 +73,58 @@ class OpenAIModel(BaseAIModel):
)
return {}
def _build_forced_tool_params(
    self,
    params: Dict[str, Any],
    forced_tool_name: Optional[str],
    tools: Optional[List[dict]],
) -> Dict[str, Any]:
    """Return extra request params that pin the model to one tool.

    Emits a ``tool_choice`` entry (tools API) or a ``function_call``
    entry (legacy functions API) depending on which tool-parameter
    style ``params`` already carries. Returns an empty dict when no
    forcing was requested or it cannot be honoured.
    """
    if not forced_tool_name:
        return {}
    known_names = self._extract_tool_names(tools)
    if known_names and forced_tool_name not in known_names:
        # Forcing a tool we are not sending would yield an invalid
        # request, so log and fall back to normal tool selection.
        self.logger.warning(
            "forced_tool_name is not in current tool list, ignored: "
            f"{forced_tool_name}"
        )
        return {}
    if "tools" in params:
        return {
            "tool_choice": {
                "type": "function",
                "function": {"name": forced_tool_name},
            }
        }
    if "functions" in params:
        return {"function_call": {"name": forced_tool_name}}
    # Neither tool-parameter style is present: nothing to attach the
    # forced choice to, so warn and skip.
    self.logger.warning(
        "forced_tool_name provided but tool params are unavailable, ignored: "
        f"{forced_tool_name}"
    )
    return {}
@staticmethod
def _extract_tool_names(tools: Optional[List[dict]]) -> List[str]:
if not tools:
return []
names: List[str] = []
for tool in tools:
if not isinstance(tool, dict):
continue
function_data = tool.get("function")
if not isinstance(function_data, dict):
continue
name = function_data.get("name")
if isinstance(name, str) and name:
names.append(name)
return names
@staticmethod
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if not isinstance(tool, dict):
@@ -108,7 +161,12 @@ class OpenAIModel(BaseAIModel):
self._supports_tools = False
retry_params = dict(params)
retry_params.pop("tools", None)
forced_tool_name = self._extract_forced_tool_name_from_choice(
retry_params.pop("tool_choice", None)
)
retry_params.update(self._build_tool_params(tools))
if forced_tool_name and "functions" in retry_params:
retry_params["function_call"] = {"name": forced_tool_name}
return await self.client.chat.completions.create(**retry_params)
if "unexpected keyword argument 'functions'" in message and "functions" in params:
@@ -122,6 +180,60 @@ class OpenAIModel(BaseAIModel):
return await self.client.chat.completions.create(**retry_params)
raise
except Exception as exc:
if self._is_timeout_error(exc):
return await self._retry_on_timeout(params)
raise
@staticmethod
def _is_timeout_error(error: Exception) -> bool:
    """Heuristically decide whether *error* represents a request timeout.

    Checks known timeout exception types first, then falls back to
    matching "timeout"/"timed out" in the exception class name and
    message (provider SDKs wrap timeouts in many different types).
    """
    if isinstance(error, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
        return True
    if "timeout" in type(error).__name__.lower():
        return True
    text = str(error).lower()
    return "timed out" in text or "timeout" in text
async def _retry_on_timeout(self, params: Dict[str, Any]):
    """Retry a timed-out chat completion once with a longer timeout.

    Doubles the configured timeout, clamped to the 120-300 second
    range, re-issues the same request once, and propagates whatever
    the retry raises.

    Args:
        params: The request params of the call that timed out; copied
            before mutation so the caller's dict is untouched.

    Returns:
        The completion response from the retried request.

    Raises:
        Exception: Whatever the retried request raises (including a
            second timeout).
    """
    # Fall back to 60s when the config carries no timeout value.
    base_timeout = float(self.config.timeout or 60)
    # Twice the configured value, but never below 120s or above 300s.
    retry_timeout = min(max(base_timeout * 2, 120.0), 300.0)
    retry_params = dict(params)
    retry_params["timeout"] = retry_timeout
    self.logger.warning(
        "chat request timed out, retry once with longer timeout: "
        f"{base_timeout:.0f}s -> {retry_timeout:.0f}s"
    )
    try:
        return await self.client.chat.completions.create(**retry_params)
    except Exception as retry_exc:
        # Extra log only when the retry ALSO timed out; either way the
        # retry failure propagates unchanged to the caller.
        if self._is_timeout_error(retry_exc):
            self.logger.error(
                "chat request still timed out after retry: "
                f"timeout={retry_timeout:.0f}s"
            )
        raise
@staticmethod
def _extract_forced_tool_name_from_choice(tool_choice: Any) -> Optional[str]:
if not tool_choice:
return None
if isinstance(tool_choice, dict):
function_data = tool_choice.get("function")
if isinstance(function_data, dict):
name = function_data.get("name")
return name if isinstance(name, str) and name else None
return None
function_data = getattr(tool_choice, "function", None)
if function_data:
name = getattr(function_data, "name", None)
return name if isinstance(name, str) and name else None
return None
async def chat(
self,
@@ -131,6 +243,7 @@ class OpenAIModel(BaseAIModel):
) -> Message:
"""Non-stream chat."""
formatted_messages = [self._format_message(msg) for msg in messages]
forced_tool_name = kwargs.pop("forced_tool_name", None)
params = {
"model": self.config.model_name,
@@ -144,6 +257,7 @@ class OpenAIModel(BaseAIModel):
params.update(self._build_tool_params(tools))
params.update(kwargs)
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
tool_mode = "none"
tool_count = 0
@@ -156,7 +270,8 @@ class OpenAIModel(BaseAIModel):
self.logger.info(
"OpenAI chat request: "
f"model={self.config.model_name}, tool_mode={tool_mode}, tool_count={tool_count}"
f"model={self.config.model_name}, tool_mode={tool_mode}, "
f"tool_count={tool_count}, forced_tool={forced_tool_name or '-'}"
)
response = await self._create_completion_with_fallback(params, tools)
@@ -177,6 +292,7 @@ class OpenAIModel(BaseAIModel):
) -> AsyncIterator[str]:
"""Streaming chat."""
formatted_messages = [self._format_message(msg) for msg in messages]
forced_tool_name = kwargs.pop("forced_tool_name", None)
params = {
"model": self.config.model_name,
@@ -188,6 +304,7 @@ class OpenAIModel(BaseAIModel):
params.update(self._build_tool_params(tools))
params.update(kwargs)
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
stream = await self._create_completion_with_fallback(params, tools)

View File

@@ -8,6 +8,8 @@ from pathlib import Path
import re
from typing import Any, Dict, Optional
import httpx
from botpy.message import Message
from src.ai import AIClient
@@ -619,6 +621,12 @@ class MessageHandler:
import traceback
logger.error(traceback.format_exc())
if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
await self._reply_plain(
message,
"模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。",
)
return
await self._reply_plain(message, "消息处理失败,请稍后重试")
async def _handle_skills_command(self, message: Message, command: str):