Enhance AIClient and OpenAIModel to support tool names in messages and improve tool capability detection. Update message formatting to include tool names for better clarity in error handling and results.

This commit is contained in:
Mimikko-zeus
2026-03-03 01:37:00 +08:00
parent ae208af6a9
commit 0b356d1fef
2 changed files with 201 additions and 91 deletions

View File

@@ -254,7 +254,8 @@ class AIClient:
messages.append(Message(
role="tool",
content=f"工具参数解析失败: {str(e)}",
tool_call_id=fallback_id
tool_call_id=fallback_id,
name="tool"
))
continue
if not tool_name:
@@ -265,9 +266,9 @@ class AIClient:
if not tool_def:
error_msg = f"未找到工具: {tool_name}"
logger.warning(error_msg)
if tool_call_id:
messages.append(Message(
role="tool",
name=tool_name,
content=error_msg,
tool_call_id=tool_call_id
))
@@ -277,16 +278,16 @@ class AIClient:
result = tool_def.function(**tool_args)
if inspect.isawaitable(result):
result = await result
if tool_call_id:
messages.append(Message(
role="tool",
name=tool_name,
content=str(result),
tool_call_id=tool_call_id
))
except Exception as e:
if tool_call_id:
messages.append(Message(
role="tool",
name=tool_name,
content=f"工具执行失败: {str(e)}",
tool_call_id=tool_call_id
))

View File

@@ -1,43 +1,103 @@
"""
OpenAI模型实现兼容OpenAI API的模型
OpenAI model implementation (including OpenAI-compatible providers).
"""
import inspect
import json
from typing import Any, AsyncIterator, Dict, List, Optional
import httpx
from typing import List, Optional, AsyncIterator, Dict, Any
from openai import AsyncOpenAI
from ..base import BaseAIModel, Message, ModelConfig
from src.utils.logger import setup_logger
logger = setup_logger('OpenAIModel')
logger = setup_logger("OpenAIModel")
class OpenAIModel(BaseAIModel):
"""OpenAI模型实现"""
"""OpenAI model implementation."""
def __init__(self, config: ModelConfig):
super().__init__(config)
self.logger = logger
# 创建支持UTF-8的httpx客户端
http_client = httpx.AsyncClient(
timeout=config.timeout,
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
)
self.client = AsyncOpenAI(
api_key=config.api_key,
base_url=config.api_base,
timeout=config.timeout,
http_client=http_client
http_client=http_client,
)
self._supports_tools = False
self._supports_functions = False
self._detect_tool_capability()
def _detect_tool_capability(self) -> None:
"""Detect tool-calling parameters supported by current SDK."""
try:
signature = inspect.signature(self.client.chat.completions.create)
parameters = signature.parameters
self._supports_tools = "tools" in parameters
self._supports_functions = "functions" in parameters
except Exception as exc:
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
"""Build SDK-compatible tool/function request parameters."""
if not tools:
return {}
if self._supports_tools:
return {"tools": tools}
if self._supports_functions:
functions = []
for tool in tools:
schema = self._extract_function_schema(tool)
if schema:
functions.append(schema)
if functions:
return {"functions": functions, "function_call": "auto"}
self.logger.warning(
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
)
return {}
@staticmethod
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if not isinstance(tool, dict):
return None
function_data = tool.get("function")
if not isinstance(function_data, dict):
return None
name = function_data.get("name")
if not name:
return None
return {
"name": name,
"description": function_data.get("description", ""),
"parameters": function_data.get(
"parameters", {"type": "object", "properties": {}}
),
}
async def chat(
self,
messages: List[Message],
tools: Optional[List[dict]] = None,
**kwargs
**kwargs,
) -> Message:
"""同步对话"""
"""Non-stream chat."""
formatted_messages = [self._format_message(msg) for msg in messages]
params = {
@@ -50,36 +110,26 @@ class OpenAIModel(BaseAIModel):
"presence_penalty": self.config.presence_penalty,
}
if tools:
params["tools"] = tools
params.update(self._build_tool_params(tools))
params.update(kwargs)
response = await self.client.chat.completions.create(**params)
choice = response.choices[0]
raw_tool_calls = (
choice.message.tool_calls
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls
else None
)
tool_calls = (
[self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
if raw_tool_calls else None
)
tool_calls = self._extract_response_tool_calls(choice.message)
return Message(
role="assistant",
content=choice.message.content or "",
tool_calls=tool_calls
tool_calls=tool_calls,
)
async def chat_stream(
self,
messages: List[Message],
tools: Optional[List[dict]] = None,
**kwargs
**kwargs,
) -> AsyncIterator[str]:
"""流式对话"""
"""Streaming chat."""
formatted_messages = [self._format_message(msg) for msg in messages]
params = {
@@ -90,41 +140,97 @@ class OpenAIModel(BaseAIModel):
"stream": True,
}
if tools:
params["tools"] = tools
params.update(self._build_tool_params(tools))
params.update(kwargs)
stream = await self.client.chat.completions.create(**params)
async for chunk in stream:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
delta = chunk.choices[0].delta
if delta.content:
yield delta.content
def _extract_response_tool_calls(self, message: Any) -> Optional[List[Dict[str, Any]]]:
raw_tool_calls = getattr(message, "tool_calls", None)
if raw_tool_calls:
return [self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
legacy_function_call = getattr(message, "function_call", None)
if legacy_function_call:
return [self._normalize_legacy_function_call(legacy_function_call)]
return None
def _normalize_legacy_function_call(self, function_call: Any) -> Dict[str, Any]:
if isinstance(function_call, dict):
function_name = function_call.get("name")
raw_arguments = function_call.get("arguments")
else:
function_name = getattr(function_call, "name", None)
raw_arguments = getattr(function_call, "arguments", None)
if isinstance(raw_arguments, dict):
arguments = json.dumps(raw_arguments, ensure_ascii=False)
elif raw_arguments is None:
arguments = "{}"
else:
arguments = str(raw_arguments)
return {
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
def _format_message(self, msg: Message) -> Dict[str, Any]:
"""将内部消息结构转换为OpenAI消息格式"""
formatted: Dict[str, Any] = {"role": msg.role}
"""Convert internal message schema to OpenAI request format."""
if msg.role == "assistant":
formatted["content"] = msg.content if msg.content else None
formatted: Dict[str, Any] = {
"role": "assistant",
"content": msg.content if msg.content else None,
}
if msg.tool_calls:
formatted["tool_calls"] = [
self._normalize_tool_call(tool_call)
for tool_call in msg.tool_calls
normalized_calls = [
self._normalize_tool_call(tool_call) for tool_call in msg.tool_calls
]
elif msg.role == "tool":
formatted["content"] = msg.content
if msg.tool_call_id:
formatted["tool_call_id"] = msg.tool_call_id
else:
formatted["content"] = msg.content
if self._supports_tools:
formatted["tool_calls"] = normalized_calls
elif self._supports_functions:
if len(normalized_calls) > 1:
self.logger.warning(
"Legacy function_call mode only supports one tool call; "
"extra tool calls were dropped."
)
function_data = normalized_calls[0].get("function") or {}
formatted["function_call"] = {
"name": function_data.get("name"),
"arguments": function_data.get("arguments", "{}"),
}
return formatted
if msg.role == "tool":
if self._supports_tools and msg.tool_call_id:
return {
"role": "tool",
"content": msg.content,
"tool_call_id": msg.tool_call_id,
}
return {
"role": "function",
"name": msg.name or "tool",
"content": msg.content,
}
formatted = {"role": msg.role, "content": msg.content}
if msg.name:
formatted["name"] = msg.name
return formatted
def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]:
"""将工具调用对象统一转换为字典"""
"""Normalize tool call object/dict to a stable dict schema."""
if isinstance(tool_call, dict):
normalized = dict(tool_call)
elif hasattr(tool_call, "model_dump"):
@@ -143,15 +249,15 @@ class OpenAIModel(BaseAIModel):
"type": getattr(tool_call, "type", "function"),
"function": {
"name": function_name,
"arguments": raw_arguments
}
"arguments": raw_arguments,
},
}
function_data = normalized.get("function") or {}
if not isinstance(function_data, dict):
function_data = {
"name": getattr(function_data, "name", None),
"arguments": getattr(function_data, "arguments", None)
"arguments": getattr(function_data, "arguments", None),
}
raw_arguments = function_data.get("arguments")
@@ -196,7 +302,7 @@ class OpenAIModel(BaseAIModel):
return f"{compact[:head]} {compact[-tail:]}"
async def embed(self, text: str) -> List[float]:
"""?????"""
"""Generate embeddings."""
if isinstance(text, bytes):
text = text.decode("utf-8", errors="ignore")
@@ -209,13 +315,17 @@ class OpenAIModel(BaseAIModel):
response = await self.client.embeddings.create(
model=self.config.model_name,
input=candidate_text,
encoding_format="float"
encoding_format="float",
)
return response.data[0].embedding
except Exception as e:
if self._is_embedding_too_long_error(e):
next_text = self._shrink_text_for_embedding(candidate_text)
if next_text and len(next_text) < len(candidate_text) and retry_count < 5:
if (
next_text
and len(next_text) < len(candidate_text)
and retry_count < 5
):
retry_count += 1
self.logger.warning(
"embedding input too long, retry with truncated text: "
@@ -232,4 +342,3 @@ class OpenAIModel(BaseAIModel):
self.logger.error(f"text preview: {repr(candidate_text[:100])}")
self.logger.error(f"full traceback:\n{traceback.format_exc()}")
raise