""" AI瀹㈡埛绔?- 鏁村悎鎵€鏈堿I鍔熻兘 """ import inspect import json import re from typing import List, Optional, Dict, Any, AsyncIterator, Tuple from pathlib import Path from .base import ModelConfig, ModelProvider, Message, ToolRegistry from .models import OpenAIModel, AnthropicModel from .personality import PersonalitySystem from .memory import MemorySystem from .task_manager import LongTaskManager from src.utils.logger import setup_logger logger = setup_logger('AIClient') class AIClient: """AI瀹㈡埛绔?- 缁熶竴鎺ュ彛""" def __init__( self, model_config: ModelConfig, embed_config: Optional[ModelConfig] = None, data_dir: Path = Path("data/ai"), use_vector_db: bool = True ): self.config = model_config self.data_dir = data_dir self.data_dir.mkdir(parents=True, exist_ok=True) # 初始化主模型 self.model = self._create_model(model_config) # 初始化嵌入模型(如果提供) self.embed_model = None if embed_config: self.embed_model = self._create_model(embed_config) logger.info( f"嵌入模型初始化完成: {embed_config.provider.value}/{embed_config.model_name}" ) # 初始化工具注册表 self.tools = ToolRegistry() # 初始化人格系统 self.personality = PersonalitySystem( config_path=data_dir / "personalities.json" ) # 初始化记忆系统 self.memory = MemorySystem( storage_path=data_dir / "long_term_memory.json", embed_func=self._embed_wrapper, importance_evaluator=self._evaluate_memory_importance, use_vector_db=use_vector_db ) # 初始化长任务管理器 self.task_manager = LongTaskManager( storage_path=data_dir / "tasks.json" ) logger.info( f"AI 客户端初始化完成: {model_config.provider.value}/{model_config.model_name}" ) def _create_model(self, config: ModelConfig): """创建模型实例。""" if config.provider == ModelProvider.OPENAI: return OpenAIModel(config) elif config.provider == ModelProvider.ANTHROPIC: return AnthropicModel(config) elif config.provider in [ModelProvider.DEEPSEEK, ModelProvider.QWEN]: # DeepSeek 和 Qwen 使用 OpenAI 兼容接口 return OpenAIModel(config) else: raise ValueError(f"不支持的模型提供商: {config.provider}") async def _embed_wrapper(self, text: str) -> List[float]: """嵌入向量包装器。""" try: # 如果有独立的嵌入模型,优先使用 if self.embed_model: return await self.embed_model.embed(text) # 否则尝试使用主模型 return await self.model.embed(text) except NotImplementedError: # 如果都不支持嵌入,返回 None(记忆系统会降级) logger.warning("Current model does not support embeddings; vector retrieval disabled") return None except Exception as e: logger.error(f"生成嵌入向量失败: {e}") return None @staticmethod def _parse_importance_score(raw: str) -> float: text = (raw or "").strip() if not text: raise ValueError("empty importance response") try: parsed = json.loads(text) if isinstance(parsed, (int, float)): return float(parsed) if isinstance(parsed, dict): for key in ["importance", "score", "value"]: if key in parsed: return float(parsed[key]) except Exception: pass match = re.search(r"-?\d+(?:\.\d+)?", text) if not match: raise ValueError(f"cannot parse importance score: {text}") return float(match.group(0)) async def _evaluate_memory_importance( self, content: str, metadata: Optional[Dict] = None ) -> float: """ 调用主模型评估记忆重要性,返回 [0, 1] 分值。 """ system_prompt = ( "你是记忆重要性评估器。请根据输入内容判断该信息是否值得长期记忆。" "输出一个 0 到 1 的数字,数字越大表示越重要。" "只输出数字,不要输出任何解释、单位或多余文本。" ) payload = json.dumps( {"content": content, "metadata": metadata or {}}, ensure_ascii=False, ) messages = [ Message(role="system", content=system_prompt), Message(role="user", content=payload), ] try: response = await self.model.chat( messages=messages, tools=None, temperature=0.0, max_tokens=16, ) score = self._parse_importance_score(response.content) return max(0.0, min(1.0, score)) except Exception as e: logger.warning(f"memory 
            logger.warning(f"Memory importance evaluation failed, falling back to neutral score: {e}")
            return 0.5

    async def chat(
        self,
        user_id: str,
        user_message: str,
        system_prompt: Optional[str] = None,
        use_memory: bool = True,
        use_tools: bool = True,
        stream: bool = False,
        **kwargs
    ) -> Union[str, AsyncIterator[str]]:
        """Chat entry point. Returns the reply text, or an async iterator of
        chunks when stream=True."""
        try:
            # Build the message list
            messages = []

            # System prompt
            if system_prompt is None:
                system_prompt = self.personality.get_system_prompt()

            # Inject memory context
            if use_memory:
                short_term, long_term = await self.memory.get_context(
                    user_id=user_id,
                    query=user_message
                )
                if short_term or long_term:
                    memory_context = self.memory.format_context(short_term, long_term)
                    system_prompt += f"\n\n{memory_context}"

            messages.append(Message(role="system", content=system_prompt))

            # Append the user message
            messages.append(Message(role="user", content=user_message))

            # Prepare tools
            tools = None
            if use_tools and self.tools.list():
                tools = self.tools.to_openai_format()

            # Call the model (the streaming branch returns an async generator)
            if stream:
                return self._chat_stream(messages, tools, **kwargs)

            response = await self.model.chat(messages, tools, **kwargs)

            # Handle tool calls
            if response.tool_calls:
                response = await self._handle_tool_calls(
                    messages, response, tools, **kwargs
                )

            # Write to memory
            if use_memory:
                stored_memory = await self.memory.add_qa_pair(
                    user_id=user_id,
                    question=user_message,
                    answer=response.content,
                    metadata={"source": "chat"},
                )
                if stored_memory:
                    logger.info(
                        "Stored Q&A pair in long-term memory:\n"
                        f"{stored_memory.content}\n"
                        f"memory_id={stored_memory.id}, "
                        f"importance={stored_memory.importance:.2f}"
                    )

            return response.content

        except Exception as e:
            logger.error(f"Chat failed: {e}")
            raise

    async def _chat_stream(
        self,
        messages: List[Message],
        tools: Optional[List[Dict]],
        **kwargs
    ) -> AsyncIterator[str]:
        """Streaming chat."""
        async for chunk in self.model.chat_stream(messages, tools, **kwargs):
            yield chunk

    async def _handle_tool_calls(
        self,
        messages: List[Message],
        response: Message,
        tools: Optional[List[Dict]],
        **kwargs
    ) -> Message:
        """Handle tool calls emitted by the model."""
        messages.append(response)

        # Execute each requested tool call
        for tool_call in response.tool_calls or []:
            try:
                tool_name, tool_args, tool_call_id = self._parse_tool_call(tool_call)
            except Exception as e:
                logger.warning(f"Failed to parse tool call: {e}")
                fallback_id = (
                    tool_call.get('id') if isinstance(tool_call, dict)
                    else getattr(tool_call, 'id', None)
                )
                if fallback_id:
                    messages.append(Message(
                        role="tool",
                        content=f"Failed to parse tool arguments: {str(e)}",
                        tool_call_id=fallback_id,
                        name="tool"
                    ))
                continue

            if not tool_name:
                logger.warning(f"Skipping invalid tool call: {tool_call}")
                continue

            tool_def = self.tools.get(tool_name)
            if not tool_def:
                error_msg = f"Tool not found: {tool_name}"
                logger.warning(error_msg)
                messages.append(Message(
                    role="tool",
                    name=tool_name,
                    content=error_msg,
                    tool_call_id=tool_call_id
                ))
                continue

            try:
                result = tool_def.function(**tool_args)
                # Tools may be sync or async; await coroutine results
                if inspect.isawaitable(result):
                    result = await result
                messages.append(Message(
                    role="tool",
                    name=tool_name,
                    content=str(result),
                    tool_call_id=tool_call_id
                ))
            except Exception as e:
                messages.append(Message(
                    role="tool",
                    name=tool_name,
                    content=f"Tool execution failed: {str(e)}",
                    tool_call_id=tool_call_id
                ))

        # Call the model again for the final response
        return await self.model.chat(messages, tools, **kwargs)

    def _parse_tool_call(self, tool_call: Any) -> Tuple[Optional[str], Dict[str, Any], Optional[str]]:
        """Normalize tool-call structures returned by different SDKs."""
        if isinstance(tool_call, dict):
            tool_call_id = tool_call.get('id')
            function = tool_call.get('function') or {}
            tool_name = function.get('name')
            raw_args = function.get('arguments')
        else:
            tool_call_id = getattr(tool_call, 'id', None)
            function = getattr(tool_call, 'function', None)
            tool_name = getattr(function, 'name', None) if function else None
            raw_args = getattr(function, 'arguments', None) if function else None

        tool_args = self._normalize_tool_args(raw_args)
        return tool_name, tool_args, tool_call_id

    def _normalize_tool_args(self, raw_args: Any) -> Dict[str, Any]:
        """Coerce tool arguments into a dict."""
        if raw_args is None:
            return {}
        if isinstance(raw_args, dict):
            return raw_args
        if isinstance(raw_args, str):
            raw_args = raw_args.strip()
            if not raw_args:
                return {}
            parsed = json.loads(raw_args)
            if not isinstance(parsed, dict):
                raise ValueError(f"Tool arguments must be a JSON object, got: {type(parsed)}")
            return parsed
        # Pydantic-style objects expose model_dump()
        if hasattr(raw_args, 'model_dump'):
            parsed = raw_args.model_dump()
            if isinstance(parsed, dict):
                return parsed
        raise ValueError(f"Unsupported tool argument type: {type(raw_args)}")

    def set_personality(self, personality_name: str) -> bool:
        """Set the active personality."""
        return self.personality.set_personality(personality_name)

    def list_personalities(self) -> List[str]:
        """List all available personalities."""
        return self.personality.list_personalities()

    def switch_model(self, model_config: ModelConfig) -> bool:
        """Runtime switch for the primary chat model."""
        new_model = self._create_model(model_config)
        self.model = new_model
        self.config = model_config
        logger.info(
            f"Switched primary model: {model_config.provider.value}/{model_config.model_name}"
        )
        return True

    async def create_long_task(
        self,
        user_id: str,
        title: str,
        description: str,
        steps: List[Dict],
        metadata: Optional[Dict] = None
    ) -> str:
        """Create a long-running task."""
        return self.task_manager.create_task(
            user_id=user_id,
            title=title,
            description=description,
            steps=steps,
            metadata=metadata
        )

    async def start_task(
        self,
        task_id: str,
        progress_callback: Optional[Callable] = None
    ):
        """Start a task."""
        await self.task_manager.start_task(task_id, progress_callback)

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Get a task's status."""
        return self.task_manager.get_task_status(task_id)

    def register_tool(self, name: str, description: str, parameters: Dict, function: Callable):
        """Register a tool."""
        from .base import ToolDefinition

        tool = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters,
            function=function
        )
        self.tools.register(tool)
        logger.info(f"Registered tool: {name}")

    def unregister_tool(self, name: str) -> bool:
        """Unregister a tool."""
        removed = self.tools.unregister(name)
        if removed:
            logger.info(f"Unregistered tool: {name}")
        return removed

    def unregister_tools_by_prefix(self, prefix: str) -> int:
        """Unregister all tools whose names share a prefix."""
        removed_count = self.tools.unregister_by_prefix(prefix)
        if removed_count:
            logger.info(f"Unregistered tools by prefix {prefix}: {removed_count}")
        return removed_count

    def clear_memory(self, user_id: str):
        """Clear a user's short-term memory."""
        self.memory.clear_short_term(user_id)
        logger.info(f"Cleared short-term memory for user {user_id}")

    async def clear_long_term_memory(self, user_id: str) -> bool:
        """Clear a user's long-term memory."""
        try:
            await self.memory.clear_long_term(user_id)
            logger.info(f"Cleared long-term memory for user {user_id}")
            return True
        except Exception as e:
            logger.warning(f"Failed to clear long-term memory for user {user_id}: {e}")
            return False

    async def list_long_term_memories(self, user_id: str, limit: int = 20):
        return await self.memory.list_long_term(user_id, limit=limit)

    async def get_long_term_memory(self, user_id: str, memory_id: str):
        return await self.memory.get_long_term(user_id, memory_id)

    async def add_long_term_memory(
        self,
        user_id: str,
        content: str,
        importance: float = 0.8,
        metadata: Optional[Dict] = None,
    ):
        return await self.memory.add_long_term(
            user_id=user_id,
            content=content,
            importance=importance,
            metadata=metadata,
        )

    async def search_long_term_memories(
        self,
        user_id: str,
        query: str,
        limit: int = 10
    ):
        return await self.memory.search_long_term(user_id, query=query, limit=limit)
    async def update_long_term_memory(
        self,
        user_id: str,
        memory_id: str,
        content: Optional[str] = None,
        importance: Optional[float] = None,
        metadata: Optional[Dict] = None,
    ):
        return await self.memory.update_long_term(
            user_id=user_id,
            memory_id=memory_id,
            content=content,
            importance=importance,
            metadata=metadata,
        )

    async def delete_long_term_memory(self, user_id: str, memory_id: str) -> bool:
        return await self.memory.delete_long_term(user_id, memory_id)

    async def clear_all_memory(self, user_id: str) -> bool:
        """Clear all of a user's memory (short-term + long-term)."""
        self.clear_memory(user_id)
        try:
            return await self.clear_long_term_memory(user_id)
        except Exception:
            return False
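
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, kept as comments so nothing runs on import).
# It shows the intended flow: build a ModelConfig, construct AIClient,
# register a tool, and chat. The ModelConfig fields (provider, model_name)
# are taken from how this module reads them; the import path and any extra
# constructor arguments (e.g. an API key) are assumptions about the rest of
# the project. The `parameters` dict follows the OpenAI function-calling
# JSON-schema shape implied by ToolRegistry.to_openai_format().
#
#   import asyncio
#   from src.ai.client import AIClient              # assumed module path
#   from src.ai.base import ModelConfig, ModelProvider
#
#   async def main():
#       config = ModelConfig(provider=ModelProvider.OPENAI, model_name="gpt-4o-mini")
#       client = AIClient(config)
#
#       # Tools are plain callables (sync or async) described by JSON schema;
#       # _handle_tool_calls awaits coroutine results automatically.
#       async def get_time() -> str:
#           from datetime import datetime
#           return datetime.now().isoformat()
#
#       client.register_tool(
#           name="get_time",
#           description="Return the current local time in ISO-8601 format",
#           parameters={"type": "object", "properties": {}},
#           function=get_time,
#       )
#
#       reply = await client.chat(user_id="u1", user_message="What time is it?")
#       print(reply)
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------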