Block publish when LLM rewrite quality degrades

This commit is contained in:
Mimikko-zeus
2026-06-04 16:29:40 +08:00
parent 5a98696255
commit f7e4c9722b
6 changed files with 132 additions and 1 deletions

View File

@@ -158,6 +158,13 @@ def run_stage0_to_stage7(
guide_llm_call=guide_llm_call,
)
markdown, stage7_report = assemble_markdown(stage6_result["items"], stage6_result["guide"])
upstream_blocking_errors: list[str] = []
for stage_name in ("stage3", "stage4", "stage5", "stage6"):
for error in stage6_result["reports"].get(stage_name, {}).get("blocking_errors", []) or []:
upstream_blocking_errors.append(str(error))
if upstream_blocking_errors:
existing_errors = list(stage7_report.get("blocking_errors", []) or [])
stage7_report["blocking_errors"] = existing_errors + upstream_blocking_errors
reports = dict(stage6_result["reports"])
reports["stage7"] = stage7_report
return {

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json
from typing import Any, Callable
from urllib.error import HTTPError
from .llm import parse_json_object
from .models import NewsItem
@@ -49,6 +50,14 @@ def _fallback(item: NewsItem) -> None:
item.summary = item.summary_raw or "该条目暂无摘要。"
def _is_transient_llm_error(exc: Exception) -> bool:
if isinstance(exc, TimeoutError):
return True
if isinstance(exc, HTTPError):
return exc.code in {429, 500, 502, 503, 504}
return False
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
obj = parse_json_object(llm_call(_build_prompt(batch)))
rewrites = obj.get("rewrites", [])
@@ -75,6 +84,7 @@ def rewrite_items(
*,
llm_call: RewriteLlmCall,
batch_size: int = 10,
max_fallback_ratio: float = 0.2,
) -> tuple[list[NewsItem], dict[str, Any]]:
rewritten_count = 0
fallback_count = 0
@@ -85,6 +95,11 @@ def rewrite_items(
rewritten_count += _apply_rewrite_batch(batch, llm_call)
except Exception as exc:
errors.append(f"batch:{type(exc).__name__}: {exc}")
if _is_transient_llm_error(exc):
for item in batch:
_fallback(item)
fallback_count += 1
continue
for item in batch:
try:
rewritten_count += _apply_rewrite_batch([item], llm_call)
@@ -93,11 +108,19 @@ def rewrite_items(
_fallback(item)
fallback_count += 1
fallback_ratio = fallback_count / len(items) if items else 0
blocking_errors: list[str] = []
if fallback_ratio > max_fallback_ratio:
blocking_errors.append("rewrite_fallback_ratio_exceeded")
report = {
"input_count": len(items),
"rewritten_count": rewritten_count,
"fallback_count": fallback_count,
"fallback_ratio": round(fallback_ratio, 4),
"batch_count": len(_chunks(items, max(1, batch_size))),
"errors": errors,
"blocking_errors": blocking_errors,
"quality_gate_failed": bool(blocking_errors),
}
return items, report

View File

@@ -37,7 +37,7 @@ def main() -> None:
env = load_env()
dry_run = is_dry_run(env)
run_daily_report(
result = run_daily_report(
run_date=env.get("AI_DAILY_RUN_DATE") or "today",
mode="dry-run" if dry_run else env.get("AI_DAILY_MODE", "publish"),
source_mode=env.get("AI_DAILY_SOURCE_MODE", "live"),
@@ -47,6 +47,10 @@ def main() -> None:
sources_path=Path(env["AI_DAILY_SOURCES_PATH"]) if env.get("AI_DAILY_SOURCES_PATH") else None,
env=env,
)
stage8 = result.get("reports", {}).get("stage8", {})
if stage8.get("status") in {"blocked", "failed"}:
print(f"AI daily report failed quality gate: {stage8.get('error') or stage8.get('status')}", file=sys.stderr)
raise SystemExit(2)
if __name__ == "__main__":

View File

@@ -52,6 +52,19 @@ class LegacyScriptDelegationTests(unittest.TestCase):
self.assertEqual(calls[0]["source_mode"], "mock")
self.assertEqual(calls[0]["llm_mode"], "mock")
def test_main_exits_nonzero_when_new_pipeline_blocks_publish(self):
module = load_pipeline_module()
def fake_run_daily_report(**kwargs):
return {"reports": {"stage8": {"status": "blocked", "error": "rewrite_fallback_ratio_exceeded"}}}
with patch.object(module, "load_env", return_value={}):
with patch("ai_daily_report.runner.run_daily_report", side_effect=fake_run_daily_report):
with self.assertRaises(SystemExit) as raised:
module.main()
self.assertEqual(raised.exception.code, 2)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,6 @@
import json
import unittest
from urllib.error import HTTPError
from ai_daily_report.pipeline import run_stage0_to_stage8
@@ -74,6 +75,65 @@ class Stage0To8PipelineTests(unittest.TestCase):
self.assertIn("stage8", result["reports"])
self.assertEqual(result["reports"]["stage8"]["status"], "ok")
def test_run_stage0_to_stage8_blocks_publish_when_rewrite_quality_gate_fails(self):
configs = [{"name": "AI HOT", "type": "fake", "role": "primary", "priority": 10}]
def fetcher(config, run_date):
return [
{
"title_raw": f"News {index}",
"summary_raw": f"Summary {index}",
"url": f"https://example.com/{index}",
"source_label": "Example",
"section_hint": "模型发布/更新",
}
for index in range(6)
]
def semantic_llm_call(prompt):
return json.dumps({"duplicate_groups": [], "not_duplicates": [], "uncertain": []})
def rewrite_llm_call(prompt):
raise HTTPError(
url="https://llm.example/v1/chat/completions",
code=503,
msg="Service Unavailable",
hdrs=None,
fp=None,
)
def guide_llm_call(prompt):
payload = json.loads(prompt)
return json.dumps(
{
"theme": "模型能力继续更新。",
"threads": [
{
"title": "模型更新",
"text": "多条模型新闻更新。",
"item_ids": [payload["items"][0]["id"]],
"kind": "thread",
}
],
}
)
result = run_stage0_to_stage8(
configs,
"2026-06-04",
fetcher=fetcher,
semantic_llm_call=semantic_llm_call,
rewrite_llm_call=rewrite_llm_call,
guide_llm_call=guide_llm_call,
mode="publish",
base_url="https://blog.example",
client=None,
)
self.assertEqual(result["publish"].status, "blocked")
self.assertIn("rewrite_fallback_ratio_exceeded", result["reports"]["stage7"]["blocking_errors"])
self.assertIn("rewrite_fallback_ratio_exceeded", result["reports"]["stage8"]["error"])
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,6 @@
import json
import unittest
from urllib.error import HTTPError
from ai_daily_report.models import NewsItem
from ai_daily_report.rewrite import rewrite_items
@@ -91,6 +92,29 @@ class Stage4RewriteTests(unittest.TestCase):
self.assertEqual(report["fallback_count"], 0)
self.assertEqual(calls, [["a", "b"], ["a"], ["b"]])
def test_rewrite_items_does_not_retry_single_items_after_transient_http_error(self):
items = [news_item("a"), news_item("b")]
calls = 0
def llm_call(prompt):
nonlocal calls
calls += 1
raise HTTPError(
url="https://llm.example/v1/chat/completions",
code=503,
msg="Service Unavailable",
hdrs=None,
fp=None,
)
rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=2)
self.assertEqual(calls, 1)
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
self.assertEqual(report["fallback_count"], 2)
self.assertTrue(report["quality_gate_failed"])
self.assertIn("rewrite_fallback_ratio_exceeded", report["blocking_errors"])
if __name__ == "__main__":
unittest.main()