# File-viewer header: Python source, 443 lines, 13 KiB.
"""Compatibility tests for OpenAIModel tool-calling behavior."""
|
|
|
|
import asyncio
|
|
from types import SimpleNamespace
|
|
|
|
import httpx
|
|
import src.ai.models.openai_model as openai_model_module
|
|
from src.ai.base import Message, ModelConfig, ModelProvider
|
|
from src.ai.models.openai_model import OpenAIModel
|
|
|
|
|
|
def _model_config() -> ModelConfig:
    """Return the ModelConfig used to construct OpenAIModel in these tests.

    The API key and base URL are placeholder values; temperature is 0.0 and
    max_tokens is 256.
    """
    return ModelConfig(
        provider=ModelProvider.OPENAI,
        model_name="test-model",
        api_key="test-key",
        api_base="https://example.com/v1",
        temperature=0.0,
        max_tokens=256,
    )
|
|
|
|
|
|
def _tool_defs() -> list:
    """Return a fresh one-element list describing the ``demo_tool`` function tool.

    The single entry uses the ``{"type": "function", "function": {...}}`` layout
    and declares one required string parameter, ``city``. A new list is built on
    every call, so callers may mutate the result freely.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": "demo_tool",
                "description": "Demo tool",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                    },
                    "required": ["city"],
                },
            },
        }
    ]
|
|
|
|
|
|
class _FakeEmbeddings:
    """Stand-in for a client's ``embeddings`` resource that always succeeds."""

    async def create(self, **kwargs):
        """Ignore every argument and return a response with one fixed vector.

        The result exposes ``data[0].embedding == [0.1, 0.2, 0.3]``.
        """
        return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
|
|
|
|
|
class _LengthLimitedEmbeddings:
    """Embeddings fake that rejects over-long input and records every request.

    Length is measured with ``len()`` on the ``input`` argument (characters for
    a string), although the error text words the limit in tokens.
    """

    def __init__(self, limit=512):
        """Create the fake.

        Args:
            limit: Largest accepted ``len(input)``; anything longer makes
                ``create`` raise. Defaults to 512.
        """
        self.limit = limit
        # Every ``input`` value received, in call order, rejected ones included.
        self.inputs = []

    async def create(self, **kwargs):
        """Record the ``input`` keyword and return a fixed embedding.

        Raises:
            RuntimeError: If ``len(input)`` exceeds ``self.limit``. The input is
                recorded in ``self.inputs`` before the error is raised.
        """
        text = kwargs.get("input", "")
        self.inputs.append(text)
        if len(text) > self.limit:
            raise RuntimeError(f"input must be less than {self.limit} tokens")
        return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
|
|
|
|
|
class _ModernCompletions:
    """Fake ``chat.completions`` resource whose ``create`` declares ``tools``.

    The signature names ``tools`` but not ``functions``/``function_call``.
    NOTE(review): presumably OpenAIModel inspects this signature to decide
    which tool-calling style to use; confirm against openai_model.
    """

    def __init__(self):
        # Arguments of the most recent ``create`` call; None until first called.
        self.last_params = None

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        tools=None,
        stream=False,
        **kwargs,
    ):
        """Record the call and return a canned response with one tool call.

        Named parameters and any extra keyword arguments (for example
        ``tool_choice``) are stored together in ``self.last_params``.

        Returns:
            An object whose ``choices[0].message`` has ``content == "ok"``,
            ``function_call is None`` and a single ``tool_calls`` entry for
            ``demo_tool`` with arguments ``'{"city":"beijing"}'``.
        """
        self.last_params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "tools": tools,
            "stream": stream,
            **kwargs,
        }
        message = SimpleNamespace(
            content="ok",
            tool_calls=[
                SimpleNamespace(
                    id="call_1",
                    type="function",
                    function=SimpleNamespace(
                        name="demo_tool", arguments='{"city":"beijing"}'
                    ),
                )
            ],
            function_call=None,
        )
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
|
|
|
|
|
class _LegacyCompletions:
    """Fake ``chat.completions`` resource with the older function-calling signature.

    ``create`` names ``functions`` and ``function_call`` but not ``tools``.
    NOTE(review): presumably OpenAIModel inspects this signature to detect a
    legacy SDK; confirm against openai_model.
    """

    def __init__(self):
        # Arguments of the most recent ``create`` call; None until first called.
        self.last_params = None

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        functions=None,
        function_call=None,
        stream=False,
        **kwargs,
    ):
        """Record the call and return a canned response carrying a ``function_call``.

        Named parameters and any extra keyword arguments are stored together
        in ``self.last_params``.

        Returns:
            An object whose ``choices[0].message`` has empty ``content``,
            ``tool_calls is None`` and a ``function_call`` for ``demo_tool``
            with arguments ``'{"city":"beijing"}'``.
        """
        self.last_params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "functions": functions,
            "function_call": function_call,
            "stream": stream,
            **kwargs,
        }
        message = SimpleNamespace(
            content="",
            tool_calls=None,
            function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'),
        )
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
|
|
|
|
|
class _RuntimeRejectToolsCompletions:
    """Fake ``chat.completions`` that declares ``tools`` but rejects it when called.

    The signature accepts ``tools``, ``functions`` and ``function_call``; the
    failure only appears at call time, when ``tools`` is not None.
    """

    def __init__(self):
        # One dict of arguments per ``create`` call, rejected calls included.
        self.calls = []

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        tools=None,
        functions=None,
        function_call=None,
        stream=False,
        **kwargs,
    ):
        """Record the call, then fail if ``tools`` was passed, else answer.

        Returns:
            An object whose ``choices[0].message`` has empty ``content``,
            ``tool_calls is None`` and a ``function_call`` for ``demo_tool``.

        Raises:
            TypeError: If ``tools`` is not None. The call is appended to
                ``self.calls`` before the error is raised.
        """
        self.calls.append(
            {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "tools": tools,
                "functions": functions,
                "function_call": function_call,
                "stream": stream,
                **kwargs,
            }
        )
        if tools is not None:
            raise TypeError("AsyncCompletions.create() got an unexpected keyword argument 'tools'")

        message = SimpleNamespace(
            content="",
            tool_calls=None,
            function_call=SimpleNamespace(name="demo_tool", arguments='{"city":"beijing"}'),
        )
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
|
|
|
|
|
class _ModernAsyncOpenAI:
    """Fake client pairing ``_ModernCompletions`` with ``_FakeEmbeddings``."""

    def __init__(self, **kwargs):
        """Accept and ignore any constructor keyword arguments."""
        # The same completions object is exposed directly and as
        # ``chat.completions``, so whatever is recorded through the chat path
        # can be read back from ``client.completions``.
        self.completions = _ModernCompletions()
        self.chat = SimpleNamespace(completions=self.completions)
        self.embeddings = _FakeEmbeddings()
|
|
|
|
|
|
class _LegacyAsyncOpenAI:
    """Fake client pairing ``_LegacyCompletions`` with ``_FakeEmbeddings``."""

    def __init__(self, **kwargs):
        """Accept and ignore any constructor keyword arguments."""
        # The same completions object is exposed directly and as
        # ``chat.completions``, so whatever is recorded through the chat path
        # can be read back from ``client.completions``.
        self.completions = _LegacyCompletions()
        self.chat = SimpleNamespace(completions=self.completions)
        self.embeddings = _FakeEmbeddings()
|
|
|
|
|
|
class _RuntimeRejectToolsAsyncOpenAI:
    """Fake client pairing ``_RuntimeRejectToolsCompletions`` with ``_FakeEmbeddings``."""

    def __init__(self, **kwargs):
        """Accept and ignore any constructor keyword arguments."""
        # The same completions object is exposed directly and as
        # ``chat.completions``, so the recorded call list can be read back
        # from ``client.completions``.
        self.completions = _RuntimeRejectToolsCompletions()
        self.chat = SimpleNamespace(completions=self.completions)
        self.embeddings = _FakeEmbeddings()
|
|
|
|
|
|
class _LengthLimitedEmbedAsyncOpenAI:
    """Fake client pairing ``_ModernCompletions`` with ``_LengthLimitedEmbeddings``."""

    def __init__(self, **kwargs):
        """Accept and ignore any constructor keyword arguments."""
        # The same completions object is exposed directly and as
        # ``chat.completions``.
        self.completions = _ModernCompletions()
        self.chat = SimpleNamespace(completions=self.completions)
        # Embeddings resource that raises on over-long input and keeps a log
        # of every input it was given.
        self.embeddings = _LengthLimitedEmbeddings()
|
|
|
|
|
|
class _TimeoutOnceCompletions:
    """Fake ``chat.completions`` whose first call times out and later calls succeed."""

    def __init__(self):
        # One dict of arguments per ``create`` call, the failed first call included.
        self.calls = []

    async def create(
        self,
        *,
        model,
        messages,
        temperature=None,
        max_tokens=None,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        tools=None,
        stream=False,
        timeout=None,
        **kwargs,
    ):
        """Record the call; raise on the first invocation, answer afterwards.

        ``timeout`` is a named parameter so each recorded call always has a
        ``"timeout"`` key, None when the caller did not pass one.

        Returns:
            An object whose ``choices[0].message`` has
            ``content == "ok-after-timeout"`` and neither ``tool_calls`` nor
            ``function_call``.

        Raises:
            httpx.ReadTimeout: On the first call only, after it is recorded.
        """
        self.calls.append(
            {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "tools": tools,
                "stream": stream,
                "timeout": timeout,
                **kwargs,
            }
        )

        # The call was appended above, so a length of 1 means the first attempt.
        if len(self.calls) == 1:
            raise httpx.ReadTimeout("timed out")

        message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
|
|
|
|
|
class _TimeoutOnceAsyncOpenAI:
    """Fake client pairing ``_TimeoutOnceCompletions`` with ``_FakeEmbeddings``."""

    def __init__(self, **kwargs):
        """Accept and ignore any constructor keyword arguments."""
        # The same completions object is exposed directly and as
        # ``chat.completions``, so the recorded call list can be read back
        # from ``client.completions``.
        self.completions = _TimeoutOnceCompletions()
        self.chat = SimpleNamespace(completions=self.completions)
        self.embeddings = _FakeEmbeddings()
|
|
|
|
|
|
def test_openai_model_uses_tools_when_supported(monkeypatch):
    """A client whose ``create`` declares ``tools`` receives the tool list unchanged.

    Also checks that no ``functions`` argument is sent and that the canned
    tool call is surfaced on the result.
    """
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tools = _tool_defs()
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=tools)
    )

    sent = model.client.completions.last_params
    assert model._supports_tools is True
    assert sent["tools"] == tools
    assert "functions" not in sent
    assert result.tool_calls is not None
    assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
|
|
|
|
|
def test_openai_model_forces_tool_choice_when_supported(monkeypatch):
    """``forced_tool_name`` becomes a function-type ``tool_choice`` on a tools-capable client."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tools = _tool_defs()
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=tools,
            forced_tool_name="demo_tool",
        )
    )

    sent = model.client.completions.last_params
    assert sent["tool_choice"]["type"] == "function"
    assert sent["tool_choice"]["function"]["name"] == "demo_tool"
|
|
|
|
|
|
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
    """A client without ``tools`` support gets ``functions`` and ``function_call="auto"``.

    The legacy ``function_call`` in the canned response must still be surfaced
    as ``result.tool_calls``.
    """
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tools = _tool_defs()
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=tools)
    )

    sent = model.client.completions.last_params
    assert model._supports_tools is False
    assert model._supports_functions is True
    assert sent["function_call"] == "auto"
    # Checked separately so a failure shows whether the type or the
    # emptiness condition was violated.
    assert isinstance(sent["functions"], list)
    assert sent["functions"]
    assert sent["functions"][0]["name"] == "demo_tool"
    assert result.tool_calls is not None
    assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
|
|
|
|
|
def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
    """On a legacy client, ``forced_tool_name`` is sent as ``function_call={"name": ...}``."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tools = _tool_defs()
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=tools,
            forced_tool_name="demo_tool",
        )
    )

    sent = model.client.completions.last_params
    assert sent["function_call"] == {"name": "demo_tool"}
|
|
|
|
|
|
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
    """With a legacy client, messages are formatted in the function-calling style.

    A ``tool`` message must come out with role ``function`` and keep its name;
    an assistant message with ``tool_calls`` must come out with a
    ``function_call`` key and no ``tool_calls`` key.
    """
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    tool_message = Message(role="tool", name="demo_tool", content="done")
    formatted_tool = model._format_message(tool_message)

    assistant_message = Message(
        role="assistant",
        content="",
        tool_calls=[
            {
                "type": "function",
                "function": {"name": "demo_tool", "arguments": '{"city":"beijing"}'},
            }
        ],
    )
    formatted_assistant = model._format_message(assistant_message)

    assert formatted_tool["role"] == "function"
    assert formatted_tool["name"] == "demo_tool"
    assert "function_call" in formatted_assistant
    assert "tool_calls" not in formatted_assistant
|
|
|
|
|
|
def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
    """If the client rejects ``tools`` at call time, the request is retried with ``functions``.

    Expects exactly two calls: the first with ``tools``, the second without
    ``tools`` but with ``functions``. Afterwards the model must report that
    tools are unsupported.
    """
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    calls = model.client.completions.calls
    assert len(calls) == 2
    assert calls[0]["tools"] is not None
    assert calls[1]["tools"] is None
    assert calls[1]["functions"][0]["name"] == "demo_tool"
    assert model._supports_tools is False
    assert result.tool_calls is not None
    assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
|
|
|
|
|
def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch):
    """A forced tool survives the tools-to-functions retry.

    The first call carries it as ``tool_choice``; the retried call carries it
    as ``function_call={"name": "demo_tool"}``.
    """
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    calls = model.client.completions.calls
    assert len(calls) == 2
    assert calls[0]["tool_choice"]["function"]["name"] == "demo_tool"
    assert calls[1]["function_call"] == {"name": "demo_tool"}
|
|
|
|
|
|
def test_openai_model_retries_once_on_read_timeout(monkeypatch):
    """A read timeout on the first request triggers one retry with an explicit timeout.

    The first call must carry no ``timeout``; the retry must carry
    ``timeout == 120.0`` and its content must be returned to the caller.
    """
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI)

    model = OpenAIModel(_model_config())
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    calls = model.client.completions.calls
    assert len(calls) == 2
    assert calls[0]["timeout"] is None
    assert calls[1]["timeout"] == 120.0
    assert result.content == "ok-after-timeout"
|
|
|
|
|
|
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
    """After one length rejection, the model remembers the limit and truncates up front.

    First ``embed`` call: two requests reach the fake, the full 600-character
    text and then a shorter retry, and the model records a limit of 512.
    Second ``embed`` call: only one further request is made, and it is already
    shorter than 512 characters.
    """
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    long_text = "你" * 600

    first_embedding = asyncio.run(model.embed(long_text))
    assert first_embedding == [0.1, 0.2, 0.3]
    assert model._embedding_token_limit == 512

    # Snapshot the log so later requests cannot change what is asserted here.
    inputs_after_first = list(model.client.embeddings.inputs)
    assert len(inputs_after_first) == 2
    assert len(inputs_after_first[0]) == 600
    assert len(inputs_after_first[1]) < len(inputs_after_first[0])

    second_embedding = asyncio.run(model.embed(long_text))
    assert second_embedding == [0.1, 0.2, 0.3]

    inputs_after_second = list(model.client.embeddings.inputs)
    assert len(inputs_after_second) == 3
    assert len(inputs_after_second[-1]) < 512
|