Files
LocalAgent/llm/client.py
Mimikko-zeus 4b3286f546 Initial commit
2026-01-07 00:17:46 +08:00

125 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM 统一调用客户端
所有模型通过 SiliconFlow API 调用
"""
import os
import requests
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
# Project root: the directory two levels above this file (the parent of the
# directory that contains this module).
PROJECT_ROOT = Path(__file__).parent.parent
# Location of the .env file that holds the LLM endpoint settings.
ENV_PATH = PROJECT_ROOT.joinpath(".env")
class LLMClientError(Exception):
    """Error type raised by the LLM client for all of its failure cases."""
class LLMClient:
    """Unified client for calling an LLM chat endpoint over HTTP.

    The endpoint URL and API key are read from the ``LLM_API_URL`` and
    ``LLM_API_KEY`` environment variables after loading the ``.env`` file at
    ``ENV_PATH``.

    Usage:
        client = LLMClient()
        response = client.chat(
            messages=[{"role": "user", "content": "你好"}],
            model="Qwen/Qwen2.5-7B-Instruct",
            temperature=0.7,
            max_tokens=1024
        )
    """

    def __init__(self):
        """Load the configuration and validate it.

        Raises:
            LLMClientError: If ``LLM_API_URL`` is missing/empty, or if
                ``LLM_API_KEY`` is missing/empty or still set to the
                ``your_api_key_here`` placeholder.
        """
        load_dotenv(ENV_PATH)
        self.api_url = os.getenv("LLM_API_URL")
        self.api_key = os.getenv("LLM_API_KEY")
        if not self.api_url:
            raise LLMClientError("未配置 LLM_API_URL请检查 .env 文件")
        if not self.api_key or self.api_key == "your_api_key_here":
            raise LLMClientError("未配置有效的 LLM_API_KEY请检查 .env 文件")

    def chat(
        self,
        messages: list[dict],
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 1024
    ) -> str:
        """Send a non-streaming chat request and return the generated text.

        Args:
            messages: Message list, in the form
                ``[{"role": "user/assistant/system", "content": "..."}]``.
            model: Model name.
            temperature: Sampling temperature controlling randomness.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            The value found at ``choices[0].message.content`` in the JSON
            response.

        Raises:
            LLMClientError: On timeout, connection failure, any other
                request error, a non-200 status code, or a response body
                that is not JSON or lacks the expected structure.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": messages,
            "stream": False,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        # Most specific exception types first; every failure is re-raised as
        # LLMClientError with the original exception kept as its cause.
        try:
            response = requests.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=60,
            )
        except requests.exceptions.Timeout as e:
            raise LLMClientError("请求超时,请检查网络连接") from e
        except requests.exceptions.ConnectionError as e:
            raise LLMClientError("网络连接失败,请检查网络设置") from e
        except requests.exceptions.RequestException as e:
            raise LLMClientError(f"网络请求异常: {e}") from e

        if response.status_code != 200:
            error_msg = f"API 返回错误 (状态码: {response.status_code})"
            try:
                error_detail = response.json()
            except ValueError:
                # Body is not valid JSON (JSON decode errors are ValueError
                # subclasses); handled below by falling back to raw text.
                error_detail = None
            if isinstance(error_detail, dict):
                if "error" in error_detail:
                    error_msg += f": {error_detail['error']}"
            else:
                # Non-JSON body or a JSON value that is not an object:
                # include a bounded excerpt of the raw body instead.
                error_msg += f": {response.text[:200]}"
            raise LLMClientError(error_msg)

        try:
            result = response.json()
            content = result["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # ValueError covers a 200 response whose body is not valid JSON,
            # so callers only ever see LLMClientError from this method.
            raise LLMClientError(f"解析 API 响应失败: {e}") from e
        return content
# Module-level singleton, created lazily on first request.
_client: Optional[LLMClient] = None


def get_client() -> LLMClient:
    """Return the shared LLMClient, constructing it on the first call."""
    global _client
    if _client is not None:
        return _client
    _client = LLMClient()
    return _client