diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index 33934c5..dd9fac4 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -91,6 +91,36 @@ class OpenAIModel(BaseAIModel): ), } + async def _create_completion_with_fallback( + self, params: Dict[str, Any], tools: Optional[List[dict]] + ): + """Create completion with runtime fallback for old SDK signatures.""" + try: + return await self.client.chat.completions.create(**params) + except TypeError as exc: + message = str(exc) + if "unexpected keyword argument 'tools'" in message and "tools" in params: + self.logger.warning( + "SDK rejected `tools` at runtime, retrying with legacy functions." + ) + self._supports_tools = False + retry_params = dict(params) + retry_params.pop("tools", None) + retry_params.update(self._build_tool_params(tools)) + return await self.client.chat.completions.create(**retry_params) + + if "unexpected keyword argument 'functions'" in message and "functions" in params: + self.logger.warning( + "SDK rejected `functions` at runtime, retrying without tool calling." + ) + self._supports_functions = False + retry_params = dict(params) + retry_params.pop("functions", None) + retry_params.pop("function_call", None) + return await self.client.chat.completions.create(**retry_params) + + raise + async def chat( self, messages: List[Message], @@ -113,7 +143,7 @@ class OpenAIModel(BaseAIModel): params.update(self._build_tool_params(tools)) params.update(kwargs) - response = await self.client.chat.completions.create(**params) + response = await self._create_completion_with_fallback(params, tools) choice = response.choices[0] tool_calls = self._extract_response_tool_calls(choice.message) @@ -143,7 +173,7 @@ class OpenAIModel(BaseAIModel): params.update(self._build_tool_params(tools)) params.update(kwargs) - stream = await self.client.chat.completions.create(**params) + stream = await self._create_completion_with_fallback(params, tools) async for chunk in stream: delta = chunk.choices[0].delta diff --git a/tests/test_openai_model_compat.py b/tests/test_openai_model_compat.py new file mode 100644 index 0000000..55ee06d --- /dev/null +++ b/tests/test_openai_model_compat.py @@ -0,0 +1,278 @@ +"""Compatibility tests for OpenAIModel tool-calling behavior.""" + +import asyncio +from types import SimpleNamespace + +import src.ai.models.openai_model as openai_model_module +from src.ai.base import Message, ModelConfig, ModelProvider +from src.ai.models.openai_model import OpenAIModel + + +def _model_config() -> ModelConfig: + return ModelConfig( + provider=ModelProvider.OPENAI, + model_name="test-model", + api_key="test-key", + api_base="https://example.com/v1", + temperature=0.0, + max_tokens=256, + ) + + +def _tool_defs(): + return [ + { + "type": "function", + "function": { + "name": "demo_tool", + "description": "Demo tool", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + }, + "required": ["city"], + }, + }, + } + ] + + +class _FakeEmbeddings: + async def create(self, **kwargs): + return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) + + +class _ModernCompletions: + def __init__(self): + self.last_params = None + + async def create( + self, + *, + model, + messages, + temperature=None, + max_tokens=None, + top_p=None, + frequency_penalty=None, + presence_penalty=None, + tools=None, + stream=False, + **kwargs, + ): + self.last_params = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "tools": tools, + "stream": stream, + **kwargs, + } + message = SimpleNamespace( + content="ok", + tool_calls=[ + SimpleNamespace( + id="call_1", + type="function", + function=SimpleNamespace( + name="demo_tool", arguments='{"city":"beijing"}' + ), + ) + ], + function_call=None, + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + +class _LegacyCompletions: + def __init__(self): + self.last_params = None + + async def create( + self, + *, + model, + messages, + temperature=None, + max_tokens=None, + top_p=None, + frequency_penalty=None, + presence_penalty=None, + functions=None, + function_call=None, + stream=False, + **kwargs, + ): + self.last_params = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "functions": functions, + "function_call": function_call, + "stream": stream, + **kwargs, + } + message = SimpleNamespace( + content="", + tool_calls=None, + function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'), + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + +class _RuntimeRejectToolsCompletions: + def __init__(self): + self.calls = [] + + async def create( + self, + *, + model, + messages, + temperature=None, + max_tokens=None, + top_p=None, + frequency_penalty=None, + presence_penalty=None, + tools=None, + functions=None, + function_call=None, + stream=False, + **kwargs, + ): + self.calls.append( + { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "tools": tools, + "functions": functions, + "function_call": function_call, + "stream": stream, + **kwargs, + } + ) + if tools is not None: + raise TypeError("AsyncCompletions.create() got an unexpected keyword argument 'tools'") + + message = SimpleNamespace( + content="", + tool_calls=None, + function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'), + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + +class _ModernAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _ModernCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _FakeEmbeddings() + + +class _LegacyAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _LegacyCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _FakeEmbeddings() + + +class _RuntimeRejectToolsAsyncOpenAI: + def __init__(self, **kwargs): + self.completions = _RuntimeRejectToolsCompletions() + self.chat = SimpleNamespace(completions=self.completions) + self.embeddings = _FakeEmbeddings() + + +def test_openai_model_uses_tools_when_supported(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tools = _tool_defs() + result = asyncio.run( + model.chat(messages=[Message(role="user", content="hi")], tools=tools) + ) + + sent = model.client.completions.last_params + assert model._supports_tools is True + assert sent["tools"] == tools + assert "functions" not in sent + assert result.tool_calls is not None + assert result.tool_calls[0]["function"]["name"] == "demo_tool" + + +def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tools = _tool_defs() + result = asyncio.run( + model.chat(messages=[Message(role="user", content="hi")], tools=tools) + ) + + sent = model.client.completions.last_params + assert model._supports_tools is False + assert model._supports_functions is True + assert sent["function_call"] == "auto" + assert isinstance(sent["functions"], list) and sent["functions"] + assert sent["functions"][0]["name"] == "demo_tool" + assert result.tool_calls is not None + assert result.tool_calls[0]["function"]["name"] == "demo_tool" + + +def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch): + monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) + + model = OpenAIModel(_model_config()) + tool_message = Message(role="tool", name="demo_tool", content="done") + formatted_tool = model._format_message(tool_message) + + assistant_message = Message( + role="assistant", + content="", + tool_calls=[ + { + "type": "function", + "function": {"name": "demo_tool", "arguments": '{"city":"beijing"}'}, + } + ], + ) + formatted_assistant = model._format_message(assistant_message) + + assert formatted_tool["role"] == "function" + assert formatted_tool["name"] == "demo_tool" + assert "function_call" in formatted_assistant + assert "tool_calls" not in formatted_assistant + + +def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch): + monkeypatch.setattr( + openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI + ) + + model = OpenAIModel(_model_config()) + result = asyncio.run( + model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs()) + ) + + calls = model.client.completions.calls + assert len(calls) == 2 + assert calls[0]["tools"] is not None + assert calls[1]["tools"] is None + assert calls[1]["functions"][0]["name"] == "demo_tool" + assert model._supports_tools is False + assert result.tool_calls is not None + assert result.tool_calls[0]["function"]["name"] == "demo_tool"