Files
ai-daily-report/ai_daily_report/rewrite.py
2026-06-04 16:29:40 +08:00

127 lines
4.1 KiB
Python

from __future__ import annotations
import json
from typing import Any, Callable
from urllib.error import HTTPError
from .llm import parse_json_object
from .models import NewsItem
# Signature of the injected LLM callable: it receives the prompt string and
# returns the model's response as a string (parsed later as a JSON object).
RewriteLlmCall = Callable[[str], str]
def _chunks(items: list[NewsItem], size: int) -> list[list[NewsItem]]:
    """Split *items* into consecutive sub-lists of at most *size* elements.

    Order is preserved and the last chunk may be shorter than *size*. An empty
    input yields an empty list.

    Raises:
        ValueError: If *size* is less than 1. Without this check a negative
            size makes ``range`` empty, which would silently drop every item.
    """
    if size < 1:
        raise ValueError(f"chunk size must be >= 1, got {size}")
    return [items[start : start + size] for start in range(0, len(items), size)]
def _build_prompt(batch: list[NewsItem]) -> str:
    """Serialize *batch* into the JSON prompt string sent to the rewrite LLM.

    The prompt is a JSON object with three keys, in this order: ``task`` (the
    instruction text), ``items`` (one record per news item carrying its id,
    raw title, raw summary, source label and language hint) and
    ``output_schema`` (an example of the expected response shape). Non-ASCII
    characters are emitted as-is rather than escaped.
    """
    instruction = (
        "Rewrite AI news titles and summaries into concise Chinese. "
        "Preserve brand/model/API names such as GPT-5, Codex, Gemini, Claude, API, MCP. "
        "Do not add facts."
    )

    records = []
    for entry in batch:
        records.append(
            {
                "id": entry.id,
                "title_raw": entry.title_raw,
                "summary_raw": entry.summary_raw,
                "source": entry.source_label,
                "language_hint": entry.language_hint,
            }
        )

    # Example of one element of the response's "rewrites" list.
    example_rewrite = {
        "id": "item id",
        "title": "display title",
        "summary": "display summary",
        "flags": [],
    }

    prompt = {
        "task": instruction,
        "items": records,
        "output_schema": {"rewrites": [example_rewrite]},
    }
    return json.dumps(prompt, ensure_ascii=False)
def _fallback(item: NewsItem) -> None:
    """Fill the display fields of *item* from its raw fields, in place.

    ``title`` is copied from ``title_raw`` unchanged. ``summary`` is copied
    from ``summary_raw`` unless that value is missing, empty, or consists only
    of whitespace, in which case a fixed placeholder is used so the item never
    ends up with a visually blank summary.
    """
    item.title = item.title_raw

    raw_summary = item.summary_raw
    # A whitespace-only string is truthy, so normalise it to "" first;
    # otherwise it would bypass the placeholder below.
    if isinstance(raw_summary, str) and not raw_summary.strip():
        raw_summary = ""
    item.summary = raw_summary or "该条目暂无摘要。"
def _is_transient_llm_error(exc: Exception) -> bool:
    """Return True when *exc* looks like a temporary failure of the LLM call.

    Treated as transient:

    * a timeout raised directly (``TimeoutError``, or ``socket.timeout`` on
      Python versions where that is still a separate class);
    * an ``HTTPError`` whose status code is 429, 500, 502, 503 or 504;
    * an exception whose ``reason`` attribute is itself a timeout, which is
      how urllib reports a timeout wrapped in ``URLError``.

    Everything else, including other HTTP status codes, is not transient.
    """
    # Local import: only needed for the legacy ``socket.timeout`` class, which
    # is an alias of TimeoutError on Python 3.10+.
    import socket

    timeout_types = (TimeoutError, socket.timeout)
    if isinstance(exc, timeout_types):
        return True
    # Checked before ``reason`` because HTTPError also has a ``reason``
    # attribute (a message string) and must be classified by status code.
    if isinstance(exc, HTTPError):
        return exc.code in {429, 500, 502, 503, 504}
    return isinstance(getattr(exc, "reason", None), timeout_types)
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
    """Rewrite every item in *batch* with a single LLM call.

    The prompt built from *batch* is passed to *llm_call*; the response is
    parsed as a JSON object and its ``rewrites`` list is matched back to the
    items by ``id``. An entry is accepted only when its id belongs to the
    batch and both its title and summary are non-empty after stripping.
    Entries that are not JSON objects, or whose id is unhashable, are ignored
    rather than aborting the batch. If the same id appears more than once,
    the last accepted entry wins.

    Items are only mutated after every item in the batch has an accepted
    rewrite, so a failing call leaves the batch untouched.

    Returns:
        The number of items that were rewritten.

    Raises:
        ValueError: If ``rewrites`` is not a list, or if any item in the batch
            has no accepted rewrite (the message names the first such item).
        Exception: Whatever *llm_call* or the JSON parsing raises is
            propagated unchanged.
    """
    obj = parse_json_object(llm_call(_build_prompt(batch)))
    rewrites = obj.get("rewrites", [])
    if not isinstance(rewrites, list):
        raise ValueError("rewrites is not a list")

    by_id = {item.id: item for item in batch}
    accepted: dict[Any, tuple[str, str]] = {}
    for entry in rewrites:
        if not isinstance(entry, dict):
            continue  # malformed element; the completeness check below decides
        item_id = entry.get("id")
        title = str(entry.get("title") or "").strip()
        summary = str(entry.get("summary") or "").strip()
        if not (title and summary):
            continue
        try:
            known = item_id in by_id
        except TypeError:
            continue  # unhashable id (e.g. a list) cannot match any item
        if known:
            accepted[item_id] = (title, summary)

    # Validate completeness before touching any item.
    for item in batch:
        if item.id not in accepted:
            raise ValueError(f"missing_rewrite_for_item: {item.id}")

    for item_id, (title, summary) in accepted.items():
        target = by_id[item_id]
        target.title = title
        target.summary = summary
    return len(accepted)
def rewrite_items(
    items: list[NewsItem],
    *,
    llm_call: RewriteLlmCall,
    batch_size: int = 10,
    max_fallback_ratio: float = 0.2,
) -> tuple[list[NewsItem], dict[str, Any]]:
    """Rewrite the title and summary of every item via the LLM, with fallbacks.

    Items are processed in batches of *batch_size* (values below 1 are treated
    as 1). When a batch fails:

    * if the error is classified as transient, every item of the batch falls
      back to its raw title/summary and no further calls are made for it;
    * otherwise each item of the batch is retried on its own, and only the
      items whose individual retry also fails fall back.

    Items are modified in place; the same list object is returned.

    Args:
        items: News items to rewrite.
        llm_call: Callable that takes a prompt string and returns the LLM
            response string.
        batch_size: Maximum number of items per LLM call.
        max_fallback_ratio: If the share of items that fell back exceeds this
            value, the report is flagged as failing the quality gate.

    Returns:
        A tuple ``(items, report)``. ``report`` contains ``input_count``,
        ``rewritten_count``, ``fallback_count``, ``fallback_ratio`` (rounded
        to 4 decimals), ``batch_count``, ``errors`` (one string per failed
        batch or item call), ``blocking_errors`` and ``quality_gate_failed``.
    """
    # Chunk once and reuse the result for both iteration and the report.
    batches = _chunks(items, max(1, batch_size))

    rewritten_count = 0
    fallback_count = 0
    errors: list[str] = []

    for batch in batches:
        try:
            rewritten_count += _apply_rewrite_batch(batch, llm_call)
        except Exception as exc:
            # Deliberately broad: any failure must degrade to the fallback
            # text instead of aborting the whole run. It is recorded below.
            errors.append(f"batch:{type(exc).__name__}: {exc}")
            if _is_transient_llm_error(exc):
                # Retrying item by item would repeat the same failing call.
                for item in batch:
                    _fallback(item)
                    fallback_count += 1
                continue
            # Non-transient failure: retry each item separately so that one
            # bad item does not cost the whole batch.
            for item in batch:
                try:
                    rewritten_count += _apply_rewrite_batch([item], llm_call)
                except Exception as item_exc:
                    errors.append(f"item:{item.id}:{type(item_exc).__name__}: {item_exc}")
                    _fallback(item)
                    fallback_count += 1

    # Use a float in the empty case too, so the field has one consistent type.
    fallback_ratio = fallback_count / len(items) if items else 0.0

    blocking_errors: list[str] = []
    if fallback_ratio > max_fallback_ratio:
        blocking_errors.append("rewrite_fallback_ratio_exceeded")

    report = {
        "input_count": len(items),
        "rewritten_count": rewritten_count,
        "fallback_count": fallback_count,
        "fallback_ratio": round(fallback_ratio, 4),
        "batch_count": len(batches),
        "errors": errors,
        "blocking_errors": blocking_errors,
        "quality_gate_failed": bool(blocking_errors),
    }
    return items, report