Add fallback mechanism for OpenAIModel completion requests
This commit is contained in:
@@ -91,6 +91,36 @@ class OpenAIModel(BaseAIModel):
|
||||
),
|
||||
}
|
||||
|
||||
async def _create_completion_with_fallback(
|
||||
self, params: Dict[str, Any], tools: Optional[List[dict]]
|
||||
):
|
||||
"""Create completion with runtime fallback for old SDK signatures."""
|
||||
try:
|
||||
return await self.client.chat.completions.create(**params)
|
||||
except TypeError as exc:
|
||||
message = str(exc)
|
||||
if "unexpected keyword argument 'tools'" in message and "tools" in params:
|
||||
self.logger.warning(
|
||||
"SDK rejected `tools` at runtime, retrying with legacy functions."
|
||||
)
|
||||
self._supports_tools = False
|
||||
retry_params = dict(params)
|
||||
retry_params.pop("tools", None)
|
||||
retry_params.update(self._build_tool_params(tools))
|
||||
return await self.client.chat.completions.create(**retry_params)
|
||||
|
||||
if "unexpected keyword argument 'functions'" in message and "functions" in params:
|
||||
self.logger.warning(
|
||||
"SDK rejected `functions` at runtime, retrying without tool calling."
|
||||
)
|
||||
self._supports_functions = False
|
||||
retry_params = dict(params)
|
||||
retry_params.pop("functions", None)
|
||||
retry_params.pop("function_call", None)
|
||||
return await self.client.chat.completions.create(**retry_params)
|
||||
|
||||
raise
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
@@ -113,7 +143,7 @@ class OpenAIModel(BaseAIModel):
|
||||
params.update(self._build_tool_params(tools))
|
||||
params.update(kwargs)
|
||||
|
||||
response = await self.client.chat.completions.create(**params)
|
||||
response = await self._create_completion_with_fallback(params, tools)
|
||||
|
||||
choice = response.choices[0]
|
||||
tool_calls = self._extract_response_tool_calls(choice.message)
|
||||
@@ -143,7 +173,7 @@ class OpenAIModel(BaseAIModel):
|
||||
params.update(self._build_tool_params(tools))
|
||||
params.update(kwargs)
|
||||
|
||||
stream = await self.client.chat.completions.create(**params)
|
||||
stream = await self._create_completion_with_fallback(params, tools)
|
||||
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
278
tests/test_openai_model_compat.py
Normal file
278
tests/test_openai_model_compat.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Compatibility tests for OpenAIModel tool-calling behavior."""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.ai.models.openai_model as openai_model_module
|
||||
from src.ai.base import Message, ModelConfig, ModelProvider
|
||||
from src.ai.models.openai_model import OpenAIModel
|
||||
|
||||
|
||||
def _model_config() -> ModelConfig:
    """Build a minimal, deterministic config for OpenAIModel tests."""
    settings = {
        "provider": ModelProvider.OPENAI,
        "model_name": "test-model",
        "api_key": "test-key",
        "api_base": "https://example.com/v1",
        "temperature": 0.0,
        "max_tokens": 256,
    }
    return ModelConfig(**settings)
|
||||
|
||||
|
||||
def _tool_defs():
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "demo_tool",
|
||||
"description": "Demo tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class _FakeEmbeddings:
|
||||
async def create(self, **kwargs):
|
||||
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
||||
|
||||
|
||||
class _ModernCompletions:
|
||||
def __init__(self):
|
||||
self.last_params = None
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
tools=None,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.last_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"presence_penalty": presence_penalty,
|
||||
"tools": tools,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
message = SimpleNamespace(
|
||||
content="ok",
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name="demo_tool", arguments='{"city":"beijing"}'
|
||||
),
|
||||
)
|
||||
],
|
||||
function_call=None,
|
||||
)
|
||||
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
||||
|
||||
|
||||
class _LegacyCompletions:
|
||||
def __init__(self):
|
||||
self.last_params = None
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
functions=None,
|
||||
function_call=None,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.last_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"presence_penalty": presence_penalty,
|
||||
"functions": functions,
|
||||
"function_call": function_call,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
message = SimpleNamespace(
|
||||
content="",
|
||||
tool_calls=None,
|
||||
function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'),
|
||||
)
|
||||
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
||||
|
||||
|
||||
class _RuntimeRejectToolsCompletions:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
tools=None,
|
||||
functions=None,
|
||||
function_call=None,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.calls.append(
|
||||
{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"presence_penalty": presence_penalty,
|
||||
"tools": tools,
|
||||
"functions": functions,
|
||||
"function_call": function_call,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
if tools is not None:
|
||||
raise TypeError("AsyncCompletions.create() got an unexpected keyword argument 'tools'")
|
||||
|
||||
message = SimpleNamespace(
|
||||
content="",
|
||||
tool_calls=None,
|
||||
function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'),
|
||||
)
|
||||
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
||||
|
||||
|
||||
class _ModernAsyncOpenAI:
    """AsyncOpenAI stand-in wired to the modern fake completions endpoint."""

    def __init__(self, **kwargs):
        completions = _ModernCompletions()
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
        self.embeddings = _FakeEmbeddings()
|
||||
|
||||
|
||||
class _LegacyAsyncOpenAI:
    """AsyncOpenAI stand-in wired to the legacy (functions-only) endpoint."""

    def __init__(self, **kwargs):
        completions = _LegacyCompletions()
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
        self.embeddings = _FakeEmbeddings()
|
||||
|
||||
|
||||
class _RuntimeRejectToolsAsyncOpenAI:
    """AsyncOpenAI stand-in whose endpoint rejects `tools` only at call time."""

    def __init__(self, **kwargs):
        completions = _RuntimeRejectToolsCompletions()
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
        self.embeddings = _FakeEmbeddings()
|
||||
|
||||
|
||||
def test_openai_model_uses_tools_when_supported(monkeypatch):
    """A modern SDK receives the `tools` payload untouched, no `functions`."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tool_defs = _tool_defs()
    reply = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=tool_defs)
    )

    captured = model.client.completions.last_params
    assert model._supports_tools is True
    assert captured["tools"] == tool_defs
    assert "functions" not in captured
    assert reply.tool_calls is not None
    assert reply.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
    """Against a legacy SDK the model sends `functions`/`function_call` instead."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tool_defs = _tool_defs()
    reply = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=tool_defs)
    )

    captured = model.client.completions.last_params
    assert model._supports_tools is False
    assert model._supports_functions is True
    assert captured["function_call"] == "auto"
    assert isinstance(captured["functions"], list) and captured["functions"]
    assert captured["functions"][0]["name"] == "demo_tool"
    assert reply.tool_calls is not None
    assert reply.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
    """Legacy SDKs get `function`-role messages and `function_call` payloads."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())

    tool_payload = model._format_message(
        Message(role="tool", name="demo_tool", content="done")
    )

    assistant = Message(
        role="assistant",
        content="",
        tool_calls=[
            {
                "type": "function",
                "function": {"name": "demo_tool", "arguments": '{"city":"beijing"}'},
            }
        ],
    )
    assistant_payload = model._format_message(assistant)

    assert tool_payload["role"] == "function"
    assert tool_payload["name"] == "demo_tool"
    assert "function_call" in assistant_payload
    assert "tool_calls" not in assistant_payload
|
||||
|
||||
|
||||
def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
    """A runtime TypeError on `tools` triggers one retry using `functions`."""
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    reply = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    attempts = model.client.completions.calls
    assert len(attempts) == 2
    assert attempts[0]["tools"] is not None
    assert attempts[1]["tools"] is None
    assert attempts[1]["functions"][0]["name"] == "demo_tool"
    assert model._supports_tools is False
    assert reply.tool_calls is not None
    assert reply.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
Reference in New Issue
Block a user