Files
LocalAgent/intent/classifier.py
Mimikko-zeus 4b3286f546 Initial commit
2026-01-07 00:17:46 +08:00

153 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
意图识别器
使用小参数 LLM 进行意图二分类
"""
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

from llm.client import get_client, LLMClientError, ENV_PATH
from llm.prompts import INTENT_CLASSIFICATION_SYSTEM, INTENT_CLASSIFICATION_USER
from intent.labels import CHAT, EXECUTION, EXECUTION_CONFIDENCE_THRESHOLD, VALID_LABELS


@dataclass()
class IntentResult:
    """
    Outcome of one intent-classification request.

    Attributes are documented inline below; only ``raw_response`` is optional.
    """

    # Intent label, expected to be "chat" or "execution".
    label: str
    # Confidence score, nominally in the 0.0 - 1.0 range.
    confidence: float
    # Human-readable explanation (written in Chinese).
    reason: str
    # Unparsed LLM response kept for debugging; None when not available.
    raw_response: Optional[str] = None


class IntentClassifier:
    """
    Intent classifier.

    Uses a small-parameter LLM (e.g. qwen2.5:7b-instruct) to quickly decide
    whether a user message is plain chat (CHAT) or a task to execute
    (EXECUTION). Failure paths degrade to CHAT instead of raising.
    """

    def __init__(self):
        # Load the project's .env file so INTENT_MODEL_NAME can be read below.
        load_dotenv(ENV_PATH)
        # May be None when the variable is unset; passed to client.chat as-is.
        self.model_name = os.getenv("INTENT_MODEL_NAME")

    def classify(self, user_input: str) -> IntentResult:
        """
        Classify the intent of a user message.

        Args:
            user_input: The raw text entered by the user.

        Returns:
            IntentResult with label, confidence and reason. Any Exception
            raised while calling the LLM or parsing its output is converted
            into a CHAT result with confidence 0.0 and the error text
            embedded in ``reason``.
        """
        try:
            client = get_client()
            messages = [
                {"role": "system", "content": INTENT_CLASSIFICATION_SYSTEM},
                {"role": "user", "content": INTENT_CLASSIFICATION_USER.format(user_input=user_input)},
            ]
            response = client.chat(
                messages=messages,
                model=self.model_name,
                temperature=0.1,  # low temperature for more deterministic output
                max_tokens=256,
            )
            return self._parse_response(response)
        except LLMClientError as e:
            # The LLM call itself failed: fall back to chat mode.
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason=f"意图识别失败({e}),默认为对话模式",
            )
        except Exception as e:
            # Deliberate catch-all boundary: any other error also falls back to chat mode.
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason=f"意图识别异常({e}),默认为对话模式",
            )

    def _parse_response(self, response: str) -> IntentResult:
        """
        Parse the raw LLM response into an IntentResult.

        Expects a JSON object with ``label``, ``confidence`` and ``reason``
        keys, possibly surrounded by extra text. Malformed payloads and
        unknown labels degrade to CHAT. An EXECUTION label whose confidence
        is below EXECUTION_CONFIDENCE_THRESHOLD is downgraded to CHAT.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            IntentResult whose ``raw_response`` is the original response text.
        """
        try:
            data = json.loads(self._extract_json(response))
            # json.loads may legally return a list, string or number; only an
            # object carries the expected fields.
            if not isinstance(data, dict):
                raise TypeError("intent payload is not a JSON object")
            raw_label = data.get("label")
            # str() guards against non-string labels (null, numbers) whose
            # .lower() would otherwise raise AttributeError.
            label = "" if raw_label is None else str(raw_label).strip().lower()
            confidence = self._normalize_confidence(data.get("confidence", 0.0))
            raw_reason = data.get("reason")
            reason = "" if raw_reason is None else str(raw_reason)
        except (ValueError, TypeError, AttributeError):
            # json.JSONDecodeError is a ValueError subclass. Any malformed
            # response falls back to chat mode.
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason="响应解析失败,默认为对话模式",
                raw_response=response,
            )

        # Reject labels outside the known label set.
        if label not in VALID_LABELS:
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason=f"无效的意图标签 '{label}',默认为对话模式",
                raw_response=response,
            )

        # Only act on EXECUTION when the model is confident enough.
        if label == EXECUTION and confidence < EXECUTION_CONFIDENCE_THRESHOLD:
            return IntentResult(
                label=CHAT,
                confidence=confidence,
                reason=f"执行任务置信度不足({confidence:.2f} < {EXECUTION_CONFIDENCE_THRESHOLD}),降级为对话模式。原因: {reason}",
                raw_response=response,
            )

        return IntentResult(
            label=label,
            confidence=confidence,
            reason=reason,
            raw_response=response,
        )

    @staticmethod
    def _normalize_confidence(value) -> float:
        """
        Convert a raw confidence value to a float clamped to [0.0, 1.0].

        NaN maps to 0.0. NaN compares False against any threshold, so it
        would otherwise let an EXECUTION result bypass the confidence check.
        Raises ValueError/TypeError when ``value`` is not numeric.
        """
        confidence = float(value)
        if math.isnan(confidence):
            return 0.0
        return min(max(confidence, 0.0), 1.0)

    def _extract_json(self, text: str) -> str:
        """
        Extract the JSON object substring from an LLM response.

        The LLM may wrap the JSON in explanatory text. The method tries each
        '{' in turn and returns the text of the first complete JSON object
        that decodes there. Trailing prose that contains braces therefore
        does not corrupt the result.

        If no object decodes, it falls back to the span from the first '{' to
        the last '}'. When no such span exists, it returns the text unchanged
        and lets json.loads report the error.
        """
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                _, end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find('{', start + 1)
                continue
            return text[start:end]
        # No complete object found: keep the original heuristic.
        start = text.find('{')
        end = text.rfind('}')
        if start != -1 and end > start:
            return text[start:end + 1]
        return text


# Convenience function
def classify_intent(user_input: str) -> IntentResult:
    """
    Classify ``user_input`` in one call.

    A new IntentClassifier is constructed on every invocation and its
    ``classify`` result is returned directly.
    """
    return IntentClassifier().classify(user_input)