Enhance AIClient and OpenAIModel to support tool names in messages and improve tool capability detection. Updated message formatting to include tool names for better clarity in error handling and results.
This commit is contained in:
@@ -254,7 +254,8 @@ class AIClient:
|
|||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
content=f"工具参数解析失败: {str(e)}",
|
content=f"工具参数解析失败: {str(e)}",
|
||||||
tool_call_id=fallback_id
|
tool_call_id=fallback_id,
|
||||||
|
name="tool"
|
||||||
))
|
))
|
||||||
continue
|
continue
|
||||||
if not tool_name:
|
if not tool_name:
|
||||||
@@ -265,9 +266,9 @@ class AIClient:
|
|||||||
if not tool_def:
|
if not tool_def:
|
||||||
error_msg = f"未找到工具: {tool_name}"
|
error_msg = f"未找到工具: {tool_name}"
|
||||||
logger.warning(error_msg)
|
logger.warning(error_msg)
|
||||||
if tool_call_id:
|
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
|
name=tool_name,
|
||||||
content=error_msg,
|
content=error_msg,
|
||||||
tool_call_id=tool_call_id
|
tool_call_id=tool_call_id
|
||||||
))
|
))
|
||||||
@@ -277,16 +278,16 @@ class AIClient:
|
|||||||
result = tool_def.function(**tool_args)
|
result = tool_def.function(**tool_args)
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
result = await result
|
result = await result
|
||||||
if tool_call_id:
|
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
|
name=tool_name,
|
||||||
content=str(result),
|
content=str(result),
|
||||||
tool_call_id=tool_call_id
|
tool_call_id=tool_call_id
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if tool_call_id:
|
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
|
name=tool_name,
|
||||||
content=f"工具执行失败: {str(e)}",
|
content=f"工具执行失败: {str(e)}",
|
||||||
tool_call_id=tool_call_id
|
tool_call_id=tool_call_id
|
||||||
))
|
))
|
||||||
|
|||||||
@@ -1,43 +1,103 @@
|
|||||||
"""
|
"""
|
||||||
OpenAI模型实现(兼容OpenAI API的模型)
|
OpenAI model implementation (including OpenAI-compatible providers).
|
||||||
"""
|
"""
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from typing import List, Optional, AsyncIterator, Dict, Any
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from ..base import BaseAIModel, Message, ModelConfig
|
from ..base import BaseAIModel, Message, ModelConfig
|
||||||
from src.utils.logger import setup_logger
|
from src.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger('OpenAIModel')
|
logger = setup_logger("OpenAIModel")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModel(BaseAIModel):
|
class OpenAIModel(BaseAIModel):
|
||||||
"""OpenAI模型实现"""
|
"""OpenAI model implementation."""
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
# 创建支持UTF-8的httpx客户端
|
|
||||||
http_client = httpx.AsyncClient(
|
http_client = httpx.AsyncClient(
|
||||||
timeout=config.timeout,
|
timeout=config.timeout,
|
||||||
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
|
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client = AsyncOpenAI(
|
self.client = AsyncOpenAI(
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.api_base,
|
base_url=config.api_base,
|
||||||
timeout=config.timeout,
|
timeout=config.timeout,
|
||||||
http_client=http_client
|
http_client=http_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._supports_tools = False
|
||||||
|
self._supports_functions = False
|
||||||
|
self._detect_tool_capability()
|
||||||
|
|
||||||
|
def _detect_tool_capability(self) -> None:
|
||||||
|
"""Detect tool-calling parameters supported by current SDK."""
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(self.client.chat.completions.create)
|
||||||
|
parameters = signature.parameters
|
||||||
|
self._supports_tools = "tools" in parameters
|
||||||
|
self._supports_functions = "functions" in parameters
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}")
|
||||||
|
|
||||||
|
def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]:
|
||||||
|
"""Build SDK-compatible tool/function request parameters."""
|
||||||
|
if not tools:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if self._supports_tools:
|
||||||
|
return {"tools": tools}
|
||||||
|
|
||||||
|
if self._supports_functions:
|
||||||
|
functions = []
|
||||||
|
for tool in tools:
|
||||||
|
schema = self._extract_function_schema(tool)
|
||||||
|
if schema:
|
||||||
|
functions.append(schema)
|
||||||
|
|
||||||
|
if functions:
|
||||||
|
return {"functions": functions, "function_call": "auto"}
|
||||||
|
|
||||||
|
self.logger.warning(
|
||||||
|
"Tool calling is not supported by current OpenAI SDK; tools were ignored."
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
if not isinstance(tool, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
function_data = tool.get("function")
|
||||||
|
if not isinstance(function_data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
name = function_data.get("name")
|
||||||
|
if not name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"description": function_data.get("description", ""),
|
||||||
|
"parameters": function_data.get(
|
||||||
|
"parameters", {"type": "object", "properties": {}}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> Message:
|
) -> Message:
|
||||||
"""同步对话"""
|
"""Non-stream chat."""
|
||||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
@@ -50,36 +110,26 @@ class OpenAIModel(BaseAIModel):
|
|||||||
"presence_penalty": self.config.presence_penalty,
|
"presence_penalty": self.config.presence_penalty,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
params.update(self._build_tool_params(tools))
|
||||||
params["tools"] = tools
|
|
||||||
|
|
||||||
params.update(kwargs)
|
params.update(kwargs)
|
||||||
|
|
||||||
response = await self.client.chat.completions.create(**params)
|
response = await self.client.chat.completions.create(**params)
|
||||||
|
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
raw_tool_calls = (
|
tool_calls = self._extract_response_tool_calls(choice.message)
|
||||||
choice.message.tool_calls
|
|
||||||
if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
tool_calls = (
|
|
||||||
[self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
|
|
||||||
if raw_tool_calls else None
|
|
||||||
)
|
|
||||||
return Message(
|
return Message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=choice.message.content or "",
|
content=choice.message.content or "",
|
||||||
tool_calls=tool_calls
|
tool_calls=tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
"""流式对话"""
|
"""Streaming chat."""
|
||||||
formatted_messages = [self._format_message(msg) for msg in messages]
|
formatted_messages = [self._format_message(msg) for msg in messages]
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
@@ -90,41 +140,97 @@ class OpenAIModel(BaseAIModel):
|
|||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
params.update(self._build_tool_params(tools))
|
||||||
params["tools"] = tools
|
|
||||||
|
|
||||||
params.update(kwargs)
|
params.update(kwargs)
|
||||||
|
|
||||||
stream = await self.client.chat.completions.create(**params)
|
stream = await self.client.chat.completions.create(**params)
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
if chunk.choices[0].delta.content:
|
delta = chunk.choices[0].delta
|
||||||
yield chunk.choices[0].delta.content
|
if delta.content:
|
||||||
|
yield delta.content
|
||||||
|
|
||||||
|
def _extract_response_tool_calls(self, message: Any) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
raw_tool_calls = getattr(message, "tool_calls", None)
|
||||||
|
if raw_tool_calls:
|
||||||
|
return [self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls]
|
||||||
|
|
||||||
|
legacy_function_call = getattr(message, "function_call", None)
|
||||||
|
if legacy_function_call:
|
||||||
|
return [self._normalize_legacy_function_call(legacy_function_call)]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _normalize_legacy_function_call(self, function_call: Any) -> Dict[str, Any]:
|
||||||
|
if isinstance(function_call, dict):
|
||||||
|
function_name = function_call.get("name")
|
||||||
|
raw_arguments = function_call.get("arguments")
|
||||||
|
else:
|
||||||
|
function_name = getattr(function_call, "name", None)
|
||||||
|
raw_arguments = getattr(function_call, "arguments", None)
|
||||||
|
|
||||||
|
if isinstance(raw_arguments, dict):
|
||||||
|
arguments = json.dumps(raw_arguments, ensure_ascii=False)
|
||||||
|
elif raw_arguments is None:
|
||||||
|
arguments = "{}"
|
||||||
|
else:
|
||||||
|
arguments = str(raw_arguments)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def _format_message(self, msg: Message) -> Dict[str, Any]:
|
def _format_message(self, msg: Message) -> Dict[str, Any]:
|
||||||
"""将内部消息结构转换为OpenAI消息格式"""
|
"""Convert internal message schema to OpenAI request format."""
|
||||||
formatted: Dict[str, Any] = {"role": msg.role}
|
|
||||||
|
|
||||||
if msg.role == "assistant":
|
if msg.role == "assistant":
|
||||||
formatted["content"] = msg.content if msg.content else None
|
formatted: Dict[str, Any] = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": msg.content if msg.content else None,
|
||||||
|
}
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
formatted["tool_calls"] = [
|
normalized_calls = [
|
||||||
self._normalize_tool_call(tool_call)
|
self._normalize_tool_call(tool_call) for tool_call in msg.tool_calls
|
||||||
for tool_call in msg.tool_calls
|
|
||||||
]
|
]
|
||||||
elif msg.role == "tool":
|
if self._supports_tools:
|
||||||
formatted["content"] = msg.content
|
formatted["tool_calls"] = normalized_calls
|
||||||
if msg.tool_call_id:
|
elif self._supports_functions:
|
||||||
formatted["tool_call_id"] = msg.tool_call_id
|
if len(normalized_calls) > 1:
|
||||||
else:
|
self.logger.warning(
|
||||||
formatted["content"] = msg.content
|
"Legacy function_call mode only supports one tool call; "
|
||||||
|
"extra tool calls were dropped."
|
||||||
|
)
|
||||||
|
function_data = normalized_calls[0].get("function") or {}
|
||||||
|
formatted["function_call"] = {
|
||||||
|
"name": function_data.get("name"),
|
||||||
|
"arguments": function_data.get("arguments", "{}"),
|
||||||
|
}
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
if msg.role == "tool":
|
||||||
|
if self._supports_tools and msg.tool_call_id:
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"content": msg.content,
|
||||||
|
"tool_call_id": msg.tool_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "function",
|
||||||
|
"name": msg.name or "tool",
|
||||||
|
"content": msg.content,
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted = {"role": msg.role, "content": msg.content}
|
||||||
if msg.name:
|
if msg.name:
|
||||||
formatted["name"] = msg.name
|
formatted["name"] = msg.name
|
||||||
|
|
||||||
return formatted
|
return formatted
|
||||||
|
|
||||||
def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]:
|
def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]:
|
||||||
"""将工具调用对象统一转换为字典"""
|
"""Normalize tool call object/dict to a stable dict schema."""
|
||||||
if isinstance(tool_call, dict):
|
if isinstance(tool_call, dict):
|
||||||
normalized = dict(tool_call)
|
normalized = dict(tool_call)
|
||||||
elif hasattr(tool_call, "model_dump"):
|
elif hasattr(tool_call, "model_dump"):
|
||||||
@@ -143,15 +249,15 @@ class OpenAIModel(BaseAIModel):
|
|||||||
"type": getattr(tool_call, "type", "function"),
|
"type": getattr(tool_call, "type", "function"),
|
||||||
"function": {
|
"function": {
|
||||||
"name": function_name,
|
"name": function_name,
|
||||||
"arguments": raw_arguments
|
"arguments": raw_arguments,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
function_data = normalized.get("function") or {}
|
function_data = normalized.get("function") or {}
|
||||||
if not isinstance(function_data, dict):
|
if not isinstance(function_data, dict):
|
||||||
function_data = {
|
function_data = {
|
||||||
"name": getattr(function_data, "name", None),
|
"name": getattr(function_data, "name", None),
|
||||||
"arguments": getattr(function_data, "arguments", None)
|
"arguments": getattr(function_data, "arguments", None),
|
||||||
}
|
}
|
||||||
raw_arguments = function_data.get("arguments")
|
raw_arguments = function_data.get("arguments")
|
||||||
|
|
||||||
@@ -196,7 +302,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
return f"{compact[:head]} {compact[-tail:]}"
|
return f"{compact[:head]} {compact[-tail:]}"
|
||||||
|
|
||||||
async def embed(self, text: str) -> List[float]:
|
async def embed(self, text: str) -> List[float]:
|
||||||
"""?????"""
|
"""Generate embeddings."""
|
||||||
if isinstance(text, bytes):
|
if isinstance(text, bytes):
|
||||||
text = text.decode("utf-8", errors="ignore")
|
text = text.decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
@@ -209,13 +315,17 @@ class OpenAIModel(BaseAIModel):
|
|||||||
response = await self.client.embeddings.create(
|
response = await self.client.embeddings.create(
|
||||||
model=self.config.model_name,
|
model=self.config.model_name,
|
||||||
input=candidate_text,
|
input=candidate_text,
|
||||||
encoding_format="float"
|
encoding_format="float",
|
||||||
)
|
)
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self._is_embedding_too_long_error(e):
|
if self._is_embedding_too_long_error(e):
|
||||||
next_text = self._shrink_text_for_embedding(candidate_text)
|
next_text = self._shrink_text_for_embedding(candidate_text)
|
||||||
if next_text and len(next_text) < len(candidate_text) and retry_count < 5:
|
if (
|
||||||
|
next_text
|
||||||
|
and len(next_text) < len(candidate_text)
|
||||||
|
and retry_count < 5
|
||||||
|
):
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"embedding input too long, retry with truncated text: "
|
"embedding input too long, retry with truncated text: "
|
||||||
@@ -232,4 +342,3 @@ class OpenAIModel(BaseAIModel):
|
|||||||
self.logger.error(f"text preview: {repr(candidate_text[:100])}")
|
self.logger.error(f"text preview: {repr(candidate_text[:100])}")
|
||||||
self.logger.error(f"full traceback:\n{traceback.format_exc()}")
|
self.logger.error(f"full traceback:\n{traceback.format_exc()}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user