- Updated .env.example to include API key placeholder and configuration instructions. - Refactored main.py to support streaming responses from the LLM, improving user experience during chat interactions. - Enhanced LLMClient to include methods for streaming chat and collecting responses. - Modified safety review process to pass static analysis warnings to the LLM for better code safety evaluation. - Improved UI components in chat_view.py to handle streaming messages effectively.
247 lines
7.5 KiB
Python
247 lines
7.5 KiB
Python
"""
|
||
LLM 统一调用客户端
|
||
所有模型通过 SiliconFlow API 调用
|
||
支持流式和非流式两种模式
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import requests
|
||
from pathlib import Path
|
||
from typing import Optional, Generator, Callable
|
||
from dotenv import load_dotenv
|
||
|
||
# Project root directory: the parent of the directory that contains this module.
PROJECT_ROOT = Path(__file__).parent.parent

# The .env file in the project root; LLMClient loads its API settings from it.
ENV_PATH = PROJECT_ROOT.joinpath(".env")
|
||
|
||
|
||
class LLMClientError(Exception):
    """
    Error raised by the LLM client.

    Used for missing or placeholder .env configuration, network failures,
    non-200 API responses, and API responses that cannot be parsed.
    """
|
||
|
||
|
||
class LLMClient:
    """
    Unified LLM client.

    Reads LLM_API_URL and LLM_API_KEY from the project .env file and sends
    chat requests to that endpoint, either non-streaming or streaming
    (server-sent events). Request and response bodies follow the
    OpenAI-style chat-completions shape (``choices[0].message.content`` for
    full replies, ``choices[0].delta.content`` for streamed fragments).

    Usage:
        client = LLMClient()

        # Non-streaming call
        response = client.chat(
            messages=[{"role": "user", "content": "你好"}],
            model="Qwen/Qwen2.5-7B-Instruct"
        )

        # Streaming call
        for chunk in client.chat_stream(
            messages=[{"role": "user", "content": "你好"}],
            model="Qwen/Qwen2.5-7B-Instruct"
        ):
            print(chunk, end="", flush=True)
    """

    def __init__(self):
        """
        Load API settings from the project .env file.

        Raises:
            LLMClientError: if LLM_API_URL is unset or empty, or if
                LLM_API_KEY is unset, empty, or still the
                "your_api_key_here" placeholder value.
        """
        load_dotenv(ENV_PATH)

        self.api_url = os.getenv("LLM_API_URL")
        self.api_key = os.getenv("LLM_API_KEY")

        if not self.api_url:
            raise LLMClientError("未配置 LLM_API_URL,请检查 .env 文件")
        # A key equal to the placeholder value is treated the same as no key.
        if not self.api_key or self.api_key == "your_api_key_here":
            raise LLMClientError("未配置有效的 LLM_API_KEY,请检查 .env 文件")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_headers(self) -> dict[str, str]:
        """Return the HTTP headers (bearer auth + JSON body) used by every request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _post(self, payload: dict, timeout: int, stream: bool = False) -> "requests.Response":
        """
        POST ``payload`` as JSON to the configured API URL.

        Args:
            payload: request body
            timeout: requests timeout in seconds
            stream: passed through to requests; True defers reading the body

        Returns:
            The requests Response (any status code).

        Raises:
            LLMClientError: on timeout, connection failure, or any other
                requests transport error.
        """
        try:
            return requests.post(
                self.api_url,
                headers=self._build_headers(),
                json=payload,
                timeout=timeout,
                stream=stream
            )
        # Timeout is checked first so connect timeouts (which are also
        # ConnectionErrors) get the timeout message.
        except requests.exceptions.Timeout as e:
            raise LLMClientError(f"请求超时({timeout}秒),请检查网络连接或稍后重试") from e
        except requests.exceptions.ConnectionError as e:
            raise LLMClientError("网络连接失败,请检查网络设置") from e
        except requests.exceptions.RequestException as e:
            raise LLMClientError(f"网络请求异常: {str(e)}") from e

    @staticmethod
    def _raise_for_status(response: "requests.Response") -> None:
        """
        Raise LLMClientError describing the failure unless the status is 200.

        The message carries the JSON "error" field when the body is a JSON
        object containing one; otherwise it carries the first 200 characters
        of the raw body, so the server's explanation is never dropped.
        """
        if response.status_code == 200:
            return

        error_msg = f"API 返回错误 (状态码: {response.status_code})"
        try:
            error_detail = response.json()
        except ValueError:
            # Not JSON at all (e.g. an HTML gateway error page).
            error_detail = None

        if isinstance(error_detail, dict) and "error" in error_detail:
            error_msg += f": {error_detail['error']}"
        else:
            error_msg += f": {response.text[:200]}"
        raise LLMClientError(error_msg)

    @staticmethod
    def _iter_sse_content(lines) -> Generator[str, None, None]:
        """
        Yield the non-empty ``delta.content`` strings from raw SSE lines.

        Args:
            lines: iterable of SSE lines as bytes (what requests'
                ``Response.iter_lines()`` produces) or str

        Stops at the ``[DONE]`` sentinel. The following are skipped:
        - lines that are not ``data:`` fields (blank keep-alives, comments);
        - payloads that are not valid JSON objects;
        - chunks without a usable ``choices[0].delta.content``.
        """
        for raw in lines:
            if not raw:
                continue
            line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            if not line.startswith("data:"):
                continue
            data = line[len("data:"):]
            # The SSE format allows (but does not require) one space after the colon.
            if data.startswith(" "):
                data = data[1:]
            if data.strip() == "[DONE]":
                break

            try:
                chunk = json.loads(data)
            except json.JSONDecodeError:
                continue

            choices = chunk.get("choices") if isinstance(chunk, dict) else None
            if not isinstance(choices, list) or not choices:
                continue
            first = choices[0]
            # "delta" may be absent or explicitly null in some chunks.
            delta = first.get("delta") if isinstance(first, dict) else None
            content = delta.get("content") if isinstance(delta, dict) else None
            if isinstance(content, str) and content:
                yield content

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def chat(
        self,
        messages: list[dict],
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        timeout: int = 180
    ) -> str:
        """
        Run a non-streaming chat completion.

        Args:
            messages: chat message list, e.g. [{"role": "user", "content": "..."}]
            model: model name sent to the API
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate
            timeout: requests timeout in seconds (default 180)

        Returns:
            ``choices[0].message.content`` from the API response.

        Raises:
            LLMClientError: on a network failure, a non-200 status, or a
                response body that is not JSON or lacks the expected fields.
        """
        payload = {
            "model": model,
            "messages": messages,
            "stream": False,
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        response = self._post(payload, timeout)
        self._raise_for_status(response)

        try:
            result = response.json()
            return result["choices"][0]["message"]["content"]
        # ValueError covers a 200 response whose body is not valid JSON.
        except (ValueError, KeyError, IndexError, TypeError) as e:
            raise LLMClientError(f"解析 API 响应失败: {str(e)}") from e

    def chat_stream(
        self,
        messages: list[dict],
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        timeout: int = 180
    ) -> Generator[str, None, None]:
        """
        Run a streaming chat completion, yielding text fragments as they arrive.

        This is a generator. The HTTP request is sent only when iteration
        starts, so errors surface on the first ``next()``, not at call time.

        Args:
            messages: chat message list
            model: model name sent to the API
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate
            timeout: requests timeout in seconds. For a streamed response it
                bounds the connect and each wait for more data, not the total
                duration of the stream.

        Yields:
            Each non-empty ``delta.content`` fragment, in arrival order.

        Raises:
            LLMClientError: on a network failure (including a connection that
                breaks mid-stream) or a non-200 status.
        """
        payload = {
            "model": model,
            "messages": messages,
            "stream": True,
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        response = self._post(payload, timeout, stream=True)
        try:
            self._raise_for_status(response)
            yield from self._iter_sse_content(response.iter_lines())
        except requests.exceptions.RequestException as e:
            # e.g. ChunkedEncodingError or a read timeout while iterating the body.
            raise LLMClientError(f"流式响应读取失败: {str(e)}") from e
        finally:
            # With stream=True the connection stays open until the body is
            # consumed or closed. Release it on errors and when the caller
            # stops iterating early.
            response.close()

    def chat_stream_collect(
        self,
        messages: list[dict],
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        timeout: int = 180,
        on_chunk: Optional[Callable[[str], None]] = None
    ) -> str:
        """
        Stream a chat completion and return the concatenated text.

        Args:
            messages: chat message list
            model: model name sent to the API
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate
            timeout: requests timeout in seconds
            on_chunk: optional callback invoked with each fragment as it
                arrives (e.g. to update a UI incrementally)

        Returns:
            All fragments joined into one string.

        Raises:
            LLMClientError: propagated from ``chat_stream``.
        """
        full_content: list[str] = []

        for chunk in self.chat_stream(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout
        ):
            full_content.append(chunk)
            if on_chunk:
                on_chunk(chunk)

        return ''.join(full_content)
|
||
|
||
|
||
# Module-level cached client; stays None until get_client() first succeeds.
_client: Optional[LLMClient] = None


def get_client() -> LLMClient:
    """
    Return the shared LLMClient, constructing it on first use.

    Raises:
        LLMClientError: propagated from LLMClient() when the .env settings
            are missing or invalid. A failed construction leaves the cache
            empty, so the next call tries again.
    """
    global _client
    if _client is not None:
        return _client
    _client = LLMClient()
    return _client
|