Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message processing.

This commit is contained in:
Mimikko-zeus
2026-03-03 14:14:16 +08:00
parent 00501eb44d
commit 7d7a4b8f54
5 changed files with 343 additions and 4 deletions

View File

@@ -191,12 +191,18 @@ class AIClient:
if use_tools and self.tools.list(): if use_tools and self.tools.list():
tools = self.tools.to_openai_format() tools = self.tools.to_openai_format()
tool_names = [tool.name for tool in self.tools.list()] tool_names = [tool.name for tool in self.tools.list()]
forced_tool_name = self._extract_forced_tool_name(user_message, tool_names)
if forced_tool_name:
kwargs = dict(kwargs)
kwargs["forced_tool_name"] = forced_tool_name
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
logger.info( logger.info(
"LLM请求: " "LLM请求: "
f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, " f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, "
f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, " f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, "
f"tool_names={self._preview_log_payload(tool_names)}" f"tool_names={self._preview_log_payload(tool_names)}, "
f"forced_tool={forced_tool_name or '-'}"
) )
logger.info( logger.info(
"LLM输入: " "LLM输入: "
@@ -251,7 +257,7 @@ class AIClient:
return response.content return response.content
except Exception as e: except Exception as e:
logger.error(f"对话失败: {e}") logger.error(f"对话失败: {type(e).__name__}: {e!r}")
raise raise
async def _chat_stream( async def _chat_stream(
@@ -342,7 +348,10 @@ class AIClient:
)) ))
# 再次调用模型获取最终响应 # 再次调用模型获取最终响应
final_response = await self.model.chat(messages, tools, **kwargs) final_kwargs = dict(kwargs)
# Force only the first model turn, avoid recursive force after tool result.
final_kwargs.pop("forced_tool_name", None)
final_response = await self.model.chat(messages, tools, **final_kwargs)
logger.info( logger.info(
"LLM最终输出: " "LLM最终输出: "
f"content={self._preview_log_payload(final_response.content)}" f"content={self._preview_log_payload(final_response.content)}"
@@ -400,6 +409,52 @@ class AIClient:
if len(text) > max_len: if len(text) > max_len:
return text[:max_len] + "..." return text[:max_len] + "..."
return text return text
@staticmethod
def _extract_forced_tool_name(
user_message: str, available_tool_names: List[str]
) -> Optional[str]:
if not user_message or not available_tool_names:
return None
triggers = ["调用工具", "使用工具", "只调用", "务必调用", "必须调用", "tool"]
if not any(trigger in user_message for trigger in triggers):
return None
pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)")
explicit_matches = [
name for name in pattern.findall(user_message) if name in available_tool_names
]
if len(explicit_matches) == 1:
return explicit_matches[0]
if len(explicit_matches) > 1:
return None
contained = [name for name in available_tool_names if name in user_message]
if len(contained) == 1:
return contained[0]
# 允许只写 skill/tool 前缀(如 humanizer_zh前提是前缀下只有一个工具。
prefixes = sorted(
{name.split(".", 1)[0] for name in available_tool_names},
key=len,
reverse=True,
)
matched_prefixes = [
prefix
for prefix in prefixes
if re.search(rf"\b{re.escape(prefix)}\b", user_message)
]
if len(matched_prefixes) == 1:
prefix_tools = [
name
for name in available_tool_names
if name.startswith(f"{matched_prefixes[0]}.")
]
if len(prefix_tools) == 1:
return prefix_tools[0]
return None
def set_personality(self, personality_name: str) -> bool: def set_personality(self, personality_name: str) -> bool:
"""设置人格。""" """设置人格。"""

View File

@@ -1,6 +1,7 @@
""" """
OpenAI model implementation (including OpenAI-compatible providers). OpenAI model implementation (including OpenAI-compatible providers).
""" """
import asyncio
import inspect import inspect
import json import json
import re import re
@@ -72,6 +73,58 @@ class OpenAIModel(BaseAIModel):
) )
return {} return {}
def _build_forced_tool_params(
self,
params: Dict[str, Any],
forced_tool_name: Optional[str],
tools: Optional[List[dict]],
) -> Dict[str, Any]:
"""Build request params for forcing one specific tool call."""
if not forced_tool_name:
return {}
available_tool_names = self._extract_tool_names(tools)
if available_tool_names and forced_tool_name not in available_tool_names:
self.logger.warning(
"forced_tool_name is not in current tool list, ignored: "
f"{forced_tool_name}"
)
return {}
if "tools" in params:
return {
"tool_choice": {
"type": "function",
"function": {"name": forced_tool_name},
}
}
if "functions" in params:
return {"function_call": {"name": forced_tool_name}}
self.logger.warning(
"forced_tool_name provided but tool params are unavailable, ignored: "
f"{forced_tool_name}"
)
return {}
@staticmethod
def _extract_tool_names(tools: Optional[List[dict]]) -> List[str]:
if not tools:
return []
names: List[str] = []
for tool in tools:
if not isinstance(tool, dict):
continue
function_data = tool.get("function")
if not isinstance(function_data, dict):
continue
name = function_data.get("name")
if isinstance(name, str) and name:
names.append(name)
return names
@staticmethod @staticmethod
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]: def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if not isinstance(tool, dict): if not isinstance(tool, dict):
@@ -108,7 +161,12 @@ class OpenAIModel(BaseAIModel):
self._supports_tools = False self._supports_tools = False
retry_params = dict(params) retry_params = dict(params)
retry_params.pop("tools", None) retry_params.pop("tools", None)
forced_tool_name = self._extract_forced_tool_name_from_choice(
retry_params.pop("tool_choice", None)
)
retry_params.update(self._build_tool_params(tools)) retry_params.update(self._build_tool_params(tools))
if forced_tool_name and "functions" in retry_params:
retry_params["function_call"] = {"name": forced_tool_name}
return await self.client.chat.completions.create(**retry_params) return await self.client.chat.completions.create(**retry_params)
if "unexpected keyword argument 'functions'" in message and "functions" in params: if "unexpected keyword argument 'functions'" in message and "functions" in params:
@@ -122,6 +180,60 @@ class OpenAIModel(BaseAIModel):
return await self.client.chat.completions.create(**retry_params) return await self.client.chat.completions.create(**retry_params)
raise raise
except Exception as exc:
if self._is_timeout_error(exc):
return await self._retry_on_timeout(params)
raise
@staticmethod
def _is_timeout_error(error: Exception) -> bool:
if isinstance(error, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
return True
error_name = type(error).__name__.lower()
if "timeout" in error_name:
return True
message = str(error).lower()
return "timed out" in message or "timeout" in message
async def _retry_on_timeout(self, params: Dict[str, Any]):
base_timeout = float(self.config.timeout or 60)
retry_timeout = min(max(base_timeout * 2, 120.0), 300.0)
retry_params = dict(params)
retry_params["timeout"] = retry_timeout
self.logger.warning(
"chat request timed out, retry once with longer timeout: "
f"{base_timeout:.0f}s -> {retry_timeout:.0f}s"
)
try:
return await self.client.chat.completions.create(**retry_params)
except Exception as retry_exc:
if self._is_timeout_error(retry_exc):
self.logger.error(
"chat request still timed out after retry: "
f"timeout={retry_timeout:.0f}s"
)
raise
@staticmethod
def _extract_forced_tool_name_from_choice(tool_choice: Any) -> Optional[str]:
if not tool_choice:
return None
if isinstance(tool_choice, dict):
function_data = tool_choice.get("function")
if isinstance(function_data, dict):
name = function_data.get("name")
return name if isinstance(name, str) and name else None
return None
function_data = getattr(tool_choice, "function", None)
if function_data:
name = getattr(function_data, "name", None)
return name if isinstance(name, str) and name else None
return None
async def chat( async def chat(
self, self,
@@ -131,6 +243,7 @@ class OpenAIModel(BaseAIModel):
) -> Message: ) -> Message:
"""Non-stream chat.""" """Non-stream chat."""
formatted_messages = [self._format_message(msg) for msg in messages] formatted_messages = [self._format_message(msg) for msg in messages]
forced_tool_name = kwargs.pop("forced_tool_name", None)
params = { params = {
"model": self.config.model_name, "model": self.config.model_name,
@@ -144,6 +257,7 @@ class OpenAIModel(BaseAIModel):
params.update(self._build_tool_params(tools)) params.update(self._build_tool_params(tools))
params.update(kwargs) params.update(kwargs)
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
tool_mode = "none" tool_mode = "none"
tool_count = 0 tool_count = 0
@@ -156,7 +270,8 @@ class OpenAIModel(BaseAIModel):
self.logger.info( self.logger.info(
"OpenAI chat request: " "OpenAI chat request: "
f"model={self.config.model_name}, tool_mode={tool_mode}, tool_count={tool_count}" f"model={self.config.model_name}, tool_mode={tool_mode}, "
f"tool_count={tool_count}, forced_tool={forced_tool_name or '-'}"
) )
response = await self._create_completion_with_fallback(params, tools) response = await self._create_completion_with_fallback(params, tools)
@@ -177,6 +292,7 @@ class OpenAIModel(BaseAIModel):
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Streaming chat.""" """Streaming chat."""
formatted_messages = [self._format_message(msg) for msg in messages] formatted_messages = [self._format_message(msg) for msg in messages]
forced_tool_name = kwargs.pop("forced_tool_name", None)
params = { params = {
"model": self.config.model_name, "model": self.config.model_name,
@@ -188,6 +304,7 @@ class OpenAIModel(BaseAIModel):
params.update(self._build_tool_params(tools)) params.update(self._build_tool_params(tools))
params.update(kwargs) params.update(kwargs)
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
stream = await self._create_completion_with_fallback(params, tools) stream = await self._create_completion_with_fallback(params, tools)

View File

@@ -8,6 +8,8 @@ from pathlib import Path
import re import re
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import httpx
from botpy.message import Message from botpy.message import Message
from src.ai import AIClient from src.ai import AIClient
@@ -619,6 +621,12 @@ class MessageHandler:
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
await self._reply_plain(
message,
"模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。",
)
return
await self._reply_plain(message, "消息处理失败,请稍后重试") await self._reply_plain(message, "消息处理失败,请稍后重试")
async def _handle_skills_command(self, message: Message, command: str): async def _handle_skills_command(self, message: Message, command: str):

View File

@@ -0,0 +1,39 @@
"""Tests for AIClient forced tool name extraction."""
from src.ai.client import AIClient
def test_extract_forced_tool_name_full_name():
    """A full dotted tool name written in the message forces that exact tool."""
    available = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    text = "please call tool humanizer_zh.read_skill_doc and return first 100 chars"
    assert (
        AIClient._extract_forced_tool_name(text, available)
        == "humanizer_zh.read_skill_doc"
    )
def test_extract_forced_tool_name_unique_prefix():
    """A bare skill prefix resolves when the prefix owns exactly one tool."""
    available = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    text = "please call tool humanizer_zh only"
    assert (
        AIClient._extract_forced_tool_name(text, available)
        == "humanizer_zh.read_skill_doc"
    )
def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
    """A prefix shared by several tools is ambiguous and must not force."""
    available = [
        "skills_creator.create_skill",
        "skills_creator.reload_skill",
    ]
    text = "please call tool skills_creator"
    assert AIClient._extract_forced_tool_name(text, available) is None

View File

@@ -3,6 +3,7 @@
import asyncio import asyncio
from types import SimpleNamespace from types import SimpleNamespace
import httpx
import src.ai.models.openai_model as openai_model_module import src.ai.models.openai_model as openai_model_module
from src.ai.base import Message, ModelConfig, ModelProvider from src.ai.base import Message, ModelConfig, ModelProvider
from src.ai.models.openai_model import OpenAIModel from src.ai.models.openai_model import OpenAIModel
@@ -216,6 +217,55 @@ class _LengthLimitedEmbedAsyncOpenAI:
self.embeddings = _LengthLimitedEmbeddings() self.embeddings = _LengthLimitedEmbeddings()
class _TimeoutOnceCompletions:
def __init__(self):
self.calls = []
async def create(
self,
*,
model,
messages,
temperature=None,
max_tokens=None,
top_p=None,
frequency_penalty=None,
presence_penalty=None,
tools=None,
stream=False,
timeout=None,
**kwargs,
):
self.calls.append(
{
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"tools": tools,
"stream": stream,
"timeout": timeout,
**kwargs,
}
)
if len(self.calls) == 1:
raise httpx.ReadTimeout("timed out")
message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None)
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
class _TimeoutOnceAsyncOpenAI:
    # Minimal AsyncOpenAI stand-in whose chat endpoint times out exactly once.
    def __init__(self, **kwargs):
        # Accepts arbitrary client kwargs (api_key, base_url, ...) and ignores them.
        self.completions = _TimeoutOnceCompletions()
        # Mirror the real SDK layout: client.chat.completions.create(...).
        self.chat = SimpleNamespace(completions=self.completions)
        # _FakeEmbeddings is defined elsewhere in this test module.
        self.embeddings = _FakeEmbeddings()
def test_openai_model_uses_tools_when_supported(monkeypatch): def test_openai_model_uses_tools_when_supported(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
@@ -233,6 +283,24 @@ def test_openai_model_uses_tools_when_supported(monkeypatch):
assert result.tool_calls[0]["function"]["name"] == "demo_tool" assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_forces_tool_choice_when_supported(monkeypatch):
    """forced_tool_name is translated into an explicit tool_choice payload."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tool_defs = _tool_defs()
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=tool_defs,
            forced_tool_name="demo_tool",
        )
    )

    choice = model.client.completions.last_params["tool_choice"]
    assert choice["type"] == "function"
    assert choice["function"]["name"] == "demo_tool"
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
@@ -252,6 +320,23 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
assert result.tool_calls[0]["function"]["name"] == "demo_tool" assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
    """On legacy SDKs the forced tool becomes a function_call parameter."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tool_defs = _tool_defs()
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=tool_defs,
            forced_tool_name="demo_tool",
        )
    )

    sent_params = model.client.completions.last_params
    assert sent_params["function_call"] == {"name": "demo_tool"}
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch): def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
@@ -297,6 +382,41 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
assert result.tool_calls[0]["function"]["name"] == "demo_tool" assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch):
    """Falling back from tools to functions keeps forcing the same tool."""
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    recorded = model.client.completions.calls
    assert len(recorded) == 2
    # First attempt uses the modern API, the retry uses the legacy one.
    assert recorded[0]["tool_choice"]["function"]["name"] == "demo_tool"
    assert recorded[1]["function_call"] == {"name": "demo_tool"}
def test_openai_model_retries_once_on_read_timeout(monkeypatch):
    """A single ReadTimeout triggers exactly one retry with a 120s timeout."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI)

    model = OpenAIModel(_model_config())
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    recorded = model.client.completions.calls
    assert len(recorded) == 2
    assert recorded[0]["timeout"] is None
    assert recorded[1]["timeout"] == 120.0
    assert result.content == "ok-after-timeout"
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch): def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI