Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message processing.
This commit is contained in:
39
tests/test_ai_client_forced_tool.py
Normal file
39
tests/test_ai_client_forced_tool.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Tests for AIClient forced tool name extraction."""
|
||||
|
||||
from src.ai.client import AIClient
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_full_name():
    """A fully-qualified tool name mentioned verbatim in the message is returned."""
    available = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    user_message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars"

    result = AIClient._extract_forced_tool_name(user_message, available)

    assert result == "humanizer_zh.read_skill_doc"
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_unique_prefix():
    """A prefix matching exactly one tool resolves to that tool's full name."""
    available = [
        "humanizer_zh.read_skill_doc",
        "skills_creator.create_skill",
    ]
    user_message = "please call tool humanizer_zh only"

    result = AIClient._extract_forced_tool_name(user_message, available)

    assert result == "humanizer_zh.read_skill_doc"
|
||||
|
||||
|
||||
def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
    """A prefix shared by several tools cannot be resolved and yields None."""
    available = [
        "skills_creator.create_skill",
        "skills_creator.reload_skill",
    ]
    user_message = "please call tool skills_creator"

    result = AIClient._extract_forced_tool_name(user_message, available)

    assert result is None
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import src.ai.models.openai_model as openai_model_module
|
||||
from src.ai.base import Message, ModelConfig, ModelProvider
|
||||
from src.ai.models.openai_model import OpenAIModel
|
||||
@@ -216,6 +217,55 @@ class _LengthLimitedEmbedAsyncOpenAI:
|
||||
self.embeddings = _LengthLimitedEmbeddings()
|
||||
|
||||
|
||||
class _TimeoutOnceCompletions:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
tools=None,
|
||||
stream=False,
|
||||
timeout=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.calls.append(
|
||||
{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"presence_penalty": presence_penalty,
|
||||
"tools": tools,
|
||||
"stream": stream,
|
||||
"timeout": timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
if len(self.calls) == 1:
|
||||
raise httpx.ReadTimeout("timed out")
|
||||
|
||||
message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None)
|
||||
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
||||
|
||||
|
||||
class _TimeoutOnceAsyncOpenAI:
    """AsyncOpenAI stand-in whose chat endpoint fails once with a read timeout."""

    def __init__(self, **kwargs):
        completions = _TimeoutOnceCompletions()
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
        self.embeddings = _FakeEmbeddings()
|
||||
|
||||
|
||||
def test_openai_model_uses_tools_when_supported(monkeypatch):
|
||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
||||
|
||||
@@ -233,6 +283,24 @@ def test_openai_model_uses_tools_when_supported(monkeypatch):
|
||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_forces_tool_choice_when_supported(monkeypatch):
    """With a modern SDK, forced_tool_name is sent as an explicit tool_choice."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    params = model.client.completions.last_params
    choice = params["tool_choice"]
    assert choice["type"] == "function"
    assert choice["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
|
||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
||||
|
||||
@@ -252,6 +320,23 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
|
||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
    """With a legacy SDK, forced_tool_name is sent via the function_call field."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    params = model.client.completions.last_params
    assert params["function_call"] == {"name": "demo_tool"}
|
||||
|
||||
|
||||
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
|
||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
||||
|
||||
@@ -297,6 +382,41 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
|
||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch):
    """The forced tool survives the retry when the server rejects `tools`.

    First attempt uses tool_choice; the fallback retry must carry the same
    forced tool through the legacy function_call parameter.
    """
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    asyncio.run(
        model.chat(
            messages=[Message(role="user", content="hi")],
            tools=_tool_defs(),
            forced_tool_name="demo_tool",
        )
    )

    attempts = model.client.completions.calls
    assert len(attempts) == 2
    first, second = attempts
    assert first["tool_choice"]["function"]["name"] == "demo_tool"
    assert second["function_call"] == {"name": "demo_tool"}
|
||||
|
||||
|
||||
def test_openai_model_retries_once_on_read_timeout(monkeypatch):
    """A single ReadTimeout triggers one retry with an explicit 120 s timeout."""
    monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI)

    model = OpenAIModel(_model_config())
    result = asyncio.run(
        model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
    )

    attempts = model.client.completions.calls
    assert len(attempts) == 2
    # First attempt goes out with the default (unset) timeout; the retry pins one.
    assert attempts[0]["timeout"] is None
    assert attempts[1]["timeout"] == 120.0
    assert result.content == "ok-after-timeout"
|
||||
|
||||
|
||||
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
|
||||
|
||||
Reference in New Issue
Block a user