- Renamed `check_environment` to `check_api_key_configured` for clarity, simplifying the API key validation logic. - Removed the blocking behavior of the API key check during application startup, allowing the app to run while providing a prompt for configuration. - Updated `LocalAgentApp` to accept an `api_configured` parameter, enabling conditional messaging for API key setup. - Enhanced the `SandboxRunner` to support backup management and improved execution result handling with detailed metrics. - Integrated data governance strategies into the `HistoryManager`, ensuring compliance and improved data management. - Added privacy settings and metrics tracking across various components to enhance user experience and application safety.
581 lines · 22 KiB · Python
"""
|
||
LLM 统一调用客户端
|
||
所有模型通过 SiliconFlow API 调用
|
||
支持流式和非流式两种模式
|
||
支持自动重试机制
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import requests
|
||
from pathlib import Path
|
||
from typing import Optional, Generator, Callable, List, Dict, Any
|
||
from dotenv import load_dotenv
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
# 获取项目根目录
|
||
PROJECT_ROOT = Path(__file__).parent.parent
|
||
ENV_PATH = PROJECT_ROOT / ".env"
|
||
|
||
# 配置日志目录
|
||
LOGS_DIR = PROJECT_ROOT / "workspace" / "logs"
|
||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 配置日志记录器
|
||
logger = logging.getLogger(__name__)
|
||
logger.setLevel(logging.DEBUG)
|
||
|
||
# 创建文件处理器 - 按日期命名
|
||
log_file = LOGS_DIR / f"llm_calls_{datetime.now().strftime('%Y%m%d')}.log"
|
||
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||
file_handler.setLevel(logging.DEBUG)
|
||
|
||
# 设置日志格式
|
||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||
file_handler.setFormatter(formatter)
|
||
|
||
# 添加处理器
|
||
logger.addHandler(file_handler)
|
||
|
||
|
||
class LLMClientError(Exception):
    """Exception raised by the LLM client.

    Carries an ``error_type`` category (one of the ``TYPE_*`` constants) so
    callers can decide whether a failure is retryable, plus the original
    lower-level exception when one triggered this error.
    """

    # Error categories.
    TYPE_NETWORK = "network"  # network failures (timeouts, refused connections, ...)
    TYPE_SERVER = "server"    # server-side errors (5xx)
    TYPE_CLIENT = "client"    # client-side errors (4xx)
    TYPE_PARSE = "parse"      # response parsing failures
    TYPE_CONFIG = "config"    # missing or invalid configuration

    def __init__(self, message: str, error_type: str = TYPE_CLIENT, original_exception: Optional[Exception] = None):
        super().__init__(message)
        self.original_exception = original_exception
        self.error_type = error_type
class LLMClient:
    """
    Unified LLM invocation client.

    Usage:
        client = LLMClient()

        # Non-streaming call
        response = client.chat(
            messages=[{"role": "user", "content": "你好"}],
            model="Qwen/Qwen2.5-7B-Instruct"
        )

        # Streaming call
        for chunk in client.chat_stream(
            messages=[{"role": "user", "content": "你好"}],
            model="Qwen/Qwen2.5-7B-Instruct"
        ):
            print(chunk, end="", flush=True)

    Features:
        - Automatic retry: network errors are retried (3 attempts by default)
        - Exponential backoff: the retry delay grows between attempts
    """

    # Retry configuration.
    DEFAULT_MAX_RETRIES = 3
    DEFAULT_RETRY_DELAY = 1.0  # initial retry delay (seconds)
    DEFAULT_RETRY_BACKOFF = 2.0  # backoff multiplier

    def __init__(self, max_retries: int = DEFAULT_MAX_RETRIES):
        # Load credentials from the project-level .env file.
        load_dotenv(ENV_PATH)

        self.api_url = os.getenv("LLM_API_URL")
        self.api_key = os.getenv("LLM_API_KEY")
        self.max_retries = max_retries

        # Fail fast on missing configuration; the messages are user-facing (Chinese).
        if not self.api_url:
            raise LLMClientError("未配置 LLM_API_URL,请检查 .env 文件", error_type=LLMClientError.TYPE_CONFIG)
        # "your_api_key_here" is the placeholder value shipped in the template .env.
        if not self.api_key or self.api_key == "your_api_key_here":
            raise LLMClientError("未配置有效的 LLM_API_KEY,请检查 .env 文件", error_type=LLMClientError.TYPE_CONFIG)
def _should_retry(self, exception: Exception) -> bool:
    """Decide whether *exception* represents a transient, retryable failure.

    Retryable cases:
      - raw network errors (timeouts, connection failures)
      - LLMClientError tagged as a network or server (5xx / 429) error
      - LLMClientError wrapping an underlying network exception
    """
    # A raw network exception escaping the wrappers (defensive fallback;
    # normally these are wrapped in LLMClientError before reaching us).
    if isinstance(exception, (requests.exceptions.ConnectionError,
                              requests.exceptions.Timeout)):
        return True

    # LLMClientError: decide by its declared category first.
    if isinstance(exception, LLMClientError):
        retryable_categories = (LLMClientError.TYPE_NETWORK, LLMClientError.TYPE_SERVER)
        if exception.error_type in retryable_categories:
            return True

        # Fall back to inspecting the wrapped original exception, if any.
        cause = exception.original_exception
        if cause:
            transient = (requests.exceptions.ConnectionError,
                         requests.exceptions.Timeout,
                         requests.exceptions.ChunkedEncodingError)
            if isinstance(cause, transient):
                return True

    return False
def _do_request_with_retry(
|
||
self,
|
||
request_func: Callable,
|
||
operation_name: str = "请求"
|
||
):
|
||
"""带重试的请求执行"""
|
||
last_exception = None
|
||
retry_count = 0
|
||
|
||
for attempt in range(self.max_retries + 1):
|
||
try:
|
||
result = request_func()
|
||
|
||
# 记录成功的请求(包括重试后成功)
|
||
if retry_count > 0:
|
||
try:
|
||
from llm.config_metrics import get_config_metrics
|
||
workspace = PROJECT_ROOT / "workspace"
|
||
if workspace.exists():
|
||
metrics = get_config_metrics(workspace)
|
||
metrics.record_retry_success(retry_count)
|
||
except:
|
||
pass
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
last_exception = e
|
||
|
||
# 判断是否应该重试
|
||
if attempt < self.max_retries and self._should_retry(e):
|
||
retry_count += 1
|
||
delay = self.DEFAULT_RETRY_DELAY * (self.DEFAULT_RETRY_BACKOFF ** attempt)
|
||
|
||
# 记录重试信息
|
||
error_type = getattr(e, 'error_type', 'unknown') if isinstance(e, LLMClientError) else type(e).__name__
|
||
print(f"[重试] {operation_name}失败 (错误类型: {error_type}),{delay:.1f}秒后重试 ({attempt + 1}/{self.max_retries})...")
|
||
|
||
# 记录重试次数到配置度量
|
||
try:
|
||
from llm.config_metrics import get_config_metrics
|
||
workspace = PROJECT_ROOT / "workspace"
|
||
if workspace.exists():
|
||
metrics = get_config_metrics(workspace)
|
||
metrics.increment_retry()
|
||
except:
|
||
pass # 度量记录失败不影响主流程
|
||
|
||
time.sleep(delay)
|
||
continue
|
||
else:
|
||
# 记录最终失败
|
||
if retry_count > 0:
|
||
try:
|
||
from llm.config_metrics import get_config_metrics
|
||
workspace = PROJECT_ROOT / "workspace"
|
||
if workspace.exists():
|
||
metrics = get_config_metrics(workspace)
|
||
metrics.record_retry_failure(retry_count)
|
||
except:
|
||
pass
|
||
raise
|
||
|
||
# 所有重试都失败
|
||
raise last_exception
|
||
|
||
def chat(
    self,
    messages: List[Dict[str, str]],
    model: str,
    temperature: float = 0.7,
    max_tokens: int = 1024,
    timeout: int = 180
) -> str:
    """
    Call the LLM for a chat completion (non-streaming, with automatic retry).

    Args:
        messages: list of chat messages ({"role": ..., "content": ...})
        model: model name
        temperature: sampling temperature
        max_tokens: maximum number of tokens to generate
        timeout: request timeout in seconds (default 180)

    Returns:
        The text content generated by the LLM.

    Raises:
        LLMClientError: on network failure, non-200 status, or an
            unparseable response body (after retries are exhausted).
    """
    # Log the input - full content, no truncation.
    logger.info("=" * 80)
    logger.info(f"LLM 调用 [非流式] - 模型: {model}")
    logger.info(f"参数: temperature={temperature}, max_tokens={max_tokens}, timeout={timeout}s")
    logger.info(f"时间戳: {datetime.now().isoformat()}")
    logger.info("-" * 80)
    logger.info("输入消息:")
    for i, msg in enumerate(messages):
        role = msg.get('role', 'unknown')
        content = msg.get('content', '')
        logger.info(f" [{i+1}] {role} ({len(content)} 字符):")
        # Record the full message, not truncated.
        for line in content.split('\n'):
            logger.info(f" {line}")
    logger.info("-" * 80)

    def do_request():
        # Single request attempt; retry policy is applied by the wrapper below.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": model,
            "messages": messages,
            "stream": False,
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        # Log request details.
        logger.debug(f"API URL: {self.api_url}")
        logger.debug(f"请求 Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}")

        try:
            start_time = time.time()
            response = requests.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=timeout
            )
            elapsed_time = time.time() - start_time
            logger.info(f"请求耗时: {elapsed_time:.2f}秒")
        except requests.exceptions.Timeout as e:
            # Wrap as a network-category error so _should_retry allows a retry.
            logger.error(f"请求超时: {timeout}秒")
            raise LLMClientError(
                f"请求超时({timeout}秒),请检查网络连接或稍后重试",
                error_type=LLMClientError.TYPE_NETWORK,
                original_exception=e
            )
        except requests.exceptions.ConnectionError as e:
            logger.error(f"网络连接失败: {str(e)}")
            raise LLMClientError(
                "网络连接失败,请检查网络设置",
                error_type=LLMClientError.TYPE_NETWORK,
                original_exception=e
            )
        except requests.exceptions.RequestException as e:
            # Catch-all for any other requests-level failure.
            logger.error(f"网络请求异常: {str(e)}")
            raise LLMClientError(
                f"网络请求异常: {str(e)}",
                error_type=LLMClientError.TYPE_NETWORK,
                original_exception=e
            )

        # Log the response status.
        logger.debug(f"响应状态码: {response.status_code}")

        if response.status_code != 200:
            error_msg = f"API 返回错误 (状态码: {response.status_code})"
            try:
                error_detail = response.json()
                logger.error(f"错误详情: {json.dumps(error_detail, ensure_ascii=False, indent=2)}")
                if "error" in error_detail:
                    error_msg += f": {error_detail['error']}"
            # NOTE(review): bare except keeps the best-effort fallback to raw
            # text, but also catches SystemExit/KeyboardInterrupt - consider
            # narrowing to ``except Exception``.
            except:
                logger.error(f"错误响应: {response.text[:500]}")
                error_msg += f": {response.text[:200]}"

            # Map the status code to an error category (drives retry policy).
            if response.status_code >= 500:
                error_type = LLMClientError.TYPE_SERVER
            elif response.status_code == 429:
                error_type = LLMClientError.TYPE_SERVER  # rate limiting is retryable too
            else:
                error_type = LLMClientError.TYPE_CLIENT

            raise LLMClientError(error_msg, error_type=error_type)

        try:
            result = response.json()
            # OpenAI-compatible response shape: choices[0].message.content.
            content = result["choices"][0]["message"]["content"]

            # Log the output - full content, no truncation.
            logger.info("输出响应:")
            logger.info(f" 长度: {len(content)} 字符")
            for line in content.split('\n'):
                logger.info(f" {line}")
            logger.info("=" * 80)

            return content
        except (KeyError, IndexError, TypeError) as e:
            # The body was valid JSON but not the expected shape.
            logger.error(f"解析 API 响应失败: {str(e)}")
            logger.error(f"原始响应: {response.text[:1000]}")
            raise LLMClientError(
                f"解析 API 响应失败: {str(e)}",
                error_type=LLMClientError.TYPE_PARSE
            )

    return self._do_request_with_retry(do_request, "LLM调用")
def chat_stream(
    self,
    messages: List[Dict[str, str]],
    model: str,
    temperature: float = 0.7,
    max_tokens: int = 2048,
    timeout: int = 180
) -> Generator[str, None, None]:
    """
    Call the LLM for a chat completion (streaming, with automatic retry).

    Retry only covers establishing the connection; once streaming has
    started, a mid-stream failure is logged and re-raised to the caller.

    Args:
        messages: list of chat messages
        model: model name
        temperature: sampling temperature
        max_tokens: maximum number of tokens to generate
        timeout: request timeout in seconds

    Yields:
        Generated text fragments, one at a time.
    """
    # Log the input - full content, no truncation.
    logger.info("=" * 80)
    logger.info(f"LLM 调用 [流式] - 模型: {model}")
    logger.info(f"参数: temperature={temperature}, max_tokens={max_tokens}, timeout={timeout}s")
    logger.info(f"时间戳: {datetime.now().isoformat()}")
    logger.info("-" * 80)
    logger.info("输入消息:")
    for i, msg in enumerate(messages):
        role = msg.get('role', 'unknown')
        content = msg.get('content', '')
        logger.info(f" [{i+1}] {role} ({len(content)} 字符):")
        # Record the full message, not truncated.
        for line in content.split('\n'):
            logger.info(f" {line}")
    logger.info("-" * 80)
    logger.info("开始接收流式输出...")

    def do_request():
        # Establish the streaming connection; retry policy applies here only.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": model,
            "messages": messages,
            "stream": True,
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        # Log request details.
        logger.debug(f"API URL: {self.api_url}")
        logger.debug(f"请求 Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}")

        try:
            start_time = time.time()
            response = requests.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=timeout,
                stream=True
            )
            elapsed_time = time.time() - start_time
            logger.info(f"连接建立耗时: {elapsed_time:.2f}秒")
        except requests.exceptions.Timeout as e:
            # Wrap as a network-category error so _should_retry allows a retry.
            logger.error(f"请求超时: {timeout}秒")
            raise LLMClientError(
                f"请求超时({timeout}秒),请检查网络连接或稍后重试",
                error_type=LLMClientError.TYPE_NETWORK,
                original_exception=e
            )
        except requests.exceptions.ConnectionError as e:
            logger.error(f"网络连接失败: {str(e)}")
            raise LLMClientError(
                "网络连接失败,请检查网络设置",
                error_type=LLMClientError.TYPE_NETWORK,
                original_exception=e
            )
        except requests.exceptions.RequestException as e:
            # Catch-all for any other requests-level failure.
            logger.error(f"网络请求异常: {str(e)}")
            raise LLMClientError(
                f"网络请求异常: {str(e)}",
                error_type=LLMClientError.TYPE_NETWORK,
                original_exception=e
            )

        # Log the response status.
        logger.debug(f"响应状态码: {response.status_code}")

        if response.status_code != 200:
            error_msg = f"API 返回错误 (状态码: {response.status_code})"
            try:
                error_detail = response.json()
                logger.error(f"错误详情: {json.dumps(error_detail, ensure_ascii=False, indent=2)}")
                if "error" in error_detail:
                    error_msg += f": {error_detail['error']}"
            # NOTE(review): bare except - same pattern as in chat(); consider
            # narrowing to ``except Exception``.
            except:
                logger.error(f"错误响应: {response.text[:500]}")
                error_msg += f": {response.text[:200]}"

            # Map the status code to an error category (drives retry policy).
            if response.status_code >= 500:
                error_type = LLMClientError.TYPE_SERVER
            elif response.status_code == 429:
                error_type = LLMClientError.TYPE_SERVER  # rate limiting is retryable too
            else:
                error_type = LLMClientError.TYPE_CLIENT

            raise LLMClientError(error_msg, error_type=error_type)

        return response

    # Retries for a streaming request apply only to connection establishment.
    response = self._do_request_with_retry(do_request, "流式LLM调用")

    # Collect the complete output for logging.
    full_output = []

    # Parse the SSE (server-sent events) stream.
    try:
        for line in response.iter_lines():
            if line:
                line = line.decode('utf-8')
                if line.startswith('data: '):
                    data = line[6:]  # strip the "data: " prefix
                    if data == '[DONE]':
                        break
                    try:
                        chunk = json.loads(data)
                        if 'choices' in chunk and len(chunk['choices']) > 0:
                            delta = chunk['choices'][0].get('delta', {})
                            content = delta.get('content', '')
                            if content:
                                full_output.append(content)
                                yield content
                    except json.JSONDecodeError:
                        # Skip malformed SSE data lines (e.g. keep-alives).
                        continue
    except Exception as e:
        logger.error(f"流式输出异常: {str(e)}")
        raise

    # Log the complete output - no truncation.
    complete_output = ''.join(full_output)
    logger.info("流式输出完成:")
    logger.info(f" 总长度: {len(complete_output)} 字符")
    for line in complete_output.split('\n'):
        logger.info(f" {line}")
    logger.info("=" * 80)
def chat_stream_collect(
    self,
    messages: List[Dict[str, str]],
    model: str,
    temperature: float = 0.7,
    max_tokens: int = 2048,
    timeout: int = 180,
    on_chunk: Optional[Callable[[str], None]] = None
) -> str:
    """Run a streaming chat call and gather the fragments into one string.

    Args:
        messages: list of chat messages
        model: model name
        temperature: sampling temperature
        max_tokens: maximum number of tokens to generate
        timeout: request timeout in seconds
        on_chunk: optional callback invoked with each received fragment

    Returns:
        The complete generated text.
    """
    pieces: List[str] = []

    stream = self.chat_stream(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout
    )
    for piece in stream:
        pieces.append(piece)
        # Forward each fragment to the caller's callback, if provided.
        if on_chunk:
            on_chunk(piece)

    return ''.join(pieces)
# Global singleton (lazily initialised by get_client()).
_client: Optional[LLMClient] = None
def get_client() -> LLMClient:
    """Return the shared LLMClient singleton, creating it on first access."""
    global _client
    if _client is None:
        # Construction validates the .env configuration and may raise
        # LLMClientError; the singleton is only cached on success.
        _client = LLMClient()
    return _client
def reset_client() -> None:
    """Drop the cached LLMClient singleton (call after configuration changes).

    The next get_client() call re-reads the configuration and builds a
    fresh client.
    """
    global _client
    _client = None
def test_connection(timeout: int = 10) -> tuple[bool, str]:
    """
    Probe whether the configured API connection works.

    Args:
        timeout: timeout in seconds for the probe request

    Returns:
        A (success, message) tuple describing the outcome.
    """
    try:
        client = get_client()

        # Fire a tiny request just to exercise the full round trip.
        client.chat(
            messages=[{"role": "user", "content": "hi"}],
            model=os.getenv("INTENT_MODEL_NAME") or "Qwen/Qwen2.5-7B-Instruct",
            temperature=0.1,
            max_tokens=10,
            timeout=timeout
        )
        return (True, "连接成功")

    except LLMClientError as e:
        # Translate known failure signatures into friendlier messages.
        # The checks run in order; the first match wins.
        error_msg = str(e)
        if "未配置" in error_msg or "API Key" in error_msg:
            return (False, f"配置错误: {error_msg}")
        if "状态码: 401" in error_msg or "Unauthorized" in error_msg:
            return (False, "API Key 无效,请检查配置")
        if "状态码: 403" in error_msg:
            return (False, "API Key 权限不足")
        if "状态码: 404" in error_msg:
            return (False, "API 地址错误或模型不存在")
        if "网络连接失败" in error_msg:
            return (False, "网络连接失败,请检查网络设置")
        if "请求超时" in error_msg:
            return (False, f"连接超时({timeout}秒),请检查网络或稍后重试")
        return (False, f"连接失败: {error_msg}")

    except Exception as e:
        # Boundary catch-all: report rather than crash the caller's UI.
        return (False, f"未知错误: {str(e)}")