Enhance AIClient and OpenAIModel to support tool names in messages and improve tool capability detection. Updated message formatting to include tool names for better clarity in error handling and results.
This commit is contained in:
@@ -254,7 +254,8 @@ class AIClient:
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
content=f"工具参数解析失败: {str(e)}",
|
||||
tool_call_id=fallback_id
|
||||
tool_call_id=fallback_id,
|
||||
name="tool"
|
||||
))
|
||||
continue
|
||||
if not tool_name:
|
||||
@@ -265,31 +266,31 @@ class AIClient:
|
||||
if not tool_def:
|
||||
error_msg = f"未找到工具: {tool_name}"
|
||||
logger.warning(error_msg)
|
||||
if tool_call_id:
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
content=error_msg,
|
||||
tool_call_id=tool_call_id
|
||||
))
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
name=tool_name,
|
||||
content=error_msg,
|
||||
tool_call_id=tool_call_id
|
||||
))
|
||||
continue
|
||||
|
||||
try:
|
||||
result = tool_def.function(**tool_args)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if tool_call_id:
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
content=str(result),
|
||||
tool_call_id=tool_call_id
|
||||
))
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
name=tool_name,
|
||||
content=str(result),
|
||||
tool_call_id=tool_call_id
|
||||
))
|
||||
except Exception as e:
|
||||
if tool_call_id:
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
content=f"工具执行失败: {str(e)}",
|
||||
tool_call_id=tool_call_id
|
||||
))
|
||||
messages.append(Message(
|
||||
role="tool",
|
||||
name=tool_name,
|
||||
content=f"工具执行失败: {str(e)}",
|
||||
tool_call_id=tool_call_id
|
||||
))
|
||||
|
||||
# 再次调用模型获取最终响应
|
||||
return await self.model.chat(messages, tools, **kwargs)
|
||||
|
||||
@@ -1,45 +1,105 @@
|
||||
"""
|
||||
OpenAI模型实现(兼容OpenAI API的模型)
|
||||
OpenAI model implementation (including OpenAI-compatible providers).
|
||||
"""
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from typing import List, Optional, AsyncIterator, Dict, Any
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..base import BaseAIModel, Message, ModelConfig
|
||||
from src.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger('OpenAIModel')
|
||||
logger = setup_logger("OpenAIModel")
|
||||
|
||||
|
||||
class OpenAIModel(BaseAIModel):
|
||||
"""OpenAI模型实现"""
|
||||
|
||||
"""OpenAI model implementation."""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__(config)
|
||||
self.logger = logger
|
||||
|
||||
# 创建支持UTF-8的httpx客户端
|
||||
|
||||
http_client = httpx.AsyncClient(
|
||||
timeout=config.timeout,
|
||||
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
|
||||
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
|
||||
)
|
||||
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=config.api_key,
|
||||
base_url=config.api_base,
|
||||
timeout=config.timeout,
|
||||
http_client=http_client
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
|
||||
self._supports_tools = False
|
||||
self._supports_functions = False
|
||||
self._detect_tool_capability()
|
||||
|
||||
def _detect_tool_capability(self) -> None:
|
||||
"""Detect tool-calling parameters supported by current SDK."""
|
||||
try:
|
||||
signature = inspect.signature(self.client.chat.completions.create)
|
||||
parameters = signature.parameters
|
||||
self._supports_tools = "tools" in parameters
|
||||
self._supports_functions = "functions" in parameters
|
||||
except Exception as exc:
|
||||
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
|
||||
|
||||
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
|
||||
"""Build SDK-compatible tool/function request parameters."""
|
||||
if not tools:
|
||||
return {}
|
||||
|
||||
if self._supports_tools:
|
||||
return {"tools": tools}
|
||||
|
||||
if self._supports_functions:
|
||||
functions = []
|
||||
for tool in tools:
|
||||
schema = self._extract_function_schema(tool)
|
||||
if schema:
|
||||
functions.append(schema)
|
||||
|
||||
if functions:
|
||||
return {"functions": functions, "function_call": "auto"}
|
||||
|
||||
self.logger.warning(
|
||||
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
|
||||
)
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
if not isinstance(tool, dict):
|
||||
return None
|
||||
|
||||
function_data = tool.get("function")
|
||||
if not isinstance(function_data, dict):
|
||||
return None
|
||||
|
||||
name = function_data.get("name")
|
||||
if not name:
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"description": function_data.get("description", ""),
|
||||
"parameters": function_data.get(
|
||||
"parameters", {"type": "object", "properties": {}}
|
||||
),
|
||||
}
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[dict]] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Message:
|
||||
"""同步对话"""
|
||||
"""Non-stream chat."""
|
||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||
|
||||
|
||||
params = {
|
||||
"model": self.config.model_name,
|
||||
"messages": formatted_messages,
|
||||
@@ -49,39 +109,29 @@ class OpenAIModel(BaseAIModel):
|
||||
"frequency_penalty": self.config.frequency_penalty,
|
||||
"presence_penalty": self.config.presence_penalty,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
|
||||
|
||||
params.update(self._build_tool_params(tools))
|
||||
params.update(kwargs)
|
||||
|
||||
|
||||
response = await self.client.chat.completions.create(**params)
|
||||
|
||||
|
||||
choice = response.choices[0]
|
||||
raw_tool_calls = (
|
||||
choice.message.tool_calls
|
||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls
|
||||
else None
|
||||
)
|
||||
tool_calls = (
|
||||
[self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
|
||||
if raw_tool_calls else None
|
||||
)
|
||||
tool_calls = self._extract_response_tool_calls(choice.message)
|
||||
return Message(
|
||||
role="assistant",
|
||||
content=choice.message.content or "",
|
||||
tool_calls=tool_calls
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: Optional[List[dict]] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> AsyncIterator[str]:
|
||||
"""流式对话"""
|
||||
"""Streaming chat."""
|
||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||
|
||||
|
||||
params = {
|
||||
"model": self.config.model_name,
|
||||
"messages": formatted_messages,
|
||||
@@ -89,42 +139,98 @@ class OpenAIModel(BaseAIModel):
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
|
||||
|
||||
params.update(self._build_tool_params(tools))
|
||||
params.update(kwargs)
|
||||
|
||||
|
||||
stream = await self.client.chat.completions.create(**params)
|
||||
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
yield delta.content
|
||||
|
||||
def _extract_response_tool_calls(self, message: Any) -> Optional[List[Dict[str, Any]]]:
|
||||
raw_tool_calls = getattr(message, "tool_calls", None)
|
||||
if raw_tool_calls:
|
||||
return [self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
|
||||
|
||||
legacy_function_call = getattr(message, "function_call", None)
|
||||
if legacy_function_call:
|
||||
return [self._normalize_legacy_function_call(legacy_function_call)]
|
||||
|
||||
return None
|
||||
|
||||
def _normalize_legacy_function_call(self, function_call: Any) -> Dict[str, Any]:
|
||||
if isinstance(function_call, dict):
|
||||
function_name = function_call.get("name")
|
||||
raw_arguments = function_call.get("arguments")
|
||||
else:
|
||||
function_name = getattr(function_call, "name", None)
|
||||
raw_arguments = getattr(function_call, "arguments", None)
|
||||
|
||||
if isinstance(raw_arguments, dict):
|
||||
arguments = json.dumps(raw_arguments, ensure_ascii=False)
|
||||
elif raw_arguments is None:
|
||||
arguments = "{}"
|
||||
else:
|
||||
arguments = str(raw_arguments)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
|
||||
def _format_message(self, msg: Message) -> Dict[str, Any]:
|
||||
"""将内部消息结构转换为OpenAI消息格式"""
|
||||
formatted: Dict[str, Any] = {"role": msg.role}
|
||||
|
||||
"""Convert internal message schema to OpenAI request format."""
|
||||
if msg.role == "assistant":
|
||||
formatted["content"] = msg.content if msg.content else None
|
||||
formatted: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": msg.content if msg.content else None,
|
||||
}
|
||||
if msg.tool_calls:
|
||||
formatted["tool_calls"] = [
|
||||
self._normalize_tool_call(tool_call)
|
||||
for tool_call in msg.tool_calls
|
||||
normalized_calls = [
|
||||
self._normalize_tool_call(tool_call) for tool_call in msg.tool_calls
|
||||
]
|
||||
elif msg.role == "tool":
|
||||
formatted["content"] = msg.content
|
||||
if msg.tool_call_id:
|
||||
formatted["tool_call_id"] = msg.tool_call_id
|
||||
else:
|
||||
formatted["content"] = msg.content
|
||||
if msg.name:
|
||||
formatted["name"] = msg.name
|
||||
|
||||
if self._supports_tools:
|
||||
formatted["tool_calls"] = normalized_calls
|
||||
elif self._supports_functions:
|
||||
if len(normalized_calls) > 1:
|
||||
self.logger.warning(
|
||||
"Legacy function_call mode only supports one tool call; "
|
||||
"extra tool calls were dropped."
|
||||
)
|
||||
function_data = normalized_calls[0].get("function") or {}
|
||||
formatted["function_call"] = {
|
||||
"name": function_data.get("name"),
|
||||
"arguments": function_data.get("arguments", "{}"),
|
||||
}
|
||||
return formatted
|
||||
|
||||
if msg.role == "tool":
|
||||
if self._supports_tools and msg.tool_call_id:
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": msg.content,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"role": "function",
|
||||
"name": msg.name or "tool",
|
||||
"content": msg.content,
|
||||
}
|
||||
|
||||
formatted = {"role": msg.role, "content": msg.content}
|
||||
if msg.name:
|
||||
formatted["name"] = msg.name
|
||||
return formatted
|
||||
|
||||
def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]:
|
||||
"""将工具调用对象统一转换为字典"""
|
||||
"""Normalize tool call object/dict to a stable dict schema."""
|
||||
if isinstance(tool_call, dict):
|
||||
normalized = dict(tool_call)
|
||||
elif hasattr(tool_call, "model_dump"):
|
||||
@@ -137,37 +243,37 @@ class OpenAIModel(BaseAIModel):
|
||||
else:
|
||||
function_name = getattr(function, "name", None)
|
||||
raw_arguments = getattr(function, "arguments", None)
|
||||
|
||||
|
||||
normalized = {
|
||||
"id": getattr(tool_call, "id", None),
|
||||
"type": getattr(tool_call, "type", "function"),
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": raw_arguments
|
||||
}
|
||||
"arguments": raw_arguments,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
function_data = normalized.get("function") or {}
|
||||
if not isinstance(function_data, dict):
|
||||
function_data = {
|
||||
"name": getattr(function_data, "name", None),
|
||||
"arguments": getattr(function_data, "arguments", None)
|
||||
"arguments": getattr(function_data, "arguments", None),
|
||||
}
|
||||
raw_arguments = function_data.get("arguments")
|
||||
|
||||
|
||||
if isinstance(raw_arguments, dict):
|
||||
arguments = json.dumps(raw_arguments, ensure_ascii=False)
|
||||
elif raw_arguments is None:
|
||||
arguments = "{}"
|
||||
else:
|
||||
arguments = str(raw_arguments)
|
||||
|
||||
|
||||
function_data["arguments"] = arguments
|
||||
normalized["function"] = function_data
|
||||
normalized["type"] = normalized.get("type") or "function"
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _is_embedding_too_long_error(error: Exception) -> bool:
|
||||
status_code = getattr(error, "status_code", None)
|
||||
@@ -196,7 +302,7 @@ class OpenAIModel(BaseAIModel):
|
||||
return f"{compact[:head]} {compact[-tail:]}"
|
||||
|
||||
async def embed(self, text: str) -> List[float]:
|
||||
"""?????"""
|
||||
"""Generate embeddings."""
|
||||
if isinstance(text, bytes):
|
||||
text = text.decode("utf-8", errors="ignore")
|
||||
|
||||
@@ -209,13 +315,17 @@ class OpenAIModel(BaseAIModel):
|
||||
response = await self.client.embeddings.create(
|
||||
model=self.config.model_name,
|
||||
input=candidate_text,
|
||||
encoding_format="float"
|
||||
encoding_format="float",
|
||||
)
|
||||
return response.data[0].embedding
|
||||
except Exception as e:
|
||||
if self._is_embedding_too_long_error(e):
|
||||
next_text = self._shrink_text_for_embedding(candidate_text)
|
||||
if next_text and len(next_text) < len(candidate_text) and retry_count < 5:
|
||||
if (
|
||||
next_text
|
||||
and len(next_text) < len(candidate_text)
|
||||
and retry_count < 5
|
||||
):
|
||||
retry_count += 1
|
||||
self.logger.warning(
|
||||
"embedding input too long, retry with truncated text: "
|
||||
@@ -232,4 +342,3 @@ class OpenAIModel(BaseAIModel):
|
||||
self.logger.error(f"text preview: {repr(candidate_text[:100])}")
|
||||
self.logger.error(f"full traceback:\n{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user