Enhance AIClient and OpenAIModel to support tool names in messages and improve tool capability detection. Update message formatting to include tool names so tool errors and results are easier to attribute.

Mimikko-zeus
2026-03-03 01:37:00 +08:00
parent ae208af6a9
commit 0b356d1fef
2 changed files with 201 additions and 91 deletions
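The core of the change: OpenAIModel now probes the installed SDK once to see whether chat.completions.create accepts the modern `tools` parameter or only the legacy `functions` parameter, and builds request parameters accordingly. A minimal standalone sketch of that idea (the helper name `build_tool_request_params` and its return shape are illustrative, not the committed API):

    import inspect
    from typing import Any, Dict, List, Optional

    def build_tool_request_params(create_fn, tools: Optional[List[dict]]) -> Dict[str, Any]:
        """Probe create_fn's signature and return whichever tool parameters it accepts."""
        if not tools:
            return {}
        try:
            params = inspect.signature(create_fn).parameters
        except (TypeError, ValueError):
            # Some callables cannot be introspected; fall back to sending no tools.
            return {}
        if "tools" in params:
            return {"tools": tools}
        if "functions" in params:
            # Older SDKs only understand the flat function schema plus function_call="auto".
            functions = [
                t["function"]
                for t in tools
                if isinstance(t, dict) and isinstance(t.get("function"), dict)
            ]
            if functions:
                return {"functions": functions, "function_call": "auto"}
        return {}

In the committed code the probe runs once in __init__ and caches the result on _supports_tools / _supports_functions; _build_tool_params then reuses those flags on every request.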

View File

@@ -254,7 +254,8 @@ class AIClient:
             messages.append(Message(
                 role="tool",
                 content=f"工具参数解析失败: {str(e)}",
-                tool_call_id=fallback_id
+                tool_call_id=fallback_id,
+                name="tool"
             ))
             continue
         if not tool_name:
@@ -265,9 +266,9 @@ class AIClient:
         if not tool_def:
             error_msg = f"未找到工具: {tool_name}"
             logger.warning(error_msg)
-            if tool_call_id:
             messages.append(Message(
                 role="tool",
+                name=tool_name,
                 content=error_msg,
                 tool_call_id=tool_call_id
             ))
@@ -277,16 +278,16 @@ class AIClient:
             result = tool_def.function(**tool_args)
             if inspect.isawaitable(result):
                 result = await result
-            if tool_call_id:
             messages.append(Message(
                 role="tool",
+                name=tool_name,
                 content=str(result),
                 tool_call_id=tool_call_id
             ))
         except Exception as e:
-            if tool_call_id:
             messages.append(Message(
                 role="tool",
+                name=tool_name,
                 content=f"工具执行失败: {str(e)}",
                 tool_call_id=tool_call_id
             ))
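Net effect of the AIClient hunks above: every tool-role Message now carries a name, including both error paths. Hypothetical examples of what gets appended (the tool name, arguments, and call id are made-up; Message is the project's dataclass from ..base, as imported in the second file):

    # Successful call to a hypothetical "get_weather" tool:
    Message(
        role="tool",
        name="get_weather",
        content='{"temp_c": 21}',
        tool_call_id="call_123",
    )

    # Argument parsing failed before a tool was resolved, so the generic name "tool" is used
    # ("工具参数解析失败" is the diff's own string, "failed to parse tool arguments"):
    Message(
        role="tool",
        name="tool",
        content="工具参数解析失败: Expecting value: line 1 column 1 (char 0)",
        tool_call_id=fallback_id,  # synthetic id generated earlier in AIClient
    )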

View File

@@ -1,43 +1,103 @@
""" """
OpenAI模型实现兼容OpenAI API的模型 OpenAI model implementation (including OpenAI-compatible providers).
""" """
import inspect
import json import json
from typing import Any, AsyncIterator, Dict, List, Optional
import httpx import httpx
from typing import List, Optional, AsyncIterator, Dict, Any
from openai import AsyncOpenAI from openai import AsyncOpenAI
from ..base import BaseAIModel, Message, ModelConfig from ..base import BaseAIModel, Message, ModelConfig
from src.utils.logger import setup_logger from src.utils.logger import setup_logger
logger = setup_logger('OpenAIModel') logger = setup_logger("OpenAIModel")
class OpenAIModel(BaseAIModel): class OpenAIModel(BaseAIModel):
"""OpenAI模型实现""" """OpenAI model implementation."""
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
self.logger = logger self.logger = logger
# 创建支持UTF-8的httpx客户端
http_client = httpx.AsyncClient( http_client = httpx.AsyncClient(
timeout=config.timeout, timeout=config.timeout,
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10) limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
) )
self.client = AsyncOpenAI( self.client = AsyncOpenAI(
api_key=config.api_key, api_key=config.api_key,
base_url=config.api_base, base_url=config.api_base,
timeout=config.timeout, timeout=config.timeout,
http_client=http_client http_client=http_client,
) )
self._supports_tools = False
self._supports_functions = False
self._detect_tool_capability()
def _detect_tool_capability(self) -> None:
"""Detect tool-calling parameters supported by current SDK."""
try:
signature = inspect.signature(self.client.chat.completions.create)
parameters = signature.parameters
self._supports_tools = "tools" in parameters
self._supports_functions = "functions" in parameters
except Exception as exc:
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
"""Build SDK-compatible tool/function request parameters."""
if not tools:
return {}
if self._supports_tools:
return {"tools": tools}
if self._supports_functions:
functions = []
for tool in tools:
schema = self._extract_function_schema(tool)
if schema:
functions.append(schema)
if functions:
return {"functions": functions, "function_call": "auto"}
self.logger.warning(
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
)
return {}
@staticmethod
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if not isinstance(tool, dict):
return None
function_data = tool.get("function")
if not isinstance(function_data, dict):
return None
name = function_data.get("name")
if not name:
return None
return {
"name": name,
"description": function_data.get("description", ""),
"parameters": function_data.get(
"parameters", {"type": "object", "properties": {}}
),
}
async def chat( async def chat(
self, self,
messages: List[Message], messages: List[Message],
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
**kwargs **kwargs,
) -> Message: ) -> Message:
"""同步对话""" """Non-stream chat."""
formatted_messages = [self._format_message(msg) for msg in messages] formatted_messages = [self._format_message(msg) for msg in messages]
params = { params = {
@@ -50,36 +110,26 @@ class OpenAIModel(BaseAIModel):
"presence_penalty": self.config.presence_penalty, "presence_penalty": self.config.presence_penalty,
} }
if tools: params.update(self._build_tool_params(tools))
params["tools"] = tools
params.update(kwargs) params.update(kwargs)
response = await self.client.chat.completions.create(**params) response = await self.client.chat.completions.create(**params)
choice = response.choices[0] choice = response.choices[0]
raw_tool_calls = ( tool_calls = self._extract_response_tool_calls(choice.message)
choice.message.tool_calls
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls
else None
)
tool_calls = (
[self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
if raw_tool_calls else None
)
return Message( return Message(
role="assistant", role="assistant",
content=choice.message.content or "", content=choice.message.content or "",
tool_calls=tool_calls tool_calls=tool_calls,
) )
async def chat_stream( async def chat_stream(
self, self,
messages: List[Message], messages: List[Message],
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
**kwargs **kwargs,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""流式对话""" """Streaming chat."""
formatted_messages = [self._format_message(msg) for msg in messages] formatted_messages = [self._format_message(msg) for msg in messages]
params = { params = {
@@ -90,41 +140,97 @@ class OpenAIModel(BaseAIModel):
"stream": True, "stream": True,
} }
if tools: params.update(self._build_tool_params(tools))
params["tools"] = tools
params.update(kwargs) params.update(kwargs)
stream = await self.client.chat.completions.create(**params) stream = await self.client.chat.completions.create(**params)
async for chunk in stream: async for chunk in stream:
if chunk.choices[0].delta.content: delta = chunk.choices[0].delta
yield chunk.choices[0].delta.content if delta.content:
yield delta.content
def _extract_response_tool_calls(self, message: Any) -> Optional[List[Dict[str, Any]]]:
raw_tool_calls = getattr(message, "tool_calls", None)
if raw_tool_calls:
return [self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
legacy_function_call = getattr(message, "function_call", None)
if legacy_function_call:
return [self._normalize_legacy_function_call(legacy_function_call)]
return None
def _normalize_legacy_function_call(self, function_call: Any) -> Dict[str, Any]:
if isinstance(function_call, dict):
function_name = function_call.get("name")
raw_arguments = function_call.get("arguments")
else:
function_name = getattr(function_call, "name", None)
raw_arguments = getattr(function_call, "arguments", None)
if isinstance(raw_arguments, dict):
arguments = json.dumps(raw_arguments, ensure_ascii=False)
elif raw_arguments is None:
arguments = "{}"
else:
arguments = str(raw_arguments)
return {
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
def _format_message(self, msg: Message) -> Dict[str, Any]: def _format_message(self, msg: Message) -> Dict[str, Any]:
"""将内部消息结构转换为OpenAI消息格式""" """Convert internal message schema to OpenAI request format."""
formatted: Dict[str, Any] = {"role": msg.role}
if msg.role == "assistant": if msg.role == "assistant":
formatted["content"] = msg.content if msg.content else None formatted: Dict[str, Any] = {
"role": "assistant",
"content": msg.content if msg.content else None,
}
if msg.tool_calls: if msg.tool_calls:
formatted["tool_calls"] = [ normalized_calls = [
self._normalize_tool_call(tool_call) self._normalize_tool_call(tool_call) for tool_call in msg.tool_calls
for tool_call in msg.tool_calls
] ]
elif msg.role == "tool": if self._supports_tools:
formatted["content"] = msg.content formatted["tool_calls"] = normalized_calls
if msg.tool_call_id: elif self._supports_functions:
formatted["tool_call_id"] = msg.tool_call_id if len(normalized_calls) > 1:
else: self.logger.warning(
formatted["content"] = msg.content "Legacy function_call mode only supports one tool call; "
"extra tool calls were dropped."
)
function_data = normalized_calls[0].get("function") or {}
formatted["function_call"] = {
"name": function_data.get("name"),
"arguments": function_data.get("arguments", "{}"),
}
return formatted
if msg.role == "tool":
if self._supports_tools and msg.tool_call_id:
return {
"role": "tool",
"content": msg.content,
"tool_call_id": msg.tool_call_id,
}
return {
"role": "function",
"name": msg.name or "tool",
"content": msg.content,
}
formatted = {"role": msg.role, "content": msg.content}
if msg.name: if msg.name:
formatted["name"] = msg.name formatted["name"] = msg.name
return formatted return formatted
def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]: def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]:
"""将工具调用对象统一转换为字典""" """Normalize tool call object/dict to a stable dict schema."""
if isinstance(tool_call, dict): if isinstance(tool_call, dict):
normalized = dict(tool_call) normalized = dict(tool_call)
elif hasattr(tool_call, "model_dump"): elif hasattr(tool_call, "model_dump"):
@@ -143,15 +249,15 @@ class OpenAIModel(BaseAIModel):
"type": getattr(tool_call, "type", "function"), "type": getattr(tool_call, "type", "function"),
"function": { "function": {
"name": function_name, "name": function_name,
"arguments": raw_arguments "arguments": raw_arguments,
} },
} }
function_data = normalized.get("function") or {} function_data = normalized.get("function") or {}
if not isinstance(function_data, dict): if not isinstance(function_data, dict):
function_data = { function_data = {
"name": getattr(function_data, "name", None), "name": getattr(function_data, "name", None),
"arguments": getattr(function_data, "arguments", None) "arguments": getattr(function_data, "arguments", None),
} }
raw_arguments = function_data.get("arguments") raw_arguments = function_data.get("arguments")
@@ -196,7 +302,7 @@ class OpenAIModel(BaseAIModel):
return f"{compact[:head]} {compact[-tail:]}" return f"{compact[:head]} {compact[-tail:]}"
async def embed(self, text: str) -> List[float]: async def embed(self, text: str) -> List[float]:
"""?????""" """Generate embeddings."""
if isinstance(text, bytes): if isinstance(text, bytes):
text = text.decode("utf-8", errors="ignore") text = text.decode("utf-8", errors="ignore")
@@ -209,13 +315,17 @@ class OpenAIModel(BaseAIModel):
             response = await self.client.embeddings.create(
                 model=self.config.model_name,
                 input=candidate_text,
-                encoding_format="float"
+                encoding_format="float",
             )
             return response.data[0].embedding
         except Exception as e:
             if self._is_embedding_too_long_error(e):
                 next_text = self._shrink_text_for_embedding(candidate_text)
-                if next_text and len(next_text) < len(candidate_text) and retry_count < 5:
+                if (
+                    next_text
+                    and len(next_text) < len(candidate_text)
+                    and retry_count < 5
+                ):
                     retry_count += 1
                     self.logger.warning(
                         "embedding input too long, retry with truncated text: "
@@ -232,4 +342,3 @@ class OpenAIModel(BaseAIModel):
self.logger.error(f"text preview: {repr(candidate_text[:100])}") self.logger.error(f"text preview: {repr(candidate_text[:100])}")
self.logger.error(f"full traceback:\n{traceback.format_exc()}") self.logger.error(f"full traceback:\n{traceback.format_exc()}")
raise raise
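Tying the two files together: a tool-result Message built by AIClient is serialized by the new _format_message according to the detected capability. A rough sketch of both outcomes (the values are illustrative; the output shapes follow the diff above):

    # Hypothetical tool result appended by AIClient:
    msg = Message(role="tool", name="get_weather", content='{"temp_c": 21}', tool_call_id="call_123")

    # With a tools-capable SDK (_supports_tools is True), _format_message(msg) yields:
    #   {"role": "tool", "content": '{"temp_c": 21}', "tool_call_id": "call_123"}
    # (name is omitted here; the result is matched to its call by tool_call_id)

    # With a legacy SDK (only `functions` supported), or when tool_call_id is missing:
    #   {"role": "function", "name": "get_weather", "content": '{"temp_c": 21}'}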