"""
|
||
AI瀹㈡埛绔?- 鏁村悎鎵€鏈堿I鍔熻兘
|
||
"""
|
||
import inspect
import json
import re
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple, Union

from src.utils.logger import setup_logger

from .base import Message, ModelConfig, ModelProvider, ToolRegistry
from .memory import MemorySystem
from .models import AnthropicModel, OpenAIModel
from .personality import PersonalitySystem
from .task_manager import LongTaskManager

logger = setup_logger('AIClient')


class AIClient:
    """AI client - unified interface."""

def __init__(
|
||
self,
|
||
model_config: ModelConfig,
|
||
embed_config: Optional[ModelConfig] = None,
|
||
data_dir: Path = Path("data/ai"),
|
||
use_vector_db: bool = True
|
||
):
|
||
self.config = model_config
|
||
self.data_dir = data_dir
|
||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 初始化主模型
|
||
self.model = self._create_model(model_config)
|
||
|
||
# 初始化嵌入模型(如果提供)
|
||
self.embed_model = None
|
||
if embed_config:
|
||
self.embed_model = self._create_model(embed_config)
|
||
logger.info(
|
||
f"嵌入模型初始化完成: {embed_config.provider.value}/{embed_config.model_name}"
|
||
)
|
||
|
||
# 初始化工具注册表
|
||
self.tools = ToolRegistry()
|
||
|
||
# 初始化人格系统
|
||
self.personality = PersonalitySystem(
|
||
config_path=data_dir / "personalities.json"
|
||
)
|
||
|
||
# 初始化记忆系统
|
||
self.memory = MemorySystem(
|
||
storage_path=data_dir / "long_term_memory.json",
|
||
embed_func=self._embed_wrapper,
|
||
importance_evaluator=self._evaluate_memory_importance,
|
||
use_vector_db=use_vector_db
|
||
)
|
||
|
||
# 初始化长任务管理器
|
||
self.task_manager = LongTaskManager(
|
||
storage_path=data_dir / "tasks.json"
|
||
)
|
||
|
||
logger.info(
|
||
f"AI 客户端初始化完成: {model_config.provider.value}/{model_config.model_name}"
|
||
)
|
||
|
||
def _create_model(self, config: ModelConfig):
|
||
"""创建模型实例。"""
|
||
if config.provider == ModelProvider.OPENAI:
|
||
return OpenAIModel(config)
|
||
elif config.provider == ModelProvider.ANTHROPIC:
|
||
return AnthropicModel(config)
|
||
elif config.provider in [ModelProvider.DEEPSEEK, ModelProvider.QWEN]:
|
||
# DeepSeek 和 Qwen 使用 OpenAI 兼容接口
|
||
return OpenAIModel(config)
|
||
else:
|
||
raise ValueError(f"不支持的模型提供商: {config.provider}")
|
||
|
||
async def _embed_wrapper(self, text: str) -> List[float]:
|
||
"""嵌入向量包装器。"""
|
||
try:
|
||
# 如果有独立的嵌入模型,优先使用
|
||
if self.embed_model:
|
||
return await self.embed_model.embed(text)
|
||
# 否则尝试使用主模型
|
||
return await self.model.embed(text)
|
||
except NotImplementedError:
|
||
# 如果都不支持嵌入,返回 None(记忆系统会降级)
|
||
logger.warning("Current model does not support embeddings; vector retrieval disabled")
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"生成嵌入向量失败: {e}")
|
||
return None
|
||
|
||
@staticmethod
|
||
def _parse_importance_score(raw: str) -> float:
|
||
text = (raw or "").strip()
|
||
if not text:
|
||
raise ValueError("empty importance response")
|
||
|
||
try:
|
||
parsed = json.loads(text)
|
||
if isinstance(parsed, (int, float)):
|
||
return float(parsed)
|
||
if isinstance(parsed, dict):
|
||
for key in ["importance", "score", "value"]:
|
||
if key in parsed:
|
||
return float(parsed[key])
|
||
except Exception:
|
||
pass
|
||
|
||
match = re.search(r"-?\d+(?:\.\d+)?", text)
|
||
if not match:
|
||
raise ValueError(f"cannot parse importance score: {text}")
|
||
return float(match.group(0))
|
||
|
||
async def _evaluate_memory_importance(
|
||
self, content: str, metadata: Optional[Dict] = None
|
||
) -> float:
|
||
"""
|
||
调用主模型评估记忆重要性,返回 [0, 1] 分值。
|
||
"""
|
||
system_prompt = (
|
||
"你是记忆重要性评估器。请根据输入内容判断该信息是否值得长期记忆。"
|
||
"输出一个 0 到 1 的数字,数字越大表示越重要。"
|
||
"只输出数字,不要输出任何解释、单位或多余文本。"
|
||
)
|
||
payload = json.dumps(
|
||
{"content": content, "metadata": metadata or {}},
|
||
ensure_ascii=False,
|
||
)
|
||
messages = [
|
||
Message(role="system", content=system_prompt),
|
||
Message(role="user", content=payload),
|
||
]
|
||
|
||
try:
|
||
response = await self.model.chat(
|
||
messages=messages,
|
||
tools=None,
|
||
temperature=0.0,
|
||
max_tokens=16,
|
||
)
|
||
score = self._parse_importance_score(response.content)
|
||
return max(0.0, min(1.0, score))
|
||
except Exception as e:
|
||
logger.warning(f"memory importance evaluation failed, fallback to neutral score: {e}")
|
||
return 0.5
|
||
|
||
async def chat(
|
||
self,
|
||
user_id: str,
|
||
user_message: str,
|
||
system_prompt: Optional[str] = None,
|
||
use_memory: bool = True,
|
||
use_tools: bool = True,
|
||
stream: bool = False,
|
||
**kwargs
|
||
) -> str:
|
||
"""对话接口。"""
|
||
try:
|
||
# 构建消息列表
|
||
messages = []
|
||
|
||
# 系统提示词
|
||
if system_prompt is None:
|
||
system_prompt = self.personality.get_system_prompt()
|
||
|
||
# 注入记忆上下文
|
||
if use_memory:
|
||
short_term, long_term = await self.memory.get_context(
|
||
user_id=user_id,
|
||
query=user_message
|
||
)
|
||
|
||
if short_term or long_term:
|
||
memory_context = self.memory.format_context(short_term, long_term)
|
||
system_prompt += f"\n\n{memory_context}"
|
||
|
||
messages.append(Message(role="system", content=system_prompt))
|
||
|
||
# 添加用户消息
|
||
messages.append(Message(role="user", content=user_message))
|
||
|
||
# 准备工具
|
||
tools = None
|
||
if use_tools and self.tools.list():
|
||
tools = self.tools.to_openai_format()
|
||
|
||
# 调用模型
|
||
if stream:
|
||
return self._chat_stream(messages, tools, **kwargs)
|
||
else:
|
||
response = await self.model.chat(messages, tools, **kwargs)
|
||
|
||
# 处理工具调用
|
||
if response.tool_calls:
|
||
response = await self._handle_tool_calls(
|
||
messages, response, tools, **kwargs
|
||
)
|
||
|
||
# 写入记忆
|
||
if use_memory:
|
||
stored_memory = await self.memory.add_qa_pair(
|
||
user_id=user_id,
|
||
question=user_message,
|
||
answer=response.content,
|
||
metadata={"source": "chat"},
|
||
)
|
||
if stored_memory:
|
||
logger.info(
|
||
"已写入长期记忆问答对:\n"
|
||
f"{stored_memory.content}\n"
|
||
f"memory_id={stored_memory.id}, "
|
||
f"importance={stored_memory.importance:.2f}"
|
||
)
|
||
|
||
return response.content
|
||
|
||
except Exception as e:
|
||
logger.error(f"对话失败: {e}")
|
||
raise
|
||
|
||
async def _chat_stream(
|
||
self,
|
||
messages: List[Message],
|
||
tools: Optional[List[Dict]],
|
||
**kwargs
|
||
) -> AsyncIterator[str]:
|
||
"""流式对话。"""
|
||
async for chunk in self.model.chat_stream(messages, tools, **kwargs):
|
||
yield chunk
|
||
|
||
async def _handle_tool_calls(
|
||
self,
|
||
messages: List[Message],
|
||
response: Message,
|
||
tools: Optional[List[Dict]],
|
||
**kwargs
|
||
) -> Message:
|
||
"""处理工具调用。"""
|
||
messages.append(response)
|
||
|
||
# 鎵ц宸ュ叿璋冪敤
|
||
for tool_call in response.tool_calls or []:
|
||
try:
|
||
tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
|
||
except Exception as e:
|
||
logger.warning(f"解析工具调用失败: {e}")
|
||
fallback_id = tool_call.get('id') if isinstance(tool_call, dict) else getattr(tool_call, 'id', None)
|
||
if fallback_id:
|
||
messages.append(Message(
|
||
role="tool",
|
||
content=f"工具参数解析失败: {str(e)}",
|
||
tool_call_id=fallback_id,
|
||
name="tool"
|
||
))
|
||
continue
|
||
if not tool_name:
|
||
logger.warning(f"跳过无效工具调用: {tool_call}")
|
||
continue
|
||
|
||
tool_def = self.tools.get(tool_name)
|
||
if not tool_def:
|
||
error_msg = f"未找到工具: {tool_name}"
|
||
logger.warning(error_msg)
|
||
messages.append(Message(
|
||
role="tool",
|
||
name=tool_name,
|
||
content=error_msg,
|
||
tool_call_id=tool_call_id
|
||
))
|
||
continue
|
||
|
||
try:
|
||
result = tool_def.function(**tool_args)
|
||
if inspect.isawaitable(result):
|
||
result = await result
|
||
messages.append(Message(
|
||
role="tool",
|
||
name=tool_name,
|
||
content=str(result),
|
||
tool_call_id=tool_call_id
|
||
))
|
||
except Exception as e:
|
||
messages.append(Message(
|
||
role="tool",
|
||
name=tool_name,
|
||
content=f"工具执行失败: {str(e)}",
|
||
tool_call_id=tool_call_id
|
||
))
|
||
|
||
# 再次调用模型获取最终响应
|
||
return await self.model.chat(messages, tools, **kwargs)
|
||
|
||
def _parse_tool_call(self, tool_call: Any) -> Tuple[Optional[str], Dict[str, Any], Optional[str]]:
|
||
"""兼容不同 SDK 返回的工具调用结构。"""
|
||
if isinstance(tool_call, dict):
|
||
tool_call_id = tool_call.get('id')
|
||
function = tool_call.get('function') or {}
|
||
tool_name = function.get('name')
|
||
raw_args = function.get('arguments')
|
||
else:
|
||
tool_call_id = getattr(tool_call, 'id', None)
|
||
function = getattr(tool_call, 'function', None)
|
||
tool_name = getattr(function, 'name', None) if function else None
|
||
raw_args = getattr(function, 'arguments', None) if function else None
|
||
|
||
tool_args = self._normalize_tool_args(raw_args)
|
||
return tool_name, tool_args, tool_call_id
|
||
|
||
def _normalize_tool_args(self, raw_args: Any) -> Dict[str, Any]:
|
||
"""将工具参数统一转换为字典。"""
|
||
if raw_args is None:
|
||
return {}
|
||
|
||
if isinstance(raw_args, dict):
|
||
return raw_args
|
||
|
||
if isinstance(raw_args, str):
|
||
raw_args = raw_args.strip()
|
||
if not raw_args:
|
||
return {}
|
||
parsed = json.loads(raw_args)
|
||
if not isinstance(parsed, dict):
|
||
raise ValueError(f"工具参数必须是 JSON 对象,实际类型: {type(parsed)}")
|
||
return parsed
|
||
|
||
if hasattr(raw_args, 'model_dump'):
|
||
parsed = raw_args.model_dump()
|
||
if isinstance(parsed, dict):
|
||
return parsed
|
||
|
||
raise ValueError(f"不支持的工具参数类型: {type(raw_args)}")
|
||
|
||
def set_personality(self, personality_name: str) -> bool:
|
||
"""设置人格。"""
|
||
return self.personality.set_personality(personality_name)
|
||
|
||
def list_personalities(self) -> List[str]:
|
||
"""列出所有人格。"""
|
||
return self.personality.list_personalities()
|
||
|
||
def switch_model(self, model_config: ModelConfig) -> bool:
|
||
"""Runtime switch for primary chat model."""
|
||
new_model = self._create_model(model_config)
|
||
self.model = new_model
|
||
self.config = model_config
|
||
logger.info(
|
||
f"已切换主模型: {model_config.provider.value}/{model_config.model_name}"
|
||
)
|
||
return True
|
||
|
||
async def create_long_task(
|
||
self,
|
||
user_id: str,
|
||
title: str,
|
||
description: str,
|
||
steps: List[Dict],
|
||
metadata: Optional[Dict] = None
|
||
) -> str:
|
||
"""创建长任务。"""
|
||
return self.task_manager.create_task(
|
||
user_id=user_id,
|
||
title=title,
|
||
description=description,
|
||
steps=steps,
|
||
metadata=metadata
|
||
)
|
||
|
||
async def start_task(
|
||
self,
|
||
task_id: str,
|
||
progress_callback: Optional[callable] = None
|
||
):
|
||
"""启动任务。"""
|
||
await self.task_manager.start_task(task_id, progress_callback)
|
||
|
||
def get_task_status(self, task_id: str) -> Optional[Dict]:
|
||
"""获取任务状态。"""
|
||
return self.task_manager.get_task_status(task_id)
|
||
|
||
def register_tool(self, name: str, description: str, parameters: Dict, function: callable):
|
||
"""注册工具。"""
|
||
from .base import ToolDefinition
|
||
tool = ToolDefinition(
|
||
name=name,
|
||
description=description,
|
||
parameters=parameters,
|
||
function=function
|
||
)
|
||
self.tools.register(tool)
|
||
logger.info(f"已注册工具: {name}")
|
||
|
||
def unregister_tool(self, name: str) -> bool:
|
||
"""卸载工具。"""
|
||
removed = self.tools.unregister(name)
|
||
if removed:
|
||
logger.info(f"已卸载工具: {name}")
|
||
return removed
|
||
|
||
def unregister_tools_by_prefix(self, prefix: str) -> int:
|
||
"""按前缀批量卸载工具。"""
|
||
removed_count = self.tools.unregister_by_prefix(prefix)
|
||
if removed_count:
|
||
logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
|
||
return removed_count
|
||
|
||
def clear_memory(self, user_id: str):
|
||
"""清除用户短期记忆。"""
|
||
self.memory.clear_short_term(user_id)
|
||
logger.info(f"Cleared short-term memory for user {user_id}")
|
||
|
||
async def clear_long_term_memory(self, user_id: str) -> bool:
|
||
try:
|
||
await self.memory.clear_long_term(user_id)
|
||
logger.info(f"Cleared long-term memory for user {user_id}")
|
||
return True
|
||
except Exception as e:
|
||
logger.warning(f"Failed to clear long-term memory for user {user_id}: {e}")
|
||
return False
|
||
|
||
async def list_long_term_memories(self, user_id: str, limit: int = 20):
|
||
return await self.memory.list_long_term(user_id, limit=limit)
|
||
|
||
async def get_long_term_memory(self, user_id: str, memory_id: str):
|
||
return await self.memory.get_long_term(user_id, memory_id)
|
||
|
||
async def add_long_term_memory(
|
||
self,
|
||
user_id: str,
|
||
content: str,
|
||
importance: float = 0.8,
|
||
metadata: Optional[Dict] = None,
|
||
):
|
||
return await self.memory.add_long_term(
|
||
user_id=user_id,
|
||
content=content,
|
||
importance=importance,
|
||
metadata=metadata,
|
||
)
|
||
|
||
async def search_long_term_memories(
|
||
self, user_id: str, query: str, limit: int = 10
|
||
):
|
||
return await self.memory.search_long_term(user_id, query=query, limit=limit)
|
||
|
||
async def update_long_term_memory(
|
||
self,
|
||
user_id: str,
|
||
memory_id: str,
|
||
content: Optional[str] = None,
|
||
importance: Optional[float] = None,
|
||
metadata: Optional[Dict] = None,
|
||
):
|
||
return await self.memory.update_long_term(
|
||
user_id=user_id,
|
||
memory_id=memory_id,
|
||
content=content,
|
||
importance=importance,
|
||
metadata=metadata,
|
||
)
|
||
|
||
async def delete_long_term_memory(self, user_id: str, memory_id: str) -> bool:
|
||
return await self.memory.delete_long_term(user_id, memory_id)
|
||
|
||
async def clear_all_memory(self, user_id: str) -> bool:
|
||
"""清除用户全部记忆(短期 + 长期)。"""
|
||
self.clear_memory(user_id)
|
||
try:
|
||
return await self.clear_long_term_memory(user_id)
|
||
except Exception:
|
||
return False
|
||
|