Enhance AIClient and MCPServer to support tool registration with source tracking. Add logging for tool calls and improve error handling. Introduce methods for embedding token limit extraction and budget application in OpenAIModel. Add tests for MCP tool registration and execution.
This commit is contained in:
@@ -43,6 +43,7 @@ class AIClient:
|
|||||||
|
|
||||||
# 初始化工具注册表
|
# 初始化工具注册表
|
||||||
self.tools = ToolRegistry()
|
self.tools = ToolRegistry()
|
||||||
|
self._tool_sources: Dict[str, str] = {}
|
||||||
|
|
||||||
# 初始化人格系统
|
# 初始化人格系统
|
||||||
self.personality = PersonalitySystem(
|
self.personality = PersonalitySystem(
|
||||||
@@ -242,8 +243,11 @@ class AIClient:
|
|||||||
) -> Message:
|
) -> Message:
|
||||||
"""处理工具调用。"""
|
"""处理工具调用。"""
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
total_calls = len(response.tool_calls or [])
|
||||||
|
if total_calls:
|
||||||
|
logger.info(f"检测到工具调用请求: {total_calls} 个")
|
||||||
|
|
||||||
# 执行工具调用
|
# 执行工具调用
|
||||||
for tool_call in response.tool_calls or []:
|
for tool_call in response.tool_calls or []:
|
||||||
try:
|
try:
|
||||||
tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
|
tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
|
||||||
@@ -263,6 +267,7 @@ class AIClient:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
tool_def = self.tools.get(tool_name)
|
tool_def = self.tools.get(tool_name)
|
||||||
|
tool_source = self._tool_sources.get(tool_name, "custom")
|
||||||
if not tool_def:
|
if not tool_def:
|
||||||
error_msg = f"未找到工具: {tool_name}"
|
error_msg = f"未找到工具: {tool_name}"
|
||||||
logger.warning(error_msg)
|
logger.warning(error_msg)
|
||||||
@@ -275,9 +280,19 @@ class AIClient:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(
|
||||||
|
"工具调用开始: "
|
||||||
|
f"source={tool_source}, name={tool_name}, "
|
||||||
|
f"args={self._preview_log_payload(tool_args)}"
|
||||||
|
)
|
||||||
result = tool_def.function(**tool_args)
|
result = tool_def.function(**tool_args)
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
result = await result
|
result = await result
|
||||||
|
logger.info(
|
||||||
|
"工具调用成功: "
|
||||||
|
f"source={tool_source}, name={tool_name}, "
|
||||||
|
f"result={self._preview_log_payload(result)}"
|
||||||
|
)
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
@@ -285,6 +300,10 @@ class AIClient:
|
|||||||
tool_call_id=tool_call_id
|
tool_call_id=tool_call_id
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"工具调用失败: "
|
||||||
|
f"source={tool_source}, name={tool_name}, error={e}"
|
||||||
|
)
|
||||||
messages.append(Message(
|
messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
@@ -335,6 +354,18 @@ class AIClient:
|
|||||||
|
|
||||||
raise ValueError(f"不支持的工具参数类型: {type(raw_args)}")
|
raise ValueError(f"不支持的工具参数类型: {type(raw_args)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
|
||||||
|
"""日志中展示参数/结果时使用的简短预览。"""
|
||||||
|
try:
|
||||||
|
text = json.dumps(payload, ensure_ascii=False, default=str)
|
||||||
|
except Exception:
|
||||||
|
text = str(payload)
|
||||||
|
|
||||||
|
if len(text) > max_len:
|
||||||
|
return text[:max_len] + "..."
|
||||||
|
return text
|
||||||
|
|
||||||
def set_personality(self, personality_name: str) -> bool:
|
def set_personality(self, personality_name: str) -> bool:
|
||||||
"""设置人格。"""
|
"""设置人格。"""
|
||||||
return self.personality.set_personality(personality_name)
|
return self.personality.set_personality(personality_name)
|
||||||
@@ -382,7 +413,14 @@ class AIClient:
|
|||||||
"""获取任务状态。"""
|
"""获取任务状态。"""
|
||||||
return self.task_manager.get_task_status(task_id)
|
return self.task_manager.get_task_status(task_id)
|
||||||
|
|
||||||
def register_tool(self, name: str, description: str, parameters: Dict, function: callable):
|
def register_tool(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: Dict,
|
||||||
|
function: callable,
|
||||||
|
source: str = "custom",
|
||||||
|
):
|
||||||
"""注册工具。"""
|
"""注册工具。"""
|
||||||
from .base import ToolDefinition
|
from .base import ToolDefinition
|
||||||
tool = ToolDefinition(
|
tool = ToolDefinition(
|
||||||
@@ -392,18 +430,23 @@ class AIClient:
|
|||||||
function=function
|
function=function
|
||||||
)
|
)
|
||||||
self.tools.register(tool)
|
self.tools.register(tool)
|
||||||
logger.info(f"已注册工具: {name}")
|
self._tool_sources[name] = source
|
||||||
|
logger.info(f"已注册工具: {name} (source={source})")
|
||||||
|
|
||||||
def unregister_tool(self, name: str) -> bool:
|
def unregister_tool(self, name: str) -> bool:
|
||||||
"""卸载工具。"""
|
"""卸载工具。"""
|
||||||
removed = self.tools.unregister(name)
|
removed = self.tools.unregister(name)
|
||||||
if removed:
|
if removed:
|
||||||
|
self._tool_sources.pop(name, None)
|
||||||
logger.info(f"已卸载工具: {name}")
|
logger.info(f"已卸载工具: {name}")
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
def unregister_tools_by_prefix(self, prefix: str) -> int:
|
def unregister_tools_by_prefix(self, prefix: str) -> int:
|
||||||
"""按前缀批量卸载工具。"""
|
"""按前缀批量卸载工具。"""
|
||||||
removed_count = self.tools.unregister_by_prefix(prefix)
|
removed_count = self.tools.unregister_by_prefix(prefix)
|
||||||
|
for tool_name in list(self._tool_sources.keys()):
|
||||||
|
if tool_name.startswith(prefix):
|
||||||
|
self._tool_sources.pop(tool_name, None)
|
||||||
if removed_count:
|
if removed_count:
|
||||||
logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
|
logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
|
||||||
return removed_count
|
return removed_count
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class MCPServer:
|
|||||||
self.version = version
|
self.version = version
|
||||||
self.resources: Dict[str, MCPResource] = {}
|
self.resources: Dict[str, MCPResource] = {}
|
||||||
self.tools: Dict[str, Callable] = {}
|
self.tools: Dict[str, Callable] = {}
|
||||||
|
self.tool_specs: Dict[str, MCPTool] = {}
|
||||||
self.prompts: Dict[str, MCPPrompt] = {}
|
self.prompts: Dict[str, MCPPrompt] = {}
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
@@ -61,6 +62,7 @@ class MCPServer:
|
|||||||
def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
|
def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
|
||||||
"""注册工具"""
|
"""注册工具"""
|
||||||
tool = MCPTool(name=name, description=description, input_schema=input_schema)
|
tool = MCPTool(name=name, description=description, input_schema=input_schema)
|
||||||
|
self.tool_specs[name] = tool
|
||||||
self.tools[name] = handler
|
self.tools[name] = handler
|
||||||
logger.info(f"✅ MCP工具注册: {self.name}.{name}")
|
logger.info(f"✅ MCP工具注册: {self.name}.{name}")
|
||||||
|
|
||||||
@@ -78,10 +80,7 @@ class MCPServer:
|
|||||||
|
|
||||||
async def list_tools(self) -> List[MCPTool]:
|
async def list_tools(self) -> List[MCPTool]:
|
||||||
"""列出工具"""
|
"""列出工具"""
|
||||||
return [
|
return list(self.tool_specs.values())
|
||||||
MCPTool(name=name, description="", input_schema={})
|
|
||||||
for name in self.tools.keys()
|
|
||||||
]
|
|
||||||
|
|
||||||
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
|
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
|
||||||
"""调用工具"""
|
"""调用工具"""
|
||||||
@@ -89,7 +88,16 @@ class MCPServer:
|
|||||||
raise ValueError(f"工具不存在: {name}")
|
raise ValueError(f"工具不存在: {name}")
|
||||||
|
|
||||||
handler = self.tools[name]
|
handler = self.tools[name]
|
||||||
return await handler(**arguments)
|
logger.info(
|
||||||
|
f"MCP工具调用开始: server={self.name}, tool={name}, args={json.dumps(arguments, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = await handler(**arguments)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}")
|
||||||
|
raise
|
||||||
|
logger.info(f"MCP工具调用成功: server={self.name}, tool={name}")
|
||||||
|
return result
|
||||||
|
|
||||||
async def list_prompts(self) -> List[MCPPrompt]:
|
async def list_prompts(self) -> List[MCPPrompt]:
|
||||||
"""列出提示词"""
|
"""列出提示词"""
|
||||||
@@ -216,4 +224,5 @@ class MCPManager:
|
|||||||
raise ValueError(f"工具名格式错误: {full_tool_name}")
|
raise ValueError(f"工具名格式错误: {full_tool_name}")
|
||||||
|
|
||||||
server_name, tool_name = parts
|
server_name, tool_name = parts
|
||||||
|
logger.info(f"MCP执行请求: {full_tool_name}")
|
||||||
return await self.client.call_tool(server_name, tool_name, arguments)
|
return await self.client.call_tool(server_name, tool_name, arguments)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ OpenAI model implementation (including OpenAI-compatible providers).
|
|||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -20,6 +21,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self._embedding_token_limit: Optional[int] = None
|
||||||
|
|
||||||
http_client = httpx.AsyncClient(
|
http_client = httpx.AsyncClient(
|
||||||
timeout=config.timeout,
|
timeout=config.timeout,
|
||||||
@@ -317,6 +319,70 @@ class OpenAIModel(BaseAIModel):
|
|||||||
or "maximum context length" in message
|
or "maximum context length" in message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_embedding_token_limit(error: Exception) -> Optional[int]:
|
||||||
|
message = str(error).lower()
|
||||||
|
patterns = [
|
||||||
|
r"less than\s+(\d+)\s+tokens",
|
||||||
|
r"maximum context length.*?(\d+)\s+tokens",
|
||||||
|
r"max(?:imum)?(?: input)?(?: length)?\D+(\d+)\s+tokens",
|
||||||
|
]
|
||||||
|
for pattern in patterns:
|
||||||
|
match = re.search(pattern, message)
|
||||||
|
if not match:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
limit = int(match.group(1))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if limit > 0:
|
||||||
|
return limit
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _estimate_embedding_tokens(text: str) -> int:
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
ascii_count = 0
|
||||||
|
cjk_count = 0
|
||||||
|
other_count = 0
|
||||||
|
|
||||||
|
for ch in text:
|
||||||
|
code = ord(ch)
|
||||||
|
if code < 128:
|
||||||
|
ascii_count += 1
|
||||||
|
elif 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF:
|
||||||
|
cjk_count += 1
|
||||||
|
else:
|
||||||
|
other_count += 1
|
||||||
|
|
||||||
|
estimated = cjk_count + int(other_count * 0.7) + int(ascii_count / 4)
|
||||||
|
return max(1, estimated)
|
||||||
|
|
||||||
|
def _apply_embedding_budget(self, text: str) -> str:
|
||||||
|
if not text or not self._embedding_token_limit:
|
||||||
|
return text
|
||||||
|
|
||||||
|
safe_limit = max(32, int(self._embedding_token_limit * 0.9))
|
||||||
|
estimated_tokens = self._estimate_embedding_tokens(text)
|
||||||
|
if estimated_tokens <= safe_limit:
|
||||||
|
return text
|
||||||
|
|
||||||
|
ratio = safe_limit / max(estimated_tokens, 1)
|
||||||
|
target_length = max(64, int(len(text) * ratio))
|
||||||
|
if target_length >= len(text):
|
||||||
|
target_length = max(64, len(text) - 1)
|
||||||
|
if target_length <= 0 or target_length >= len(text):
|
||||||
|
return text
|
||||||
|
|
||||||
|
head = target_length // 2
|
||||||
|
tail = target_length - head
|
||||||
|
return f"{text[:head]} {text[-tail:]}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _shrink_text_for_embedding(text: str) -> str:
|
def _shrink_text_for_embedding(text: str) -> str:
|
||||||
compact = " ".join((text or "").split())
|
compact = " ".join((text or "").split())
|
||||||
@@ -338,6 +404,7 @@ class OpenAIModel(BaseAIModel):
|
|||||||
|
|
||||||
raw_text = str(text or "")
|
raw_text = str(text or "")
|
||||||
candidate_text = raw_text.strip() or raw_text or " "
|
candidate_text = raw_text.strip() or raw_text or " "
|
||||||
|
candidate_text = self._apply_embedding_budget(candidate_text)
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -350,16 +417,32 @@ class OpenAIModel(BaseAIModel):
|
|||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self._is_embedding_too_long_error(e):
|
if self._is_embedding_too_long_error(e):
|
||||||
next_text = self._shrink_text_for_embedding(candidate_text)
|
token_limit = self._extract_embedding_token_limit(e)
|
||||||
|
if token_limit:
|
||||||
|
if not self._embedding_token_limit:
|
||||||
|
self._embedding_token_limit = token_limit
|
||||||
|
else:
|
||||||
|
self._embedding_token_limit = min(
|
||||||
|
self._embedding_token_limit, token_limit
|
||||||
|
)
|
||||||
|
|
||||||
|
next_text = self._apply_embedding_budget(
|
||||||
|
self._shrink_text_for_embedding(candidate_text)
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
next_text
|
next_text
|
||||||
and len(next_text) < len(candidate_text)
|
and len(next_text) < len(candidate_text)
|
||||||
and retry_count < 5
|
and retry_count < 5
|
||||||
):
|
):
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
self.logger.warning(
|
limit_desc = (
|
||||||
"embedding input too long, retry with truncated text: "
|
f", token_limit={self._embedding_token_limit}"
|
||||||
f"{len(candidate_text)} -> {len(next_text)}"
|
if self._embedding_token_limit
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
self.logger.info(
|
||||||
|
"embedding input exceeded provider limit, retry with truncated text: "
|
||||||
|
f"{len(candidate_text)} -> {len(next_text)}{limit_desc}"
|
||||||
)
|
)
|
||||||
candidate_text = next_text
|
candidate_text = next_text
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -324,6 +324,43 @@ class MessageHandler:
|
|||||||
description=f"技能工具: {full_tool_name}",
|
description=f"技能工具: {full_tool_name}",
|
||||||
parameters={"type": "object", "properties": {}},
|
parameters={"type": "object", "properties": {}},
|
||||||
function=tool_func,
|
function=tool_func,
|
||||||
|
source="skills",
|
||||||
|
)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def _register_mcp_tools(self) -> int:
|
||||||
|
if not self.mcp_manager or not self.ai_client:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
tools = await self.mcp_manager.get_all_tools_for_ai()
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for item in tools:
|
||||||
|
function_info = item.get("function") if isinstance(item, dict) else None
|
||||||
|
if not isinstance(function_info, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_tool_name = function_info.get("name")
|
||||||
|
if not full_tool_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
parameters = function_info.get("parameters")
|
||||||
|
if not isinstance(parameters, dict):
|
||||||
|
parameters = {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs):
|
||||||
|
if not self.mcp_manager:
|
||||||
|
raise RuntimeError("MCP manager not initialized")
|
||||||
|
return await self.mcp_manager.execute_tool(_full_tool_name, kwargs)
|
||||||
|
|
||||||
|
self.ai_client.register_tool(
|
||||||
|
name=full_tool_name,
|
||||||
|
description=f"MCP工具: {full_tool_name}",
|
||||||
|
parameters=parameters,
|
||||||
|
function=_mcp_proxy,
|
||||||
|
source="mcp",
|
||||||
)
|
)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@@ -439,6 +476,8 @@ class MessageHandler:
|
|||||||
self.mcp_manager = MCPManager(Path("config/mcp.json"))
|
self.mcp_manager = MCPManager(Path("config/mcp.json"))
|
||||||
fs_server = FileSystemMCPServer(root_path=Path("data"))
|
fs_server = FileSystemMCPServer(root_path=Path("data"))
|
||||||
await self.mcp_manager.register_server(fs_server)
|
await self.mcp_manager.register_server(fs_server)
|
||||||
|
mcp_tool_count = await self._register_mcp_tools()
|
||||||
|
logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"MCP 初始化失败: {exc}")
|
logger.warning(f"MCP 初始化失败: {exc}")
|
||||||
|
|
||||||
|
|||||||
44
tests/test_mcp_tool_registration.py
Normal file
44
tests/test_mcp_tool_registration.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.ai.mcp.base import MCPManager, MCPServer
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyMCPServer(MCPServer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(name="dummy", version="1.0.0")
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
self.register_tool(
|
||||||
|
name="echo",
|
||||||
|
description="Echo text",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"text": {"type": "string"}},
|
||||||
|
"required": ["text"],
|
||||||
|
},
|
||||||
|
handler=self.echo,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def echo(self, text: str) -> str:
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
|
||||||
|
manager = MCPManager(tmp_path / "mcp.json")
|
||||||
|
asyncio.run(manager.register_server(_DummyMCPServer()))
|
||||||
|
|
||||||
|
tools = asyncio.run(manager.get_all_tools_for_ai())
|
||||||
|
assert len(tools) == 1
|
||||||
|
function_info = tools[0]["function"]
|
||||||
|
assert function_info["name"] == "dummy.echo"
|
||||||
|
assert function_info["description"] == "Echo text"
|
||||||
|
assert function_info["parameters"]["required"] == ["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_manager_execute_tool(tmp_path: Path):
|
||||||
|
manager = MCPManager(tmp_path / "mcp.json")
|
||||||
|
asyncio.run(manager.register_server(_DummyMCPServer()))
|
||||||
|
|
||||||
|
result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
|
||||||
|
assert result == "hello"
|
||||||
@@ -43,6 +43,18 @@ class _FakeEmbeddings:
|
|||||||
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
||||||
|
|
||||||
|
|
||||||
|
class _LengthLimitedEmbeddings:
|
||||||
|
def __init__(self):
|
||||||
|
self.inputs = []
|
||||||
|
|
||||||
|
async def create(self, **kwargs):
|
||||||
|
text = kwargs.get("input", "")
|
||||||
|
self.inputs.append(text)
|
||||||
|
if len(text) > 512:
|
||||||
|
raise RuntimeError("input must be less than 512 tokens")
|
||||||
|
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
||||||
|
|
||||||
|
|
||||||
class _ModernCompletions:
|
class _ModernCompletions:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.last_params = None
|
self.last_params = None
|
||||||
@@ -197,6 +209,13 @@ class _RuntimeRejectToolsAsyncOpenAI:
|
|||||||
self.embeddings = _FakeEmbeddings()
|
self.embeddings = _FakeEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
|
class _LengthLimitedEmbedAsyncOpenAI:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.completions = _ModernCompletions()
|
||||||
|
self.chat = SimpleNamespace(completions=self.completions)
|
||||||
|
self.embeddings = _LengthLimitedEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_uses_tools_when_supported(monkeypatch):
|
def test_openai_model_uses_tools_when_supported(monkeypatch):
|
||||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
||||||
|
|
||||||
@@ -276,3 +295,28 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
|
|||||||
assert model._supports_tools is False
|
assert model._supports_tools is False
|
||||||
assert result.tool_calls is not None
|
assert result.tool_calls is not None
|
||||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
|
||||||
|
)
|
||||||
|
|
||||||
|
model = OpenAIModel(_model_config())
|
||||||
|
long_text = "你" * 600
|
||||||
|
|
||||||
|
first_embedding = asyncio.run(model.embed(long_text))
|
||||||
|
assert first_embedding == [0.1, 0.2, 0.3]
|
||||||
|
assert model._embedding_token_limit == 512
|
||||||
|
|
||||||
|
inputs_after_first = list(model.client.embeddings.inputs)
|
||||||
|
assert len(inputs_after_first) == 2
|
||||||
|
assert len(inputs_after_first[0]) == 600
|
||||||
|
assert len(inputs_after_first[1]) < len(inputs_after_first[0])
|
||||||
|
|
||||||
|
second_embedding = asyncio.run(model.embed(long_text))
|
||||||
|
assert second_embedding == [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
inputs_after_second = list(model.client.embeddings.inputs)
|
||||||
|
assert len(inputs_after_second) == 3
|
||||||
|
assert len(inputs_after_second[-1]) < 512
|
||||||
|
|||||||
Reference in New Issue
Block a user