Implement forced tool selection in AIClient and OpenAIModel to strengthen tool invocation. Add methods for extracting the forced tool name from user messages, and update logging to reflect forced tool usage. Improve error handling for timeouts in message handling and model interactions so that users get clearer feedback and the system is more robust.
This commit is contained in:
@@ -195,6 +195,19 @@ class AIClient:
|
|||||||
if forced_tool_name:
|
if forced_tool_name:
|
||||||
kwargs = dict(kwargs)
|
kwargs = dict(kwargs)
|
||||||
kwargs["forced_tool_name"] = forced_tool_name
|
kwargs["forced_tool_name"] = forced_tool_name
|
||||||
|
if tools:
|
||||||
|
before_count = len(tools)
|
||||||
|
tools = [
|
||||||
|
tool
|
||||||
|
for tool in tools
|
||||||
|
if ((tool.get("function") or {}).get("name") == forced_tool_name)
|
||||||
|
]
|
||||||
|
if len(tools) == 1:
|
||||||
|
tool_names = [forced_tool_name]
|
||||||
|
logger.info(
|
||||||
|
"显式工具调用已收敛工具列表: "
|
||||||
|
f"{before_count} -> {len(tools)}"
|
||||||
|
)
|
||||||
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
|
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -237,6 +250,13 @@ class AIClient:
|
|||||||
response = await self._handle_tool_calls(
|
response = await self._handle_tool_calls(
|
||||||
messages, response, tools, **kwargs
|
messages, response, tools, **kwargs
|
||||||
)
|
)
|
||||||
|
elif forced_tool_name:
|
||||||
|
forced_response = await self._run_forced_tool_fallback(
|
||||||
|
forced_tool_name=forced_tool_name,
|
||||||
|
user_message=user_message,
|
||||||
|
)
|
||||||
|
if forced_response is not None:
|
||||||
|
response = forced_response
|
||||||
|
|
||||||
# 写入记忆
|
# 写入记忆
|
||||||
if use_memory:
|
if use_memory:
|
||||||
@@ -260,6 +280,50 @@ class AIClient:
|
|||||||
logger.error(f"对话失败: {type(e).__name__}: {e!r}")
|
logger.error(f"对话失败: {type(e).__name__}: {e!r}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def _run_forced_tool_fallback(
    self, forced_tool_name: str, user_message: str
) -> Optional[Message]:
    """Run the forced tool locally when the model reply carried no tool_calls.

    The tool registered under ``forced_tool_name`` is invoked with no
    arguments, and an awaitable result is awaited. The stringified result is
    truncated when ``user_message`` asks for a prefix length.

    Args:
        forced_tool_name: Key of the tool in ``self.tools``.
        user_message: Original user text, scanned for a prefix-length request.

    Returns:
        An assistant ``Message`` holding the tool output, or ``None`` when the
        tool is not registered or invoking it raises.
    """
    definition = self.tools.get(forced_tool_name)
    source = self._tool_sources.get(forced_tool_name, "custom")
    if not definition:
        logger.warning(f"强制工具回退失败,未找到工具: {forced_tool_name}")
        return None

    logger.warning(
        "模型未返回 tool_calls,启用本地强制工具执行: "
        f"source={source}, name={forced_tool_name}"
    )

    try:
        outcome = definition.function()
        if inspect.isawaitable(outcome):
            outcome = await outcome
    except Exception as exc:
        # A TypeError is reported as an argument mismatch because the tool
        # is called without arguments; any other error gets the generic text.
        if isinstance(exc, TypeError):
            logger.warning(
                "本地强制工具执行失败(参数不匹配): "
                f"name={forced_tool_name}, error={exc}"
            )
        else:
            logger.warning(
                "本地强制工具执行失败: "
                f"name={forced_tool_name}, error={exc}"
            )
        return None

    text = str(outcome)
    limit = self._extract_prefix_limit(user_message)
    if limit:
        text = text[:limit]

    logger.info(
        "本地强制工具执行成功: "
        f"source={source}, name={forced_tool_name}, "
        f"result={self._preview_log_payload(text)}"
    )
    return Message(role="assistant", content=text)
|
||||||
|
|
||||||
async def _chat_stream(
|
async def _chat_stream(
|
||||||
self,
|
self,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
@@ -410,6 +474,25 @@ class AIClient:
|
|||||||
return text[:max_len] + "..."
|
return text[:max_len] + "..."
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@staticmethod
def _extract_prefix_limit(user_message: str) -> Optional[int]:
    """Extract a requested output prefix length such as '前100字'.

    Args:
        user_message: Raw user text to scan; may be empty.

    Returns:
        The requested character count, capped at 5000, or ``None`` when the
        message is empty, contains no '前N字' phrase, or asks for zero
        characters.
    """
    if not user_message:
        return None

    # The pattern accepts one to four digits, optionally padded by whitespace.
    match = re.search(r"前\s*(\d{1,4})\s*字", user_message)
    if match is None:
        return None

    # The captured group is 1-4 decimal digits, so int() cannot fail here.
    limit = int(match.group(1))
    if limit <= 0:
        return None
    return min(limit, 5000)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_forced_tool_name(
|
def _extract_forced_tool_name(
|
||||||
user_message: str, available_tool_names: List[str]
|
user_message: str, available_tool_names: List[str]
|
||||||
|
|||||||
@@ -51,14 +51,21 @@ class OpenAIModel(BaseAIModel):
|
|||||||
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
|
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
|
||||||
|
|
||||||
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
|
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
|
||||||
"""Build SDK-compatible tool/function request parameters."""
|
"""Build request parameters for tool-calling.
|
||||||
|
|
||||||
|
Prefer modern `tools` API first, then fallback to legacy `functions`
|
||||||
|
only when runtime rejects `tools`.
|
||||||
|
"""
|
||||||
if not tools:
|
if not tools:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
if self._supports_tools:
|
|
||||||
return {"tools": tools}
|
return {"tools": tools}
|
||||||
|
|
||||||
if self._supports_functions:
|
def _build_legacy_function_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
|
||||||
|
"""Build legacy function-calling params from tool schema."""
|
||||||
|
if not tools:
|
||||||
|
return {}
|
||||||
|
|
||||||
functions = []
|
functions = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
schema = self._extract_function_schema(tool)
|
schema = self._extract_function_schema(tool)
|
||||||
@@ -68,9 +75,6 @@ class OpenAIModel(BaseAIModel):
|
|||||||
if functions:
|
if functions:
|
||||||
return {"functions": functions, "function_call": "auto"}
|
return {"functions": functions, "function_call": "auto"}
|
||||||
|
|
||||||
self.logger.warning(
|
|
||||||
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
|
|
||||||
)
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _build_forced_tool_params(
|
def _build_forced_tool_params(
|
||||||
@@ -164,7 +168,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
forced_tool_name = self._extract_forced_tool_name_from_choice(
|
forced_tool_name = self._extract_forced_tool_name_from_choice(
|
||||||
retry_params.pop("tool_choice", None)
|
retry_params.pop("tool_choice", None)
|
||||||
)
|
)
|
||||||
retry_params.update(self._build_tool_params(tools))
|
retry_params.update(self._build_legacy_function_params(tools))
|
||||||
if forced_tool_name and "functions" in retry_params:
|
if forced_tool_name and "functions" in retry_params:
|
||||||
retry_params["function_call"] = {"name": forced_tool_name}
|
retry_params["function_call"] = {"name": forced_tool_name}
|
||||||
return await self.client.chat.completions.create(**retry_params)
|
return await self.client.chat.completions.create(**retry_params)
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class MessageHandler:
|
|||||||
f"{command_name} add <model_name> (保持 provider/api_base/api_key 不变)\n"
|
f"{command_name} add <model_name> (保持 provider/api_base/api_key 不变)\n"
|
||||||
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
|
f"{command_name} add <key> <provider> <model_name> [api_base]\n"
|
||||||
f"{command_name} add <key> <json>\n"
|
f"{command_name} add <key> <json>\n"
|
||||||
" json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p\n"
|
" json 字段:provider, model_name, api_base, api_key, temperature, max_tokens, top_p, timeout\n"
|
||||||
f"{command_name} switch <key|index>\n"
|
f"{command_name} switch <key|index>\n"
|
||||||
f"{command_name} remove <key|index>"
|
f"{command_name} remove <key|index>"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,3 +37,9 @@ def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
|
|||||||
forced = AIClient._extract_forced_tool_name(message, tools)
|
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||||
|
|
||||||
assert forced is None
|
assert forced is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_prefix_limit_from_user_message():
    """``_extract_prefix_limit`` parses '前N字' requests and ignores other text."""
    expectations = (
        ("直接返回前100字", 100),
        ("前 256 字", 256),
    )
    for message, expected in expectations:
        assert AIClient._extract_prefix_limit(message) == expected
    assert AIClient._extract_prefix_limit("返回全文") is None
|
||||||
|
|||||||
@@ -313,9 +313,9 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
|
|||||||
sent = model.client.completions.last_params
|
sent = model.client.completions.last_params
|
||||||
assert model._supports_tools is False
|
assert model._supports_tools is False
|
||||||
assert model._supports_functions is True
|
assert model._supports_functions is True
|
||||||
assert sent["function_call"] == "auto"
|
assert sent["tools"] == tools
|
||||||
assert isinstance(sent["functions"], list) and sent["functions"]
|
assert sent["function_call"] is None
|
||||||
assert sent["functions"][0]["name"] == "demo_tool"
|
assert sent["functions"] is None
|
||||||
assert result.tool_calls is not None
|
assert result.tool_calls is not None
|
||||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
@@ -334,7 +334,8 @@ def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
sent = model.client.completions.last_params
|
sent = model.client.completions.last_params
|
||||||
assert sent["function_call"] == {"name": "demo_tool"}
|
assert sent["tool_choice"]["type"] == "function"
|
||||||
|
assert sent["tool_choice"]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
|
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
|
||||||
|
|||||||
Reference in New Issue
Block a user