Block publish when LLM rewrite quality degrades
This commit is contained in:
@@ -158,6 +158,13 @@ def run_stage0_to_stage7(
|
|||||||
guide_llm_call=guide_llm_call,
|
guide_llm_call=guide_llm_call,
|
||||||
)
|
)
|
||||||
markdown, stage7_report = assemble_markdown(stage6_result["items"], stage6_result["guide"])
|
markdown, stage7_report = assemble_markdown(stage6_result["items"], stage6_result["guide"])
|
||||||
|
upstream_blocking_errors: list[str] = []
|
||||||
|
for stage_name in ("stage3", "stage4", "stage5", "stage6"):
|
||||||
|
for error in stage6_result["reports"].get(stage_name, {}).get("blocking_errors", []) or []:
|
||||||
|
upstream_blocking_errors.append(str(error))
|
||||||
|
if upstream_blocking_errors:
|
||||||
|
existing_errors = list(stage7_report.get("blocking_errors", []) or [])
|
||||||
|
stage7_report["blocking_errors"] = existing_errors + upstream_blocking_errors
|
||||||
reports = dict(stage6_result["reports"])
|
reports = dict(stage6_result["reports"])
|
||||||
reports["stage7"] = stage7_report
|
reports["stage7"] = stage7_report
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
from .llm import parse_json_object
|
from .llm import parse_json_object
|
||||||
from .models import NewsItem
|
from .models import NewsItem
|
||||||
@@ -49,6 +50,14 @@ def _fallback(item: NewsItem) -> None:
|
|||||||
item.summary = item.summary_raw or "该条目暂无摘要。"
|
item.summary = item.summary_raw or "该条目暂无摘要。"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_transient_llm_error(exc: Exception) -> bool:
|
||||||
|
if isinstance(exc, TimeoutError):
|
||||||
|
return True
|
||||||
|
if isinstance(exc, HTTPError):
|
||||||
|
return exc.code in {429, 500, 502, 503, 504}
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
|
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
|
||||||
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
||||||
rewrites = obj.get("rewrites", [])
|
rewrites = obj.get("rewrites", [])
|
||||||
@@ -75,6 +84,7 @@ def rewrite_items(
|
|||||||
*,
|
*,
|
||||||
llm_call: RewriteLlmCall,
|
llm_call: RewriteLlmCall,
|
||||||
batch_size: int = 10,
|
batch_size: int = 10,
|
||||||
|
max_fallback_ratio: float = 0.2,
|
||||||
) -> tuple[list[NewsItem], dict[str, Any]]:
|
) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||||
rewritten_count = 0
|
rewritten_count = 0
|
||||||
fallback_count = 0
|
fallback_count = 0
|
||||||
@@ -85,6 +95,11 @@ def rewrite_items(
|
|||||||
rewritten_count += _apply_rewrite_batch(batch, llm_call)
|
rewritten_count += _apply_rewrite_batch(batch, llm_call)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"batch:{type(exc).__name__}: {exc}")
|
errors.append(f"batch:{type(exc).__name__}: {exc}")
|
||||||
|
if _is_transient_llm_error(exc):
|
||||||
|
for item in batch:
|
||||||
|
_fallback(item)
|
||||||
|
fallback_count += 1
|
||||||
|
continue
|
||||||
for item in batch:
|
for item in batch:
|
||||||
try:
|
try:
|
||||||
rewritten_count += _apply_rewrite_batch([item], llm_call)
|
rewritten_count += _apply_rewrite_batch([item], llm_call)
|
||||||
@@ -93,11 +108,19 @@ def rewrite_items(
|
|||||||
_fallback(item)
|
_fallback(item)
|
||||||
fallback_count += 1
|
fallback_count += 1
|
||||||
|
|
||||||
|
fallback_ratio = fallback_count / len(items) if items else 0
|
||||||
|
blocking_errors: list[str] = []
|
||||||
|
if fallback_ratio > max_fallback_ratio:
|
||||||
|
blocking_errors.append("rewrite_fallback_ratio_exceeded")
|
||||||
|
|
||||||
report = {
|
report = {
|
||||||
"input_count": len(items),
|
"input_count": len(items),
|
||||||
"rewritten_count": rewritten_count,
|
"rewritten_count": rewritten_count,
|
||||||
"fallback_count": fallback_count,
|
"fallback_count": fallback_count,
|
||||||
|
"fallback_ratio": round(fallback_ratio, 4),
|
||||||
"batch_count": len(_chunks(items, max(1, batch_size))),
|
"batch_count": len(_chunks(items, max(1, batch_size))),
|
||||||
"errors": errors,
|
"errors": errors,
|
||||||
|
"blocking_errors": blocking_errors,
|
||||||
|
"quality_gate_failed": bool(blocking_errors),
|
||||||
}
|
}
|
||||||
return items, report
|
return items, report
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def main() -> None:
|
|||||||
|
|
||||||
env = load_env()
|
env = load_env()
|
||||||
dry_run = is_dry_run(env)
|
dry_run = is_dry_run(env)
|
||||||
run_daily_report(
|
result = run_daily_report(
|
||||||
run_date=env.get("AI_DAILY_RUN_DATE") or "today",
|
run_date=env.get("AI_DAILY_RUN_DATE") or "today",
|
||||||
mode="dry-run" if dry_run else env.get("AI_DAILY_MODE", "publish"),
|
mode="dry-run" if dry_run else env.get("AI_DAILY_MODE", "publish"),
|
||||||
source_mode=env.get("AI_DAILY_SOURCE_MODE", "live"),
|
source_mode=env.get("AI_DAILY_SOURCE_MODE", "live"),
|
||||||
@@ -47,6 +47,10 @@ def main() -> None:
|
|||||||
sources_path=Path(env["AI_DAILY_SOURCES_PATH"]) if env.get("AI_DAILY_SOURCES_PATH") else None,
|
sources_path=Path(env["AI_DAILY_SOURCES_PATH"]) if env.get("AI_DAILY_SOURCES_PATH") else None,
|
||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
|
stage8 = result.get("reports", {}).get("stage8", {})
|
||||||
|
if stage8.get("status") in {"blocked", "failed"}:
|
||||||
|
print(f"AI daily report failed quality gate: {stage8.get('error') or stage8.get('status')}", file=sys.stderr)
|
||||||
|
raise SystemExit(2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -52,6 +52,19 @@ class LegacyScriptDelegationTests(unittest.TestCase):
|
|||||||
self.assertEqual(calls[0]["source_mode"], "mock")
|
self.assertEqual(calls[0]["source_mode"], "mock")
|
||||||
self.assertEqual(calls[0]["llm_mode"], "mock")
|
self.assertEqual(calls[0]["llm_mode"], "mock")
|
||||||
|
|
||||||
|
def test_main_exits_nonzero_when_new_pipeline_blocks_publish(self):
    """main() must exit with code 2 when the stage8 report status is "blocked"."""
    pipeline_module = load_pipeline_module()

    def fake_run_daily_report(**kwargs):
        # Minimal result shape: only the stage8 report is populated.
        stage8_report = {"status": "blocked", "error": "rewrite_fallback_ratio_exceeded"}
        return {"reports": {"stage8": stage8_report}}

    env_patch = patch.object(pipeline_module, "load_env", return_value={})
    runner_patch = patch("ai_daily_report.runner.run_daily_report", side_effect=fake_run_daily_report)

    # Context managers are entered left to right, matching the nested form.
    with env_patch, runner_patch, self.assertRaises(SystemExit) as raised:
        pipeline_module.main()

    self.assertEqual(raised.exception.code, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
from ai_daily_report.pipeline import run_stage0_to_stage8
|
from ai_daily_report.pipeline import run_stage0_to_stage8
|
||||||
|
|
||||||
@@ -74,6 +75,65 @@ class Stage0To8PipelineTests(unittest.TestCase):
|
|||||||
self.assertIn("stage8", result["reports"])
|
self.assertIn("stage8", result["reports"])
|
||||||
self.assertEqual(result["reports"]["stage8"]["status"], "ok")
|
self.assertEqual(result["reports"]["stage8"]["status"], "ok")
|
||||||
|
|
||||||
|
def test_run_stage0_to_stage8_blocks_publish_when_rewrite_quality_gate_fails(self):
    """Publish is blocked when every rewrite LLM call fails with HTTP 503."""
    configs = [{"name": "AI HOT", "type": "fake", "role": "primary", "priority": 10}]

    def fetcher(config, run_date):
        # Six synthetic items, all in the same section.
        raw_items = []
        for index in range(6):
            raw_items.append(
                {
                    "title_raw": f"News {index}",
                    "summary_raw": f"Summary {index}",
                    "url": f"https://example.com/{index}",
                    "source_label": "Example",
                    "section_hint": "模型发布/更新",
                }
            )
        return raw_items

    def semantic_llm_call(prompt):
        # Report no duplicates so all fetched items continue downstream.
        verdict = {"duplicate_groups": [], "not_duplicates": [], "uncertain": []}
        return json.dumps(verdict)

    def rewrite_llm_call(prompt):
        # The rewrite endpoint is unavailable on every call.
        raise HTTPError(
            "https://llm.example/v1/chat/completions",
            503,
            "Service Unavailable",
            None,
            None,
        )

    def guide_llm_call(prompt):
        first_item_id = json.loads(prompt)["items"][0]["id"]
        thread = {
            "title": "模型更新",
            "text": "多条模型新闻更新。",
            "item_ids": [first_item_id],
            "kind": "thread",
        }
        return json.dumps({"theme": "模型能力继续更新。", "threads": [thread]})

    result = run_stage0_to_stage8(
        configs,
        "2026-06-04",
        fetcher=fetcher,
        semantic_llm_call=semantic_llm_call,
        rewrite_llm_call=rewrite_llm_call,
        guide_llm_call=guide_llm_call,
        mode="publish",
        base_url="https://blog.example",
        client=None,
    )

    self.assertEqual(result["publish"].status, "blocked")

    expected_error = "rewrite_fallback_ratio_exceeded"
    reports = result["reports"]
    self.assertIn(expected_error, reports["stage7"]["blocking_errors"])
    self.assertIn(expected_error, reports["stage8"]["error"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
from ai_daily_report.models import NewsItem
|
from ai_daily_report.models import NewsItem
|
||||||
from ai_daily_report.rewrite import rewrite_items
|
from ai_daily_report.rewrite import rewrite_items
|
||||||
@@ -91,6 +92,29 @@ class Stage4RewriteTests(unittest.TestCase):
|
|||||||
self.assertEqual(report["fallback_count"], 0)
|
self.assertEqual(report["fallback_count"], 0)
|
||||||
self.assertEqual(calls, [["a", "b"], ["a"], ["b"]])
|
self.assertEqual(calls, [["a", "b"], ["a"], ["b"]])
|
||||||
|
|
||||||
|
def test_rewrite_items_does_not_retry_single_items_after_transient_http_error(self):
    """An HTTP 503 on the batch call yields one LLM call and a fallback for both items."""
    items = [news_item(key) for key in ("a", "b")]
    attempts = []

    def llm_call(prompt):
        # Record the attempt, then fail the way an unavailable endpoint would.
        attempts.append(prompt)
        raise HTTPError(
            "https://llm.example/v1/chat/completions",
            503,
            "Service Unavailable",
            None,
            None,
        )

    rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=2)

    # Only the batch call is attempted; no per-item calls follow.
    self.assertEqual(len(attempts), 1)
    titles = [item.title for item in rewritten]
    self.assertEqual(titles, ["OpenAI launches GPT-5 API", "OpenAI launches GPT-5 API"])
    self.assertEqual(report["fallback_count"], 2)
    self.assertTrue(report["quality_gate_failed"])
    self.assertIn("rewrite_fallback_ratio_exceeded", report["blocking_errors"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user