# ai-daily-report/ai_daily_report/runner.py
#
# NOTE: this file contains non-ASCII text (the Chinese strings in the mock
# fixtures below). Some code viewers flag such characters as "ambiguous
# Unicode"; presumably they are intentional - confirm before normalizing them.
from __future__ import annotations
import json
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any
from .clients import BlogApiClient, OpenAICompatibleClient, fetch_text as default_fetch_text
from .config import load_source_configs
from .env import load_env, resolve_blog_token, resolve_llm_config
from .models import SourceConfig
from .pipeline import run_stage0_to_stage8
from .sources.registry import get_source_fetcher
def _json_default(value: Any):
    """Fallback serializer passed to ``json.dumps(..., default=...)``.

    Dataclass *instances* are converted to plain dicts with ``asdict``.
    Anything else is rejected with a ``TypeError`` naming the offending type.

    Raises:
        TypeError: if ``value`` is not a dataclass instance.
    """
    # is_dataclass() is also true for dataclass classes themselves, but
    # asdict() only accepts instances; exclude types so that every
    # unsupported value produces the same, descriptive error below.
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    raise TypeError(f"Object is not JSON serializable: {type(value).__name__}")
def _mock_source_configs() -> list[SourceConfig]:
    """Build the single-entry source list used when ``source_mode`` is "mock"."""
    mock_source = SourceConfig(
        name="Mock AI HOT",
        type="mock",
        role="primary",
        priority=10,
    )
    return [mock_source]
def _mock_fetcher(config: SourceConfig, run_date: str) -> list[dict[str, Any]]:
    """Return one hard-coded raw item for mock runs.

    ``config`` and ``run_date`` are accepted so the function can be passed
    wherever a ``(config, date)`` fetcher is expected; neither is consulted.
    A fresh list and dict are built on every call.
    """
    sample_item: dict[str, Any] = dict(
        title_raw="GPT-5 API 发布",
        summary_raw="OpenAI 发布 GPT-5 API用于本地 mock 测试。",
        url="https://example.com/gpt5",
        source_label="OpenAIBlog",
        section_hint="模型发布/更新",
        origin_type="mock",
        language_hint="zh",
    )
    return [sample_item]
def _mock_semantic_llm(prompt: str) -> str:
    """Return a canned JSON reply with three empty lists.

    The ``prompt`` is ignored; the reply always has empty
    ``duplicate_groups``, ``not_duplicates`` and ``uncertain`` lists.
    """
    empty_verdict = {
        bucket: []
        for bucket in ("duplicate_groups", "not_duplicates", "uncertain")
    }
    return json.dumps(empty_verdict, ensure_ascii=False)
def _mock_rewrite_llm(prompt: str) -> str:
    """Echo each item's raw title and summary back as its "rewrite".

    ``prompt`` must be a JSON document with an ``items`` list whose entries
    carry ``id``, ``title_raw`` and ``summary_raw``. The reply is a JSON
    string of the form ``{"rewrites": [...]}`` with an empty ``flags`` list
    on every entry.
    """
    rewrites = []
    for entry in json.loads(prompt)["items"]:
        rewrites.append(
            {
                "id": entry["id"],
                "title": entry["title_raw"],
                "summary": entry["summary_raw"],
                "flags": [],
            }
        )
    return json.dumps({"rewrites": rewrites}, ensure_ascii=False)
def _mock_guide_llm(prompt: str) -> str:
    """Return a canned guide that references at most the first three items.

    ``prompt`` must be a JSON document with an ``items`` list whose entries
    carry an ``id``. The reply is a JSON string with ``intro``, ``theme``
    (same text as ``intro``), a single-entry ``threads`` list and a
    ``conclusion``.
    """
    request = json.loads(prompt)
    leading_items = request["items"][:3]
    referenced_ids = [entry["id"] for entry in leading_items]

    # The intro and the theme deliberately share one sentence.
    overview = "本地 mock 模式已生成 AI 日报,用于验证流水线。"
    thread = {
        "title": "本地链路验证",
        "text": "采集、改写、分类、导览、Markdown 和发布报告都已通过 mock 数据串联。",
        "item_ids": referenced_ids,
        "kind": "thread",
    }
    guide = {
        "intro": overview,
        "theme": overview,
        "threads": [thread],
        "conclusion": "本地 mock 结果可用于确认定时任务入口和文件输出是否正常。",
    }
    return json.dumps(guide, ensure_ascii=False)
def _select_sources(source_mode: str, sources_path: Path | None, fetch_text):
    """Return ``(source_configs, fetcher)`` for the requested ``source_mode``."""
    if source_mode == "mock":
        return _mock_source_configs(), _mock_fetcher
    if source_mode != "live":
        raise ValueError("source_mode must be 'mock' or 'live'")

    if sources_path is None:
        # Relative default: resolved against the current working directory.
        sources_path = Path("config") / "sources.json"
    source_configs = load_source_configs(sources_path)

    def fetcher(config: SourceConfig, current_date: str) -> list[dict[str, Any]]:
        # Look the fetcher up per call so each source type gets its own.
        source_fetcher = get_source_fetcher(config.type)
        return source_fetcher(config, current_date, fetch_text)

    return source_configs, fetcher


def _select_llm_calls(llm_mode: str, env: dict[str, str], llm_client_factory):
    """Return the ``(semantic, rewrite, guide)`` LLM callables for ``llm_mode``."""
    if llm_mode == "mock":
        return _mock_semantic_llm, _mock_rewrite_llm, _mock_guide_llm
    if llm_mode != "live":
        raise ValueError("llm_mode must be 'mock' or 'live'")
    # In live mode all three stages share one client's ``chat`` method.
    llm_client = llm_client_factory(**resolve_llm_config(env))
    return llm_client.chat, llm_client.chat, llm_client.chat


def _build_blog_client(mode: str, env: dict[str, str], base_url: str, blog_client_factory):
    """Return a blog client for "draft"/"publish" modes, otherwise ``None``."""
    if mode not in ("draft", "publish"):
        return None
    token = resolve_blog_token(env)
    if not token:
        raise ValueError("missing_blog_token: set BLOG_SERVICE_TOKEN or EPHRON_SERVICE_TOKEN")
    return blog_client_factory(base_url=base_url, token=token)


def _write_run_outputs(run_dir: Path, result: dict[str, Any]) -> None:
    """Write the rendered markdown and the JSON run report into ``run_dir``."""
    run_dir.mkdir(parents=True, exist_ok=True)
    (run_dir / "blog_markdown.md").write_text(result["markdown"], encoding="utf-8")
    report_json = json.dumps(
        result["reports"],
        ensure_ascii=False,
        indent=2,
        default=_json_default,
    )
    (run_dir / "run_report.json").write_text(report_json, encoding="utf-8")


def run_daily_report(
    *,
    run_date: str,
    mode: str,
    source_mode: str,
    llm_mode: str,
    out_dir: Path,
    base_url: str,
    sources_path: Path | None = None,
    fetch_text=None,
    env: dict[str, str] | None = None,
    llm_client_factory=OpenAICompatibleClient,
    blog_client_factory=BlogApiClient,
) -> dict[str, Any]:
    """Run the report pipeline once and persist its artifacts.

    Steps, in order: choose sources, choose LLM callables, optionally build
    a blog client, call ``run_stage0_to_stage8``, then write
    ``blog_markdown.md`` and ``run_report.json`` under ``out_dir / run_date``.

    Args:
        run_date: Date string passed to the pipeline; also the name of the
            output sub-directory.
        mode: Passed through to the pipeline. A blog client is created only
            for "draft" and "publish"; any other value runs with
            ``client=None``.
        source_mode: "mock" for the built-in fixture, "live" to load source
            configs from ``sources_path``.
        llm_mode: "mock" for canned replies, "live" to use a client built by
            ``llm_client_factory``.
        out_dir: Parent directory for per-run output directories.
        base_url: Passed to both the blog client and the pipeline.
        sources_path: Source config file for live mode; defaults to
            ``config/sources.json`` relative to the working directory.
        fetch_text: Text fetcher handed to live source fetchers; a falsy
            value selects ``default_fetch_text``.
        env: Environment mapping; ``None`` triggers ``load_env()``.
        llm_client_factory: Called with the resolved LLM config as keyword
            arguments in live LLM mode.
        blog_client_factory: Called with ``base_url`` and ``token`` in
            "draft"/"publish" mode.

    Returns:
        A dict with ``run_dir`` (as ``str``) plus the pipeline's
        ``markdown``, ``reports`` and ``publish`` values.

    Raises:
        ValueError: if ``source_mode`` or ``llm_mode`` is not "mock"/"live",
            or if a blog token is required but cannot be resolved.
    """
    fetch_text = fetch_text or default_fetch_text
    env = env if env is not None else load_env()

    # Resolution order (sources, LLM, blog client) is kept deliberately so
    # configuration errors surface in the same sequence as before.
    source_configs, fetcher = _select_sources(source_mode, sources_path, fetch_text)
    semantic_llm_call, rewrite_llm_call, guide_llm_call = _select_llm_calls(
        llm_mode, env, llm_client_factory
    )
    blog_client = _build_blog_client(mode, env, base_url, blog_client_factory)

    result = run_stage0_to_stage8(
        source_configs,
        run_date,
        fetcher=fetcher,
        semantic_llm_call=semantic_llm_call,
        rewrite_llm_call=rewrite_llm_call,
        guide_llm_call=guide_llm_call,
        mode=mode,
        base_url=base_url,
        client=blog_client,
    )

    run_dir = out_dir / run_date
    _write_run_outputs(run_dir, result)
    return {
        "run_dir": str(run_dir),
        "markdown": result["markdown"],
        "reports": result["reports"],
        "publish": result["publish"],
    }