Enhance AIClient and MCPServer to support tool registration with source tracking. Add logging for tool calls and improve error handling. Introduce methods for embedding token limit extraction and budget application in OpenAIModel. Add tests for MCP tool registration and execution.

Mimikko-zeus
2026-03-03 13:10:09 +08:00
parent 586f09c3a5
commit fd2a09f681
6 changed files with 274 additions and 12 deletions
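The headline change is that every tool registered with AIClient now carries a source tag ("custom", "skills", or "mcp" in this diff) that is tracked in a side map, echoed in tool-call logs, and cleaned up when tools are unregistered. A standalone sketch of that bookkeeping pattern (illustrative names only, not the project's API):

from typing import Callable, Dict

# Standalone illustration of the source-tracking idea added to AIClient below:
# a side map from tool name to origin, consulted when logging tool calls and
# cleared on unregistration.
tools: Dict[str, Callable] = {}
tool_sources: Dict[str, str] = {}

def register_tool(name: str, function: Callable, source: str = "custom") -> None:
    tools[name] = function
    tool_sources[name] = source

def unregister_tools_by_prefix(prefix: str) -> int:
    removed = [n for n in tools if n.startswith(prefix)]
    for name in removed:
        tools.pop(name)
        tool_sources.pop(name, None)
    return len(removed)

register_tool("dummy.echo", lambda text: text, source="mcp")
assert tool_sources["dummy.echo"] == "mcp"
assert unregister_tools_by_prefix("dummy.") == 1
assert "dummy.echo" not in tool_sources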

View File

@@ -43,6 +43,7 @@ class AIClient:
         # 初始化工具注册表
         self.tools = ToolRegistry()
+        self._tool_sources: Dict[str, str] = {}

         # 初始化人格系统
         self.personality = PersonalitySystem(
@@ -242,8 +243,11 @@ class AIClient:
     ) -> Message:
         """处理工具调用。"""
         messages.append(response)
+        total_calls = len(response.tool_calls or [])
+        if total_calls:
+            logger.info(f"检测到工具调用请求: {total_calls}")

-        # 鎵ц宸ュ叿璋冪敤
+        # 执行工具调用
         for tool_call in response.tool_calls or []:
             try:
                 tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
@@ -263,6 +267,7 @@ class AIClient:
                 continue

             tool_def = self.tools.get(tool_name)
+            tool_source = self._tool_sources.get(tool_name, "custom")
             if not tool_def:
                 error_msg = f"未找到工具: {tool_name}"
                 logger.warning(error_msg)
@@ -275,9 +280,19 @@ class AIClient:
                 continue

             try:
+                logger.info(
+                    "工具调用开始: "
+                    f"source={tool_source}, name={tool_name}, "
+                    f"args={self._preview_log_payload(tool_args)}"
+                )
                 result = tool_def.function(**tool_args)
                 if inspect.isawaitable(result):
                     result = await result
+                logger.info(
+                    "工具调用成功: "
+                    f"source={tool_source}, name={tool_name}, "
+                    f"result={self._preview_log_payload(result)}"
+                )
                 messages.append(Message(
                     role="tool",
                     name=tool_name,
@@ -285,6 +300,10 @@ class AIClient:
                     tool_call_id=tool_call_id
                 ))
             except Exception as e:
+                logger.warning(
+                    "工具调用失败: "
+                    f"source={tool_source}, name={tool_name}, error={e}"
+                )
                 messages.append(Message(
                     role="tool",
                     name=tool_name,
@@ -334,6 +353,18 @@ class AIClient:
             return parsed
         raise ValueError(f"不支持的工具参数类型: {type(raw_args)}")

+    @staticmethod
+    def _preview_log_payload(payload: Any, max_len: int = 240) -> str:
+        """日志中展示参数/结果时使用的简短预览。"""
+        try:
+            text = json.dumps(payload, ensure_ascii=False, default=str)
+        except Exception:
+            text = str(payload)
+        if len(text) > max_len:
+            return text[:max_len] + "..."
+        return text
+
     def set_personality(self, personality_name: str) -> bool:
         """设置人格。"""
@@ -382,7 +413,14 @@ class AIClient:
         """获取任务状态。"""
         return self.task_manager.get_task_status(task_id)

-    def register_tool(self, name: str, description: str, parameters: Dict, function: callable):
+    def register_tool(
+        self,
+        name: str,
+        description: str,
+        parameters: Dict,
+        function: callable,
+        source: str = "custom",
+    ):
         """注册工具。"""
         from .base import ToolDefinition
         tool = ToolDefinition(
@@ -392,18 +430,23 @@ class AIClient:
             function=function
         )
         self.tools.register(tool)
-        logger.info(f"已注册工具: {name}")
+        self._tool_sources[name] = source
+        logger.info(f"已注册工具: {name} (source={source})")

     def unregister_tool(self, name: str) -> bool:
         """卸载工具。"""
         removed = self.tools.unregister(name)
         if removed:
+            self._tool_sources.pop(name, None)
             logger.info(f"已卸载工具: {name}")
         return removed

     def unregister_tools_by_prefix(self, prefix: str) -> int:
         """按前缀批量卸载工具。"""
         removed_count = self.tools.unregister_by_prefix(prefix)
+        for tool_name in list(self._tool_sources.keys()):
+            if tool_name.startswith(prefix):
+                self._tool_sources.pop(tool_name, None)
         if removed_count:
             logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
         return removed_count

View File

@@ -44,6 +44,7 @@ class MCPServer:
         self.version = version
         self.resources: Dict[str, MCPResource] = {}
         self.tools: Dict[str, Callable] = {}
+        self.tool_specs: Dict[str, MCPTool] = {}
         self.prompts: Dict[str, MCPPrompt] = {}

     async def initialize(self):
@@ -61,6 +62,7 @@ class MCPServer:
     def register_tool(self, name: str, description: str, input_schema: Dict, handler: Callable):
         """注册工具"""
         tool = MCPTool(name=name, description=description, input_schema=input_schema)
+        self.tool_specs[name] = tool
         self.tools[name] = handler
         logger.info(f"✅ MCP工具注册: {self.name}.{name}")
@@ -78,10 +80,7 @@ class MCPServer:
     async def list_tools(self) -> List[MCPTool]:
         """列出工具"""
-        return [
-            MCPTool(name=name, description="", input_schema={})
-            for name in self.tools.keys()
-        ]
+        return list(self.tool_specs.values())

     async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
         """调用工具"""
@@ -89,7 +88,16 @@ class MCPServer:
             raise ValueError(f"工具不存在: {name}")

         handler = self.tools[name]
-        return await handler(**arguments)
+        logger.info(
+            f"MCP工具调用开始: server={self.name}, tool={name}, args={json.dumps(arguments, ensure_ascii=False)}"
+        )
+        try:
+            result = await handler(**arguments)
+        except Exception as exc:
+            logger.warning(f"MCP工具调用失败: server={self.name}, tool={name}, error={exc}")
+            raise
+        logger.info(f"MCP工具调用成功: server={self.name}, tool={name}")
+        return result

     async def list_prompts(self) -> List[MCPPrompt]:
         """列出提示词"""
@@ -216,4 +224,5 @@ class MCPManager:
             raise ValueError(f"工具名格式错误: {full_tool_name}")

         server_name, tool_name = parts
+        logger.info(f"MCP执行请求: {full_tool_name}")
         return await self.client.call_tool(server_name, tool_name, arguments)

View File

@@ -3,6 +3,7 @@ OpenAI model implementation (including OpenAI-compatible providers).
 """
 import inspect
 import json
+import re
 from typing import Any, AsyncIterator, Dict, List, Optional

 import httpx
@@ -20,6 +21,7 @@ class OpenAIModel(BaseAIModel):
     def __init__(self, config: ModelConfig):
         super().__init__(config)
         self.logger = logger
+        self._embedding_token_limit: Optional[int] = None

         http_client = httpx.AsyncClient(
             timeout=config.timeout,
@@ -317,6 +319,70 @@ class OpenAIModel(BaseAIModel):
             or "maximum context length" in message
         )

+    @staticmethod
+    def _extract_embedding_token_limit(error: Exception) -> Optional[int]:
+        message = str(error).lower()
+        patterns = [
+            r"less than\s+(\d+)\s+tokens",
+            r"maximum context length.*?(\d+)\s+tokens",
+            r"max(?:imum)?(?: input)?(?: length)?\D+(\d+)\s+tokens",
+        ]
+        for pattern in patterns:
+            match = re.search(pattern, message)
+            if not match:
+                continue
+            try:
+                limit = int(match.group(1))
+            except (TypeError, ValueError):
+                continue
+            if limit > 0:
+                return limit
+        return None
+
+    @staticmethod
+    def _estimate_embedding_tokens(text: str) -> int:
+        if not text:
+            return 0
+        ascii_count = 0
+        cjk_count = 0
+        other_count = 0
+        for ch in text:
+            code = ord(ch)
+            if code < 128:
+                ascii_count += 1
+            elif 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF:
+                cjk_count += 1
+            else:
+                other_count += 1
+        estimated = cjk_count + int(other_count * 0.7) + int(ascii_count / 4)
+        return max(1, estimated)
+
+    def _apply_embedding_budget(self, text: str) -> str:
+        if not text or not self._embedding_token_limit:
+            return text
+        safe_limit = max(32, int(self._embedding_token_limit * 0.9))
+        estimated_tokens = self._estimate_embedding_tokens(text)
+        if estimated_tokens <= safe_limit:
+            return text
+        ratio = safe_limit / max(estimated_tokens, 1)
+        target_length = max(64, int(len(text) * ratio))
+        if target_length >= len(text):
+            target_length = max(64, len(text) - 1)
+        if target_length <= 0 or target_length >= len(text):
+            return text
+        head = target_length // 2
+        tail = target_length - head
+        return f"{text[:head]} {text[-tail:]}"
+
     @staticmethod
     def _shrink_text_for_embedding(text: str) -> str:
         compact = " ".join((text or "").split())
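The limit extraction above is plain string parsing, so its behavior is easy to check in isolation; the snippet below applies the same three patterns to an error message of the shape the new fake embeddings client raises ("input must be less than 512 tokens"):

import re

# Same patterns as _extract_embedding_token_limit above, applied to a sample
# provider error message; the first pattern recovers the 512-token limit.
patterns = [
    r"less than\s+(\d+)\s+tokens",
    r"maximum context length.*?(\d+)\s+tokens",
    r"max(?:imum)?(?: input)?(?: length)?\D+(\d+)\s+tokens",
]
message = "input must be less than 512 tokens".lower()

limit = None
for pattern in patterns:
    match = re.search(pattern, message)
    if match:
        limit = int(match.group(1))
        break

assert limit == 512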
@@ -338,6 +404,7 @@ class OpenAIModel(BaseAIModel):
         raw_text = str(text or "")
         candidate_text = raw_text.strip() or raw_text or " "
+        candidate_text = self._apply_embedding_budget(candidate_text)

         retry_count = 0
         while True:
@@ -350,16 +417,32 @@ class OpenAIModel(BaseAIModel):
                 return response.data[0].embedding
             except Exception as e:
                 if self._is_embedding_too_long_error(e):
-                    next_text = self._shrink_text_for_embedding(candidate_text)
+                    token_limit = self._extract_embedding_token_limit(e)
+                    if token_limit:
+                        if not self._embedding_token_limit:
+                            self._embedding_token_limit = token_limit
+                        else:
+                            self._embedding_token_limit = min(
+                                self._embedding_token_limit, token_limit
+                            )
+                    next_text = self._apply_embedding_budget(
+                        self._shrink_text_for_embedding(candidate_text)
+                    )
                     if (
                         next_text
                         and len(next_text) < len(candidate_text)
                         and retry_count < 5
                     ):
                         retry_count += 1
-                        self.logger.warning(
-                            "embedding input too long, retry with truncated text: "
-                            f"{len(candidate_text)} -> {len(next_text)}"
+                        limit_desc = (
+                            f", token_limit={self._embedding_token_limit}"
+                            if self._embedding_token_limit
+                            else ""
+                        )
+                        self.logger.info(
+                            "embedding input exceeded provider limit, retry with truncated text: "
+                            f"{len(candidate_text)} -> {len(next_text)}{limit_desc}"
                         )
                         candidate_text = next_text
                         continue
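To make the budget arithmetic concrete, here are the numbers for the scenario the new test drives: a 600-character CJK input against a learned limit of 512 tokens, using the same formulas as _estimate_embedding_tokens and _apply_embedding_budget (the specific character below is only an example):

# 600 CJK characters estimate to ~600 tokens (one token per CJK char in the heuristic).
text = "测" * 600
token_limit = 512

safe_limit = max(32, int(token_limit * 0.9))     # 460
estimated_tokens = len(text)                     # 600 for pure CJK input
ratio = safe_limit / max(estimated_tokens, 1)
target_length = max(64, int(len(text) * ratio))  # roughly 460 characters kept
head = target_length // 2
tail = target_length - head
truncated = f"{text[:head]} {text[-tail:]}"      # keep the head and the tail halves

assert safe_limit == 460
assert target_length < len(text)
assert len(truncated) == target_length + 1       # the joining space adds one character
assert len(truncated) < token_limit              # now fits under the learned limit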

View File

@@ -324,6 +324,43 @@ class MessageHandler:
                 description=f"技能工具: {full_tool_name}",
                 parameters={"type": "object", "properties": {}},
                 function=tool_func,
+                source="skills",
+            )
+            count += 1
+        return count
+
+    async def _register_mcp_tools(self) -> int:
+        if not self.mcp_manager or not self.ai_client:
+            return 0
+        tools = await self.mcp_manager.get_all_tools_for_ai()
+        count = 0
+        for item in tools:
+            function_info = item.get("function") if isinstance(item, dict) else None
+            if not isinstance(function_info, dict):
+                continue
+            full_tool_name = function_info.get("name")
+            if not full_tool_name:
+                continue
+            parameters = function_info.get("parameters")
+            if not isinstance(parameters, dict):
+                parameters = {"type": "object", "properties": {}}
+
+            async def _mcp_proxy(_full_tool_name=full_tool_name, **kwargs):
+                if not self.mcp_manager:
+                    raise RuntimeError("MCP manager not initialized")
+                return await self.mcp_manager.execute_tool(_full_tool_name, kwargs)
+
+            self.ai_client.register_tool(
+                name=full_tool_name,
+                description=f"MCP工具: {full_tool_name}",
+                parameters=parameters,
+                function=_mcp_proxy,
+                source="mcp",
             )
             count += 1
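One detail worth noting in _mcp_proxy above: the tool name is bound through a default argument (_full_tool_name=full_tool_name) because Python closures capture loop variables late; without that, every proxy registered in the loop would invoke the last tool name. A minimal standalone demonstration of the difference:

# Late binding: every closure sees the loop variable's final value.
late = [lambda: name for name in ("a", "b", "c")]
assert [f() for f in late] == ["c", "c", "c"]

# Default-argument binding (the _mcp_proxy pattern): each closure keeps its own value.
bound = [lambda _name=name: _name for name in ("a", "b", "c")]
assert [f() for f in bound] == ["a", "b", "c"]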
@@ -439,6 +476,8 @@ class MessageHandler:
             self.mcp_manager = MCPManager(Path("config/mcp.json"))
             fs_server = FileSystemMCPServer(root_path=Path("data"))
             await self.mcp_manager.register_server(fs_server)
+            mcp_tool_count = await self._register_mcp_tools()
+            logger.info(f"MCP 工具注册完成: {mcp_tool_count} tools")
         except Exception as exc:
             logger.warning(f"MCP 初始化失败: {exc}")

View File

@@ -0,0 +1,44 @@
+import asyncio
+from pathlib import Path
+
+from src.ai.mcp.base import MCPManager, MCPServer
+
+
+class _DummyMCPServer(MCPServer):
+    def __init__(self):
+        super().__init__(name="dummy", version="1.0.0")
+
+    async def initialize(self):
+        self.register_tool(
+            name="echo",
+            description="Echo text",
+            input_schema={
+                "type": "object",
+                "properties": {"text": {"type": "string"}},
+                "required": ["text"],
+            },
+            handler=self.echo,
+        )
+
+    async def echo(self, text: str) -> str:
+        return text
+
+
+def test_mcp_manager_exports_tool_metadata_for_ai(tmp_path: Path):
+    manager = MCPManager(tmp_path / "mcp.json")
+    asyncio.run(manager.register_server(_DummyMCPServer()))
+    tools = asyncio.run(manager.get_all_tools_for_ai())
+
+    assert len(tools) == 1
+    function_info = tools[0]["function"]
+    assert function_info["name"] == "dummy.echo"
+    assert function_info["description"] == "Echo text"
+    assert function_info["parameters"]["required"] == ["text"]
+
+
+def test_mcp_manager_execute_tool(tmp_path: Path):
+    manager = MCPManager(tmp_path / "mcp.json")
+    asyncio.run(manager.register_server(_DummyMCPServer()))
+    result = asyncio.run(manager.execute_tool("dummy.echo", {"text": "hello"}))
+
+    assert result == "hello"

View File

@@ -43,6 +43,18 @@ class _FakeEmbeddings:
         return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])

+class _LengthLimitedEmbeddings:
+    def __init__(self):
+        self.inputs = []
+
+    async def create(self, **kwargs):
+        text = kwargs.get("input", "")
+        self.inputs.append(text)
+        if len(text) > 512:
+            raise RuntimeError("input must be less than 512 tokens")
+        return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])
+
 class _ModernCompletions:
     def __init__(self):
         self.last_params = None
@@ -197,6 +209,13 @@ class _RuntimeRejectToolsAsyncOpenAI:
         self.embeddings = _FakeEmbeddings()

+class _LengthLimitedEmbedAsyncOpenAI:
+    def __init__(self, **kwargs):
+        self.completions = _ModernCompletions()
+        self.chat = SimpleNamespace(completions=self.completions)
+        self.embeddings = _LengthLimitedEmbeddings()
+
 def test_openai_model_uses_tools_when_supported(monkeypatch):
     monkeypatch.setattr(openai_model_module, "AsyncOpenAI", _ModernAsyncOpenAI)
@@ -276,3 +295,28 @@ def test_openai_model_retries_with_functions_when_tools_rejected(monkeypatch):
     assert model._supports_tools is False
     assert result.tool_calls is not None
     assert result.tool_calls[0]["function"]["name"] == "demo_tool"
+
+
+def test_openai_model_learns_embedding_limit_and_pretruncates(monkeypatch):
+    monkeypatch.setattr(
+        openai_model_module, "AsyncOpenAI", _LengthLimitedEmbedAsyncOpenAI
+    )
+    model = OpenAIModel(_model_config())
+    long_text = "测" * 600
+
+    first_embedding = asyncio.run(model.embed(long_text))
+
+    assert first_embedding == [0.1, 0.2, 0.3]
+    assert model._embedding_token_limit == 512
+    inputs_after_first = list(model.client.embeddings.inputs)
+    assert len(inputs_after_first) == 2
+    assert len(inputs_after_first[0]) == 600
+    assert len(inputs_after_first[1]) < len(inputs_after_first[0])
+
+    second_embedding = asyncio.run(model.embed(long_text))
+
+    assert second_embedding == [0.1, 0.2, 0.3]
+    inputs_after_second = list(model.client.embeddings.inputs)
+    assert len(inputs_after_second) == 3
+    assert len(inputs_after_second[-1]) < 512