Enhance AIClient and MCPServer to support tool registration with source tracking. Add logging for tool calls and improve error handling. Introduce methods for embedding token limit extraction and budget application in OpenAIModel. Add tests for MCP tool registration and execution.
This commit is contained in:
44
tests/test_mcp_tool_registration.py
Normal file
44
tests/test_mcp_tool_registration.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from src.ai.mcp.base import MCPManager, MCPServer
|
||||
|
||||
|
||||
class _DummyMCPServer(MCPServer):
    """Minimal MCP server exposing a single ``echo`` tool for tests."""

    def __init__(self):
        super().__init__(name="dummy", version="1.0.0")

    async def initialize(self):
        # JSON schema for the single string argument accepted by ``echo``.
        schema = {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        }
        self.register_tool(
            name="echo",
            description="Echo text",
            input_schema=schema,
            handler=self.echo,
        )

    async def echo(self, text: str) -> str:
        # Trivial handler: hand the input straight back.
        return text
|
||||
|
||||
|
||||
def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
    """Registered tools are exported in OpenAI function-calling format."""
    manager = MCPManager(tmp_path / "mcp.json")
    asyncio.run(manager.register_server(_DummyMCPServer()))

    exported = asyncio.run(manager.get_all_tools_for_ai())
    assert len(exported) == 1

    fn = exported[0]["function"]
    # Tool names are namespaced as "<server>.<tool>".
    assert fn["name"] == "dummy.echo"
    assert fn["description"] == "Echo text"
    assert fn["parameters"]["required"] == ["text"]
|
||||
|
||||
|
||||
def test_mcp_manager_execute_tool(tmp_path: Path):
    """A namespaced tool call is routed to the owning server's handler."""
    manager = MCPManager(tmp_path / "mcp.json")
    asyncio.run(manager.register_server(_DummyMCPServer()))

    echoed = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
    assert echoed == "hello"
|
||||
@@ -43,6 +43,18 @@ class _FakeEmbeddings:
|
||||
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
||||
|
||||
|
||||
class _LengthLimitedEmbeddings:
|
||||
def __init__(self):
|
||||
self.inputs = []
|
||||
|
||||
async def create(self, **kwargs):
|
||||
text = kwargs.get("input", "")
|
||||
self.inputs.append(text)
|
||||
if len(text) > 512:
|
||||
raise RuntimeError("input must be less than 512 tokens")
|
||||
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
|
||||
|
||||
|
||||
class _ModernCompletions:
|
||||
def __init__(self):
|
||||
self.last_params = None
|
||||
@@ -197,6 +209,13 @@ class _RuntimeRejectToolsAsyncOpenAI:
|
||||
self.embeddings = _FakeEmbeddings()
|
||||
|
||||
|
||||
class _LengthLimitedEmbedAsyncOpenAI:
    """Stub AsyncOpenAI client whose embeddings API enforces a length cap."""

    def __init__(self, **kwargs):
        # One completions object is reachable via both the legacy and the
        # chat-style attribute paths, mirroring the real client's surface.
        completions = _ModernCompletions()
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
        self.embeddings = _LengthLimitedEmbeddings()
|
||||
|
||||
|
||||
def test_openai_model_uses_tools_when_supported(monkeypatch):
|
||||
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
|
||||
|
||||
@@ -276,3 +295,28 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
|
||||
assert model._supports_tools is False
|
||||
assert result.tool_calls is not None
|
||||
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
|
||||
|
||||
|
||||
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
    """embed() learns the token cap from a rejection and truncates upfront."""
    monkeypatch.setattr(
        openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
    )

    model = OpenAIModel(_model_config())
    oversized = "你" * 600  # well past the fake endpoint's 512-char cap

    # First call: the raw text is rejected, the limit is parsed from the
    # error message, and a truncated retry succeeds.
    assert asyncio.run(model.embed(oversized)) == [0.1, 0.2, 0.3]
    assert model._embedding_token_limit == 512

    sent = list(model.client.embeddings.inputs)
    assert len(sent) == 2
    assert len(sent[0]) == 600
    assert len(sent[1]) < len(sent[0])

    # Second call: the learned limit is applied before sending, so exactly
    # one extra, already-shortened request goes out.
    assert asyncio.run(model.embed(oversized)) == [0.1, 0.2, 0.3]

    sent = list(model.client.embeddings.inputs)
    assert len(sent) == 3
    assert len(sent[-1]) < 512
|
||||
|
||||
Reference in New Issue
Block a user