Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Add methods for extracting forced tool names from user messages and update logging to reflect forced tool usage. Improve error handling for timeout scenarios in message processing.

This commit is contained in:
Mimikko-zeus
2026-03-03 14:14:16 +08:00
parent 00501eb44d
commit 7d7a4b8f54
5 changed files with 343 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
import asyncio
from types import SimpleNamespace
import httpx
import src.ai.models.openai_model as openai_model_module
from src.ai.base import Message, ModelConfig, ModelProvider
from src.ai.models.openai_model import OpenAIModel
@@ -216,6 +217,55 @@ class _LengthLimitedEmbedAsyncOpenAI:
self.embeddings = _LengthLimitedEmbeddings()
class _TimeoutOnceCompletions:
    """Fake chat-completions endpoint that times out on the first call only.

    Every invocation's keyword arguments are recorded in ``self.calls`` so
    tests can compare what was sent on the initial attempt versus the retry.
    """

    def __init__(self):
        # One dict of keyword arguments per create() invocation, in call order.
        self.calls = []

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        tools=None,
        stream=False,
        timeout=None,
        **kwargs,
    ):
        """Record the call; raise ReadTimeout once, then return a canned reply."""
        recorded = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "tools": tools,
            "stream": stream,
            "timeout": timeout,
        }
        # Extra keyword args override the named ones, matching dict-spread order.
        recorded.update(kwargs)
        self.calls.append(recorded)

        # Simulate a transient network timeout on the very first attempt.
        if len(self.calls) == 1:
            raise httpx.ReadTimeout("timed out")

        reply = SimpleNamespace(
            content="ok-after-timeout", tool_calls=None, function_call=None
        )
        return SimpleNamespace(choices=[SimpleNamespace(message=reply)])
class _TimeoutOnceAsyncOpenAI:
    """Fake AsyncOpenAI client whose chat endpoint times out exactly once."""

    def __init__(self, **kwargs):
        # Expose the same completions object both directly and under .chat,
        # mirroring how the real SDK nests chat.completions.
        completions = _TimeoutOnceCompletions()
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
        self.embeddings = _FakeEmbeddings()
def test_openai_model_uses_tools_when_supported(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
@@ -233,6 +283,24 @@ def test_openai_model_uses_tools_when_supported(monkeypatch):
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_forces_tool_choice_when_supported(monkeypatch):
    """A modern SDK should receive an explicit function-style tool_choice."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tool_defs = _tool_defs()
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=tool_defs,
            forced_tool_name="demo_tool",
        )
    )

    params = model.client.completions.last_params
    choice = params["tool_choice"]
    assert choice["type"] == "function"
    assert choice["function"]["name"] == "demo_tool"
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
@@ -252,6 +320,23 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
    """Legacy SDKs get the old-style function_call directive instead."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    params = model.client.completions.last_params
    assert params["function_call"] == {"name": "demo_tool"}
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
@@ -297,6 +382,41 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch):
    """The forced tool must survive the tools -> functions retry fallback."""
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    calls = model.client.completions.calls
    assert len(calls) == 2
    first_call, retry_call = calls
    # First attempt used the modern tool_choice shape ...
    assert first_call["tool_choice"]["function"]["name"] == "demo_tool"
    # ... and the retry carried the same tool as a legacy function_call.
    assert retry_call["function_call"] == {"name": "demo_tool"}
def test_openai_model_retries_once_on_read_timeout(monkeypatch):
    """A ReadTimeout triggers exactly one retry that sets an explicit timeout."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI)

    model = OpenAIModel(_model_config())
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    calls = model.client.completions.calls
    assert len(calls) == 2
    first_call, retry_call = calls
    # First attempt went out without a timeout; the retry pinned one at 120s.
    assert first_call["timeout"] is None
    assert retry_call["timeout"] == 120.0
    assert result.content == "ok-after-timeout"
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
monkeypatch.setattr(
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI