Keep partial rewrite results from LLM batches
This commit is contained in:
@@ -73,9 +73,6 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int
|
|||||||
by_id[item_id].title = title
|
by_id[item_id].title = title
|
||||||
by_id[item_id].summary = summary
|
by_id[item_id].summary = summary
|
||||||
seen_ids.add(item_id)
|
seen_ids.add(item_id)
|
||||||
for item in batch:
|
|
||||||
if item.id not in seen_ids:
|
|
||||||
raise ValueError(f"missing_rewrite_for_item: {item.id}")
|
|
||||||
return len(seen_ids)
|
return len(seen_ids)
|
||||||
|
|
||||||
|
|
||||||
@@ -89,11 +86,19 @@ def rewrite_items(
|
|||||||
) -> tuple[list[NewsItem], dict[str, Any]]:
|
) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||||
rewritten_count = 0
|
rewritten_count = 0
|
||||||
fallback_count = 0
|
fallback_count = 0
|
||||||
|
missing_rewrite_count = 0
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
for batch in _chunks(items, max(1, batch_size)):
|
for batch in _chunks(items, max(1, batch_size)):
|
||||||
try:
|
try:
|
||||||
rewritten_count += _apply_rewrite_batch(batch, llm_call)
|
batch_rewritten_count = _apply_rewrite_batch(batch, llm_call)
|
||||||
|
rewritten_count += batch_rewritten_count
|
||||||
|
for item in batch:
|
||||||
|
if item.title is None or item.summary is None:
|
||||||
|
errors.append(f"missing_rewrite_for_item: {item.id}")
|
||||||
|
_fallback(item)
|
||||||
|
fallback_count += 1
|
||||||
|
missing_rewrite_count += 1
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"batch:{type(exc).__name__}: {exc}")
|
errors.append(f"batch:{type(exc).__name__}: {exc}")
|
||||||
if _is_transient_llm_error(exc):
|
if _is_transient_llm_error(exc):
|
||||||
@@ -123,6 +128,7 @@ def rewrite_items(
|
|||||||
"input_count": len(items),
|
"input_count": len(items),
|
||||||
"rewritten_count": rewritten_count,
|
"rewritten_count": rewritten_count,
|
||||||
"fallback_count": fallback_count,
|
"fallback_count": fallback_count,
|
||||||
|
"missing_rewrite_count": missing_rewrite_count,
|
||||||
"fallback_ratio": round(fallback_ratio, 4),
|
"fallback_ratio": round(fallback_ratio, 4),
|
||||||
"batch_count": len(_chunks(items, max(1, batch_size))),
|
"batch_count": len(_chunks(items, max(1, batch_size))),
|
||||||
"errors": errors,
|
"errors": errors,
|
||||||
|
|||||||
@@ -107,6 +107,27 @@ class Stage4RewriteTests(unittest.TestCase):
|
|||||||
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
|
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
|
||||||
self.assertEqual(report["fallback_count"], 2)
|
self.assertEqual(report["fallback_count"], 2)
|
||||||
|
|
||||||
|
def test_rewrite_items_keeps_partial_batch_rewrites_when_some_ids_are_missing(self):
|
||||||
|
items = [news_item("a"), news_item("b"), news_item("c")]
|
||||||
|
|
||||||
|
def llm_call(prompt):
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"rewrites": [
|
||||||
|
{"id": "a", "title": "title a", "summary": "summary a", "flags": []},
|
||||||
|
{"id": "c", "title": "title c", "summary": "summary c", "flags": []},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=3, max_fallback_ratio=0.5)
|
||||||
|
|
||||||
|
self.assertEqual([item.title for item in rewritten], ["title a", "OpenAI launches GPT-5 API", "title c"])
|
||||||
|
self.assertEqual(report["rewritten_count"], 2)
|
||||||
|
self.assertEqual(report["fallback_count"], 1)
|
||||||
|
self.assertEqual(report["missing_rewrite_count"], 1)
|
||||||
|
self.assertEqual(report["blocking_errors"], [])
|
||||||
|
|
||||||
def test_rewrite_items_defaults_to_large_batches_to_reduce_llm_requests(self):
|
def test_rewrite_items_defaults_to_large_batches_to_reduce_llm_requests(self):
|
||||||
items = [news_item(str(index)) for index in range(61)]
|
items = [news_item(str(index)) for index in range(61)]
|
||||||
batch_sizes = []
|
batch_sizes = []
|
||||||
|
|||||||
Reference in New Issue
Block a user