diff --git a/ai_daily_report/rewrite.py b/ai_daily_report/rewrite.py index 0384706..c53fd31 100644 --- a/ai_daily_report/rewrite.py +++ b/ai_daily_report/rewrite.py @@ -73,9 +73,6 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int by_id[item_id].title = title by_id[item_id].summary = summary seen_ids.add(item_id) - for item in batch: - if item.id not in seen_ids: - raise ValueError(f"missing_rewrite_for_item: {item.id}") return len(seen_ids) @@ -89,11 +86,19 @@ def rewrite_items( ) -> tuple[list[NewsItem], dict[str, Any]]: rewritten_count = 0 fallback_count = 0 + missing_rewrite_count = 0 errors: list[str] = [] for batch in _chunks(items, max(1, batch_size)): try: - rewritten_count += _apply_rewrite_batch(batch, llm_call) + batch_rewritten_count = _apply_rewrite_batch(batch, llm_call) + rewritten_count += batch_rewritten_count + for item in batch: + if item.title is None or item.summary is None: + errors.append(f"missing_rewrite_for_item: {item.id}") + _fallback(item) + fallback_count += 1 + missing_rewrite_count += 1 except Exception as exc: errors.append(f"batch:{type(exc).__name__}: {exc}") if _is_transient_llm_error(exc): @@ -123,6 +128,7 @@ def rewrite_items( "input_count": len(items), "rewritten_count": rewritten_count, "fallback_count": fallback_count, + "missing_rewrite_count": missing_rewrite_count, "fallback_ratio": round(fallback_ratio, 4), "batch_count": len(_chunks(items, max(1, batch_size))), "errors": errors, diff --git a/tests/test_stage4_rewrite.py b/tests/test_stage4_rewrite.py index 3625e4e..21e3201 100644 --- a/tests/test_stage4_rewrite.py +++ b/tests/test_stage4_rewrite.py @@ -107,6 +107,27 @@ class Stage4RewriteTests(unittest.TestCase): self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"]) self.assertEqual(report["fallback_count"], 2) + def test_rewrite_items_keeps_partial_batch_rewrites_when_some_ids_are_missing(self): + items = [news_item("a"), news_item("b"), news_item("c")] + + def llm_call(prompt): + return json.dumps( + { + "rewrites": [ + {"id": "a", "title": "title a", "summary": "summary a", "flags": []}, + {"id": "c", "title": "title c", "summary": "summary c", "flags": []}, + ] + } + ) + + rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=3, max_fallback_ratio=0.5) + + self.assertEqual([item.title for item in rewritten], ["title a", "OpenAI launches GPT-5 API", "title c"]) + self.assertEqual(report["rewritten_count"], 2) + self.assertEqual(report["fallback_count"], 1) + self.assertEqual(report["missing_rewrite_count"], 1) + self.assertEqual(report["blocking_errors"], []) + def test_rewrite_items_defaults_to_large_batches_to_reduce_llm_requests(self): items = [news_item(str(index)) for index in range(61)] batch_sizes = []