Refactor AI daily report pipeline
This commit is contained in:
109
ai_daily_report/classify.py
Normal file
109
ai_daily_report/classify.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
|
||||
from .models import NewsItem
|
||||
|
||||
|
||||
SECTION_ORDER = [
|
||||
"模型与能力",
|
||||
"产品与应用",
|
||||
"开发与基础设施",
|
||||
"公司与资本",
|
||||
"政策与安全",
|
||||
"论文与研究",
|
||||
"观点与教程",
|
||||
"人物与动态",
|
||||
]
|
||||
|
||||
SECTION_ALIASES = {
|
||||
"模型发布/更新": "模型与能力",
|
||||
"产品发布/更新": "产品与应用",
|
||||
"产品与工具": "产品与应用",
|
||||
"开发与工程": "开发与基础设施",
|
||||
"行业动态": "公司与资本",
|
||||
"行业与公司": "公司与资本",
|
||||
"论文研究": "论文与研究",
|
||||
"论文与研究": "论文与研究",
|
||||
"技巧与观点": "观点与教程",
|
||||
"观点与教程": "观点与教程",
|
||||
"人物与花絮": "人物与动态",
|
||||
}
|
||||
|
||||
|
||||
RULES = [
|
||||
("政策与安全", ("监管", "政策", "安全", "风险", "滥用", "攻击", "合规", "版权")),
|
||||
("论文与研究", ("论文", "研究", "arxiv", "cvpr", "benchmark", "评测", "实验")),
|
||||
("开发与基础设施", ("sdk", "api", "mcp", "kubernetes", "框架", "开源", "github", "部署", "基础设施")),
|
||||
("公司与资本", ("融资", "ipo", "上市", "招股书", "合作", "估值", "收购", "资本")),
|
||||
("模型与能力", ("模型", "gpt", "claude", "gemini", "grok", "token", "参数", "多模态", "语音", "推理")),
|
||||
("产品与应用", ("agent", "应用", "产品", "平台", "上线", "工具", "智能体")),
|
||||
("观点与教程", ("教程", "观点", "方法论", "guide", "实践", "技巧")),
|
||||
("人物与动态", ("黄仁勋", "纳德拉", "访谈", "演讲", "人物")),
|
||||
]
|
||||
|
||||
|
||||
def normalize_section_hint(section_hint: str) -> str:
|
||||
hint = (section_hint or "").strip()
|
||||
if hint in SECTION_ORDER:
|
||||
return hint
|
||||
return SECTION_ALIASES.get(hint, "")
|
||||
|
||||
|
||||
def rule_classify(item: NewsItem) -> str:
|
||||
text = f"{item.title or item.title_raw} {item.summary or item.summary_raw}".lower()
|
||||
for section, keywords in RULES:
|
||||
if any(keyword.lower() in text for keyword in keywords):
|
||||
return section
|
||||
return "公司与资本"
|
||||
|
||||
|
||||
def rank_score(item: NewsItem) -> int:
|
||||
text = f"{item.title or item.title_raw} {item.summary or item.summary_raw}"
|
||||
score = max(0, 200 - item.source_priority)
|
||||
if item.source_role == "primary":
|
||||
score += 10
|
||||
if item.canonical_url:
|
||||
score += 10
|
||||
if any(ch.isdigit() for ch in text):
|
||||
score += 10
|
||||
if item.duplicate_sources:
|
||||
score += min(20, len(item.duplicate_sources) * 5)
|
||||
score -= len(item.quality_flags) * 10
|
||||
return score
|
||||
|
||||
|
||||
def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||
hint_classified = 0
|
||||
rule_classified = 0
|
||||
|
||||
for item in items:
|
||||
mapped = normalize_section_hint(item.section_hint)
|
||||
if mapped:
|
||||
item.section = mapped
|
||||
hint_classified += 1
|
||||
else:
|
||||
item.section = rule_classify(item)
|
||||
rule_classified += 1
|
||||
|
||||
section_index = {section: index for index, section in enumerate(SECTION_ORDER)}
|
||||
ordered = sorted(
|
||||
items,
|
||||
key=lambda item: (
|
||||
section_index.get(item.section or "", len(SECTION_ORDER)),
|
||||
-rank_score(item),
|
||||
item.title or item.title_raw,
|
||||
),
|
||||
)
|
||||
section_counts = Counter(item.section for item in ordered if item.section)
|
||||
report = {
|
||||
"input_count": len(items),
|
||||
"section_counts": dict(section_counts),
|
||||
"hint_classified": hint_classified,
|
||||
"rule_classified": rule_classified,
|
||||
"llm_classified": 0,
|
||||
"fallback_classified": 0,
|
||||
"invalid_section_count": sum(1 for item in ordered if item.section not in SECTION_ORDER),
|
||||
}
|
||||
return ordered, report
|
||||
Reference in New Issue
Block a user