Implement forced tool selection in AIClient and OpenAIModel, enhancing tool invocation capabilities. Added methods for extracting forced tool names from user messages and updated logging to reflect forced tool usage. Improved error handling for timeout scenarios in message processing.
This commit is contained in:
@@ -191,12 +191,18 @@ class AIClient:
|
|||||||
if use_tools and self.tools.list():
|
if use_tools and self.tools.list():
|
||||||
tools = self.tools.to_openai_format()
|
tools = self.tools.to_openai_format()
|
||||||
tool_names = [tool.name for tool in self.tools.list()]
|
tool_names = [tool.name for tool in self.tools.list()]
|
||||||
|
forced_tool_name = self._extract_forced_tool_name(user_message, tool_names)
|
||||||
|
if forced_tool_name:
|
||||||
|
kwargs = dict(kwargs)
|
||||||
|
kwargs["forced_tool_name"] = forced_tool_name
|
||||||
|
logger.info(f"检测到显式工具调用意图,启用强制调用: {forced_tool_name}")
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"LLM请求: "
|
"LLM请求: "
|
||||||
f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, "
|
f"user_id={user_id}, use_memory={use_memory}, use_tools={use_tools}, "
|
||||||
f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, "
|
f"registered_tools={len(tool_names)}, sent_tools={len(tools or [])}, "
|
||||||
f"tool_names={self._preview_log_payload(tool_names)}"
|
f"tool_names={self._preview_log_payload(tool_names)}, "
|
||||||
|
f"forced_tool={forced_tool_name or '-'}"
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"LLM输入: "
|
"LLM输入: "
|
||||||
@@ -251,7 +257,7 @@ class AIClient:
|
|||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"对话失败: {e}")
|
logger.error(f"对话失败: {type(e).__name__}: {e!r}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _chat_stream(
|
async def _chat_stream(
|
||||||
@@ -342,7 +348,10 @@ class AIClient:
|
|||||||
))
|
))
|
||||||
|
|
||||||
# 再次调用模型获取最终响应
|
# 再次调用模型获取最终响应
|
||||||
final_response = await self.model.chat(messages, tools, **kwargs)
|
final_kwargs = dict(kwargs)
|
||||||
|
# Force only the first model turn, avoid recursive force after tool result.
|
||||||
|
final_kwargs.pop("forced_tool_name", None)
|
||||||
|
final_response = await self.model.chat(messages, tools, **final_kwargs)
|
||||||
logger.info(
|
logger.info(
|
||||||
"LLM最终输出: "
|
"LLM最终输出: "
|
||||||
f"content={self._preview_log_payload(final_response.content)}"
|
f"content={self._preview_log_payload(final_response.content)}"
|
||||||
@@ -401,6 +410,52 @@ class AIClient:
|
|||||||
return text[:max_len] + "..."
|
return text[:max_len] + "..."
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_forced_tool_name(
|
||||||
|
user_message: str, available_tool_names: List[str]
|
||||||
|
) -> Optional[str]:
|
||||||
|
if not user_message or not available_tool_names:
|
||||||
|
return None
|
||||||
|
|
||||||
|
triggers = ["调用工具", "使用工具", "只调用", "务必调用", "必须调用", "tool"]
|
||||||
|
if not any(trigger in user_message for trigger in triggers):
|
||||||
|
return None
|
||||||
|
|
||||||
|
pattern = re.compile(r"([A-Za-z0-9_]+\.[A-Za-z0-9_]+)")
|
||||||
|
explicit_matches = [
|
||||||
|
name for name in pattern.findall(user_message) if name in available_tool_names
|
||||||
|
]
|
||||||
|
if len(explicit_matches) == 1:
|
||||||
|
return explicit_matches[0]
|
||||||
|
if len(explicit_matches) > 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
contained = [name for name in available_tool_names if name in user_message]
|
||||||
|
if len(contained) == 1:
|
||||||
|
return contained[0]
|
||||||
|
|
||||||
|
# 允许只写 skill/tool 前缀(如 humanizer_zh),前提是前缀下只有一个工具。
|
||||||
|
prefixes = sorted(
|
||||||
|
{name.split(".", 1)[0] for name in available_tool_names},
|
||||||
|
key=len,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
matched_prefixes = [
|
||||||
|
prefix
|
||||||
|
for prefix in prefixes
|
||||||
|
if re.search(rf"\b{re.escape(prefix)}\b", user_message)
|
||||||
|
]
|
||||||
|
if len(matched_prefixes) == 1:
|
||||||
|
prefix_tools = [
|
||||||
|
name
|
||||||
|
for name in available_tool_names
|
||||||
|
if name.startswith(f"{matched_prefixes[0]}.")
|
||||||
|
]
|
||||||
|
if len(prefix_tools) == 1:
|
||||||
|
return prefix_tools[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def set_personality(self, personality_name: str) -> bool:
|
def set_personality(self, personality_name: str) -> bool:
|
||||||
"""设置人格。"""
|
"""设置人格。"""
|
||||||
return self.personality.set_personality(personality_name)
|
return self.personality.set_personality(personality_name)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
OpenAI model implementation (including OpenAI-compatible providers).
|
OpenAI model implementation (including OpenAI-compatible providers).
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@@ -72,6 +73,58 @@ class OpenAIModel(BaseAIModel):
|
|||||||
)
|
)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def _build_forced_tool_params(
|
||||||
|
self,
|
||||||
|
params: Dict[str, Any],
|
||||||
|
forced_tool_name: Optional[str],
|
||||||
|
tools: Optional[List[dict]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build request params for forcing one specific tool call."""
|
||||||
|
if not forced_tool_name:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
available_tool_names = self._extract_tool_names(tools)
|
||||||
|
if available_tool_names and forced_tool_name not in available_tool_names:
|
||||||
|
self.logger.warning(
|
||||||
|
"forced_tool_name is not in current tool list, ignored: "
|
||||||
|
f"{forced_tool_name}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if "tools" in params:
|
||||||
|
return {
|
||||||
|
"tool_choice": {
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": forced_tool_name},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if "functions" in params:
|
||||||
|
return {"function_call": {"name": forced_tool_name}}
|
||||||
|
|
||||||
|
self.logger.warning(
|
||||||
|
"forced_tool_name provided but tool params are unavailable, ignored: "
|
||||||
|
f"{forced_tool_name}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tool_names(tools: Optional[List[dict]]) -> List[str]:
|
||||||
|
if not tools:
|
||||||
|
return []
|
||||||
|
|
||||||
|
names: List[str] = []
|
||||||
|
for tool in tools:
|
||||||
|
if not isinstance(tool, dict):
|
||||||
|
continue
|
||||||
|
function_data = tool.get("function")
|
||||||
|
if not isinstance(function_data, dict):
|
||||||
|
continue
|
||||||
|
name = function_data.get("name")
|
||||||
|
if isinstance(name, str) and name:
|
||||||
|
names.append(name)
|
||||||
|
return names
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
if not isinstance(tool, dict):
|
if not isinstance(tool, dict):
|
||||||
@@ -108,7 +161,12 @@ class OpenAIModel(BaseAIModel):
|
|||||||
self._supports_tools = False
|
self._supports_tools = False
|
||||||
retry_params = dict(params)
|
retry_params = dict(params)
|
||||||
retry_params.pop("tools", None)
|
retry_params.pop("tools", None)
|
||||||
|
forced_tool_name = self._extract_forced_tool_name_from_choice(
|
||||||
|
retry_params.pop("tool_choice", None)
|
||||||
|
)
|
||||||
retry_params.update(self._build_tool_params(tools))
|
retry_params.update(self._build_tool_params(tools))
|
||||||
|
if forced_tool_name and "functions" in retry_params:
|
||||||
|
retry_params["function_call"] = {"name": forced_tool_name}
|
||||||
return await self.client.chat.completions.create(**retry_params)
|
return await self.client.chat.completions.create(**retry_params)
|
||||||
|
|
||||||
if "unexpected keyword argument 'functions'" in message and "functions" in params:
|
if "unexpected keyword argument 'functions'" in message and "functions" in params:
|
||||||
@@ -122,6 +180,60 @@ class OpenAIModel(BaseAIModel):
|
|||||||
return await self.client.chat.completions.create(**retry_params)
|
return await self.client.chat.completions.create(**retry_params)
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
if self._is_timeout_error(exc):
|
||||||
|
return await self._retry_on_timeout(params)
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_timeout_error(error: Exception) -> bool:
|
||||||
|
if isinstance(error, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
|
||||||
|
return True
|
||||||
|
|
||||||
|
error_name = type(error).__name__.lower()
|
||||||
|
if "timeout" in error_name:
|
||||||
|
return True
|
||||||
|
|
||||||
|
message = str(error).lower()
|
||||||
|
return "timed out" in message or "timeout" in message
|
||||||
|
|
||||||
|
async def _retry_on_timeout(self, params: Dict[str, Any]):
|
||||||
|
base_timeout = float(self.config.timeout or 60)
|
||||||
|
retry_timeout = min(max(base_timeout * 2, 120.0), 300.0)
|
||||||
|
retry_params = dict(params)
|
||||||
|
retry_params["timeout"] = retry_timeout
|
||||||
|
self.logger.warning(
|
||||||
|
"chat request timed out, retry once with longer timeout: "
|
||||||
|
f"{base_timeout:.0f}s -> {retry_timeout:.0f}s"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return await self.client.chat.completions.create(**retry_params)
|
||||||
|
except Exception as retry_exc:
|
||||||
|
if self._is_timeout_error(retry_exc):
|
||||||
|
self.logger.error(
|
||||||
|
"chat request still timed out after retry: "
|
||||||
|
f"timeout={retry_timeout:.0f}s"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_forced_tool_name_from_choice(tool_choice: Any) -> Optional[str]:
|
||||||
|
if not tool_choice:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(tool_choice, dict):
|
||||||
|
function_data = tool_choice.get("function")
|
||||||
|
if isinstance(function_data, dict):
|
||||||
|
name = function_data.get("name")
|
||||||
|
return name if isinstance(name, str) and name else None
|
||||||
|
return None
|
||||||
|
|
||||||
|
function_data = getattr(tool_choice, "function", None)
|
||||||
|
if function_data:
|
||||||
|
name = getattr(function_data, "name", None)
|
||||||
|
return name if isinstance(name, str) and name else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
@@ -131,6 +243,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
) -> Message:
|
) -> Message:
|
||||||
"""Non-stream chat."""
|
"""Non-stream chat."""
|
||||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||||
|
forced_tool_name = kwargs.pop("forced_tool_name", None)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
@@ -144,6 +257,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
|
|
||||||
params.update(self._build_tool_params(tools))
|
params.update(self._build_tool_params(tools))
|
||||||
params.update(kwargs)
|
params.update(kwargs)
|
||||||
|
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
|
||||||
|
|
||||||
tool_mode = "none"
|
tool_mode = "none"
|
||||||
tool_count = 0
|
tool_count = 0
|
||||||
@@ -156,7 +270,8 @@ class OpenAIModel(BaseAIModel):
|
|||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"OpenAI chat request: "
|
"OpenAI chat request: "
|
||||||
f"model={self.config.model_name}, tool_mode={tool_mode}, tool_count={tool_count}"
|
f"model={self.config.model_name}, tool_mode={tool_mode}, "
|
||||||
|
f"tool_count={tool_count}, forced_tool={forced_tool_name or '-'}"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self._create_completion_with_fallback(params, tools)
|
response = await self._create_completion_with_fallback(params, tools)
|
||||||
@@ -177,6 +292,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
"""Streaming chat."""
|
"""Streaming chat."""
|
||||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||||
|
forced_tool_name = kwargs.pop("forced_tool_name", None)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
@@ -188,6 +304,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
|
|
||||||
params.update(self._build_tool_params(tools))
|
params.update(self._build_tool_params(tools))
|
||||||
params.update(kwargs)
|
params.update(kwargs)
|
||||||
|
params.update(self._build_forced_tool_params(params, forced_tool_name, tools))
|
||||||
|
|
||||||
stream = await self._create_completion_with_fallback(params, tools)
|
stream = await self._create_completion_with_fallback(params, tools)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from pathlib import Path
|
|||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from botpy.message import Message
|
from botpy.message import Message
|
||||||
|
|
||||||
from src.ai import AIClient
|
from src.ai import AIClient
|
||||||
@@ -619,6 +621,12 @@ class MessageHandler:
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
if isinstance(exc, (httpx.ReadTimeout, TimeoutError, asyncio.TimeoutError)):
|
||||||
|
await self._reply_plain(
|
||||||
|
message,
|
||||||
|
"模型响应超时,请稍后重试,或将当前模型配置的 timeout 调大(建议 120-180 秒)。",
|
||||||
|
)
|
||||||
|
return
|
||||||
await self._reply_plain(message, "消息处理失败,请稍后重试")
|
await self._reply_plain(message, "消息处理失败,请稍后重试")
|
||||||
|
|
||||||
async def _handle_skills_command(self, message: Message, command: str):
|
async def _handle_skills_command(self, message: Message, command: str):
|
||||||
|
|||||||
39
tests/test_ai_client_forced_tool.py
Normal file
39
tests/test_ai_client_forced_tool.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Tests for AIClient forced tool name extraction."""
|
||||||
|
|
||||||
|
from src.ai.client import AIClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forced_tool_name_full_name():
|
||||||
|
tools = [
|
||||||
|
"humanizer_zh.read_skill_doc",
|
||||||
|
"skills_creator.create_skill",
|
||||||
|
]
|
||||||
|
message = "please call tool humanizer_zh.read_skill_doc and return first 100 chars"
|
||||||
|
|
||||||
|
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||||
|
|
||||||
|
assert forced == "humanizer_zh.read_skill_doc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forced_tool_name_unique_prefix():
|
||||||
|
tools = [
|
||||||
|
"humanizer_zh.read_skill_doc",
|
||||||
|
"skills_creator.create_skill",
|
||||||
|
]
|
||||||
|
message = "please call tool humanizer_zh only"
|
||||||
|
|
||||||
|
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||||
|
|
||||||
|
assert forced == "humanizer_zh.read_skill_doc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_forced_tool_name_ambiguous_prefix_returns_none():
|
||||||
|
tools = [
|
||||||
|
"skills_creator.create_skill",
|
||||||
|
"skills_creator.reload_skill",
|
||||||
|
]
|
||||||
|
message = "please call tool skills_creator"
|
||||||
|
|
||||||
|
forced = AIClient._extract_forced_tool_name(message, tools)
|
||||||
|
|
||||||
|
assert forced is None
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import httpx
|
||||||
import src.ai.models.openai_model as openai_model_module
|
import src.ai.models.openai_model as openai_model_module
|
||||||
from src.ai.base import Message, ModelConfig, ModelProvider
|
from src.ai.base import Message, ModelConfig, ModelProvider
|
||||||
from src.ai.models.openai_model import OpenAIModel
|
from src.ai.models.openai_model import OpenAIModel
|
||||||
@@ -216,6 +217,55 @@ class _LengthLimitedEmbedAsyncOpenAI:
|
|||||||
self.embeddings = _LengthLimitedEmbeddings()
|
self.embeddings = _LengthLimitedEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
|
class _TimeoutOnceCompletions:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
temperature=None,
|
||||||
|
max_tokens=None,
|
||||||
|
top_p=None,
|
||||||
|
frequency_penalty=None,
|
||||||
|
presence_penalty=None,
|
||||||
|
tools=None,
|
||||||
|
stream=False,
|
||||||
|
timeout=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"top_p": top_p,
|
||||||
|
"frequency_penalty": frequency_penalty,
|
||||||
|
"presence_penalty": presence_penalty,
|
||||||
|
"tools": tools,
|
||||||
|
"stream": stream,
|
||||||
|
"timeout": timeout,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self.calls) == 1:
|
||||||
|
raise httpx.ReadTimeout("timed out")
|
||||||
|
|
||||||
|
message = SimpleNamespace(content="ok-after-timeout", tool_calls=None, function_call=None)
|
||||||
|
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
||||||
|
|
||||||
|
|
||||||
|
class _TimeoutOnceAsyncOpenAI:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.completions = _TimeoutOnceCompletions()
|
||||||
|
self.chat = SimpleNamespace(completions=self.completions)
|
||||||
|
self.embeddings = _FakeEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_uses_tools_when_supported(monkeypatch):
|
def test_openai_model_uses_tools_when_supported(monkeypatch):
|
||||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
||||||
|
|
||||||
@@ -233,6 +283,24 @@ def test_openai_model_uses_tools_when_supported(monkeypatch):
|
|||||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_forces_tool_choice_when_supported(monkeypatch):
|
||||||
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
||||||
|
|
||||||
|
model = OpenAIModel(_model_config())
|
||||||
|
tools = _tool_defs()
|
||||||
|
asyncio.run(
|
||||||
|
model.chat(
|
||||||
|
messages=[Message(role="user", content="hi")],
|
||||||
|
tools=tools,
|
||||||
|
forced_tool_name="demo_tool",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
sent = model.client.completions.last_params
|
||||||
|
assert sent["tool_choice"]["type"] == "function"
|
||||||
|
assert sent["tool_choice"]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
|
def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
|
||||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
||||||
|
|
||||||
@@ -252,6 +320,23 @@ def test_openai_model_falls_back_to_functions_for_legacy_sdk(monkeypatch):
|
|||||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_forces_function_call_for_legacy_sdk(monkeypatch):
|
||||||
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
||||||
|
|
||||||
|
model = OpenAIModel(_model_config())
|
||||||
|
tools = _tool_defs()
|
||||||
|
asyncio.run(
|
||||||
|
model.chat(
|
||||||
|
messages=[Message(role="user", content="hi")],
|
||||||
|
tools=tools,
|
||||||
|
forced_tool_name="demo_tool",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
sent = model.client.completions.last_params
|
||||||
|
assert sent["function_call"] == {"name": "demo_tool"}
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
|
def test_openai_model_formats_tool_messages_for_legacy_sdk(monkeypatch):
|
||||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _LegacyAsyncOpenAI)
|
||||||
|
|
||||||
@@ -297,6 +382,41 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
|
|||||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_preserves_forced_tool_when_fallback_to_functions(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
openai_model_module, "AsyncOpenAI", _RuntimeRejectToolsAsyncOpenAI
|
||||||
|
)
|
||||||
|
|
||||||
|
model = OpenAIModel(_model_config())
|
||||||
|
asyncio.run(
|
||||||
|
model.chat(
|
||||||
|
messages=[Message(role="user", content="hi")],
|
||||||
|
tools=_tool_defs(),
|
||||||
|
forced_tool_name="demo_tool",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = model.client.completions.calls
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0]["tool_choice"]["function"]["name"] == "demo_tool"
|
||||||
|
assert calls[1]["function_call"] == {"name": "demo_tool"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_retries_once_on_read_timeout(monkeypatch):
|
||||||
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _TimeoutOnceAsyncOpenAI)
|
||||||
|
|
||||||
|
model = OpenAIModel(_model_config())
|
||||||
|
result = asyncio.run(
|
||||||
|
model.chat(messages=[Message(role="user", content="hi")], tools=_tool_defs())
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = model.client.completions.calls
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0]["timeout"] is None
|
||||||
|
assert calls[1]["timeout"] == 120.0
|
||||||
|
assert result.content == "ok-after-timeout"
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
|
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
|
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
|
||||||
|
|||||||
Reference in New Issue
Block a user