Keep partial rewrite results from LLM batches

This commit is contained in:
Mimikko-zeus
2026-06-04 16:51:12 +08:00
parent 6eca615f42
commit dd12755ff1
2 changed files with 31 additions and 4 deletions

View File

@@ -73,9 +73,6 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int
by_id[item_id].title = title by_id[item_id].title = title
by_id[item_id].summary = summary by_id[item_id].summary = summary
seen_ids.add(item_id) seen_ids.add(item_id)
for item in batch:
if item.id not in seen_ids:
raise ValueError(f"missing_rewrite_for_item: {item.id}")
return len(seen_ids) return len(seen_ids)
@@ -89,11 +86,19 @@ def rewrite_items(
) -> tuple[list[NewsItem], dict[str, Any]]: ) -> tuple[list[NewsItem], dict[str, Any]]:
rewritten_count = 0 rewritten_count = 0
fallback_count = 0 fallback_count = 0
missing_rewrite_count = 0
errors: list[str] = [] errors: list[str] = []
for batch in _chunks(items, max(1, batch_size)): for batch in _chunks(items, max(1, batch_size)):
try: try:
rewritten_count += _apply_rewrite_batch(batch, llm_call) batch_rewritten_count = _apply_rewrite_batch(batch, llm_call)
rewritten_count += batch_rewritten_count
for item in batch:
if item.title is None or item.summary is None:
errors.append(f"missing_rewrite_for_item: {item.id}")
_fallback(item)
fallback_count += 1
missing_rewrite_count += 1
except Exception as exc: except Exception as exc:
errors.append(f"batch:{type(exc).__name__}: {exc}") errors.append(f"batch:{type(exc).__name__}: {exc}")
if _is_transient_llm_error(exc): if _is_transient_llm_error(exc):
@@ -123,6 +128,7 @@ def rewrite_items(
"input_count": len(items), "input_count": len(items),
"rewritten_count": rewritten_count, "rewritten_count": rewritten_count,
"fallback_count": fallback_count, "fallback_count": fallback_count,
"missing_rewrite_count": missing_rewrite_count,
"fallback_ratio": round(fallback_ratio, 4), "fallback_ratio": round(fallback_ratio, 4),
"batch_count": len(_chunks(items, max(1, batch_size))), "batch_count": len(_chunks(items, max(1, batch_size))),
"errors": errors, "errors": errors,

View File

@@ -107,6 +107,27 @@ class Stage4RewriteTests(unittest.TestCase):
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"]) self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
self.assertEqual(report["fallback_count"], 2) self.assertEqual(report["fallback_count"], 2)
def test_rewrite_items_keeps_partial_batch_rewrites_when_some_ids_are_missing(self):
    # Scenario: one batch of three items ("a", "b", "c") is sent to a stubbed
    # LLM whose JSON reply only carries rewrites for "a" and "c".
    partial_reply = {
        "rewrites": [
            {"id": "a", "title": "title a", "summary": "summary a", "flags": []},
            {"id": "c", "title": "title c", "summary": "summary c", "flags": []},
        ]
    }

    def llm_call(prompt):
        # The prompt is ignored; every call gets the same reply without "b".
        return json.dumps(partial_reply)

    source_items = [news_item(item_id) for item_id in ("a", "b", "c")]
    rewritten, report = rewrite_items(
        source_items,
        llm_call=llm_call,
        batch_size=3,
        max_fallback_ratio=0.5,
    )

    # "a" and "c" carry the stubbed titles; "b" is expected to show the
    # non-rewritten title (presumably the news_item fixture default -- verify
    # against the helper).
    observed_titles = [item.title for item in rewritten]
    self.assertEqual(observed_titles, ["title a", "OpenAI launches GPT-5 API", "title c"])

    # Report: two rewrites applied, one fallback that is also counted as a
    # missing rewrite, and no blocking errors.
    self.assertEqual(report["rewritten_count"], 2)
    self.assertEqual(report["fallback_count"], 1)
    self.assertEqual(report["missing_rewrite_count"], 1)
    self.assertEqual(report["blocking_errors"], [])
def test_rewrite_items_defaults_to_large_batches_to_reduce_llm_requests(self): def test_rewrite_items_defaults_to_large_batches_to_reduce_llm_requests(self):
items = [news_item(str(index)) for index in range(61)] items = [news_item(str(index)) for index in range(61)]
batch_sizes = [] batch_sizes = []