From 0b356d1fef5b69c543e96b48be595f20c3f01f5a Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Tue, 3 Mar 2026 01:37:00 +0800 Subject: [PATCH] Enhance AIClient and OpenAIModel to support tool names in messages and improve tool capability detection. Updated message formatting to include tool names for better clarity in error handling and results. --- src/ai/client.py | 39 +++--- src/ai/models/openai_model.py | 253 ++++++++++++++++++++++++---------- 2 files changed, 201 insertions(+), 91 deletions(-) diff --git a/src/ai/client.py b/src/ai/client.py index cfbdc47..edca079 100644 --- a/src/ai/client.py +++ b/src/ai/client.py @@ -254,7 +254,8 @@ class AIClient: messages.append(Message( role="tool", content=f"工具参数解析失败: {str(e)}", - tool_call_id=fallback_id + tool_call_id=fallback_id, + name="tool" )) continue if not tool_name: @@ -265,31 +266,31 @@ class AIClient: if not tool_def: error_msg = f"未找到工具: {tool_name}" logger.warning(error_msg) - if tool_call_id: - messages.append(Message( - role="tool", - content=error_msg, - tool_call_id=tool_call_id - )) + messages.append(Message( + role="tool", + name=tool_name, + content=error_msg, + tool_call_id=tool_call_id + )) continue try: result = tool_def.function(**tool_args) if inspect.isawaitable(result): result = await result - if tool_call_id: - messages.append(Message( - role="tool", - content=str(result), - tool_call_id=tool_call_id - )) + messages.append(Message( + role="tool", + name=tool_name, + content=str(result), + tool_call_id=tool_call_id + )) except Exception as e: - if tool_call_id: - messages.append(Message( - role="tool", - content=f"工具执行失败: {str(e)}", - tool_call_id=tool_call_id - )) + messages.append(Message( + role="tool", + name=tool_name, + content=f"工具执行失败: {str(e)}", + tool_call_id=tool_call_id + )) # 再次调用模型获取最终响应 return await self.model.chat(messages, tools, **kwargs) diff --git a/src/ai/models/openai_model.py b/src/ai/models/openai_model.py index fe1e855..33934c5 100644 --- a/src/ai/models/openai_model.py +++ b/src/ai/models/openai_model.py @@ -1,45 +1,105 @@ """ -OpenAI模型实现(兼容OpenAI API的模型) +OpenAI model implementation (including OpenAI-compatible providers). """ +import inspect import json +from typing import Any, AsyncIterator, Dict, List, Optional + import httpx -from typing import List, Optional, AsyncIterator, Dict, Any from openai import AsyncOpenAI + from ..base import BaseAIModel, Message, ModelConfig from src.utils.logger import setup_logger -logger = setup_logger('OpenAIModel') +logger = setup_logger("OpenAIModel") class OpenAIModel(BaseAIModel): - """OpenAI模型实现""" - + """OpenAI model implementation.""" + def __init__(self, config: ModelConfig): super().__init__(config) self.logger = logger - - # 创建支持UTF-8的httpx客户端 + http_client = httpx.AsyncClient( timeout=config.timeout, - limits=httpx.Limits(max_keepalive_connections=5, max_connections=10) + limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), ) - + self.client = AsyncOpenAI( api_key=config.api_key, base_url=config.api_base, timeout=config.timeout, - http_client=http_client + http_client=http_client, ) - + + self._supports_tools = False + self._supports_functions = False + self._detect_tool_capability() + + def _detect_tool_capability(self) -> None: + """Detect tool-calling parameters supported by current SDK.""" + try: + signature = inspect.signature(self.client.chat.completions.create) + parameters = signature.parameters + self._supports_tools = "tools" in parameters + self._supports_functions = "functions" in parameters + except Exception as exc: + self.logger.warning(f"Failed to inspect OpenAI completion signature: {exc}") + + def _build_tool_params(self, tools: Optional[List[dict]]) -> Dict[str, Any]: + """Build SDK-compatible tool/function request parameters.""" + if not tools: + return {} + + if self._supports_tools: + return {"tools": tools} + + if self._supports_functions: + functions = [] + for tool in tools: + schema = self._extract_function_schema(tool) + if schema: + functions.append(schema) + + if functions: + return {"functions": functions, "function_call": "auto"} + + self.logger.warning( + "Tool calling is not supported by current OpenAI SDK; tools were ignored." + ) + return {} + + @staticmethod + def _extract_function_schema(tool: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if not isinstance(tool, dict): + return None + + function_data = tool.get("function") + if not isinstance(function_data, dict): + return None + + name = function_data.get("name") + if not name: + return None + + return { + "name": name, + "description": function_data.get("description", ""), + "parameters": function_data.get( + "parameters", {"type": "object", "properties": {}} + ), + } + async def chat( self, messages: List[Message], tools: Optional[List[dict]] = None, - **kwargs + **kwargs, ) -> Message: - """同步对话""" + """Non-stream chat.""" formatted_messages = [self._format_message(msg) for msg in messages] - + params = { "model": self.config.model_name, "messages": formatted_messages, @@ -49,39 +109,29 @@ class OpenAIModel(BaseAIModel): "frequency_penalty": self.config.frequency_penalty, "presence_penalty": self.config.presence_penalty, } - - if tools: - params["tools"] = tools - + + params.update(self._build_tool_params(tools)) params.update(kwargs) - + response = await self.client.chat.completions.create(**params) - + choice = response.choices[0] - raw_tool_calls = ( - choice.message.tool_calls - if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls - else None - ) - tool_calls = ( - [self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls] - if raw_tool_calls else None - ) + tool_calls = self._extract_response_tool_calls(choice.message) return Message( role="assistant", content=choice.message.content or "", - tool_calls=tool_calls + tool_calls=tool_calls, ) - + async def chat_stream( self, messages: List[Message], tools: Optional[List[dict]] = None, - **kwargs + **kwargs, ) -> AsyncIterator[str]: - """流式对话""" + """Streaming chat.""" formatted_messages = [self._format_message(msg) for msg in messages] - + params = { "model": self.config.model_name, "messages": formatted_messages, @@ -89,42 +139,98 @@ class OpenAIModel(BaseAIModel): "max_tokens": self.config.max_tokens, "stream": True, } - - if tools: - params["tools"] = tools - + + params.update(self._build_tool_params(tools)) params.update(kwargs) - + stream = await self.client.chat.completions.create(**params) - + async for chunk in stream: - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content + delta = chunk.choices[0].delta + if delta.content: + yield delta.content + + def _extract_response_tool_calls(self, message: Any) -> Optional[List[Dict[str, Any]]]: + raw_tool_calls = getattr(message, "tool_calls", None) + if raw_tool_calls: + return [self._normalize_tool_call(tool_call) for tool_call in raw_tool_calls] + + legacy_function_call = getattr(message, "function_call", None) + if legacy_function_call: + return [self._normalize_legacy_function_call(legacy_function_call)] + + return None + + def _normalize_legacy_function_call(self, function_call: Any) -> Dict[str, Any]: + if isinstance(function_call, dict): + function_name = function_call.get("name") + raw_arguments = function_call.get("arguments") + else: + function_name = getattr(function_call, "name", None) + raw_arguments = getattr(function_call, "arguments", None) + + if isinstance(raw_arguments, dict): + arguments = json.dumps(raw_arguments, ensure_ascii=False) + elif raw_arguments is None: + arguments = "{}" + else: + arguments = str(raw_arguments) + + return { + "type": "function", + "function": { + "name": function_name, + "arguments": arguments, + }, + } def _format_message(self, msg: Message) -> Dict[str, Any]: - """将内部消息结构转换为OpenAI消息格式""" - formatted: Dict[str, Any] = {"role": msg.role} - + """Convert internal message schema to OpenAI request format.""" if msg.role == "assistant": - formatted["content"] = msg.content if msg.content else None + formatted: Dict[str, Any] = { + "role": "assistant", + "content": msg.content if msg.content else None, + } if msg.tool_calls: - formatted["tool_calls"] = [ - self._normalize_tool_call(tool_call) - for tool_call in msg.tool_calls + normalized_calls = [ + self._normalize_tool_call(tool_call) for tool_call in msg.tool_calls ] - elif msg.role == "tool": - formatted["content"] = msg.content - if msg.tool_call_id: - formatted["tool_call_id"] = msg.tool_call_id - else: - formatted["content"] = msg.content - if msg.name: - formatted["name"] = msg.name - + if self._supports_tools: + formatted["tool_calls"] = normalized_calls + elif self._supports_functions: + if len(normalized_calls) > 1: + self.logger.warning( + "Legacy function_call mode only supports one tool call; " + "extra tool calls were dropped." + ) + function_data = normalized_calls[0].get("function") or {} + formatted["function_call"] = { + "name": function_data.get("name"), + "arguments": function_data.get("arguments", "{}"), + } + return formatted + + if msg.role == "tool": + if self._supports_tools and msg.tool_call_id: + return { + "role": "tool", + "content": msg.content, + "tool_call_id": msg.tool_call_id, + } + + return { + "role": "function", + "name": msg.name or "tool", + "content": msg.content, + } + + formatted = {"role": msg.role, "content": msg.content} + if msg.name: + formatted["name"] = msg.name return formatted def _normalize_tool_call(self, tool_call: Any) -> Dict[str, Any]: - """将工具调用对象统一转换为字典""" + """Normalize tool call object/dict to a stable dict schema.""" if isinstance(tool_call, dict): normalized = dict(tool_call) elif hasattr(tool_call, "model_dump"): @@ -137,37 +243,37 @@ class OpenAIModel(BaseAIModel): else: function_name = getattr(function, "name", None) raw_arguments = getattr(function, "arguments", None) - + normalized = { "id": getattr(tool_call, "id", None), "type": getattr(tool_call, "type", "function"), "function": { "name": function_name, - "arguments": raw_arguments - } + "arguments": raw_arguments, + }, } - + function_data = normalized.get("function") or {} if not isinstance(function_data, dict): function_data = { "name": getattr(function_data, "name", None), - "arguments": getattr(function_data, "arguments", None) + "arguments": getattr(function_data, "arguments", None), } raw_arguments = function_data.get("arguments") - + if isinstance(raw_arguments, dict): arguments = json.dumps(raw_arguments, ensure_ascii=False) elif raw_arguments is None: arguments = "{}" else: arguments = str(raw_arguments) - + function_data["arguments"] = arguments normalized["function"] = function_data normalized["type"] = normalized.get("type") or "function" - + return normalized - + @staticmethod def _is_embedding_too_long_error(error: Exception) -> bool: status_code = getattr(error, "status_code", None) @@ -196,7 +302,7 @@ class OpenAIModel(BaseAIModel): return f"{compact[:head]} {compact[-tail:]}" async def embed(self, text: str) -> List[float]: - """?????""" + """Generate embeddings.""" if isinstance(text, bytes): text = text.decode("utf-8", errors="ignore") @@ -209,13 +315,17 @@ class OpenAIModel(BaseAIModel): response = await self.client.embeddings.create( model=self.config.model_name, input=candidate_text, - encoding_format="float" + encoding_format="float", ) return response.data[0].embedding except Exception as e: if self._is_embedding_too_long_error(e): next_text = self._shrink_text_for_embedding(candidate_text) - if next_text and len(next_text) < len(candidate_text) and retry_count < 5: + if ( + next_text + and len(next_text) < len(candidate_text) + and retry_count < 5 + ): retry_count += 1 self.logger.warning( "embedding input too long, retry with truncated text: " @@ -232,4 +342,3 @@ class OpenAIModel(BaseAIModel): self.logger.error(f"text preview: {repr(candidate_text[:100])}") self.logger.error(f"full traceback:\n{traceback.format_exc()}") raise -