"""Compatibility tests for OpenAIModel tool-calling behavior.""" import asyncio from types import SimpleNamespace import httpx import src.ai.models.openai_model as openai_model_module from src.ai.base import Message, ModelConfig, ModelProvider from src.ai.models.openai_model import OpenAIModel def _model_config() -> ModelConfig: return ModelConfig( provider=ModelProvider.OPENAI, model_name="test-model", api_key="test-key", api_base="https://example.com/v1", temperature=0.0, max_tokens=256, ) def _tool_defs(): return [ { "type": "function", "function": { "name": "demo_tool", "description": "Demo tool", "parameters": { "type": "object", "properties": { "city": {"type": "string"}, }, "required": ["city"], }, }, } ] class _FakeEmbeddings: async def create(self, **kwargs): return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) class _LengthLimitedEmbeddings: def __init__(self): self.inputs = [] async def create(self, **kwargs): text = kwargs.get("input", "") self.inputs.append(text) if len(text) > 512: raise RuntimeError("input must be less than 512 tokens") return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]) class _ModernCompletions: def __init__(self): self.last_params = None async def create( self, *, model, messages, temperature=None, max_tokens=None, top_p=None, frequency_penalty=None, presence_penalty=None, tools=None, stream=False, **kwargs, ): self.last_params = { "model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "tools": tools, "stream": stream, **kwargs, } message = SimpleNamespace( content="ok", tool_calls=[ SimpleNamespace( id="call_1", type="function", function=SimpleNamespace( name="demo_tool", arguments='{"city":"beijing"}' ), ) ], function_call=None, ) return SimpleNamespace(choices=[SimpleNamespace(message=message)]) class _LegacyCompletions: def __init__(self): self.last_params = None async def create( self, *, model, messages, temperature=None, max_tokens=None, top_p=None, frequency_penalty=None, presence_penalty=None, functions=None, function_call=None, stream=False, **kwargs, ): self.last_params = { "model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "functions": functions, "function_call": function_call, "stream": stream, **kwargs, } message = SimpleNamespace( content="", tool_calls=None, function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'), ) return SimpleNamespace(choices=[SimpleNamespace(message=message)]) class _RuntimeRejectToolsCompletions: def __init__(self): self.calls = [] async def create( self, *, model, messages, temperature=None, max_tokens=None, top_p=None, frequency_penalty=None, presence_penalty=None, tools=None, functions=None, function_call=None, stream=False, **kwargs, ): self.calls.append( { "model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "tools": tools, "functions": functions, "function_call": function_call, "stream": stream, **kwargs, } ) if tools is not None: raise TypeError("AsyncCompletions.create() got an unexpected keyword argument 'tools'") message = SimpleNamespace( content="", tool_calls=None, function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'), ) return SimpleNamespace(choices=[SimpleNamespace(message=message)]) class _ModernAsyncOpenAI: def __init__(self, **kwargs): self.completions = _ModernCompletions() self.chat = SimpleNamespace(completions=self.completions) self.embeddings = _FakeEmbeddings() class _LegacyAsyncOpenAI: def __init__(self, **kwargs): self.completions = _LegacyCompletions() self.chat = SimpleNamespace(completions=self.completions) self.embeddings = _FakeEmbeddings() class _RuntimeRejectToolsAsyncOpenAI: def __init__(self, **kwargs): self.completions = _RuntimeRejectToolsCompletions() self.chat = SimpleNamespace(completions=self.completions) self.embeddings = _FakeEmbeddings() class _LengthLimitedEmbedAsyncOpenAI: def __init__(self, **kwargs): self.completions = _ModernCompletions() self.chat = SimpleNamespace(completions=self.completions) self.embeddings = _LengthLimitedEmbeddings() class _TimeoutOnceCompletions: def __init__(self): self.calls = [] async def create( self, *, model, messages, temperature=None, max_tokens=None, top_p=None, frequency_penalty=None, presence_penalty=None, tools=None, stream=False, timeout=None, **kwargs, ): self.calls.append( { "model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "tools": tools, "stream": stream, "timeout": timeout, **kwargs, } ) if len(self.calls) == 1: raise httpx.ReadTimeout("timed out") message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None) return SimpleNamespace(choices=[SimpleNamespace(message=message)]) class _TimeoutOnceAsyncOpenAI: def __init__(self, **kwargs): self.completions = _TimeoutOnceCompletions() self.chat = SimpleNamespace(completions=self.completions) self.embeddings = _FakeEmbeddings() def test_openai_model_uses_tools_when_supported(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) model = OpenAIModel(_model_config()) tools = _tool_defs() result = asyncio.run( model.chat(messages=[Message(role="user", content="hi")], tools=tools) ) sent = model.client.completions.last_params assert model._supports_tools is True assert sent["tools"] == tools assert "functions" not in sent assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" def test_openai_model_forces_tool_choice_when_supported(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI) model = OpenAIModel(_model_config()) tools = _tool_defs() asyncio.run( model.chat( messages=[Message(role="user", content="hi")], tools=tools, forced_tool_name="demo_tool", ) ) sent = model.client.completions.last_params assert sent["tool_choice"]["type"] == "function" assert sent["tool_choice"]["function"]["name"] == "demo_tool" def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) model = OpenAIModel(_model_config()) tools = _tool_defs() result = asyncio.run( model.chat(messages=[Message(role="user", content="hi")], tools=tools) ) sent = model.client.completions.last_params assert model._supports_tools is False assert model._supports_functions is True assert sent["tools"] == tools assert sent["function_call"] is None assert sent["functions"] is None assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) model = OpenAIModel(_model_config()) tools = _tool_defs() asyncio.run( model.chat( messages=[Message(role="user", content="hi")], tools=tools, forced_tool_name="demo_tool", ) ) sent = model.client.completions.last_params assert sent["tool_choice"]["type"] == "function" assert sent["tool_choice"]["function"]["name"] == "demo_tool" def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI) model = OpenAIModel(_model_config()) tool_message = Message(role="tool", name="demo_tool", content="done") formatted_tool = model._format_message(tool_message) assistant_message = Message( role="assistant", content="", tool_calls=[ { "type": "function", "function": {"name": "demo_tool", "arguments": '{"city":"beijing"}'}, } ], ) formatted_assistant = model._format_message(assistant_message) assert formatted_tool["role"] == "function" assert formatted_tool["name"] == "demo_tool" assert "function_call" in formatted_assistant assert "tool_calls" not in formatted_assistant def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch): monkeypatch.setattr( openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI ) model = OpenAIModel(_model_config()) result = asyncio.run( model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs()) ) calls = model.client.completions.calls assert len(calls) == 2 assert calls[0]["tools"] is not None assert calls[1]["tools"] is None assert calls[1]["functions"][0]["name"] == "demo_tool" assert model._supports_tools is False assert result.tool_calls is not None assert result.tool_calls[0]["function"]["name"] == "demo_tool" def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch): monkeypatch.setattr( openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI ) model = OpenAIModel(_model_config()) asyncio.run( model.chat( messages=[Message(role="user", content="hi")], tools=_tool_defs(), forced_tool_name="demo_tool", ) ) calls = model.client.completions.calls assert len(calls) == 2 assert calls[0]["tool_choice"]["function"]["name"] == "demo_tool" assert calls[1]["function_call"] == {"name": "demo_tool"} def test_openai_model_retries_once_on_read_timeout(monkeypatch): monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI) model = OpenAIModel(_model_config()) result = asyncio.run( model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs()) ) calls = model.client.completions.calls assert len(calls) == 2 assert calls[0]["timeout"] is None assert calls[1]["timeout"] == 120.0 assert result.content == "ok-after-timeout" def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch): monkeypatch.setattr( openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI ) model = OpenAIModel(_model_config()) long_text = "你" * 600 first_embedding = asyncio.run(model.embed(long_text)) assert first_embedding == [0.1, 0.2, 0.3] assert model._embedding_token_limit == 512 inputs_after_first = list(model.client.embeddings.inputs) assert len(inputs_after_first) == 2 assert len(inputs_after_first[0]) == 600 assert len(inputs_after_first[1]) < len(inputs_after_first[0]) second_embedding = asyncio.run(model.embed(long_text)) assert second_embedding == [0.1, 0.2, 0.3] inputs_after_second = list(model.client.embeddings.inputs) assert len(inputs_after_second) == 3 assert len(inputs_after_second[-1]) < 512