Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message handling and model interactions, ensuring better user feedback and robustness.

This commit is contained in:
Mimikko-zeus
2026-03-03 14:25:21 +08:00
parent 7d7a4b8f54
commit 4a2666b1f2
5 changed files with 114 additions and 20 deletions

View File

@@ -195,6 +195,19 @@ class AIClient:
if forced_tool_name: if forced_tool_name:
kwargs = dict(kwargs) kwargs = dict(kwargs)
kwargs["forced_tool_name"] = forced_tool_name kwargs["forced_tool_name"] = forced_tool_name
if tools:
before_count = len(tools)
tools = [
tool
for tool in tools
if ((tool.get("function") or {}).get("name") == forced_tool_name)
]
if len(tools) == 1:
tool_names = [forced_tool_name]
logger.info(
"显式工具调用已收敛工具列表: "
f"{before_count} -> {len(tools)}"
)
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}") logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
logger.info( logger.info(
@@ -237,6 +250,13 @@ class AIClient:
response = await self._handle_tool_calls( response = await self._handle_tool_calls(
messages, response, tools, **kwargs messages, response, tools, **kwargs
) )
elif forced_tool_name:
forced_response = await self._run_forced_tool_fallback(
forced_tool_name=forced_tool_name,
user_message=user_message,
)
if forced_response is not None:
response = forced_response
# 写入记忆 # 写入记忆
if use_memory: if use_memory:
@@ -260,6 +280,50 @@ class AIClient:
logger.error(f"对话失败: {type(e).__name__}: {e!r}") logger.error(f"对话失败: {type(e).__name__}: {e!r}")
raise raise
async def _run_forced_tool_fallback(
    self, forced_tool_name: str, user_message: str
) -> Optional[Message]:
    """Run the forced tool in-process when the model emitted no tool_calls.

    Returns an assistant ``Message`` carrying the tool output (optionally
    truncated to a user-requested prefix length), or ``None`` when the tool
    is unknown or its zero-argument invocation fails.
    """
    definition = self.tools.get(forced_tool_name)
    source = self._tool_sources.get(forced_tool_name, "custom")
    if not definition:
        logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}")
        return None
    logger.warning(
        "模型未返回 tool_calls启用本地强制工具执行: "
        f"source={source}, name={forced_tool_name}"
    )
    try:
        outcome = definition.function()
        # Tools may be sync or async; await only when a coroutine came back.
        if inspect.isawaitable(outcome):
            outcome = await outcome
    except TypeError as err:
        # The tool signature requires arguments we cannot supply locally.
        logger.warning(
            "本地强制工具执行失败(参数不匹配): "
            f"name={forced_tool_name}, error={err}"
        )
        return None
    except Exception as err:
        logger.warning(
            "本地强制工具执行失败: "
            f"name={forced_tool_name}, error={err}"
        )
        return None
    text = str(outcome)
    limit = self._extract_prefix_limit(user_message)
    if limit:
        text = text[:limit]
    logger.info(
        "本地强制工具执行成功: "
        f"source={source}, name={forced_tool_name}, "
        f"result={self._preview_log_payload(text)}"
    )
    return Message(role="assistant", content=text)
async def _chat_stream( async def _chat_stream(
self, self,
messages: List[Message], messages: List[Message],
@@ -410,6 +474,25 @@ class AIClient:
return text[:max_len] + "..." return text[:max_len] + "..."
return text return text
@staticmethod
def _extract_prefix_limit(user_message: str) -> Optional[int]:
"""Extract requested output prefix length like '前100字'."""
if not user_message:
return None
match = re.search(r"\s*(\d{1,4})\s*字", user_message)
if not match:
return None
try:
limit = int(match.group(1))
except (TypeError, ValueError):
return None
if limit <= 0:
return None
return min(limit, 5000)
@staticmethod @staticmethod
def _extract_forced_tool_name( def _extract_forced_tool_name(
user_message: str, available_tool_names: List[str] user_message: str, available_tool_names: List[str]

View File

@@ -51,14 +51,21 @@ class OpenAIModel(BaseAIModel):
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}") self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]: def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
"""Build SDK-compatible tool/function request parameters.""" """Build request parameters for tool-calling.
Prefer modern `tools` API first, then fallback to legacy `functions`
only when runtime rejects `tools`.
"""
if not tools: if not tools:
return {} return {}
if self._supports_tools:
return {"tools": tools} return {"tools": tools}
if self._supports_functions: def _build_legacy_function_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
"""Build legacy function-calling params from tool schema."""
if not tools:
return {}
functions = [] functions = []
for tool in tools: for tool in tools:
schema = self._extract_function_schema(tool) schema = self._extract_function_schema(tool)
@@ -68,9 +75,6 @@ class OpenAIModel(BaseAIModel):
if functions: if functions:
return {"functions": functions, "function_call": "auto"} return {"functions": functions, "function_call": "auto"}
self.logger.warning(
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
)
return {} return {}
def _build_forced_tool_params( def _build_forced_tool_params(
@@ -164,7 +168,7 @@ class OpenAIModel(BaseAIModel):
forced_tool_name = self._extract_forced_tool_name_from_choice( forced_tool_name = self._extract_forced_tool_name_from_choice(
retry_params.pop("tool_choice", None) retry_params.pop("tool_choice", None)
) )
retry_params.update(self._build_tool_params(tools)) retry_params.update(self._build_legacy_function_params(tools))
if forced_tool_name and "functions" in retry_params: if forced_tool_name and "functions" in retry_params:
retry_params["function_call"] = {"name": forced_tool_name} retry_params["function_call"] = {"name": forced_tool_name}
return await self.client.chat.completions.create(**retry_params) return await self.client.chat.completions.create(**retry_params)

View File

@@ -102,7 +102,7 @@ class MessageHandler:
f"{command_name} add <model_name> (保持 provider/api_base/api_key 不变)\n" f"{command_name} add <model_name> (保持 provider/api_base/api_key 不变)\n"
f"{command_name} add <key> <provider> <model_name> [api_base]\n" f"{command_name} add <key> <provider> <model_name> [api_base]\n"
f"{command_name} add <key> <json>\n" f"{command_name} add <key> <json>\n"
" json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n" " json 字段provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n"
f"{command_name} switch <key|index>\n" f"{command_name} switch <key|index>\n"
f"{command_name} remove <key|index>" f"{command_name} remove <key|index>"
) )

View File

@@ -37,3 +37,9 @@ def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
forced = AIClient._extract_forced_tool_name(message, tools) forced = AIClient._extract_forced_tool_name(message, tools)
assert forced is None assert forced is None
def test_extract_prefix_limit_from_user_message():
    """'前N字' style requests yield the parsed limit; plain text yields None."""
    expectations = {
        "直接返回前100字": 100,
        "前 256 字": 256,
        "返回全文": None,
    }
    for message, expected in expectations.items():
        assert AIClient._extract_prefix_limit(message) == expected

View File

@@ -313,9 +313,9 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
sent = model.client.completions.last_params sent = model.client.completions.last_params
assert model._supports_tools is False assert model._supports_tools is False
assert model._supports_functions is True assert model._supports_functions is True
assert sent["function_call"] == "auto" assert sent["tools"] == tools
assert isinstance(sent["functions"], list) and sent["functions"] assert sent["function_call"] is None
assert sent["functions"][0]["name"] == "demo_tool" assert sent["functions"] is None
assert result.tool_calls is not None assert result.tool_calls is not None
assert result.tool_calls[0]["function"]["name"] == "demo_tool" assert result.tool_calls[0]["function"]["name"] == "demo_tool"
@@ -334,7 +334,8 @@ def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
) )
sent = model.client.completions.last_params sent = model.client.completions.last_params
assert sent["function_call"] == {"name": "demo_tool"} assert sent["tool_choice"]["type"] == "function"
assert sent["tool_choice"]["function"]["name"] == "demo_tool"
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch): def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):