diff --git a/ai_daily_report/rewrite.py b/ai_daily_report/rewrite.py index acae2b9..8aa00a3 100644 --- a/ai_daily_report/rewrite.py +++ b/ai_daily_report/rewrite.py @@ -74,15 +74,13 @@ def _is_transient_llm_error(exc: Exception) -> bool: return False -def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tuple[int, int]: - obj = parse_json_object(llm_call(_build_prompt(batch))) - rewrites = obj.get("rewrites", []) - if not isinstance(rewrites, list): - raise ValueError("rewrites is not a list") +def _apply_rewrite_results(batch: list[NewsItem], rewrites: list[Any]) -> tuple[int, int]: by_id = {item.id: item for item in batch} seen_ids: set[str] = set() section_count = 0 for entry in rewrites: + if not isinstance(entry, dict): + continue item_id = entry.get("id") title = str(entry.get("title") or "").strip() summary = str(entry.get("summary") or "").strip() @@ -97,11 +95,20 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tup return len(seen_ids), section_count +def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tuple[int, int]: + obj = parse_json_object(llm_call(_build_prompt(batch))) + rewrites = obj.get("rewrites", []) + if not isinstance(rewrites, list): + raise ValueError("rewrites is not a list") + return _apply_rewrite_results(batch, rewrites) + + def rewrite_items( items: list[NewsItem], *, llm_call: RewriteLlmCall, batch_size: int = 30, + retry_batch_size: int = 10, max_fallback_ratio: float = 0.2, retry_single_items: bool = False, ) -> tuple[list[NewsItem], dict[str, Any]]: @@ -109,6 +116,7 @@ def rewrite_items( llm_section_count = 0 fallback_count = 0 missing_rewrite_count = 0 + batch_retry_count = 0 errors: list[str] = [] for batch in _chunks(items, max(1, batch_size)): @@ -129,6 +137,25 @@ def rewrite_items( _fallback(item) fallback_count += 1 continue + if len(batch) > max(1, retry_batch_size): + for retry_batch in _chunks(batch, max(1, retry_batch_size)): + batch_retry_count += 1 + try: + retry_rewritten_count, retry_section_count = _apply_rewrite_batch(retry_batch, llm_call) + rewritten_count += retry_rewritten_count + llm_section_count += retry_section_count + for item in retry_batch: + if item.title is None or item.summary is None: + errors.append(f"missing_rewrite_for_item: {item.id}") + _fallback(item) + fallback_count += 1 + missing_rewrite_count += 1 + except Exception as retry_exc: + errors.append(f"batch_retry:{type(retry_exc).__name__}: {retry_exc}") + for item in retry_batch: + _fallback(item) + fallback_count += 1 + continue if not retry_single_items: for item in batch: _fallback(item) @@ -157,6 +184,7 @@ def rewrite_items( "missing_rewrite_count": missing_rewrite_count, "fallback_ratio": round(fallback_ratio, 4), "batch_count": len(_chunks(items, max(1, batch_size))), + "batch_retry_count": batch_retry_count, "errors": errors, "blocking_errors": blocking_errors, "quality_gate_failed": bool(blocking_errors), diff --git a/tests/test_stage4_rewrite.py b/tests/test_stage4_rewrite.py index a9338cf..b623001 100644 --- a/tests/test_stage4_rewrite.py +++ b/tests/test_stage4_rewrite.py @@ -132,6 +132,42 @@ class Stage4RewriteTests(unittest.TestCase): self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"]) self.assertEqual(report["fallback_count"], 2) + def test_rewrite_items_retries_failed_large_batch_as_smaller_batches_by_default(self): + items = [news_item(str(index)) for index in range(30)] + calls = [] + + def llm_call(prompt): + payload = json.loads(prompt) + ids = [item["id"] for item in payload["items"]] + calls.append(ids) + if len(ids) == 30: + return "not json" + return json.dumps( + { + "rewrites": [ + { + "id": item_id, + "title": f"title {item_id}", + "summary": f"summary {item_id}", + "section": "模型与能力", + "flags": [], + } + for item_id in ids + ] + }, + ensure_ascii=False, + ) + + rewritten, report = rewrite_items(items, llm_call=llm_call) + + self.assertEqual([len(call) for call in calls], [30, 10, 10, 10]) + self.assertEqual(report["rewritten_count"], 30) + self.assertEqual(report["llm_section_count"], 30) + self.assertEqual(report["fallback_count"], 0) + self.assertEqual(report["batch_retry_count"], 3) + self.assertEqual(report["blocking_errors"], []) + self.assertEqual(rewritten[0].title, "title 0") + def test_rewrite_items_keeps_partial_batch_rewrites_when_some_ids_are_missing(self): items = [news_item("a"), news_item("b"), news_item("c")]