"""Compatibility tests for OpenAIModel tool-calling behavior.

Each test replaces ``AsyncOpenAI`` inside ``src.ai.models.openai_model`` with
an in-memory fake whose ``chat.completions.create`` signature mimics one SDK
flavour:

* modern -- accepts ``tools=`` and replies with ``tool_calls``;
* legacy -- accepts only ``functions=`` / ``function_call=`` and replies with
  ``function_call``;
* runtime-reject -- declares ``tools=`` but raises ``TypeError`` when it is
  actually passed, so the model has to retry with ``functions=``.

The tests then inspect what ``OpenAIModel`` sent to the fake and how it
normalized the reply into ``tool_calls``.
"""

import asyncio
from types import SimpleNamespace

import src.ai.models.openai_model as openai_model_module
from src.ai.base import Message, ModelConfig, ModelProvider
from src.ai.models.openai_model import OpenAIModel

# Arguments carried by every fake tool/function call reply.
_DEMO_ARGUMENTS = '{"city":"beijing"}'


def _model_config() -> ModelConfig:
    """Return a deterministic OpenAI config pointing at a dummy endpoint."""
    return ModelConfig(
        provider=ModelProvider.OPENAI,
        model_name="test-model",
        api_key="test-key",
        api_base="https://example.com/v1",
        temperature=0.0,
        max_tokens=256,
    )


def _tool_defs():
    """Return a fresh OpenAI-style ``tools`` list declaring ``demo_tool``.

    A new list is built on every call so nothing the model does to its input
    can leak between tests.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": "demo_tool",
                "description": "Demo tool",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                    },
                    "required": ["city"],
                },
            },
        }
    ]


def _completion(message):
    """Wrap *message* in the minimal ``choices=[<choice with .message>]`` shape."""
    return SimpleNamespace(choices=[SimpleNamespace(message=message)])


def _tool_call_reply():
    """Return a reply whose assistant message uses the modern ``tool_calls`` format."""
    message = SimpleNamespace(
        content="ok",
        tool_calls=[
            SimpleNamespace(
                id="call_1",
                type="function",
                function=SimpleNamespace(name="demo_tool", arguments=_DEMO_ARGUMENTS),
            )
        ],
        function_call=None,
    )
    return _completion(message)


def _function_call_reply():
    """Return a reply whose assistant message uses the legacy ``function_call`` format."""
    message = SimpleNamespace(
        content="",
        tool_calls=None,
        function_call=SimpleNamespace(name="demo_tool", arguments=_DEMO_ARGUMENTS),
    )
    return _completion(message)


class _FakeEmbeddings:
    """Stub ``embeddings`` resource that always returns a fixed 3-element vector."""

    async def create(self, **kwargs):
        return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])


# NOTE: the ``create`` signatures below are spelled out on purpose instead of
# being collapsed into ``**kwargs``: which keyword parameters exist is what
# distinguishes a "modern" from a "legacy" SDK, and OpenAIModel may inspect
# them to choose between ``tools`` and ``functions``.


class _ModernCompletions:
    """``chat.completions`` fake for an SDK that accepts ``tools=``."""

    def __init__(self):
        # Keyword arguments of the most recent ``create`` call (None until called).
        self.last_params = None

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        tools=None,
        stream=False,
        **kwargs,
    ):
        self.last_params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "tools": tools,
            "stream": stream,
            **kwargs,
        }
        return _tool_call_reply()


class _LegacyCompletions:
    """``chat.completions`` fake for an SDK that only knows ``functions=``."""

    def __init__(self):
        # Keyword arguments of the most recent ``create`` call (None until called).
        self.last_params = None

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        functions=None,
        function_call=None,
        stream=False,
        **kwargs,
    ):
        self.last_params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "functions": functions,
            "function_call": function_call,
            "stream": stream,
            **kwargs,
        }
        return _function_call_reply()


class _RuntimeRejectToolsCompletions:
    """Fake that declares ``tools=`` but rejects it at call time.

    Every call is recorded in ``calls`` (including the rejected one) so tests
    can verify the retry sequence.
    """

    def __init__(self):
        self.calls = []

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        tools=None,
        functions=None,
        function_call=None,
        stream=False,
        **kwargs,
    ):
        self.calls.append(
            {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "tools": tools,
                "functions": functions,
                "function_call": function_call,
                "stream": stream,
                **kwargs,
            }
        )
        if tools is not None:
            # Same message shape Python emits for an unknown keyword argument.
            raise TypeError(
                "AsyncCompletions.create() got an unexpected keyword argument 'tools'"
            )
        return _function_call_reply()


class _FakeAsyncOpenAI:
    """Minimal ``AsyncOpenAI`` stand-in; subclasses choose the completions fake.

    The completions fake is exposed both as ``chat.completions`` (the SDK path)
    and as a top-level ``completions`` attribute so tests can read recorded
    calls via ``model.client.completions``.
    """

    completions_cls = None  # overridden by each concrete fake client

    def __init__(self, **kwargs):
        self.completions = self.completions_cls()
        self.chat = SimpleNamespace(completions=self.completions)
        self.embeddings = _FakeEmbeddings()


class _ModernAsyncOpenAI(_FakeAsyncOpenAI):
    completions_cls = _ModernCompletions


class _LegacyAsyncOpenAI(_FakeAsyncOpenAI):
    completions_cls = _LegacyCompletions


class _RuntimeRejectToolsAsyncOpenAI(_FakeAsyncOpenAI):
    completions_cls = _RuntimeRejectToolsCompletions


def test_openai_model_uses_tools_when_supported(monkeypatch):
    """A tools-capable client receives ``tools`` verbatim and no ``functions``."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
    model = OpenAIModel(_model_config())
    tools = _tool_defs()

    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=tools)
    )

    sent = model.client.completions.last_params
    assert model._supports_tools is True
    assert sent["tools"] == tools
    assert "functions" not in sent
    assert result.tool_calls is not None
    assert result.tool_calls[0]["function"]["name"] == "demo_tool"


def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
    """A legacy client gets ``functions`` + ``function_call="auto"`` instead of tools."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
    model = OpenAIModel(_model_config())
    tools = _tool_defs()

    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=tools)
    )

    sent = model.client.completions.last_params
    assert model._supports_tools is False
    assert model._supports_functions is True
    assert sent["function_call"] == "auto"
    assert isinstance(sent["functions"], list) and sent["functions"]
    assert sent["functions"][0]["name"] == "demo_tool"
    # The legacy ``function_call`` reply is normalized back into ``tool_calls``.
    assert result.tool_calls is not None
    assert result.tool_calls[0]["function"]["name"] == "demo_tool"


def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
    """On a legacy client, tool messages become ``function`` role messages.

    Assistant ``tool_calls`` are likewise rewritten as a ``function_call``.
    """
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
    model = OpenAIModel(_model_config())

    tool_message = Message(role="tool", name="demo_tool", content="done")
    formatted_tool = model._format_message(tool_message)

    assistant_message = Message(
        role="assistant",
        content="",
        tool_calls=[
            {
                "type": "function",
                "function": {"name": "demo_tool", "arguments": _DEMO_ARGUMENTS},
            }
        ],
    )
    formatted_assistant = model._format_message(assistant_message)

    assert formatted_tool["role"] == "function"
    assert formatted_tool["name"] == "demo_tool"
    assert "function_call" in formatted_assistant
    assert "tool_calls" not in formatted_assistant


def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
    """A ``TypeError`` on ``tools=`` triggers one retry via ``functions=``.

    The retry also switches ``_supports_tools`` off.
    """
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )
    model = OpenAIModel(_model_config())

    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    calls = model.client.completions.calls
    assert len(calls) == 2
    assert calls[0]["tools"] is not None
    assert calls[1]["tools"] is None
    assert calls[1]["functions"][0]["name"] == "demo_tool"
    assert model._supports_tools is False
    assert result.tool_calls is not None
    assert result.tool_calls[0]["function"]["name"] == "demo_tool"