Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message processing.
This commit is contained in:
@@ -191,12 +191,18 @@ class AIClient:
|
||||
if use_tools and self.tools.list():
|
||||
tools = self.tools.to_openai_format()
|
||||
tool_names = [tool.name for tool in self.tools.list()]
|
||||
forced_tool_name = self._extract_forced_tool_name(user_message, tool_names)
|
||||
if forced_tool_name:
|
||||
kwargs = dict(kwargs)
|
||||
kwargs["forced_tool_name"] = forced_tool_name
|
||||
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
|
||||
|
||||
logger.info(
|
||||
"LLM请求: "
|
||||
f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, "
|
||||
f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, "
|
||||
f"tool_names={self._preview_log_payload(tool_names)}"
|
||||
f"tool_names={self._preview_log_payload(tool_names)}, "
|
||||
f"forced_tool={forced_tool_name or '-'}"
|
||||
)
|
||||
logger.info(
|
||||
"LLM输入: "
|
||||
@@ -251,7 +257,7 @@ class AIClient:
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"对话失败: {e}")
|
||||
logger.error(f"对话失败: {type(e).__name__}: {e!r}")
|
||||
raise
|
||||
|
||||
async def _chat_stream(
|
||||
@@ -342,7 +348,10 @@ class AIClient:
|
||||
))
|
||||
|
||||
# 再次调用模型获取最终响应
|
||||
final_response = await self.model.chat(messages, tools, **kwargs)
|
||||
final_kwargs = dict(kwargs)
|
||||
# Force only the first model turn, avoid recursive force after tool result.
|
||||
final_kwargs.pop("forced_tool_name", None)
|
||||
final_response = await self.model.chat(messages, tools, **final_kwargs)
|
||||
logger.info(
|
||||
"LLM最终输出: "
|
||||
f"content={self._preview_log_payload(final_response.content)}"
|
||||
@@ -400,6 +409,52 @@ class AIClient:
|
||||
if len(text) > max_len:
|
||||
return text[:max_len] + "..."
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _extract_forced_tool_name(
|
||||
user_message: str, available_tool_names: List[str]
|
||||
) -> Optional[str]:
|
||||
if not user_message or not available_tool_names:
|
||||
return None
|
||||
|
||||
triggers = ["调用工具", "使用工具", "只调用", "务必调用", "必须调用", "tool"]
|
||||
if not any(trigger in user_message for trigger in triggers):
|
||||
return None
|
||||
|
||||
pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)")
|
||||
explicit_matches = [
|
||||
name for name in pattern.findall(user_message) if name in available_tool_names
|
||||
]
|
||||
if len(explicit_matches) == 1:
|
||||
return explicit_matches[0]
|
||||
if len(explicit_matches) > 1:
|
||||
return None
|
||||
|
||||
contained = [name for name in available_tool_names if name in user_message]
|
||||
if len(contained) == 1:
|
||||
return contained[0]
|
||||
|
||||
# 允许只写 skill/tool 前缀(如 humanizer_zh),前提是前缀下只有一个工具。
|
||||
prefixes = sorted(
|
||||
{name.split(".", 1)[0] for name in available_tool_names},
|
||||
key=len,
|
||||
reverse=True,
|
||||
)
|
||||
matched_prefixes = [
|
||||
prefix
|
||||
for prefix in prefixes
|
||||
if re.search(rf"\b{re.escape(prefix)}\b", user_message)
|
||||
]
|
||||
if len(matched_prefixes) == 1:
|
||||
prefix_tools = [
|
||||
name
|
||||
for name in available_tool_names
|
||||
if name.startswith(f"{matched_prefixes[0]}.")
|
||||
]
|
||||
if len(prefix_tools) == 1:
|
||||
return prefix_tools[0]
|
||||
|
||||
return None
|
||||
|
||||
def set_personality(self, personality_name: str) -> bool:
|
||||
"""设置人格。"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
OpenAI model implementation (including OpenAI-compatible providers).
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
@@ -72,6 +73,58 @@ class OpenAIModel(BaseAIModel):
|
||||
)
|
||||
return {}
|
||||
|
||||
def _build_forced_tool_params(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
forced_tool_name: Optional[str],
|
||||
tools: Optional[List[dict]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Build request params for forcing one specific tool call."""
|
||||
if not forced_tool_name:
|
||||
return {}
|
||||
|
||||
available_tool_names = self._extract_tool_names(tools)
|
||||
if available_tool_names and forced_tool_name not in available_tool_names:
|
||||
self.logger.warning(
|
||||
"forced_tool_name is not in current tool list, ignored: "
|
||||
f"{forced_tool_name}"
|
||||
)
|
||||
return {}
|
||||
|
||||
if "tools" in params:
|
||||
return {
|
||||
"tool_choice": {
|
||||
"type": "function",
|
||||
"function": {"name": forced_tool_name},
|
||||
}
|
||||
}
|
||||
|
||||
if "functions" in params:
|
||||
return {"function_call": {"name": forced_tool_name}}
|
||||
|
||||
self.logger.warning(
|
||||
"forced_tool_name provided but tool params are unavailable, ignored: "
|
||||
f"{forced_tool_name}"
|
||||
)
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_names(tools: Optional[List[dict]]) -> List[str]:
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
names: List[str] = []
|
||||
for tool in tools:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
function_data = tool.get("function")
|
||||
if not isinstance(function_data, dict):
|
||||
continue
|
||||
name = function_data.get("name")
|
||||
if isinstance(name, str) and name:
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
@staticmethod
|
||||
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
if not isinstance(tool, dict):
|
||||
@@ -108,7 +161,12 @@ class OpenAIModel(BaseAIModel):
|
||||
self._supports_tools = False
|
||||
retry_params = dict(params)
|
||||
retry_params.pop("tools", None)
|
||||
forced_tool_name = self._extract_forced_tool_name_from_choice(
|
||||
retry_params.pop("tool_choice", None)
|
||||
)
|
||||
retry_params.update(self._build_tool_params(tools))
|
||||
if forced_tool_name and "functions" in retry_params:
|
||||
retry_params["function_call"] = {"name": forced_tool_name}
|
||||
return await self.client.chat.completions.create(**retry_params)
|
||||
|
||||
if "unexpected keyword argument 'functions'" in message and "functions" in params:
|
||||
@@ -122,6 +180,60 @@ class OpenAIModel(BaseAIModel):
|
||||
return await self.client.chat.completions.create(**retry_params)
|
||||
|
||||
raise
|
||||
except Exception as exc:
|
||||
if self._is_timeout_error(exc):
|
||||
return await self._retry_on_timeout(params)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _is_timeout_error(error: Exception) -> bool:
|
||||
if isinstance(error, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
|
||||
return True
|
||||
|
||||
error_name = type(error).__name__.lower()
|
||||
if "timeout" in error_name:
|
||||
return True
|
||||
|
||||
message = str(error).lower()
|
||||
return "timed out" in message or "timeout" in message
|
||||
|
||||
async def _retry_on_timeout(self, params: Dict[str, Any]):
|
||||
base_timeout = float(self.config.timeout or 60)
|
||||
retry_timeout = min(max(base_timeout * 2, 120.0), 300.0)
|
||||
retry_params = dict(params)
|
||||
retry_params["timeout"] = retry_timeout
|
||||
self.logger.warning(
|
||||
"chat request timed out, retry once with longer timeout: "
|
||||
f"{base_timeout:.0f}s -> {retry_timeout:.0f}s"
|
||||
)
|
||||
try:
|
||||
return await self.client.chat.completions.create(**retry_params)
|
||||
except Exception as retry_exc:
|
||||
if self._is_timeout_error(retry_exc):
|
||||
self.logger.error(
|
||||
"chat request still timed out after retry: "
|
||||
f"timeout={retry_timeout:.0f}s"
|
||||
)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _extract_forced_tool_name_from_choice(tool_choice: Any) -> Optional[str]:
|
||||
if not tool_choice:
|
||||
return None
|
||||
|
||||
if isinstance(tool_choice, dict):
|
||||
function_data = tool_choice.get("function")
|
||||
if isinstance(function_data, dict):
|
||||
name = function_data.get("name")
|
||||
return name if isinstance(name, str) and name else None
|
||||
return None
|
||||
|
||||
function_data = getattr(tool_choice, "function", None)
|
||||
if function_data:
|
||||
name = getattr(function_data, "name", None)
|
||||
return name if isinstance(name, str) and name else None
|
||||
|
||||
return None
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -131,6 +243,7 @@ class OpenAIModel(BaseAIModel):
|
||||
) -> Message:
|
||||
"""Non-stream chat."""
|
||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||
forced_tool_name = kwargs.pop("forced_tool_name", None)
|
||||
|
||||
params = {
|
||||
"model": self.config.model_name,
|
||||
@@ -144,6 +257,7 @@ class OpenAIModel(BaseAIModel):
|
||||
|
||||
params.update(self._build_tool_params(tools))
|
||||
params.update(kwargs)
|
||||
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
|
||||
|
||||
tool_mode = "none"
|
||||
tool_count = 0
|
||||
@@ -156,7 +270,8 @@ class OpenAIModel(BaseAIModel):
|
||||
|
||||
self.logger.info(
|
||||
"OpenAI chat request: "
|
||||
f"model={self.config.model_name}, tool_mode={tool_mode}, tool_count={tool_count}"
|
||||
f"model={self.config.model_name}, tool_mode={tool_mode}, "
|
||||
f"tool_count={tool_count}, forced_tool={forced_tool_name or '-'}"
|
||||
)
|
||||
|
||||
response = await self._create_completion_with_fallback(params, tools)
|
||||
@@ -177,6 +292,7 @@ class OpenAIModel(BaseAIModel):
|
||||
) -> AsyncIterator[str]:
|
||||
"""Streaming chat."""
|
||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||
forced_tool_name = kwargs.pop("forced_tool_name", None)
|
||||
|
||||
params = {
|
||||
"model": self.config.model_name,
|
||||
@@ -188,6 +304,7 @@ class OpenAIModel(BaseAIModel):
|
||||
|
||||
params.update(self._build_tool_params(tools))
|
||||
params.update(kwargs)
|
||||
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
|
||||
|
||||
stream = await self._create_completion_with_fallback(params, tools)
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from botpy.message import Message
|
||||
|
||||
from src.ai import AIClient
|
||||
@@ -619,6 +621,12 @@ class MessageHandler:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
|
||||
await self._reply_plain(
|
||||
message,
|
||||
"模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。",
|
||||
)
|
||||
return
|
||||
await self._reply_plain(message, "消息处理失败,请稍后重试")
|
||||
|
||||
async def _handle_skills_command(self, message: Message, command: str):
|
||||
|
||||
Reference in New Issue
Block a user