Files
LocalAgent/llm/client.py
Mimikko-zeus 8a538bb950 feat: refactor API key configuration and enhance application initialization
- Renamed `check_environment` to `check_api_key_configured` for clarity, simplifying the API key validation logic.
- Removed the blocking behavior of the API key check during application startup, allowing the app to run while providing a prompt for configuration.
- Updated `LocalAgentApp` to accept an `api_configured` parameter, enabling conditional messaging for API key setup.
- Enhanced the `SandboxRunner` to support backup management and improved execution result handling with detailed metrics.
- Integrated data governance strategies into the `HistoryManager`, ensuring compliance and improved data management.
- Added privacy settings and metrics tracking across various components to enhance user experience and application safety.
2026-02-27 14:32:30 +08:00

581 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM 统一调用客户端
所有模型通过 SiliconFlow API 调用
支持流式和非流式两种模式
支持自动重试机制
"""
import os
import json
import time
import requests
from pathlib import Path
from typing import Optional, Generator, Callable, List, Dict, Any
from dotenv import load_dotenv
import logging
from datetime import datetime
# 获取项目根目录
PROJECT_ROOT = Path(__file__).parent.parent
ENV_PATH = PROJECT_ROOT / ".env"
# 配置日志目录
LOGS_DIR = PROJECT_ROOT / "workspace" / "logs"
LOGS_DIR.mkdir(parents=True, exist_ok=True)
# 配置日志记录器
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# 创建文件处理器 - 按日期命名
log_file = LOGS_DIR / f"llm_calls_{datetime.now().strftime('%Y%m%d')}.log"
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setLevel(logging.DEBUG)
# 设置日志格式
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# 添加处理器
logger.addHandler(file_handler)
class LLMClientError(Exception):
"""LLM 客户端异常"""
# 异常类型分类
TYPE_NETWORK = "network" # 网络错误(超时、连接失败等)
TYPE_SERVER = "server" # 服务器错误5xx
TYPE_CLIENT = "client" # 客户端错误4xx
TYPE_PARSE = "parse" # 解析错误
TYPE_CONFIG = "config" # 配置错误
def __init__(self, message: str, error_type: str = TYPE_CLIENT, original_exception: Optional[Exception] = None):
super().__init__(message)
self.error_type = error_type
self.original_exception = original_exception
class LLMClient:
"""
统一的 LLM 调用客户端
使用方式:
client = LLMClient()
# 非流式调用
response = client.chat(
messages=[{"role": "user", "content": "你好"}],
model="Qwen/Qwen2.5-7B-Instruct"
)
# 流式调用
for chunk in client.chat_stream(
messages=[{"role": "user", "content": "你好"}],
model="Qwen/Qwen2.5-7B-Instruct"
):
print(chunk, end="", flush=True)
特性:
- 自动重试网络错误时自动重试默认3次
- 指数退避:重试间隔逐渐增加
"""
# 重试配置
DEFAULT_MAX_RETRIES = 3
DEFAULT_RETRY_DELAY = 1.0 # 初始重试延迟(秒)
DEFAULT_RETRY_BACKOFF = 2.0 # 退避倍数
def __init__(self, max_retries: int = DEFAULT_MAX_RETRIES):
load_dotenv(ENV_PATH)
self.api_url = os.getenv("LLM_API_URL")
self.api_key = os.getenv("LLM_API_KEY")
self.max_retries = max_retries
if not self.api_url:
raise LLMClientError("未配置 LLM_API_URL请检查 .env 文件", error_type=LLMClientError.TYPE_CONFIG)
if not self.api_key or self.api_key == "your_api_key_here":
raise LLMClientError("未配置有效的 LLM_API_KEY请检查 .env 文件", error_type=LLMClientError.TYPE_CONFIG)
def _should_retry(self, exception: Exception) -> bool:
"""
判断是否应该重试
可重试的异常类型:
- 网络错误(超时、连接失败)
- 服务器错误5xx
- 限流错误429
"""
# 直接的网络异常(理论上不应该到这里,但保留作为兜底)
if isinstance(exception, (requests.exceptions.ConnectionError,
requests.exceptions.Timeout)):
return True
# LLMClientError 根据错误类型判断
if isinstance(exception, LLMClientError):
# 网络错误和服务器错误可以重试
if exception.error_type in (LLMClientError.TYPE_NETWORK, LLMClientError.TYPE_SERVER):
return True
# 检查原始异常
if exception.original_exception:
if isinstance(exception.original_exception,
(requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.ChunkedEncodingError)):
return True
return False
def _do_request_with_retry(
self,
request_func: Callable,
operation_name: str = "请求"
):
"""带重试的请求执行"""
last_exception = None
retry_count = 0
for attempt in range(self.max_retries + 1):
try:
result = request_func()
# 记录成功的请求(包括重试后成功)
if retry_count > 0:
try:
from llm.config_metrics import get_config_metrics
workspace = PROJECT_ROOT / "workspace"
if workspace.exists():
metrics = get_config_metrics(workspace)
metrics.record_retry_success(retry_count)
except:
pass
return result
except Exception as e:
last_exception = e
# 判断是否应该重试
if attempt < self.max_retries and self._should_retry(e):
retry_count += 1
delay = self.DEFAULT_RETRY_DELAY * (self.DEFAULT_RETRY_BACKOFF ** attempt)
# 记录重试信息
error_type = getattr(e, 'error_type', 'unknown') if isinstance(e, LLMClientError) else type(e).__name__
print(f"[重试] {operation_name}失败 (错误类型: {error_type}){delay:.1f}秒后重试 ({attempt + 1}/{self.max_retries})...")
# 记录重试次数到配置度量
try:
from llm.config_metrics import get_config_metrics
workspace = PROJECT_ROOT / "workspace"
if workspace.exists():
metrics = get_config_metrics(workspace)
metrics.increment_retry()
except:
pass # 度量记录失败不影响主流程
time.sleep(delay)
continue
else:
# 记录最终失败
if retry_count > 0:
try:
from llm.config_metrics import get_config_metrics
workspace = PROJECT_ROOT / "workspace"
if workspace.exists():
metrics = get_config_metrics(workspace)
metrics.record_retry_failure(retry_count)
except:
pass
raise
# 所有重试都失败
raise last_exception
def chat(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float = 0.7,
max_tokens: int = 1024,
timeout: int = 180
) -> str:
"""
调用 LLM 进行对话(非流式,带自动重试)
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大生成 token 数
timeout: 超时时间(秒),默认 180 秒
Returns:
LLM 生成的文本内容
"""
# 记录输入 - 完整内容不截断
logger.info("=" * 80)
logger.info(f"LLM 调用 [非流式] - 模型: {model}")
logger.info(f"参数: temperature={temperature}, max_tokens={max_tokens}, timeout={timeout}s")
logger.info(f"时间戳: {datetime.now().isoformat()}")
logger.info("-" * 80)
logger.info("输入消息:")
for i, msg in enumerate(messages):
role = msg.get('role', 'unknown')
content = msg.get('content', '')
logger.info(f" [{i+1}] {role} ({len(content)} 字符):")
# 完整记录,不截断
for line in content.split('\n'):
logger.info(f" {line}")
logger.info("-" * 80)
def do_request():
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model,
"messages": messages,
"stream": False,
"temperature": temperature,
"max_tokens": max_tokens
}
# 记录请求详情
logger.debug(f"API URL: {self.api_url}")
logger.debug(f"请求 Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
start_time = time.time()
response = requests.post(
self.api_url,
headers=headers,
json=payload,
timeout=timeout
)
elapsed_time = time.time() - start_time
logger.info(f"请求耗时: {elapsed_time:.2f}")
except requests.exceptions.Timeout as e:
logger.error(f"请求超时: {timeout}")
raise LLMClientError(
f"请求超时({timeout}秒),请检查网络连接或稍后重试",
error_type=LLMClientError.TYPE_NETWORK,
original_exception=e
)
except requests.exceptions.ConnectionError as e:
logger.error(f"网络连接失败: {str(e)}")
raise LLMClientError(
"网络连接失败,请检查网络设置",
error_type=LLMClientError.TYPE_NETWORK,
original_exception=e
)
except requests.exceptions.RequestException as e:
logger.error(f"网络请求异常: {str(e)}")
raise LLMClientError(
f"网络请求异常: {str(e)}",
error_type=LLMClientError.TYPE_NETWORK,
original_exception=e
)
# 记录响应状态
logger.debug(f"响应状态码: {response.status_code}")
if response.status_code != 200:
error_msg = f"API 返回错误 (状态码: {response.status_code})"
try:
error_detail = response.json()
logger.error(f"错误详情: {json.dumps(error_detail, ensure_ascii=False, indent=2)}")
if "error" in error_detail:
error_msg += f": {error_detail['error']}"
except:
logger.error(f"错误响应: {response.text[:500]}")
error_msg += f": {response.text[:200]}"
# 根据状态码确定错误类型
if response.status_code >= 500:
error_type = LLMClientError.TYPE_SERVER
elif response.status_code == 429:
error_type = LLMClientError.TYPE_SERVER # 限流也可重试
else:
error_type = LLMClientError.TYPE_CLIENT
raise LLMClientError(error_msg, error_type=error_type)
try:
result = response.json()
content = result["choices"][0]["message"]["content"]
# 记录输出 - 完整内容不截断
logger.info("输出响应:")
logger.info(f" 长度: {len(content)} 字符")
for line in content.split('\n'):
logger.info(f" {line}")
logger.info("=" * 80)
return content
except (KeyError, IndexError, TypeError) as e:
logger.error(f"解析 API 响应失败: {str(e)}")
logger.error(f"原始响应: {response.text[:1000]}")
raise LLMClientError(
f"解析 API 响应失败: {str(e)}",
error_type=LLMClientError.TYPE_PARSE
)
return self._do_request_with_retry(do_request, "LLM调用")
def chat_stream(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float = 0.7,
max_tokens: int = 2048,
timeout: int = 180
) -> Generator[str, None, None]:
"""
调用 LLM 进行对话(流式,带自动重试)
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大生成 token 数
timeout: 超时时间(秒)
Yields:
逐个返回生成的文本片段
"""
# 记录输入 - 完整内容不截断
logger.info("=" * 80)
logger.info(f"LLM 调用 [流式] - 模型: {model}")
logger.info(f"参数: temperature={temperature}, max_tokens={max_tokens}, timeout={timeout}s")
logger.info(f"时间戳: {datetime.now().isoformat()}")
logger.info("-" * 80)
logger.info("输入消息:")
for i, msg in enumerate(messages):
role = msg.get('role', 'unknown')
content = msg.get('content', '')
logger.info(f" [{i+1}] {role} ({len(content)} 字符):")
# 完整记录,不截断
for line in content.split('\n'):
logger.info(f" {line}")
logger.info("-" * 80)
logger.info("开始接收流式输出...")
def do_request():
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model,
"messages": messages,
"stream": True,
"temperature": temperature,
"max_tokens": max_tokens
}
# 记录请求详情
logger.debug(f"API URL: {self.api_url}")
logger.debug(f"请求 Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
start_time = time.time()
response = requests.post(
self.api_url,
headers=headers,
json=payload,
timeout=timeout,
stream=True
)
elapsed_time = time.time() - start_time
logger.info(f"连接建立耗时: {elapsed_time:.2f}")
except requests.exceptions.Timeout as e:
logger.error(f"请求超时: {timeout}")
raise LLMClientError(
f"请求超时({timeout}秒),请检查网络连接或稍后重试",
error_type=LLMClientError.TYPE_NETWORK,
original_exception=e
)
except requests.exceptions.ConnectionError as e:
logger.error(f"网络连接失败: {str(e)}")
raise LLMClientError(
"网络连接失败,请检查网络设置",
error_type=LLMClientError.TYPE_NETWORK,
original_exception=e
)
except requests.exceptions.RequestException as e:
logger.error(f"网络请求异常: {str(e)}")
raise LLMClientError(
f"网络请求异常: {str(e)}",
error_type=LLMClientError.TYPE_NETWORK,
original_exception=e
)
# 记录响应状态
logger.debug(f"响应状态码: {response.status_code}")
if response.status_code != 200:
error_msg = f"API 返回错误 (状态码: {response.status_code})"
try:
error_detail = response.json()
logger.error(f"错误详情: {json.dumps(error_detail, ensure_ascii=False, indent=2)}")
if "error" in error_detail:
error_msg += f": {error_detail['error']}"
except:
logger.error(f"错误响应: {response.text[:500]}")
error_msg += f": {response.text[:200]}"
# 根据状态码确定错误类型
if response.status_code >= 500:
error_type = LLMClientError.TYPE_SERVER
elif response.status_code == 429:
error_type = LLMClientError.TYPE_SERVER # 限流也可重试
else:
error_type = LLMClientError.TYPE_CLIENT
raise LLMClientError(error_msg, error_type=error_type)
return response
# 流式请求的重试只在建立连接阶段
response = self._do_request_with_retry(do_request, "流式LLM调用")
# 收集完整输出用于日志
full_output = []
# 解析 SSE 流
try:
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith('data: '):
data = line[6:] # 去掉 "data: " 前缀
if data == '[DONE]':
break
try:
chunk = json.loads(data)
if 'choices' in chunk and len(chunk['choices']) > 0:
delta = chunk['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
full_output.append(content)
yield content
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"流式输出异常: {str(e)}")
raise
# 记录完整输出 - 不截断
complete_output = ''.join(full_output)
logger.info("流式输出完成:")
logger.info(f" 总长度: {len(complete_output)} 字符")
for line in complete_output.split('\n'):
logger.info(f" {line}")
logger.info("=" * 80)
def chat_stream_collect(
self,
messages: List[Dict[str, str]],
model: str,
temperature: float = 0.7,
max_tokens: int = 2048,
timeout: int = 180,
on_chunk: Optional[Callable[[str], None]] = None
) -> str:
"""
流式调用并收集完整结果
Args:
messages: 消息列表
model: 模型名称
temperature: 温度参数
max_tokens: 最大生成 token 数
timeout: 超时时间(秒)
on_chunk: 每收到一个片段时的回调函数
Returns:
完整的生成文本
"""
full_content = []
for chunk in self.chat_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout
):
full_content.append(chunk)
if on_chunk:
on_chunk(chunk)
return ''.join(full_content)
# 全局单例(延迟初始化)
_client: Optional[LLMClient] = None
def get_client() -> LLMClient:
"""获取 LLM 客户端单例"""
global _client
if _client is None:
_client = LLMClient()
return _client
def reset_client() -> None:
"""重置 LLM 客户端单例(配置变更后调用)"""
global _client
_client = None
def test_connection(timeout: int = 10) -> tuple[bool, str]:
"""
测试 API 连接是否正常
Args:
timeout: 超时时间(秒)
Returns:
(是否成功, 消息)
"""
try:
client = get_client()
# 发送简单的测试请求
response = client.chat(
messages=[{"role": "user", "content": "hi"}],
model=os.getenv("INTENT_MODEL_NAME") or "Qwen/Qwen2.5-7B-Instruct",
temperature=0.1,
max_tokens=10,
timeout=timeout
)
return (True, "连接成功")
except LLMClientError as e:
error_msg = str(e)
if "未配置" in error_msg or "API Key" in error_msg:
return (False, f"配置错误: {error_msg}")
elif "状态码: 401" in error_msg or "Unauthorized" in error_msg:
return (False, "API Key 无效,请检查配置")
elif "状态码: 403" in error_msg:
return (False, "API Key 权限不足")
elif "状态码: 404" in error_msg:
return (False, "API 地址错误或模型不存在")
elif "网络连接失败" in error_msg:
return (False, "网络连接失败,请检查网络设置")
elif "请求超时" in error_msg:
return (False, f"连接超时({timeout}秒),请检查网络或稍后重试")
else:
return (False, f"连接失败: {error_msg}")
except Exception as e:
return (False, f"未知错误: {str(e)}")