Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message handling and model interactions, ensuring better user feedback and robustness.

This commit is contained in:
Mimikko-zeus
2026-03-03 14:25:21 +08:00
parent 7d7a4b8f54
commit 4a2666b1f2
5 changed files with 114 additions and 20 deletions

View File

@@ -195,6 +195,19 @@ class AIClient:
if forced_tool_name:
kwargs = dict(kwargs)
kwargs["forced_tool_name"] = forced_tool_name
if tools:
before_count = len(tools)
tools = [
tool
for tool in tools
if ((tool.get("function") or {}).get("name") == forced_tool_name)
]
if len(tools) == 1:
tool_names = [forced_tool_name]
logger.info(
"显式工具调用已收敛工具列表: "
f"{before_count} -> {len(tools)}"
)
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
logger.info(
@@ -237,6 +250,13 @@ class AIClient:
response = await self._handle_tool_calls(
messages, response, tools, **kwargs
)
elif forced_tool_name:
forced_response = await self._run_forced_tool_fallback(
forced_tool_name=forced_tool_name,
user_message=user_message,
)
if forced_response is not None:
response = forced_response
# 写入记忆
if use_memory:
@@ -259,6 +279,50 @@ class AIClient:
except Exception as e:
logger.error(f"对话失败: {type(e).__name__}: {e!r}")
raise
async def _run_forced_tool_fallback(
    self, forced_tool_name: str, user_message: str
) -> Optional[Message]:
    """Execute forced tool locally when model did not emit tool_calls.

    Args:
        forced_tool_name: Name of the tool the user explicitly requested.
        user_message: Original user message; scanned for an optional
            "<n>字" prefix-length request used to truncate the result.

    Returns:
        An assistant ``Message`` carrying the tool's stringified result,
        or ``None`` when the tool is unknown or local execution fails
        (the caller then keeps the model's original response).
    """
    tool_def = self.tools.get(forced_tool_name)
    # Source label is log-only metadata; defaults to "custom" when untracked.
    tool_source = self._tool_sources.get(forced_tool_name, "custom")
    if not tool_def:
        logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}")
        return None
    logger.warning(
        "模型未返回 tool_calls启用本地强制工具执行: "
        f"source={tool_source}, name={forced_tool_name}"
    )
    try:
        # Invoke with no arguments; await the result if the tool is async.
        result = tool_def.function()
        if inspect.isawaitable(result):
            result = await result
    except TypeError as exc:
        # Tool requires arguments we cannot supply locally — give up quietly.
        logger.warning(
            "本地强制工具执行失败(参数不匹配): "
            f"name={forced_tool_name}, error={exc}"
        )
        return None
    except Exception as exc:
        # Best-effort fallback: any tool failure is logged, never raised.
        logger.warning(
            "本地强制工具执行失败: "
            f"name={forced_tool_name}, error={exc}"
        )
        return None
    result_text = str(result)
    # Honor an explicit "first N characters" request from the user, if any.
    prefix_limit = self._extract_prefix_limit(user_message)
    if prefix_limit:
        result_text = result_text[:prefix_limit]
    logger.info(
        "本地强制工具执行成功: "
        f"source={tool_source}, name={forced_tool_name}, "
        f"result={self._preview_log_payload(result_text)}"
    )
    return Message(role="assistant", content=result_text)
async def _chat_stream(
self,
@@ -410,6 +474,25 @@ class AIClient:
return text[:max_len] + "..."
return text
@staticmethod
def _extract_prefix_limit(user_message: str) -> Optional[int]:
"""Extract requested output prefix length like '前100字'."""
if not user_message:
return None
match = re.search(r"\s*(\d{1,4})\s*字", user_message)
if not match:
return None
try:
limit = int(match.group(1))
except (TypeError, ValueError):
return None
if limit <= 0:
return None
return min(limit, 5000)
@staticmethod
def _extract_forced_tool_name(
user_message: str, available_tool_names: List[str]

View File

@@ -51,26 +51,30 @@ class OpenAIModel(BaseAIModel):
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
"""Build SDK-compatible tool/function request parameters."""
"""Build request parameters for tool-calling.
Prefer modern `tools` API first, then fallback to legacy `functions`
only when runtime rejects `tools`.
"""
if not tools:
return {}
if self._supports_tools:
return {"tools": tools}
return {"tools": tools}
if self._supports_functions:
functions = []
for tool in tools:
schema = self._extract_function_schema(tool)
if schema:
functions.append(schema)
def _build_legacy_function_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
"""Build legacy function-calling params from tool schema."""
if not tools:
return {}
if functions:
return {"functions": functions, "function_call": "auto"}
functions = []
for tool in tools:
schema = self._extract_function_schema(tool)
if schema:
functions.append(schema)
if functions:
return {"functions": functions, "function_call": "auto"}
self.logger.warning(
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
)
return {}
def _build_forced_tool_params(
@@ -164,7 +168,7 @@ class OpenAIModel(BaseAIModel):
forced_tool_name = self._extract_forced_tool_name_from_choice(
retry_params.pop("tool_choice", None)
)
retry_params.update(self._build_tool_params(tools))
retry_params.update(self._build_legacy_function_params(tools))
if forced_tool_name and "functions" in retry_params:
retry_params["function_call"] = {"name": forced_tool_name}
return await self.client.chat.completions.create(**retry_params)

View File

@@ -102,7 +102,7 @@ class MessageHandler:
f"{command_name} add <model_name> (保持 provider/api_base/api_key 不变)\n"
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
f"{command_name} add <key> <json>\n"
" json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
" json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n"
f"{command_name} switch <key|index>\n"
f"{command_name} remove <key|index>"
)