Enhance AIClient and MCPServer to support tool registration with source tracking. Added logging for tool calls and improved error handling. Introduced methods for embedding token limit extraction and budget application in OpenAIModel. Added tests for MCP tool registration and execution.

This commit is contained in:
Mimikko-zeus
2026-03-03 13:10:09 +08:00
parent 586f09c3a5
commit fd2a09f681
6 changed files with 274 additions and 12 deletions

View File

@@ -43,6 +43,7 @@ class AIClient:
# 初始化工具注册表
self.tools = ToolRegistry()
self._tool_sources: Dict[str, str] = {}
# 初始化人格系统
self.personality = PersonalitySystem(
@@ -242,8 +243,11 @@ class AIClient:
) -> Message:
"""处理工具调用。"""
messages.append(response)
total_calls = len(response.tool_calls or [])
if total_calls:
logger.info(f"检测到工具调用请求: {total_calls}")
# 鎵ц宸ュ叿璋冪敤
# 执行工具调用
for tool_call in response.tool_calls or []:
try:
tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
@@ -263,6 +267,7 @@ class AIClient:
continue
tool_def = self.tools.get(tool_name)
tool_source = self._tool_sources.get(tool_name, "custom")
if not tool_def:
error_msg = f"未找到工具: {tool_name}"
logger.warning(error_msg)
@@ -275,9 +280,19 @@ class AIClient:
continue
try:
logger.info(
"工具调用开始: "
f"source={tool_source}, name={tool_name}, "
f"args={self._preview_log_payload(tool_args)}"
)
result = tool_def.function(**tool_args)
if inspect.isawaitable(result):
result = await result
logger.info(
"工具调用成功: "
f"source={tool_source}, name={tool_name}, "
f"result={self._preview_log_payload(result)}"
)
messages.append(Message(
role="tool",
name=tool_name,
@@ -285,6 +300,10 @@ class AIClient:
tool_call_id=tool_call_id
))
except Exception as e:
logger.warning(
"工具调用失败: "
f"source={tool_source}, name={tool_name}, error={e}"
)
messages.append(Message(
role="tool",
name=tool_name,
@@ -335,6 +354,18 @@ class AIClient:
raise ValueError(f"不支持的工具参数类型: {type(raw_args)}")
@staticmethod
def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
"""日志中展示参数/结果时使用的简短预览。"""
try:
text = json.dumps(payload, ensure_ascii=False, default=str)
except Exception:
text = str(payload)
if len(text) > max_len:
return text[:max_len] + "..."
return text
def set_personality(self, personality_name: str) -> bool:
"""设置人格。"""
return self.personality.set_personality(personality_name)
@@ -382,7 +413,14 @@ class AIClient:
"""获取任务状态。"""
return self.task_manager.get_task_status(task_id)
def register_tool(self, name: str, description: str, parameters: Dict, function: callable):
def register_tool(
self,
name: str,
description: str,
parameters: Dict,
function: callable,
source: str = "custom",
):
"""注册工具。"""
from .base import ToolDefinition
tool = ToolDefinition(
@@ -392,18 +430,23 @@ class AIClient:
function=function
)
self.tools.register(tool)
logger.info(f"已注册工具: {name}")
self._tool_sources[name] = source
logger.info(f"已注册工具: {name} (source={source})")
def unregister_tool(self, name: str) -> bool:
"""卸载工具。"""
removed = self.tools.unregister(name)
if removed:
self._tool_sources.pop(name, None)
logger.info(f"已卸载工具: {name}")
return removed
def unregister_tools_by_prefix(self, prefix: str) -> int:
"""按前缀批量卸载工具。"""
removed_count = self.tools.unregister_by_prefix(prefix)
for tool_name in list(self._tool_sources.keys()):
if tool_name.startswith(prefix):
self._tool_sources.pop(tool_name, None)
if removed_count:
logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
return removed_count

View File

@@ -44,6 +44,7 @@ class MCPServer:
self.version = version
self.resources: Dict[str, MCPResource] = {}
self.tools: Dict[str, Callable] = {}
self.tool_specs: Dict[str, MCPTool] = {}
self.prompts: Dict[str, MCPPrompt] = {}
async def initialize(self):
@@ -61,6 +62,7 @@ class MCPServer:
def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
"""注册工具"""
tool = MCPTool(name=name, description=description, input_schema=input_schema)
self.tool_specs[name] = tool
self.tools[name] = handler
logger.info(f"✅ MCP工具注册: {self.name}.{name}")
@@ -78,10 +80,7 @@ class MCPServer:
async def list_tools(self) -> List[MCPTool]:
"""列出工具"""
return [
MCPTool(name=name, description="", input_schema={})
for name in self.tools.keys()
]
return list(self.tool_specs.values())
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
@@ -89,7 +88,16 @@ class MCPServer:
raise ValueError(f"工具不存在: {name}")
handler = self.tools[name]
return await handler(**arguments)
logger.info(
f"MCP工具调用开始: server={self.name}, tool={name}, args={json.dumps(arguments, ensure_ascii=False)}"
)
try:
result = await handler(**arguments)
except Exception as exc:
logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}")
raise
logger.info(f"MCP工具调用成功: server={self.name}, tool={name}")
return result
async def list_prompts(self) -> List[MCPPrompt]:
"""列出提示词"""
@@ -216,4 +224,5 @@ class MCPManager:
raise ValueError(f"工具名格式错误: {full_tool_name}")
server_name, tool_name = parts
logger.info(f"MCP执行请求: {full_tool_name}")
return await self.client.call_tool(server_name, tool_name, arguments)

View File

@@ -3,6 +3,7 @@ OpenAI model implementation (including OpenAI-compatible providers).
"""
import inspect
import json
import re
from typing import Any, AsyncIterator, Dict, List, Optional
import httpx
@@ -20,6 +21,7 @@ class OpenAIModel(BaseAIModel):
def __init__(self, config: ModelConfig):
super().__init__(config)
self.logger = logger
self._embedding_token_limit: Optional[int] = None
http_client = httpx.AsyncClient(
timeout=config.timeout,
@@ -317,6 +319,70 @@ class OpenAIModel(BaseAIModel):
or "maximum context length" in message
)
@staticmethod
def _extract_embedding_token_limit(error: Exception) -> Optional[int]:
message = str(error).lower()
patterns = [
r"less than\s+(\d+)\s+tokens",
r"maximum context length.*?(\d+)\s+tokens",
r"max(?:imum)?(?: input)?(?: length)?\D+(\d+)\s+tokens",
]
for pattern in patterns:
match = re.search(pattern, message)
if not match:
continue
try:
limit = int(match.group(1))
except (TypeError, ValueError):
continue
if limit > 0:
return limit
return None
@staticmethod
def _estimate_embedding_tokens(text: str) -> int:
if not text:
return 0
ascii_count = 0
cjk_count = 0
other_count = 0
for ch in text:
code = ord(ch)
if code < 128:
ascii_count += 1
elif 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF:
cjk_count += 1
else:
other_count += 1
estimated = cjk_count + int(other_count * 0.7) + int(ascii_count / 4)
return max(1, estimated)
def _apply_embedding_budget(self, text: str) -> str:
if not text or not self._embedding_token_limit:
return text
safe_limit = max(32, int(self._embedding_token_limit * 0.9))
estimated_tokens = self._estimate_embedding_tokens(text)
if estimated_tokens <= safe_limit:
return text
ratio = safe_limit / max(estimated_tokens, 1)
target_length = max(64, int(len(text) * ratio))
if target_length >= len(text):
target_length = max(64, len(text) - 1)
if target_length <= 0 or target_length >= len(text):
return text
head = target_length // 2
tail = target_length - head
return f"{text[:head]} {text[-tail:]}"
@staticmethod
def _shrink_text_for_embedding(text: str) -> str:
compact = " ".join((text or "").split())
@@ -338,6 +404,7 @@ class OpenAIModel(BaseAIModel):
raw_text = str(text or "")
candidate_text = raw_text.strip() or raw_text or " "
candidate_text = self._apply_embedding_budget(candidate_text)
retry_count = 0
while True:
@@ -350,16 +417,32 @@ class OpenAIModel(BaseAIModel):
return response.data[0].embedding
except Exception as e:
if self._is_embedding_too_long_error(e):
next_text = self._shrink_text_for_embedding(candidate_text)
token_limit = self._extract_embedding_token_limit(e)
if token_limit:
if not self._embedding_token_limit:
self._embedding_token_limit = token_limit
else:
self._embedding_token_limit = min(
self._embedding_token_limit, token_limit
)
next_text = self._apply_embedding_budget(
self._shrink_text_for_embedding(candidate_text)
)
if (
next_text
and len(next_text) < len(candidate_text)
and retry_count < 5
):
retry_count += 1
self.logger.warning(
"embedding input too long, retry with truncated text: "
f"{len(candidate_text)} -> {len(next_text)}"
limit_desc = (
f", token_limit={self._embedding_token_limit}"
if self._embedding_token_limit
else ""
)
self.logger.info(
"embedding input exceeded provider limit, retry with truncated text: "
f"{len(candidate_text)} -> {len(next_text)}{limit_desc}"
)
candidate_text = next_text
continue

View File

@@ -324,6 +324,43 @@ class MessageHandler:
description=f"技能工具: {full_tool_name}",
parameters={"type": "object", "properties": {}},
function=tool_func,
source="skills",
)
count += 1
return count
async def _register_mcp_tools(self) -> int:
if not self.mcp_manager or not self.ai_client:
return 0
tools = await self.mcp_manager.get_all_tools_for_ai()
count = 0
for item in tools:
function_info = item.get("function") if isinstance(item, dict) else None
if not isinstance(function_info, dict):
continue
full_tool_name = function_info.get("name")
if not full_tool_name:
continue
parameters = function_info.get("parameters")
if not isinstance(parameters, dict):
parameters = {"type": "object", "properties": {}}
async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs):
if not self.mcp_manager:
raise RuntimeError("MCP manager not initialized")
return await self.mcp_manager.execute_tool(_full_tool_name, kwargs)
self.ai_client.register_tool(
name=full_tool_name,
description=f"MCP工具: {full_tool_name}",
parameters=parameters,
function=_mcp_proxy,
source="mcp",
)
count += 1
@@ -439,6 +476,8 @@ class MessageHandler:
self.mcp_manager = MCPManager(Path("config/mcp.json"))
fs_server = FileSystemMCPServer(root_path=Path("data"))
await self.mcp_manager.register_server(fs_server)
mcp_tool_count = await self._register_mcp_tools()
logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools")
except Exception as exc:
logger.warning(f"MCP 初始化失败: {exc}")

View File

@@ -0,0 +1,44 @@
import asyncio
from pathlib import Path
from src.ai.mcp.base import MCPManager, MCPServer
class _DummyMCPServer(MCPServer):
def __init__(self):
super().__init__(name="dummy", version="1.0.0")
async def initialize(self):
self.register_tool(
name="echo",
description="Echo text",
input_schema={
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"],
},
handler=self.echo,
)
async def echo(self, text: str) -> str:
return text
def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
manager = MCPManager(tmp_path / "mcp.json")
asyncio.run(manager.register_server(_DummyMCPServer()))
tools = asyncio.run(manager.get_all_tools_for_ai())
assert len(tools) == 1
function_info = tools[0]["function"]
assert function_info["name"] == "dummy.echo"
assert function_info["description"] == "Echo text"
assert function_info["parameters"]["required"] == ["text"]
def test_mcp_manager_execute_tool(tmp_path: Path):
manager = MCPManager(tmp_path / "mcp.json")
asyncio.run(manager.register_server(_DummyMCPServer()))
result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
assert result == "hello"

View File

@@ -43,6 +43,18 @@ class _FakeEmbeddings:
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
class _LengthLimitedEmbeddings:
def __init__(self):
self.inputs = []
async def create(self, **kwargs):
text = kwargs.get("input", "")
self.inputs.append(text)
if len(text) > 512:
raise RuntimeError("input must be less than 512 tokens")
return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
class _ModernCompletions:
def __init__(self):
self.last_params = None
@@ -197,6 +209,13 @@ class _RuntimeRejectToolsAsyncOpenAI:
self.embeddings = _FakeEmbeddings()
class _LengthLimitedEmbedAsyncOpenAI:
def __init__(self, **kwargs):
self.completions = _ModernCompletions()
self.chat = SimpleNamespace(completions=self.completions)
self.embeddings = _LengthLimitedEmbeddings()
def test_openai_model_uses_tools_when_supported(monkeypatch):
monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
@@ -276,3 +295,28 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
assert model._supports_tools is False
assert result.tool_calls is not None
assert result.tool_calls[0]["function"]["name"] == "demo_tool"
def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
monkeypatch.setattr(
openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
)
model = OpenAIModel(_model_config())
long_text = "" * 600
first_embedding = asyncio.run(model.embed(long_text))
assert first_embedding == [0.1, 0.2, 0.3]
assert model._embedding_token_limit == 512
inputs_after_first = list(model.client.embeddings.inputs)
assert len(inputs_after_first) == 2
assert len(inputs_after_first[0]) == 600
assert len(inputs_after_first[1]) < len(inputs_after_first[0])
second_embedding = asyncio.run(model.embed(long_text))
assert second_embedding == [0.1, 0.2, 0.3]
inputs_after_second = list(model.client.embeddings.inputs)
assert len(inputs_after_second) == 3
assert len(inputs_after_second[-1]) < 512