Refactor AI daily report pipeline
This commit is contained in:
103
ai_daily_report/rewrite.py
Normal file
103
ai_daily_report/rewrite.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Callable
|
||||
|
||||
from .llm import parse_json_object
|
||||
from .models import NewsItem
|
||||
|
||||
|
||||
# Signature of the injected LLM client: it receives the prompt string and
# returns the model's reply as a string.
RewriteLlmCall = Callable[[str], str]
|
||||
|
||||
|
||||
def _chunks(items: list[NewsItem], size: int) -> list[list[NewsItem]]:
|
||||
return [items[index : index + size] for index in range(0, len(items), size)]
|
||||
|
||||
|
||||
def _build_prompt(batch: list[NewsItem]) -> str:
|
||||
payload = {
|
||||
"task": (
|
||||
"Rewrite AI news titles and summaries into concise Chinese. Preserve brand/model/API names "
|
||||
"such as GPT-5, Codex, Gemini, Claude, API, MCP. Do not add facts."
|
||||
),
|
||||
"items": [
|
||||
{
|
||||
"id": item.id,
|
||||
"title_raw": item.title_raw,
|
||||
"summary_raw": item.summary_raw,
|
||||
"source": item.source_label,
|
||||
"language_hint": item.language_hint,
|
||||
}
|
||||
for item in batch
|
||||
],
|
||||
"output_schema": {
|
||||
"rewrites": [
|
||||
{
|
||||
"id": "item id",
|
||||
"title": "display title",
|
||||
"summary": "display summary",
|
||||
"flags": [],
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
def _fallback(item: NewsItem) -> None:
|
||||
item.title = item.title_raw
|
||||
item.summary = item.summary_raw or "该条目暂无摘要。"
|
||||
|
||||
|
||||
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
    """Rewrite every item in ``batch`` using a single LLM call.

    The prompt built from ``batch`` is sent through ``llm_call`` and the reply
    is parsed with ``parse_json_object``. Each entry of the reply's
    ``rewrites`` list is matched to a batch item by ``id``; an entry counts
    only if it carries a non-blank ``title`` and ``summary``. When the same
    id appears more than once, the last usable entry wins.

    Items are updated only after every item in the batch has a usable
    rewrite, so a failing call leaves the batch untouched.

    Args:
        batch: Items whose ``title`` and ``summary`` are set in place.
        llm_call: Callable that takes the prompt and returns the reply text.

    Returns:
        The number of distinct item ids that received a rewrite.

    Raises:
        ValueError: If ``rewrites`` is not a list, or if any item in the
            batch has no usable rewrite in the reply.
    """
    obj = parse_json_object(llm_call(_build_prompt(batch)))
    rewrites = obj.get("rewrites", [])
    if not isinstance(rewrites, list):
        raise ValueError("rewrites is not a list")

    by_id = {item.id: item for item in batch}
    accepted: dict[Any, tuple[str, str]] = {}
    for entry in rewrites:
        # Skip malformed entries instead of failing the whole batch; the
        # completeness check below still catches items left without a rewrite.
        if not isinstance(entry, dict):
            continue
        item_id = entry.get("id")
        try:
            known = item_id in by_id
        except TypeError:
            # Unhashable id (e.g. a list or object) can never match an item.
            continue
        title = str(entry.get("title") or "").strip()
        summary = str(entry.get("summary") or "").strip()
        if known and title and summary:
            accepted[item_id] = (title, summary)

    for item in batch:
        if item.id not in accepted:
            raise ValueError(f"missing_rewrite_for_item: {item.id}")

    # Apply only once the reply is known to cover the whole batch.
    for item_id, (title, summary) in accepted.items():
        target = by_id[item_id]
        target.title = title
        target.summary = summary
    return len(accepted)
|
||||
|
||||
|
||||
def rewrite_items(
    items: list[NewsItem],
    *,
    llm_call: RewriteLlmCall,
    batch_size: int = 10,
) -> tuple[list[NewsItem], dict[str, Any]]:
    """Rewrite the title and summary of every item, batch by batch.

    Each batch is first sent to the LLM in one call. If that call fails for
    any reason, every item of the batch is retried on its own; an item whose
    individual retry also fails gets its display fields from ``_fallback``.
    Failures are recorded in the report rather than raised.

    Args:
        items: Items to rewrite; they are modified in place.
        llm_call: Callable that takes a prompt and returns the reply text.
        batch_size: Items per LLM call; values below 1 are treated as 1.

    Returns:
        A tuple of the same ``items`` list and a report dict with the keys
        ``input_count``, ``rewritten_count``, ``fallback_count``,
        ``batch_count`` and ``errors`` (a list of error strings).
    """
    # Chunk once and reuse the result for both the loop and the report.
    batches = _chunks(items, max(1, batch_size))

    rewritten_count = 0
    fallback_count = 0
    errors: list[str] = []

    for batch in batches:
        try:
            rewritten_count += _apply_rewrite_batch(batch, llm_call)
        except Exception as exc:
            # Deliberate best-effort boundary: record the failure and fall
            # back to one call per item so a single bad entry cannot sink
            # the whole batch.
            errors.append(f"batch:{type(exc).__name__}: {exc}")
            for item in batch:
                try:
                    rewritten_count += _apply_rewrite_batch([item], llm_call)
                except Exception as item_exc:
                    errors.append(f"item:{item.id}:{type(item_exc).__name__}: {item_exc}")
                    _fallback(item)
                    fallback_count += 1

    report = {
        "input_count": len(items),
        "rewritten_count": rewritten_count,
        "fallback_count": fallback_count,
        "batch_count": len(batches),
        "errors": errors,
    }
    return items, report
|
||||
Reference in New Issue
Block a user