Initial commit
This commit is contained in:
2
intent/__init__.py
Normal file
2
intent/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# 意图识别模块
|
||||
|
||||
BIN
intent/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
intent/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
intent/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
intent/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
intent/__pycache__/classifier.cpython-310.pyc
Normal file
BIN
intent/__pycache__/classifier.cpython-310.pyc
Normal file
Binary file not shown.
BIN
intent/__pycache__/classifier.cpython-313.pyc
Normal file
BIN
intent/__pycache__/classifier.cpython-313.pyc
Normal file
Binary file not shown.
BIN
intent/__pycache__/labels.cpython-310.pyc
Normal file
BIN
intent/__pycache__/labels.cpython-310.pyc
Normal file
Binary file not shown.
BIN
intent/__pycache__/labels.cpython-313.pyc
Normal file
BIN
intent/__pycache__/labels.cpython-313.pyc
Normal file
Binary file not shown.
152
intent/classifier.py
Normal file
152
intent/classifier.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
意图识别器
|
||||
使用小参数 LLM 进行意图二分类
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from llm.client import get_client, LLMClientError, ENV_PATH
|
||||
from llm.prompts import INTENT_CLASSIFICATION_SYSTEM, INTENT_CLASSIFICATION_USER
|
||||
from intent.labels import CHAT, EXECUTION, EXECUTION_CONFIDENCE_THRESHOLD, VALID_LABELS
|
||||
|
||||
|
||||
@dataclass
class IntentResult:
    """Outcome of a single intent classification."""

    label: str  # intent label: "chat" or "execution"
    confidence: float  # confidence score, 0.0 ~ 1.0
    reason: str  # explanation of the decision (written in Chinese)
    raw_response: Optional[str] = None  # raw LLM response, kept for debugging
|
||||
|
||||
|
||||
class IntentClassifier:
    """Binary intent classifier backed by a small-parameter LLM.

    The user's text is sent to a chat model (for example
    ``qwen2.5:7b-instruct``) and the reply is mapped onto an
    ``IntentResult``. Every failure path falls back to the ``CHAT`` label
    with confidence 0.0, so ``classify`` never raises.
    """

    def __init__(self):
        """Load the env file at ``ENV_PATH`` and read ``INTENT_MODEL_NAME``."""
        load_dotenv(ENV_PATH)
        # NOTE(review): stays None when the variable is unset and is then
        # passed to the client as ``model=None`` -- confirm the client copes.
        self.model_name = os.getenv("INTENT_MODEL_NAME")

    def classify(self, user_input: str) -> IntentResult:
        """Classify the intent of ``user_input``.

        Args:
            user_input: Text entered by the user.

        Returns:
            IntentResult carrying ``label``, ``confidence`` and ``reason``.
            Any exception raised while calling the model is converted into
            a ``CHAT`` result with confidence 0.0.
        """
        try:
            client = get_client()
            messages = [
                {"role": "system", "content": INTENT_CLASSIFICATION_SYSTEM},
                {
                    "role": "user",
                    "content": INTENT_CLASSIFICATION_USER.format(user_input=user_input),
                },
            ]
            response = client.chat(
                messages=messages,
                model=self.model_name,
                temperature=0.1,  # low temperature for more deterministic output
                max_tokens=256,
            )
            return self._parse_response(response)
        except LLMClientError as e:
            # The LLM call failed: fall back to chat mode.
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason=f"意图识别失败({e}),默认为对话模式",
            )
        except Exception as e:
            # Deliberate best-effort boundary: any other error also falls
            # back to chat mode instead of propagating to the caller.
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason=f"意图识别异常({e}),默认为对话模式",
            )

    def _parse_response(self, response: str) -> IntentResult:
        """Turn the raw LLM reply into an ``IntentResult``.

        The reply is expected to contain a JSON object with ``label``,
        ``confidence`` and ``reason`` keys. Falls back to ``CHAT`` when the
        reply cannot be parsed, when the label is not in ``VALID_LABELS``,
        or when an ``EXECUTION`` label comes with a confidence below
        ``EXECUTION_CONFIDENCE_THRESHOLD``.

        Args:
            response: Raw text returned by the model.

        Returns:
            The parsed (or fallback) result, with ``raw_response`` set.
        """
        try:
            if not isinstance(response, str):
                raise TypeError("LLM response is not text")
            # The model may wrap the JSON in extra prose; isolate it first.
            data = json.loads(self._extract_json(response))
            if not isinstance(data, dict):
                raise TypeError("LLM response is not a JSON object")

            # str() keeps a non-string label (null, number, ...) from raising
            # and routes it to the invalid-label fallback below.
            label = str(data.get("label", "")).strip().lower()
            confidence = float(data.get("confidence", 0.0))
            reason = data.get("reason")
            reason = "无" if reason is None else str(reason)
        except (json.JSONDecodeError, ValueError, TypeError):
            # Parsing failed: fall back to chat mode.
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason="响应解析失败,默认为对话模式",
                raw_response=response,
            )

        # NaN compares unequal to itself. Left alone it would slip past the
        # threshold test below, because ``nan < threshold`` is False.
        if confidence != confidence:
            confidence = 0.0
        # Keep the reported score inside the documented 0.0 ~ 1.0 range.
        confidence = min(1.0, max(0.0, confidence))

        # Reject labels outside the known set.
        if label not in VALID_LABELS:
            return IntentResult(
                label=CHAT,
                confidence=0.0,
                reason=f"无效的意图标签 '{label}',默认为对话模式",
                raw_response=response,
            )

        # Apply the confidence threshold: a low-confidence execution is
        # downgraded to chat.
        if label == EXECUTION and confidence < EXECUTION_CONFIDENCE_THRESHOLD:
            return IntentResult(
                label=CHAT,
                confidence=confidence,
                reason=f"执行任务置信度不足({confidence:.2f} < {EXECUTION_CONFIDENCE_THRESHOLD}),降级为对话模式。原因: {reason}",
                raw_response=response,
            )

        return IntentResult(
            label=label,
            confidence=confidence,
            reason=reason,
            raw_response=response,
        )

    def _extract_json(self, text: str) -> str:
        """Return the first substring of ``text`` that is a JSON object.

        The model may add explanatory text before or after the JSON. Each
        ``{`` is tried in order as the start of a JSON value, so braces in
        surrounding prose or a second trailing object do not break
        extraction.

        Args:
            text: Raw model output.

        Returns:
            The JSON object substring, or ``text`` unchanged when none is
            found (so that ``json.loads`` in the caller reports the error).
        """
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                _, end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                # Not a JSON value at this brace; try the next one.
                start = text.find('{', start + 1)
                continue
            return text[start:end]

        # Nothing parseable: hand the text back and let json.loads fail.
        return text
|
||||
|
||||
|
||||
# Convenience function
def classify_intent(user_input: str) -> IntentResult:
    """Classify ``user_input`` using a newly constructed ``IntentClassifier``.

    Args:
        user_input: Text entered by the user.

    Returns:
        The ``IntentResult`` produced by ``IntentClassifier.classify``.
    """
    return IntentClassifier().classify(user_input)
|
||||
|
||||
15
intent/labels.py
Normal file
15
intent/labels.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Intent label definitions."""

# Intent type constants.
CHAT = "chat"
EXECUTION = "execution"

# Confidence threshold for execution tasks.
# Anything below this value is classified as chat (better to skip an
# execution than to execute by mistake).
EXECUTION_CONFIDENCE_THRESHOLD = 0.6

# All valid labels. A frozenset so that no importer can mutate the shared
# constant; membership tests, iteration and equality with a set still work.
VALID_LABELS = frozenset({CHAT, EXECUTION})
|
||||
|
||||
Reference in New Issue
Block a user