Block publish when LLM rewrite quality degrades
This commit is contained in:
@@ -158,6 +158,13 @@ def run_stage0_to_stage7(
|
||||
guide_llm_call=guide_llm_call,
|
||||
)
|
||||
markdown, stage7_report = assemble_markdown(stage6_result["items"], stage6_result["guide"])
|
||||
upstream_blocking_errors: list[str] = []
|
||||
for stage_name in ("stage3", "stage4", "stage5", "stage6"):
|
||||
for error in stage6_result["reports"].get(stage_name, {}).get("blocking_errors", []) or []:
|
||||
upstream_blocking_errors.append(str(error))
|
||||
if upstream_blocking_errors:
|
||||
existing_errors = list(stage7_report.get("blocking_errors", []) or [])
|
||||
stage7_report["blocking_errors"] = existing_errors + upstream_blocking_errors
|
||||
reports = dict(stage6_result["reports"])
|
||||
reports["stage7"] = stage7_report
|
||||
return {
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Callable
|
||||
from urllib.error import HTTPError
|
||||
|
||||
from .llm import parse_json_object
|
||||
from .models import NewsItem
|
||||
@@ -49,6 +50,14 @@ def _fallback(item: NewsItem) -> None:
|
||||
item.summary = item.summary_raw or "该条目暂无摘要。"
|
||||
|
||||
|
||||
def _is_transient_llm_error(exc: Exception) -> bool:
|
||||
if isinstance(exc, TimeoutError):
|
||||
return True
|
||||
if isinstance(exc, HTTPError):
|
||||
return exc.code in {429, 500, 502, 503, 504}
|
||||
return False
|
||||
|
||||
|
||||
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
|
||||
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
||||
rewrites = obj.get("rewrites", [])
|
||||
@@ -75,6 +84,7 @@ def rewrite_items(
|
||||
*,
|
||||
llm_call: RewriteLlmCall,
|
||||
batch_size: int = 10,
|
||||
max_fallback_ratio: float = 0.2,
|
||||
) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||
rewritten_count = 0
|
||||
fallback_count = 0
|
||||
@@ -85,6 +95,11 @@ def rewrite_items(
|
||||
rewritten_count += _apply_rewrite_batch(batch, llm_call)
|
||||
except Exception as exc:
|
||||
errors.append(f"batch:{type(exc).__name__}: {exc}")
|
||||
if _is_transient_llm_error(exc):
|
||||
for item in batch:
|
||||
_fallback(item)
|
||||
fallback_count += 1
|
||||
continue
|
||||
for item in batch:
|
||||
try:
|
||||
rewritten_count += _apply_rewrite_batch([item], llm_call)
|
||||
@@ -93,11 +108,19 @@ def rewrite_items(
|
||||
_fallback(item)
|
||||
fallback_count += 1
|
||||
|
||||
fallback_ratio = fallback_count / len(items) if items else 0
|
||||
blocking_errors: list[str] = []
|
||||
if fallback_ratio > max_fallback_ratio:
|
||||
blocking_errors.append("rewrite_fallback_ratio_exceeded")
|
||||
|
||||
report = {
|
||||
"input_count": len(items),
|
||||
"rewritten_count": rewritten_count,
|
||||
"fallback_count": fallback_count,
|
||||
"fallback_ratio": round(fallback_ratio, 4),
|
||||
"batch_count": len(_chunks(items, max(1, batch_size))),
|
||||
"errors": errors,
|
||||
"blocking_errors": blocking_errors,
|
||||
"quality_gate_failed": bool(blocking_errors),
|
||||
}
|
||||
return items, report
|
||||
|
||||
@@ -37,7 +37,7 @@ def main() -> None:
|
||||
|
||||
env = load_env()
|
||||
dry_run = is_dry_run(env)
|
||||
run_daily_report(
|
||||
result = run_daily_report(
|
||||
run_date=env.get("AI_DAILY_RUN_DATE") or "today",
|
||||
mode="dry-run" if dry_run else env.get("AI_DAILY_MODE", "publish"),
|
||||
source_mode=env.get("AI_DAILY_SOURCE_MODE", "live"),
|
||||
@@ -47,6 +47,10 @@ def main() -> None:
|
||||
sources_path=Path(env["AI_DAILY_SOURCES_PATH"]) if env.get("AI_DAILY_SOURCES_PATH") else None,
|
||||
env=env,
|
||||
)
|
||||
stage8 = result.get("reports", {}).get("stage8", {})
|
||||
if stage8.get("status") in {"blocked", "failed"}:
|
||||
print(f"AI daily report failed quality gate: {stage8.get('error') or stage8.get('status')}", file=sys.stderr)
|
||||
raise SystemExit(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -52,6 +52,19 @@ class LegacyScriptDelegationTests(unittest.TestCase):
|
||||
self.assertEqual(calls[0]["source_mode"], "mock")
|
||||
self.assertEqual(calls[0]["llm_mode"], "mock")
|
||||
|
||||
def test_main_exits_nonzero_when_new_pipeline_blocks_publish(self):
|
||||
module = load_pipeline_module()
|
||||
|
||||
def fake_run_daily_report(**kwargs):
|
||||
return {"reports": {"stage8": {"status": "blocked", "error": "rewrite_fallback_ratio_exceeded"}}}
|
||||
|
||||
with patch.object(module, "load_env", return_value={}):
|
||||
with patch("ai_daily_report.runner.run_daily_report", side_effect=fake_run_daily_report):
|
||||
with self.assertRaises(SystemExit) as raised:
|
||||
module.main()
|
||||
|
||||
self.assertEqual(raised.exception.code, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import unittest
|
||||
from urllib.error import HTTPError
|
||||
|
||||
from ai_daily_report.pipeline import run_stage0_to_stage8
|
||||
|
||||
@@ -74,6 +75,65 @@ class Stage0To8PipelineTests(unittest.TestCase):
|
||||
self.assertIn("stage8", result["reports"])
|
||||
self.assertEqual(result["reports"]["stage8"]["status"], "ok")
|
||||
|
||||
def test_run_stage0_to_stage8_blocks_publish_when_rewrite_quality_gate_fails(self):
|
||||
configs = [{"name": "AI HOT", "type": "fake", "role": "primary", "priority": 10}]
|
||||
|
||||
def fetcher(config, run_date):
|
||||
return [
|
||||
{
|
||||
"title_raw": f"News {index}",
|
||||
"summary_raw": f"Summary {index}",
|
||||
"url": f"https://example.com/{index}",
|
||||
"source_label": "Example",
|
||||
"section_hint": "模型发布/更新",
|
||||
}
|
||||
for index in range(6)
|
||||
]
|
||||
|
||||
def semantic_llm_call(prompt):
|
||||
return json.dumps({"duplicate_groups": [], "not_duplicates": [], "uncertain": []})
|
||||
|
||||
def rewrite_llm_call(prompt):
|
||||
raise HTTPError(
|
||||
url="https://llm.example/v1/chat/completions",
|
||||
code=503,
|
||||
msg="Service Unavailable",
|
||||
hdrs=None,
|
||||
fp=None,
|
||||
)
|
||||
|
||||
def guide_llm_call(prompt):
|
||||
payload = json.loads(prompt)
|
||||
return json.dumps(
|
||||
{
|
||||
"theme": "模型能力继续更新。",
|
||||
"threads": [
|
||||
{
|
||||
"title": "模型更新",
|
||||
"text": "多条模型新闻更新。",
|
||||
"item_ids": [payload["items"][0]["id"]],
|
||||
"kind": "thread",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
result = run_stage0_to_stage8(
|
||||
configs,
|
||||
"2026-06-04",
|
||||
fetcher=fetcher,
|
||||
semantic_llm_call=semantic_llm_call,
|
||||
rewrite_llm_call=rewrite_llm_call,
|
||||
guide_llm_call=guide_llm_call,
|
||||
mode="publish",
|
||||
base_url="https://blog.example",
|
||||
client=None,
|
||||
)
|
||||
|
||||
self.assertEqual(result["publish"].status, "blocked")
|
||||
self.assertIn("rewrite_fallback_ratio_exceeded", result["reports"]["stage7"]["blocking_errors"])
|
||||
self.assertIn("rewrite_fallback_ratio_exceeded", result["reports"]["stage8"]["error"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import unittest
|
||||
from urllib.error import HTTPError
|
||||
|
||||
from ai_daily_report.models import NewsItem
|
||||
from ai_daily_report.rewrite import rewrite_items
|
||||
@@ -91,6 +92,29 @@ class Stage4RewriteTests(unittest.TestCase):
|
||||
self.assertEqual(report["fallback_count"], 0)
|
||||
self.assertEqual(calls, [["a", "b"], ["a"], ["b"]])
|
||||
|
||||
def test_rewrite_items_does_not_retry_single_items_after_transient_http_error(self):
|
||||
items = [news_item("a"), news_item("b")]
|
||||
calls = 0
|
||||
|
||||
def llm_call(prompt):
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise HTTPError(
|
||||
url="https://llm.example/v1/chat/completions",
|
||||
code=503,
|
||||
msg="Service Unavailable",
|
||||
hdrs=None,
|
||||
fp=None,
|
||||
)
|
||||
|
||||
rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=2)
|
||||
|
||||
self.assertEqual(calls, 1)
|
||||
self.assertEqual([item.title for item in rewritten], ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
|
||||
self.assertEqual(report["fallback_count"], 2)
|
||||
self.assertTrue(report["quality_gate_failed"])
|
||||
self.assertIn("rewrite_fallback_ratio_exceeded", report["blocking_errors"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user