From f7e4c9722bf1d453663d55e170b2cccadd32712e Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Thu, 4 Jun 2026 16:29:40 +0800 Subject: [PATCH] Block publish when LLM rewrite quality degrades --- ai_daily_report/pipeline.py | 7 +++ ai_daily_report/rewrite.py | 23 ++++++++++ script/ai_daily_blog_pipeline.py | 6 ++- tests/test_legacy_script_delegation.py | 13 ++++++ tests/test_stage0_to_8_pipeline.py | 60 ++++++++++++++++++++++++++ tests/test_stage4_rewrite.py | 24 +++++++++++ 6 files changed, 132 insertions(+), 1 deletion(-) diff --git a/ai_daily_report/pipeline.py b/ai_daily_report/pipeline.py index e2bc8a9..6f036c5 100644 --- a/ai_daily_report/pipeline.py +++ b/ai_daily_report/pipeline.py @@ -158,6 +158,13 @@ def run_stage0_to_stage7( guide_llm_call=guide_llm_call, ) markdown, stage7_report = assemble_markdown(stage6_result["items"], stage6_result["guide"]) + upstream_blocking_errors: list[str] = [] + for stage_name in ("stage3", "stage4", "stage5", "stage6"): + for error in stage6_result["reports"].get(stage_name, {}).get("blocking_errors", []) or []: + upstream_blocking_errors.append(str(error)) + if upstream_blocking_errors: + existing_errors = list(stage7_report.get("blocking_errors", []) or []) + stage7_report["blocking_errors"] = existing_errors + upstream_blocking_errors reports = dict(stage6_result["reports"]) reports["stage7"] = stage7_report return { diff --git a/ai_daily_report/rewrite.py b/ai_daily_report/rewrite.py index 6bc9063..aa857b8 100644 --- a/ai_daily_report/rewrite.py +++ b/ai_daily_report/rewrite.py @@ -2,6 +2,7 @@ from __future__ import annotations import json from typing import Any, Callable +from urllib.error import HTTPError from .llm import parse_json_object from .models import NewsItem @@ -49,6 +50,14 @@ def _fallback(item: NewsItem) -> None: item.summary = item.summary_raw or "该条目暂无摘要。" +def _is_transient_llm_error(exc: Exception) -> bool: + if isinstance(exc, TimeoutError): + return True + if isinstance(exc, HTTPError): + return exc.code in {429, 500, 502, 503, 504} + return False + + def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int: obj = parse_json_object(llm_call(_build_prompt(batch))) rewrites = obj.get("rewrites", []) @@ -75,6 +84,7 @@ def rewrite_items( *, llm_call: RewriteLlmCall, batch_size: int = 10, + max_fallback_ratio: float = 0.2, ) -> tuple[list[NewsItem], dict[str, Any]]: rewritten_count = 0 fallback_count = 0 @@ -85,6 +95,11 @@ def rewrite_items( rewritten_count += _apply_rewrite_batch(batch, llm_call) except Exception as exc: errors.append(f"batch:{type(exc).__name__}: {exc}") + if _is_transient_llm_error(exc): + for item in batch: + _fallback(item) + fallback_count += 1 + continue for item in batch: try: rewritten_count += _apply_rewrite_batch([item], llm_call) @@ -93,11 +108,19 @@ def rewrite_items( _fallback(item) fallback_count += 1 + fallback_ratio = fallback_count / len(items) if items else 0 + blocking_errors: list[str] = [] + if fallback_ratio > max_fallback_ratio: + blocking_errors.append("rewrite_fallback_ratio_exceeded") + report = { "input_count": len(items), "rewritten_count": rewritten_count, "fallback_count": fallback_count, + "fallback_ratio": round(fallback_ratio, 4), "batch_count": len(_chunks(items, max(1, batch_size))), "errors": errors, + "blocking_errors": blocking_errors, + "quality_gate_failed": bool(blocking_errors), } return items, report diff --git a/script/ai_daily_blog_pipeline.py b/script/ai_daily_blog_pipeline.py index b71003e..caeb4ab 100644 --- a/script/ai_daily_blog_pipeline.py +++ b/script/ai_daily_blog_pipeline.py @@ -37,7 +37,7 @@ def main() -> None: env = load_env() dry_run = is_dry_run(env) - run_daily_report( + result = run_daily_report( run_date=env.get("AI_DAILY_RUN_DATE") or "today", mode="dry-run" if dry_run else env.get("AI_DAILY_MODE", "publish"), source_mode=env.get("AI_DAILY_SOURCE_MODE", "live"), @@ -47,6 +47,10 @@ def main() -> None: sources_path=Path(env["AI_DAILY_SOURCES_PATH"]) if env.get("AI_DAILY_SOURCES_PATH") else None, env=env, ) + stage8 = result.get("reports", {}).get("stage8", {}) + if stage8.get("status") in {"blocked", "failed"}: + print(f"AI daily report failed quality gate: {stage8.get('error') or stage8.get('status')}", file=sys.stderr) + raise SystemExit(2) if __name__ == "__main__": diff --git a/tests/test_legacy_script_delegation.py b/tests/test_legacy_script_delegation.py index 7c24e61..0441089 100644 --- a/tests/test_legacy_script_delegation.py +++ b/tests/test_legacy_script_delegation.py @@ -52,6 +52,19 @@ class LegacyScriptDelegationTests(unittest.TestCase): self.assertEqual(calls[0]["source_mode"], "mock") self.assertEqual(calls[0]["llm_mode"], "mock") + def test_main_exits_nonzero_when_new_pipeline_blocks_publish(self): + module = load_pipeline_module() + + def fake_run_daily_report(**kwargs): + return {"reports": {"stage8": {"status": "blocked", "error": "rewrite_fallback_ratio_exceeded"}}} + + with patch.object(module, "load_env", return_value={}): + with patch("ai_daily_report.runner.run_daily_report", side_effect=fake_run_daily_report): + with self.assertRaises(SystemExit) as raised: + module.main() + + self.assertEqual(raised.exception.code, 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_stage0_to_8_pipeline.py b/tests/test_stage0_to_8_pipeline.py index a81861c..68e8550 100644 --- a/tests/test_stage0_to_8_pipeline.py +++ b/tests/test_stage0_to_8_pipeline.py @@ -1,5 +1,6 @@ import json import unittest +from urllib.error import HTTPError from ai_daily_report.pipeline import run_stage0_to_stage8 @@ -74,6 +75,65 @@ class Stage0To8PipelineTests(unittest.TestCase): self.assertIn("stage8", result["reports"]) self.assertEqual(result["reports"]["stage8"]["status"], "ok") + def test_run_stage0_to_stage8_blocks_publish_when_rewrite_quality_gate_fails(self): + configs = [{"name": "AI HOT", "type": "fake", "role": "primary", "priority": 10}] + + def fetcher(config, run_date): + return [ + { + "title_raw": f"News {index}", + "summary_raw": f"Summary {index}", + "url": f"https://example.com/{index}", + "source_label": "Example", + "section_hint": "模型发布/更新", + } + for index in range(6) + ] + + def semantic_llm_call(prompt): + return json.dumps({"duplicate_groups": [], "not_duplicates": [], "uncertain": []}) + + def rewrite_llm_call(prompt): + raise HTTPError( + url="https://llm.example/v1/chat/completions", + code=503, + msg="Service Unavailable", + hdrs=None, + fp=None, + ) + + def guide_llm_call(prompt): + payload = json.loads(prompt) + return json.dumps( + { + "theme": "模型能力继续更新。", + "threads": [ + { + "title": "模型更新", + "text": "多条模型新闻更新。", + "item_ids": [payload["items"][0]["id"]], + "kind": "thread", + } + ], + } + ) + + result = run_stage0_to_stage8( + configs, + "2026-06-04", + fetcher=fetcher, + semantic_llm_call=semantic_llm_call, + rewrite_llm_call=rewrite_llm_call, + guide_llm_call=guide_llm_call, + mode="publish", + base_url="https://blog.example", + client=None, + ) + + self.assertEqual(result["publish"].status, "blocked") + self.assertIn("rewrite_fallback_ratio_exceeded", result["reports"]["stage7"]["blocking_errors"]) + self.assertIn("rewrite_fallback_ratio_exceeded", result["reports"]["stage8"]["error"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_stage4_rewrite.py b/tests/test_stage4_rewrite.py index 62ef346..c46b6a5 100644 --- a/tests/test_stage4_rewrite.py +++ b/tests/test_stage4_rewrite.py @@ -1,5 +1,6 @@ import json import unittest +from urllib.error import HTTPError from ai_daily_report.models import NewsItem from ai_daily_report.rewrite import rewrite_items @@ -91,6 +92,29 @@ class Stage4RewriteTests(unittest.TestCase): self.assertEqual(report["fallback_count"], 0) self.assertEqual(calls, [["a", "b"], ["a"], ["b"]]) + def test_rewrite_items_does_not_retry_single_items_after_transient_http_error(self): + items = [news_item("a"), news_item("b")] + calls = 0 + + def llm_call(prompt): + nonlocal calls + calls += 1 + raise HTTPError( + url="https://llm.example/v1/chat/completions", + code=503, + msg="Service Unavailable", + hdrs=None, + fp=None, + ) + + rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=2) + + self.assertEqual(calls, 1) + self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"]) + self.assertEqual(report["fallback_count"], 2) + self.assertTrue(report["quality_gate_failed"]) + self.assertIn("rewrite_fallback_ratio_exceeded", report["blocking_errors"]) + if __name__ == "__main__": unittest.main()