Refactor AI daily report pipeline
This commit is contained in:
156
ai_daily_report/runner.py
Normal file
156
ai_daily_report/runner.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .clients import BlogApiClient, OpenAICompatibleClient, fetch_text as default_fetch_text
|
||||
from .config import load_source_configs
|
||||
from .env import load_env, resolve_blog_token, resolve_llm_config
|
||||
from .models import SourceConfig
|
||||
from .pipeline import run_stage0_to_stage8
|
||||
from .sources.registry import get_source_fetcher
|
||||
|
||||
|
||||
def _json_default(value: Any):
    """Fallback serializer for ``json.dumps(..., default=_json_default)``.

    Args:
        value: An object the JSON encoder could not serialize on its own.

    Returns:
        The ``dataclasses.asdict`` form of ``value`` when it is a dataclass
        instance.

    Raises:
        TypeError: If ``value`` is anything other than a dataclass instance,
            including a dataclass *class*.
    """
    # is_dataclass() is also true for dataclass classes, but asdict() only
    # accepts instances; without the isinstance guard a class would surface
    # asdict's unrelated error instead of the message below.
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    raise TypeError(f"Object is not JSON serializable: {type(value).__name__}")
|
||||
|
||||
|
||||
def _mock_source_configs() -> list[SourceConfig]:
    """Build the source list used when ``source_mode`` is ``"mock"``.

    Returns:
        A one-element list holding a ``SourceConfig`` named "Mock AI HOT"
        with type "mock", role "primary" and priority 10.
    """
    mock_source = SourceConfig(
        name="Mock AI HOT",
        type="mock",
        role="primary",
        priority=10,
    )
    return [mock_source]
|
||||
|
||||
|
||||
def _mock_fetcher(config: SourceConfig, run_date: str) -> list[dict[str, Any]]:
    """Return one hard-coded raw item; both arguments are ignored.

    Args:
        config: Source configuration (unused).
        run_date: Date of the run (unused).

    Returns:
        A new single-element list containing a dict with the keys
        ``title_raw``, ``summary_raw``, ``url``, ``source_label``,
        ``section_hint``, ``origin_type`` and ``language_hint``.
    """
    sample_item: dict[str, Any] = dict(
        title_raw="GPT-5 API 发布",
        summary_raw="OpenAI 发布 GPT-5 API,用于本地 mock 测试。",
        url="https://example.com/gpt5",
        source_label="OpenAI:Blog",
        section_hint="模型发布/更新",
        origin_type="mock",
        language_hint="zh",
    )
    return [sample_item]
|
||||
|
||||
|
||||
def _mock_semantic_llm(prompt: str) -> str:
    """Return a fixed JSON reply with every result list empty.

    Args:
        prompt: The prompt text (unused).

    Returns:
        A JSON object string whose ``duplicate_groups``, ``not_duplicates``
        and ``uncertain`` keys each map to an empty list.
    """
    result_keys = ("duplicate_groups", "not_duplicates", "uncertain")
    empty_verdict = {key: [] for key in result_keys}
    return json.dumps(empty_verdict, ensure_ascii=False)
|
||||
|
||||
|
||||
def _mock_rewrite_llm(prompt: str) -> str:
    """Echo each item's raw title and summary back as its "rewrite".

    Args:
        prompt: A JSON document with an ``items`` list; every item must carry
            ``id``, ``title_raw`` and ``summary_raw``.

    Returns:
        A JSON object string with a ``rewrites`` list, one entry per input
        item, each holding ``id``, ``title``, ``summary`` and an empty
        ``flags`` list. Non-ASCII text is emitted unescaped.
    """
    request = json.loads(prompt)
    echoed = []
    for entry in request["items"]:
        rewrite = {
            "id": entry["id"],
            "title": entry["title_raw"],
            "summary": entry["summary_raw"],
            "flags": [],
        }
        echoed.append(rewrite)
    return json.dumps({"rewrites": echoed}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _mock_guide_llm(prompt: str) -> str:
    """Return a fixed guide reply that references the first three item ids.

    Args:
        prompt: A JSON document with an ``items`` list whose entries each
            carry an ``id``.

    Returns:
        A JSON object string with a fixed ``theme`` and a single entry in
        ``threads``; that entry's ``item_ids`` holds the ids of at most the
        first three input items. Non-ASCII text is emitted unescaped.
    """
    request = json.loads(prompt)
    leading_ids = []
    for entry in request["items"][:3]:
        leading_ids.append(entry["id"])

    thread = {
        "title": "本地链路验证",
        "text": "采集、改写、分类、导览、Markdown 和发布报告都已通过 mock 数据串联。",
        "item_ids": leading_ids,
        "kind": "thread",
    }
    guide = {
        "theme": "本地 mock 模式已生成 AI 日报,用于验证流水线。",
        "threads": [thread],
    }
    return json.dumps(guide, ensure_ascii=False)
|
||||
|
||||
|
||||
def run_daily_report(
    *,
    run_date: str,
    mode: str,
    source_mode: str,
    llm_mode: str,
    out_dir: Path,
    base_url: str,
    sources_path: Path | None = None,
    fetch_text=None,
    env: dict[str, str] | None = None,
    llm_client_factory=OpenAICompatibleClient,
    blog_client_factory=BlogApiClient,
) -> dict[str, Any]:
    """Run the pipeline once and write its artifacts to ``out_dir / run_date``.

    Args:
        run_date: Date label for the run; also the name of the output
            sub-directory.
        mode: Forwarded to the pipeline. For ``"draft"`` and ``"publish"`` a
            blog client is built, which requires a blog token in ``env``.
        source_mode: ``"mock"`` uses the built-in mock source and fetcher;
            ``"live"`` loads source configs from ``sources_path``.
        llm_mode: ``"mock"`` uses the built-in mock LLM callables; ``"live"``
            builds a client via ``llm_client_factory`` and uses its ``chat``.
        out_dir: Parent directory for the per-date output directory.
        base_url: Forwarded to the blog client factory and to the pipeline.
        sources_path: Source config file for live mode; defaults to
            ``config/sources.json`` when ``None``.
        fetch_text: Text fetcher handed to live source fetchers; a falsy
            value selects the default fetcher.
        env: Environment mapping; loaded via ``load_env()`` when ``None``.
        llm_client_factory: Called with the keyword arguments produced by
            ``resolve_llm_config(env)`` in live LLM mode.
        blog_client_factory: Called with ``base_url`` and ``token`` keyword
            arguments when a blog client is needed.

    Returns:
        A dict with ``run_dir`` (string path of the output directory) plus
        the pipeline result's ``markdown``, ``reports`` and ``publish``.

    Raises:
        ValueError: If ``source_mode`` or ``llm_mode`` is neither ``"mock"``
            nor ``"live"``, or if a blog token is required but missing.
    """
    if not fetch_text:
        fetch_text = default_fetch_text
    if env is None:
        env = load_env()

    # Sources: pick the config list and the matching fetch callable.
    if source_mode not in ("mock", "live"):
        raise ValueError("source_mode must be 'mock' or 'live'")
    if source_mode == "mock":
        source_configs, fetcher = _mock_source_configs(), _mock_fetcher
    else:
        config_file = (
            sources_path
            if sources_path is not None
            else Path("config") / "sources.json"
        )
        source_configs = load_source_configs(config_file)

        def fetcher(config: SourceConfig, current_date: str) -> list[dict[str, Any]]:
            # Resolve the fetcher per config so each source type gets its own.
            return get_source_fetcher(config.type)(config, current_date, fetch_text)

    # LLM: collect the three pipeline callables as keyword arguments.
    if llm_mode not in ("mock", "live"):
        raise ValueError("llm_mode must be 'mock' or 'live'")
    if llm_mode == "mock":
        llm_calls = {
            "semantic_llm_call": _mock_semantic_llm,
            "rewrite_llm_call": _mock_rewrite_llm,
            "guide_llm_call": _mock_guide_llm,
        }
    else:
        client = llm_client_factory(**resolve_llm_config(env))
        llm_calls = {
            "semantic_llm_call": client.chat,
            "rewrite_llm_call": client.chat,
            "guide_llm_call": client.chat,
        }

    # Blog client: only built for the modes that need a token.
    if mode in ("draft", "publish"):
        token = resolve_blog_token(env)
        if not token:
            raise ValueError("missing_blog_token: set BLOG_SERVICE_TOKEN or EPHRON_SERVICE_TOKEN")
        blog_client = blog_client_factory(base_url=base_url, token=token)
    else:
        blog_client = None

    result = run_stage0_to_stage8(
        source_configs,
        run_date,
        fetcher=fetcher,
        **llm_calls,
        mode=mode,
        base_url=base_url,
        client=blog_client,
    )

    # Persist the markdown and the run report side by side.
    run_dir = out_dir / run_date
    run_dir.mkdir(parents=True, exist_ok=True)

    markdown_file = run_dir / "blog_markdown.md"
    markdown_file.write_text(result["markdown"], encoding="utf-8")

    report_text = json.dumps(
        result["reports"],
        ensure_ascii=False,
        indent=2,
        default=_json_default,
    )
    report_file = run_dir / "run_report.json"
    report_file.write_text(report_text, encoding="utf-8")

    summary: dict[str, Any] = {"run_dir": str(run_dir)}
    summary["markdown"] = result["markdown"]
    summary["reports"] = result["reports"]
    summary["publish"] = result["publish"]
    return summary
|
||||
Reference in New Issue
Block a user