Improve LLM rewrite classification pipeline
This commit is contained in:
@@ -54,10 +54,6 @@ def assemble_markdown(items: list[NewsItem], guide: dict[str, Any] | None = None
|
|||||||
intro = _ensure_sentence(str(guide.get("intro") or "")) or _fallback_intro(items)
|
intro = _ensure_sentence(str(guide.get("intro") or "")) or _fallback_intro(items)
|
||||||
lines.extend(["## 引言", "", f"> {intro}", ""])
|
lines.extend(["## 引言", "", f"> {intro}", ""])
|
||||||
|
|
||||||
theme = _clean_text(str(guide.get("theme") or ""))
|
|
||||||
if theme:
|
|
||||||
lines.extend(["## 导览", "", f"> {_ensure_sentence(theme)}", ""])
|
|
||||||
|
|
||||||
item_number = 1
|
item_number = 1
|
||||||
for section in SECTION_ORDER:
|
for section in SECTION_ORDER:
|
||||||
section_items = [item for item in items if item.section == section]
|
section_items = [item for item in items if item.section == section]
|
||||||
|
|||||||
@@ -75,10 +75,18 @@ def rank_score(item: NewsItem) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dict[str, Any]]:
|
def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||||
|
llm_classified = 0
|
||||||
hint_classified = 0
|
hint_classified = 0
|
||||||
rule_classified = 0
|
rule_classified = 0
|
||||||
|
invalid_llm_section_count = 0
|
||||||
|
|
||||||
for item in items:
|
for item in items:
|
||||||
|
if item.section:
|
||||||
|
if item.section in SECTION_ORDER:
|
||||||
|
llm_classified += 1
|
||||||
|
continue
|
||||||
|
invalid_llm_section_count += 1
|
||||||
|
|
||||||
mapped = normalize_section_hint(item.section_hint)
|
mapped = normalize_section_hint(item.section_hint)
|
||||||
if mapped:
|
if mapped:
|
||||||
item.section = mapped
|
item.section = mapped
|
||||||
@@ -102,8 +110,9 @@ def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dic
|
|||||||
"section_counts": dict(section_counts),
|
"section_counts": dict(section_counts),
|
||||||
"hint_classified": hint_classified,
|
"hint_classified": hint_classified,
|
||||||
"rule_classified": rule_classified,
|
"rule_classified": rule_classified,
|
||||||
"llm_classified": 0,
|
"llm_classified": llm_classified,
|
||||||
"fallback_classified": 0,
|
"fallback_classified": hint_classified + rule_classified,
|
||||||
|
"invalid_llm_section_count": invalid_llm_section_count,
|
||||||
"invalid_section_count": sum(1 for item in ordered if item.section not in SECTION_ORDER),
|
"invalid_section_count": sum(1 for item in ordered if item.section not in SECTION_ORDER),
|
||||||
}
|
}
|
||||||
return ordered, report
|
return ordered, report
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
|
from .classify import SECTION_ORDER
|
||||||
from .llm import parse_json_object
|
from .llm import parse_json_object
|
||||||
from .models import NewsItem
|
from .models import NewsItem
|
||||||
|
|
||||||
@@ -18,9 +19,21 @@ def _chunks(items: list[NewsItem], size: int) -> list[list[NewsItem]]:
|
|||||||
def _build_prompt(batch: list[NewsItem]) -> str:
|
def _build_prompt(batch: list[NewsItem]) -> str:
|
||||||
payload = {
|
payload = {
|
||||||
"task": (
|
"task": (
|
||||||
"Rewrite AI news titles and summaries into concise Chinese. Preserve brand/model/API names "
|
"For each AI news item, translate when needed, rewrite the title and summary into concise Chinese, "
|
||||||
"such as GPT-5, Codex, Gemini, Claude, API, MCP. Do not add facts."
|
"and classify it into exactly one allowed section. Preserve brand/model/API names such as GPT-5, "
|
||||||
|
"Codex, Gemini, Claude, API, MCP. Do not add facts."
|
||||||
),
|
),
|
||||||
|
"allowed_sections": SECTION_ORDER,
|
||||||
|
"section_guidance": {
|
||||||
|
"模型与能力": "model releases, capability upgrades, modalities, context windows, inference, benchmarks tied to model ability",
|
||||||
|
"产品与应用": "end-user products, apps, agents, workflows, product launches, practical business or consumer use cases",
|
||||||
|
"开发与基础设施": "developer tools, APIs, SDKs, MCP, frameworks, deployment, chips, cloud, infra, open source engineering",
|
||||||
|
"公司与资本": "company strategy, financing, IPO, acquisitions, partnerships, revenue, business competition",
|
||||||
|
"政策与安全": "policy, regulation, safety, privacy, copyright, misuse, security incidents, governance",
|
||||||
|
"论文与研究": "papers, academic research, arXiv, methods, experiments, datasets, evaluations",
|
||||||
|
"观点与教程": "opinions, analysis, explainers, tutorials, guides, practices",
|
||||||
|
"人物与动态": "people-focused interviews, speeches, career moves, public appearances",
|
||||||
|
},
|
||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
"id": item.id,
|
"id": item.id,
|
||||||
@@ -28,6 +41,7 @@ def _build_prompt(batch: list[NewsItem]) -> str:
|
|||||||
"summary_raw": item.summary_raw,
|
"summary_raw": item.summary_raw,
|
||||||
"source": item.source_label,
|
"source": item.source_label,
|
||||||
"language_hint": item.language_hint,
|
"language_hint": item.language_hint,
|
||||||
|
"source_section_hint": item.section_hint,
|
||||||
}
|
}
|
||||||
for item in batch
|
for item in batch
|
||||||
],
|
],
|
||||||
@@ -37,6 +51,8 @@ def _build_prompt(batch: list[NewsItem]) -> str:
|
|||||||
"id": "item id",
|
"id": "item id",
|
||||||
"title": "display title",
|
"title": "display title",
|
||||||
"summary": "display summary",
|
"summary": "display summary",
|
||||||
|
"section": "one allowed section",
|
||||||
|
"confidence": 0.0,
|
||||||
"flags": [],
|
"flags": [],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -58,13 +74,14 @@ def _is_transient_llm_error(exc: Exception) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int:
|
def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> tuple[int, int]:
|
||||||
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
obj = parse_json_object(llm_call(_build_prompt(batch)))
|
||||||
rewrites = obj.get("rewrites", [])
|
rewrites = obj.get("rewrites", [])
|
||||||
if not isinstance(rewrites, list):
|
if not isinstance(rewrites, list):
|
||||||
raise ValueError("rewrites is not a list")
|
raise ValueError("rewrites is not a list")
|
||||||
by_id = {item.id: item for item in batch}
|
by_id = {item.id: item for item in batch}
|
||||||
seen_ids: set[str] = set()
|
seen_ids: set[str] = set()
|
||||||
|
section_count = 0
|
||||||
for entry in rewrites:
|
for entry in rewrites:
|
||||||
item_id = entry.get("id")
|
item_id = entry.get("id")
|
||||||
title = str(entry.get("title") or "").strip()
|
title = str(entry.get("title") or "").strip()
|
||||||
@@ -72,8 +89,12 @@ def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int
|
|||||||
if item_id in by_id and title and summary:
|
if item_id in by_id and title and summary:
|
||||||
by_id[item_id].title = title
|
by_id[item_id].title = title
|
||||||
by_id[item_id].summary = summary
|
by_id[item_id].summary = summary
|
||||||
|
section = str(entry.get("section") or "").strip()
|
||||||
|
if section in SECTION_ORDER:
|
||||||
|
by_id[item_id].section = section
|
||||||
|
section_count += 1
|
||||||
seen_ids.add(item_id)
|
seen_ids.add(item_id)
|
||||||
return len(seen_ids)
|
return len(seen_ids), section_count
|
||||||
|
|
||||||
|
|
||||||
def rewrite_items(
|
def rewrite_items(
|
||||||
@@ -85,14 +106,16 @@ def rewrite_items(
|
|||||||
retry_single_items: bool = False,
|
retry_single_items: bool = False,
|
||||||
) -> tuple[list[NewsItem], dict[str, Any]]:
|
) -> tuple[list[NewsItem], dict[str, Any]]:
|
||||||
rewritten_count = 0
|
rewritten_count = 0
|
||||||
|
llm_section_count = 0
|
||||||
fallback_count = 0
|
fallback_count = 0
|
||||||
missing_rewrite_count = 0
|
missing_rewrite_count = 0
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
for batch in _chunks(items, max(1, batch_size)):
|
for batch in _chunks(items, max(1, batch_size)):
|
||||||
try:
|
try:
|
||||||
batch_rewritten_count = _apply_rewrite_batch(batch, llm_call)
|
batch_rewritten_count, batch_section_count = _apply_rewrite_batch(batch, llm_call)
|
||||||
rewritten_count += batch_rewritten_count
|
rewritten_count += batch_rewritten_count
|
||||||
|
llm_section_count += batch_section_count
|
||||||
for item in batch:
|
for item in batch:
|
||||||
if item.title is None or item.summary is None:
|
if item.title is None or item.summary is None:
|
||||||
errors.append(f"missing_rewrite_for_item: {item.id}")
|
errors.append(f"missing_rewrite_for_item: {item.id}")
|
||||||
@@ -113,7 +136,9 @@ def rewrite_items(
|
|||||||
continue
|
continue
|
||||||
for item in batch:
|
for item in batch:
|
||||||
try:
|
try:
|
||||||
rewritten_count += _apply_rewrite_batch([item], llm_call)
|
item_rewritten_count, item_section_count = _apply_rewrite_batch([item], llm_call)
|
||||||
|
rewritten_count += item_rewritten_count
|
||||||
|
llm_section_count += item_section_count
|
||||||
except Exception as item_exc:
|
except Exception as item_exc:
|
||||||
errors.append(f"item:{item.id}:{type(item_exc).__name__}: {item_exc}")
|
errors.append(f"item:{item.id}:{type(item_exc).__name__}: {item_exc}")
|
||||||
_fallback(item)
|
_fallback(item)
|
||||||
@@ -127,6 +152,7 @@ def rewrite_items(
|
|||||||
report = {
|
report = {
|
||||||
"input_count": len(items),
|
"input_count": len(items),
|
||||||
"rewritten_count": rewritten_count,
|
"rewritten_count": rewritten_count,
|
||||||
|
"llm_section_count": llm_section_count,
|
||||||
"fallback_count": fallback_count,
|
"fallback_count": fallback_count,
|
||||||
"missing_rewrite_count": missing_rewrite_count,
|
"missing_rewrite_count": missing_rewrite_count,
|
||||||
"fallback_ratio": round(fallback_ratio, 4),
|
"fallback_ratio": round(fallback_ratio, 4),
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class MarkdownRenderingTests(unittest.TestCase):
|
|||||||
|
|
||||||
md, _ = assemble_markdown(items, guide)
|
md, _ = assemble_markdown(items, guide)
|
||||||
|
|
||||||
self.assertIn("## 导览", md)
|
self.assertNotIn("## 导览", md)
|
||||||
self.assertIn("## 模型与能力", md)
|
self.assertIn("## 模型与能力", md)
|
||||||
self.assertIn("[OpenAI:Blog ↗](https://openai.com/blog/test)", md)
|
self.assertIn("[OpenAI:Blog ↗](https://openai.com/blog/test)", md)
|
||||||
self.assertNotIn("> >", md)
|
self.assertNotIn("> >", md)
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class Stage0To5PipelineTests(unittest.TestCase):
|
|||||||
"id": entry["id"],
|
"id": entry["id"],
|
||||||
"title": entry["title_raw"],
|
"title": entry["title_raw"],
|
||||||
"summary": entry["summary_raw"],
|
"summary": entry["summary_raw"],
|
||||||
|
"section": "模型与能力" if "GPT-5" in entry["title_raw"] else "公司与资本",
|
||||||
"flags": [],
|
"flags": [],
|
||||||
}
|
}
|
||||||
for entry in payload["items"]
|
for entry in payload["items"]
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class Stage0To7PipelineTests(unittest.TestCase):
|
|||||||
guide_llm_call=guide_llm_call,
|
guide_llm_call=guide_llm_call,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIn("## 导览", result["markdown"])
|
self.assertNotIn("## 导览", result["markdown"])
|
||||||
self.assertIn("## 模型与能力", result["markdown"])
|
self.assertIn("## 模型与能力", result["markdown"])
|
||||||
self.assertIn("## 今日脉络", result["markdown"])
|
self.assertIn("## 今日脉络", result["markdown"])
|
||||||
self.assertEqual(result["reports"]["stage7"]["blocking_errors"], [])
|
self.assertEqual(result["reports"]["stage7"]["blocking_errors"], [])
|
||||||
|
|||||||
@@ -48,6 +48,31 @@ class Stage4RewriteTests(unittest.TestCase):
|
|||||||
self.assertEqual(report["rewritten_count"], 1)
|
self.assertEqual(report["rewritten_count"], 1)
|
||||||
self.assertEqual(report["fallback_count"], 0)
|
self.assertEqual(report["fallback_count"], 0)
|
||||||
|
|
||||||
|
def test_rewrite_items_accepts_llm_section_classification(self):
|
||||||
|
items = [news_item("a")]
|
||||||
|
|
||||||
|
def llm_call(prompt):
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"rewrites": [
|
||||||
|
{
|
||||||
|
"id": "a",
|
||||||
|
"title": "OpenAI 发布 GPT-5 API",
|
||||||
|
"summary": "OpenAI 发布 GPT-5 API,延迟表现更好。",
|
||||||
|
"section": "模型与能力",
|
||||||
|
"confidence": 0.92,
|
||||||
|
"flags": [],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=10)
|
||||||
|
|
||||||
|
self.assertEqual(rewritten[0].section, "模型与能力")
|
||||||
|
self.assertEqual(report["llm_section_count"], 1)
|
||||||
|
|
||||||
def test_rewrite_items_falls_back_when_llm_fails(self):
|
def test_rewrite_items_falls_back_when_llm_fails(self):
|
||||||
items = [news_item("a")]
|
items = [news_item("a")]
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,33 @@ class Stage5ClassifyTests(unittest.TestCase):
|
|||||||
self.assertEqual(by_id["b"].section, "开发与基础设施")
|
self.assertEqual(by_id["b"].section, "开发与基础设施")
|
||||||
self.assertEqual(report["rule_classified"], 2)
|
self.assertEqual(report["rule_classified"], 2)
|
||||||
|
|
||||||
|
def test_classify_prefers_valid_llm_section_from_rewrite_stage(self):
|
||||||
|
item = news_item(
|
||||||
|
"a",
|
||||||
|
"API 发布",
|
||||||
|
summary="这其实是一个面向开发者的基础设施能力更新。",
|
||||||
|
section_hint="产品发布/更新",
|
||||||
|
)
|
||||||
|
item.section = "开发与基础设施"
|
||||||
|
|
||||||
|
classified, report = classify_and_order_items([item])
|
||||||
|
|
||||||
|
self.assertEqual(classified[0].section, "开发与基础设施")
|
||||||
|
self.assertEqual(report["llm_classified"], 1)
|
||||||
|
self.assertEqual(report["hint_classified"], 0)
|
||||||
|
self.assertEqual(report["rule_classified"], 0)
|
||||||
|
|
||||||
|
def test_classify_falls_back_when_llm_section_is_invalid(self):
|
||||||
|
item = news_item("a", "GPT-5 发布", section_hint="模型发布/更新")
|
||||||
|
item.section = "热点新闻"
|
||||||
|
|
||||||
|
classified, report = classify_and_order_items([item])
|
||||||
|
|
||||||
|
self.assertEqual(classified[0].section, "模型与能力")
|
||||||
|
self.assertEqual(report["llm_classified"], 0)
|
||||||
|
self.assertEqual(report["hint_classified"], 1)
|
||||||
|
self.assertEqual(report["invalid_llm_section_count"], 1)
|
||||||
|
|
||||||
def test_classify_orders_items_by_local_rank_score_within_sections(self):
|
def test_classify_orders_items_by_local_rank_score_within_sections(self):
|
||||||
items = [
|
items = [
|
||||||
news_item("low", "普通模型更新", section_hint="模型发布/更新", source_priority=80),
|
news_item("low", "普通模型更新", section_hint="模型发布/更新", source_priority=80),
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ class Stage7AssembleTests(unittest.TestCase):
|
|||||||
md, report = assemble_markdown(items, guide)
|
md, report = assemble_markdown(items, guide)
|
||||||
|
|
||||||
self.assertTrue(md.startswith("## 引言\n\n> 今天的 AI 行业继续围绕模型、产品和资本展开。"))
|
self.assertTrue(md.startswith("## 引言\n\n> 今天的 AI 行业继续围绕模型、产品和资本展开。"))
|
||||||
self.assertIn("## 导览", md)
|
self.assertNotIn("## 导览", md)
|
||||||
self.assertIn("> 模型和资本两条线都在推进。", md)
|
self.assertNotIn("> 模型和资本两条线都在推进。", md)
|
||||||
self.assertIn("## 模型与能力", md)
|
self.assertIn("## 模型与能力", md)
|
||||||
self.assertIn("**1. GPT-5 API 发布**", md)
|
self.assertIn("**1. GPT-5 API 发布**", md)
|
||||||
self.assertIn("**2. Anthropic 提交 IPO 文件**", md)
|
self.assertIn("**2. Anthropic 提交 IPO 文件**", md)
|
||||||
|
|||||||
Reference in New Issue
Block a user