From 22cdd71a084add061df3009ea2d3a70676949319 Mon Sep 17 00:00:00 2001 From: Mimikko-zeus Date: Thu, 4 Jun 2026 17:12:59 +0800 Subject: [PATCH] Improve LLM rewrite classification pipeline --- ai_daily_report/assemble.py | 4 ---- ai_daily_report/classify.py | 13 ++++++++-- ai_daily_report/rewrite.py | 38 +++++++++++++++++++++++++----- tests/test_markdown_rendering.py | 2 +- tests/test_stage0_to_5_pipeline.py | 1 + tests/test_stage0_to_7_pipeline.py | 2 +- tests/test_stage4_rewrite.py | 25 ++++++++++++++++++++ tests/test_stage5_classify.py | 27 +++++++++++++++++++++ tests/test_stage7_assemble.py | 4 ++-- 9 files changed, 100 insertions(+), 16 deletions(-) diff --git a/ai_daily_report/assemble.py b/ai_daily_report/assemble.py index b1dc35f..36aaa32 100644 --- a/ai_daily_report/assemble.py +++ b/ai_daily_report/assemble.py @@ -54,10 +54,6 @@ def assemble_markdown(items: list[NewsItem], guide: dict[str, Any] | None = None intro = _ensure_sentence(str(guide.get("intro") or "")) or _fallback_intro(items) lines.extend(["## 引言", "", f"> {intro}", ""]) - theme = _clean_text(str(guide.get("theme") or "")) - if theme: - lines.extend(["## 导览", "", f"> {_ensure_sentence(theme)}", ""]) - item_number = 1 for section in SECTION_ORDER: section_items = [item for item in items if item.section == section] diff --git a/ai_daily_report/classify.py b/ai_daily_report/classify.py index 4beca1f..92c7268 100644 --- a/ai_daily_report/classify.py +++ b/ai_daily_report/classify.py @@ -75,10 +75,18 @@ def rank_score(item: NewsItem) -> int: def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dict[str, Any]]: + llm_classified = 0 hint_classified = 0 rule_classified = 0 + invalid_llm_section_count = 0 for item in items: + if item.section: + if item.section in SECTION_ORDER: + llm_classified += 1 + continue + invalid_llm_section_count += 1 + mapped = normalize_section_hint(item.section_hint) if mapped: item.section = mapped @@ -102,8 +110,9 @@ def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dic "section_counts": dict(section_counts), "hint_classified": hint_classified, "rule_classified": rule_classified, - "llm_classified": 0, - "fallback_classified": 0, + "llm_classified": llm_classified, + "fallback_classified": hint_classified + rule_classified, + "invalid_llm_section_count": invalid_llm_section_count, "invalid_section_count": sum(1 for item in ordered if item.section not in SECTION_ORDER), } return ordered, report diff --git a/ai_daily_report/rewrite.py b/ai_daily_report/rewrite.py index c53fd31..acae2b9 100644 --- a/ai_daily_report/rewrite.py +++ b/ai_daily_report/rewrite.py @@ -4,6 +4,7 @@ import json from typing import Any, Callable from urllib.error import HTTPError +from .classify import SECTION_ORDER from .llm import parse_json_object from .models import NewsItem @@ -18,9 +19,21 @@ def _chunks(items: list[NewsItem], size: int) -> list[list[NewsItem]]: def _build_prompt(batch: list[NewsItem]) -> str: payload = { "task": ( - "Rewrite AI news titles and summaries into concise Chinese. Preserve brand/model/API names " - "such as GPT-5, Codex, Gemini, Claude, API, MCP. Do not add facts." + "For each AI news item, translate when needed, rewrite the title and summary into concise Chinese, " + "and classify it into exactly one allowed section. Preserve brand/model/API names such as GPT-5, " + "Codex, Gemini, Claude, API, MCP. Do not add facts." ), + "allowed_sections": SECTION_ORDER, + "section_guidance": { + "模型与能力": "model releases, capability upgrades, modalities, context windows, inference, benchmarks tied to model ability", + "产品与应用": "end-user products, apps, agents, workflows, product launches, practical business or consumer use cases", + "开发与基础设施": "developer tools, APIs, SDKs, MCP, frameworks, deployment, chips, cloud, infra, open source engineering", + "公司与资本": "company strategy, financing, IPO, acquisitions, partnerships, revenue, business competition", + "政策与安全": "policy, regulation, safety, privacy, copyright, misuse, security incidents, governance", + "论文与研究": "papers, academic research, arXiv, methods, experiments, datasets, evaluations", + "观点与教程": "opinions, analysis, explainers, tutorials, guides, practices", + "人物与动态": "people-focused interviews, speeches, career moves, public appearances", + }, "items": [ { "id": item.id, @@ -28,6 +41,7 @@ def _build_prompt(batch: list[NewsItem]) -> str: "summary_raw": item.summary_raw, "source": item.source_label, "language_hint": item.language_hint, + "source_section_hint": item.section_hint, } for item in batch ], @@ -37,6 +51,8 @@ def _build_prompt(batch: list[NewsItem]) -> str: "id": "item id", "title": "display title", "summary": "display summary", + "section": "one allowed section", + "confidence": 0.0, "flags": [], } ] @@ -58,13 +74,14 @@ def _is_transient_llm_error(exc: Exception) -> bool: return False -def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int: +def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tuple[int, int]: obj = parse_json_object(llm_call(_build_prompt(batch))) rewrites = obj.get("rewrites", []) if not isinstance(rewrites, list): raise ValueError("rewrites is not a list") by_id = {item.id: item for item in batch} seen_ids: set[str] = set() + section_count = 0 for entry in rewrites: item_id = entry.get("id") title = str(entry.get("title") or "").strip() @@ -72,8 +89,12 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int if item_id in by_id and title and summary: by_id[item_id].title = title by_id[item_id].summary = summary + section = str(entry.get("section") or "").strip() + if section in SECTION_ORDER: + by_id[item_id].section = section + section_count += 1 seen_ids.add(item_id) - return len(seen_ids) + return len(seen_ids), section_count def rewrite_items( @@ -85,14 +106,16 @@ def rewrite_items( retry_single_items: bool = False, ) -> tuple[list[NewsItem], dict[str, Any]]: rewritten_count = 0 + llm_section_count = 0 fallback_count = 0 missing_rewrite_count = 0 errors: list[str] = [] for batch in _chunks(items, max(1, batch_size)): try: - batch_rewritten_count = _apply_rewrite_batch(batch, llm_call) + batch_rewritten_count, batch_section_count = _apply_rewrite_batch(batch, llm_call) rewritten_count += batch_rewritten_count + llm_section_count += batch_section_count for item in batch: if item.title is None or item.summary is None: errors.append(f"missing_rewrite_for_item: {item.id}") @@ -113,7 +136,9 @@ def rewrite_items( continue for item in batch: try: - rewritten_count += _apply_rewrite_batch([item], llm_call) + item_rewritten_count, item_section_count = _apply_rewrite_batch([item], llm_call) + rewritten_count += item_rewritten_count + llm_section_count += item_section_count except Exception as item_exc: errors.append(f"item:{item.id}:{type(item_exc).__name__}: {item_exc}") _fallback(item) @@ -127,6 +152,7 @@ def rewrite_items( report = { "input_count": len(items), "rewritten_count": rewritten_count, + "llm_section_count": llm_section_count, "fallback_count": fallback_count, "missing_rewrite_count": missing_rewrite_count, "fallback_ratio": round(fallback_ratio, 4), diff --git a/tests/test_markdown_rendering.py b/tests/test_markdown_rendering.py index 205f379..64e2123 100644 --- a/tests/test_markdown_rendering.py +++ b/tests/test_markdown_rendering.py @@ -27,7 +27,7 @@ class MarkdownRenderingTests(unittest.TestCase): md, _ = assemble_markdown(items, guide) - self.assertIn("## 导览", md) + self.assertNotIn("## 导览", md) self.assertIn("## 模型与能力", md) self.assertIn("[OpenAI:Blog ↗](https://openai.com/blog/test)", md) self.assertNotIn("> >", md) diff --git a/tests/test_stage0_to_5_pipeline.py b/tests/test_stage0_to_5_pipeline.py index 2df7038..1e8c2f7 100644 --- a/tests/test_stage0_to_5_pipeline.py +++ b/tests/test_stage0_to_5_pipeline.py @@ -37,6 +37,7 @@ class Stage0To5PipelineTests(unittest.TestCase): "id": entry["id"], "title": entry["title_raw"], "summary": entry["summary_raw"], + "section": "模型与能力" if "GPT-5" in entry["title_raw"] else "公司与资本", "flags": [], } for entry in payload["items"] diff --git a/tests/test_stage0_to_7_pipeline.py b/tests/test_stage0_to_7_pipeline.py index b86e078..fee9ceb 100644 --- a/tests/test_stage0_to_7_pipeline.py +++ b/tests/test_stage0_to_7_pipeline.py @@ -66,7 +66,7 @@ class Stage0To7PipelineTests(unittest.TestCase): guide_llm_call=guide_llm_call, ) - self.assertIn("## 导览", result["markdown"]) + self.assertNotIn("## 导览", result["markdown"]) self.assertIn("## 模型与能力", result["markdown"]) self.assertIn("## 今日脉络", result["markdown"]) self.assertEqual(result["reports"]["stage7"]["blocking_errors"], []) diff --git a/tests/test_stage4_rewrite.py b/tests/test_stage4_rewrite.py index 21e3201..a9338cf 100644 --- a/tests/test_stage4_rewrite.py +++ b/tests/test_stage4_rewrite.py @@ -48,6 +48,31 @@ class Stage4RewriteTests(unittest.TestCase): self.assertEqual(report["rewritten_count"], 1) self.assertEqual(report["fallback_count"], 0) + def test_rewrite_items_accepts_llm_section_classification(self): + items = [news_item("a")] + + def llm_call(prompt): + return json.dumps( + { + "rewrites": [ + { + "id": "a", + "title": "OpenAI 发布 GPT-5 API", + "summary": "OpenAI 发布 GPT-5 API,延迟表现更好。", + "section": "模型与能力", + "confidence": 0.92, + "flags": [], + } + ] + }, + ensure_ascii=False, + ) + + rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=10) + + self.assertEqual(rewritten[0].section, "模型与能力") + self.assertEqual(report["llm_section_count"], 1) + def test_rewrite_items_falls_back_when_llm_fails(self): items = [news_item("a")] diff --git a/tests/test_stage5_classify.py b/tests/test_stage5_classify.py index a158ca3..3b66177 100644 --- a/tests/test_stage5_classify.py +++ b/tests/test_stage5_classify.py @@ -45,6 +45,33 @@ class Stage5ClassifyTests(unittest.TestCase): self.assertEqual(by_id["b"].section, "开发与基础设施") self.assertEqual(report["rule_classified"], 2) + def test_classify_prefers_valid_llm_section_from_rewrite_stage(self): + item = news_item( + "a", + "API 发布", + summary="这其实是一个面向开发者的基础设施能力更新。", + section_hint="产品发布/更新", + ) + item.section = "开发与基础设施" + + classified, report = classify_and_order_items([item]) + + self.assertEqual(classified[0].section, "开发与基础设施") + self.assertEqual(report["llm_classified"], 1) + self.assertEqual(report["hint_classified"], 0) + self.assertEqual(report["rule_classified"], 0) + + def test_classify_falls_back_when_llm_section_is_invalid(self): + item = news_item("a", "GPT-5 发布", section_hint="模型发布/更新") + item.section = "热点新闻" + + classified, report = classify_and_order_items([item]) + + self.assertEqual(classified[0].section, "模型与能力") + self.assertEqual(report["llm_classified"], 0) + self.assertEqual(report["hint_classified"], 1) + self.assertEqual(report["invalid_llm_section_count"], 1) + def test_classify_orders_items_by_local_rank_score_within_sections(self): items = [ news_item("low", "普通模型更新", section_hint="模型发布/更新", source_priority=80), diff --git a/tests/test_stage7_assemble.py b/tests/test_stage7_assemble.py index c711ee2..979d49e 100644 --- a/tests/test_stage7_assemble.py +++ b/tests/test_stage7_assemble.py @@ -45,8 +45,8 @@ class Stage7AssembleTests(unittest.TestCase): md, report = assemble_markdown(items, guide) self.assertTrue(md.startswith("## 引言\n\n> 今天的 AI 行业继续围绕模型、产品和资本展开。")) - self.assertIn("## 导览", md) - self.assertIn("> 模型和资本两条线都在推进。", md) + self.assertNotIn("## 导览", md) + self.assertNotIn("> 模型和资本两条线都在推进。", md) self.assertIn("## 模型与能力", md) self.assertIn("**1. GPT-5 API 发布**", md) self.assertIn("**2. Anthropic 提交 IPO 文件**", md)