Retry failed rewrite batches in smaller chunks
This commit is contained in:
@@ -74,15 +74,13 @@ def _is_transient_llm_error(exc: Exception) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tuple[int, int]:
|
def _apply_rewrite_results(batch: list[NewsItem], rewrites: list[Any]) -> tuple[int, int]:
|
||||||
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
|
||||||
rewrites = obj.get("rewrites", [])
|
|
||||||
if not isinstance(rewrites, list):
|
|
||||||
raise ValueError("rewrites is not a list")
|
|
||||||
by_id = {item.id: item for item in batch}
|
by_id = {item.id: item for item in batch}
|
||||||
seen_ids: set[str] = set()
|
seen_ids: set[str] = set()
|
||||||
section_count = 0
|
section_count = 0
|
||||||
for entry in rewrites:
|
for entry in rewrites:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
item_id = entry.get("id")
|
item_id = entry.get("id")
|
||||||
title = str(entry.get("title") or "").strip()
|
title = str(entry.get("title") or "").strip()
|
||||||
summary = str(entry.get("summary") or "").strip()
|
summary = str(entry.get("summary") or "").strip()
|
||||||
@@ -97,11 +95,20 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tup
|
|||||||
return len(seen_ids), section_count
|
return len(seen_ids), section_count
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tuple[int, int]:
|
||||||
|
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
||||||
|
rewrites = obj.get("rewrites", [])
|
||||||
|
if not isinstance(rewrites, list):
|
||||||
|
raise ValueError("rewrites is not a list")
|
||||||
|
return _apply_rewrite_results(batch, rewrites)
|
||||||
|
|
||||||
|
|
||||||
def rewrite_items(
|
def rewrite_items(
|
||||||
items: list[NewsItem],
|
items: list[NewsItem],
|
||||||
*,
|
*,
|
||||||
llm_call: RewriteLlmCall,
|
llm_call: RewriteLlmCall,
|
||||||
batch_size: int = 30,
|
batch_size: int = 30,
|
||||||
|
retry_batch_size: int = 10,
|
||||||
max_fallback_ratio: float = 0.2,
|
max_fallback_ratio: float = 0.2,
|
||||||
retry_single_items: bool = False,
|
retry_single_items: bool = False,
|
||||||
) -> tuple[list[NewsItem], dict[str, Any]]:
|
) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||||
@@ -109,6 +116,7 @@ def rewrite_items(
|
|||||||
llm_section_count = 0
|
llm_section_count = 0
|
||||||
fallback_count = 0
|
fallback_count = 0
|
||||||
missing_rewrite_count = 0
|
missing_rewrite_count = 0
|
||||||
|
batch_retry_count = 0
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
for batch in _chunks(items, max(1, batch_size)):
|
for batch in _chunks(items, max(1, batch_size)):
|
||||||
@@ -129,6 +137,25 @@ def rewrite_items(
|
|||||||
_fallback(item)
|
_fallback(item)
|
||||||
fallback_count += 1
|
fallback_count += 1
|
||||||
continue
|
continue
|
||||||
|
if len(batch) > max(1, retry_batch_size):
|
||||||
|
for retry_batch in _chunks(batch, max(1, retry_batch_size)):
|
||||||
|
batch_retry_count += 1
|
||||||
|
try:
|
||||||
|
retry_rewritten_count, retry_section_count = _apply_rewrite_batch(retry_batch, llm_call)
|
||||||
|
rewritten_count += retry_rewritten_count
|
||||||
|
llm_section_count += retry_section_count
|
||||||
|
for item in retry_batch:
|
||||||
|
if item.title is None or item.summary is None:
|
||||||
|
errors.append(f"missing_rewrite_for_item: {item.id}")
|
||||||
|
_fallback(item)
|
||||||
|
fallback_count += 1
|
||||||
|
missing_rewrite_count += 1
|
||||||
|
except Exception as retry_exc:
|
||||||
|
errors.append(f"batch_retry:{type(retry_exc).__name__}: {retry_exc}")
|
||||||
|
for item in retry_batch:
|
||||||
|
_fallback(item)
|
||||||
|
fallback_count += 1
|
||||||
|
continue
|
||||||
if not retry_single_items:
|
if not retry_single_items:
|
||||||
for item in batch:
|
for item in batch:
|
||||||
_fallback(item)
|
_fallback(item)
|
||||||
@@ -157,6 +184,7 @@ def rewrite_items(
|
|||||||
"missing_rewrite_count": missing_rewrite_count,
|
"missing_rewrite_count": missing_rewrite_count,
|
||||||
"fallback_ratio": round(fallback_ratio, 4),
|
"fallback_ratio": round(fallback_ratio, 4),
|
||||||
"batch_count": len(_chunks(items, max(1, batch_size))),
|
"batch_count": len(_chunks(items, max(1, batch_size))),
|
||||||
|
"batch_retry_count": batch_retry_count,
|
||||||
"errors": errors,
|
"errors": errors,
|
||||||
"blocking_errors": blocking_errors,
|
"blocking_errors": blocking_errors,
|
||||||
"quality_gate_failed": bool(blocking_errors),
|
"quality_gate_failed": bool(blocking_errors),
|
||||||
|
|||||||
@@ -132,6 +132,42 @@ class Stage4RewriteTests(unittest.TestCase):
|
|||||||
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
|
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
|
||||||
self.assertEqual(report["fallback_count"], 2)
|
self.assertEqual(report["fallback_count"], 2)
|
||||||
|
|
||||||
|
def test_rewrite_items_retries_failed_large_batch_as_smaller_batches_by_default(self):
|
||||||
|
items = [news_item(str(index)) for index in range(30)]
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def llm_call(prompt):
|
||||||
|
payload = json.loads(prompt)
|
||||||
|
ids = [item["id"] for item in payload["items"]]
|
||||||
|
calls.append(ids)
|
||||||
|
if len(ids) == 30:
|
||||||
|
return "not json"
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"rewrites": [
|
||||||
|
{
|
||||||
|
"id": item_id,
|
||||||
|
"title": f"title {item_id}",
|
||||||
|
"summary": f"summary {item_id}",
|
||||||
|
"section": "模型与能力",
|
||||||
|
"flags": [],
|
||||||
|
}
|
||||||
|
for item_id in ids
|
||||||
|
]
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
rewritten, report = rewrite_items(items, llm_call=llm_call)
|
||||||
|
|
||||||
|
self.assertEqual([len(call) for call in calls], [30, 10, 10, 10])
|
||||||
|
self.assertEqual(report["rewritten_count"], 30)
|
||||||
|
self.assertEqual(report["llm_section_count"], 30)
|
||||||
|
self.assertEqual(report["fallback_count"], 0)
|
||||||
|
self.assertEqual(report["batch_retry_count"], 3)
|
||||||
|
self.assertEqual(report["blocking_errors"], [])
|
||||||
|
self.assertEqual(rewritten[0].title, "title 0")
|
||||||
|
|
||||||
def test_rewrite_items_keeps_partial_batch_rewrites_when_some_ids_are_missing(self):
|
def test_rewrite_items_keeps_partial_batch_rewrites_when_some_ids_are_missing(self):
|
||||||
items = [news_item("a"), news_item("b"), news_item("c")]
|
items = [news_item("a"), news_item("b"), news_item("c")]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user