Files
ai-daily-report/ai_daily_report/runner.py

226 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import inspect
import json
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any

from .clients import BlogApiClient, OpenAICompatibleClient, fetch_text as default_fetch_text
from .config import load_pipeline_config, load_source_configs
from .env import load_env, resolve_blog_token, resolve_llm_config
from .models import SourceConfig
from .observability import LlmCallObserver, summarize_observed_calls
from .pipeline import run_stage0_to_stage8
from .publish import load_published_urls, update_published_urls
from .sources.registry import get_source_fetcher
def _json_default(value: Any):
    """``default=`` hook for :func:`json.dumps` that serializes dataclass instances.

    Dataclass instances are converted with :func:`dataclasses.asdict`. Anything
    else (including dataclass *classes*) raises ``TypeError``, which is what
    ``json.dumps`` expects from a ``default`` hook for unsupported objects.
    """
    # is_dataclass() is also True for dataclass types, but asdict() only accepts
    # instances, so exclude classes explicitly (as the dataclasses docs advise).
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    raise TypeError(f"Object is not JSON serializable: {type(value).__name__}")
def _mock_source_configs() -> list[SourceConfig]:
    """Return the single fake source configuration used when ``source_mode == "mock"``."""
    mock_source = SourceConfig(
        name="Mock AI HOT",
        type="mock",
        role="primary",
        priority=10,
    )
    return [mock_source]
def _mock_fetcher(config: SourceConfig, run_date: str) -> list[dict[str, Any]]:
    """Return one canned raw item; ``config`` and ``run_date`` are ignored.

    Used as the fetcher when ``source_mode == "mock"`` so the pipeline can run
    without touching any real source.
    """
    canned_item: dict[str, Any] = dict(
        title_raw="GPT-5 API 发布",
        summary_raw="OpenAI 发布 GPT-5 API用于本地 mock 测试。",
        url="https://example.com/gpt5",
        source_label="OpenAIBlog",
        section_hint="模型发布/更新",
        origin_type="mock",
        language_hint="zh",
    )
    return [canned_item]
def _mock_semantic_llm(prompt: str) -> str:
    """Mock semantic-dedup LLM: always report no duplicates; ``prompt`` is ignored."""
    verdict = {
        "duplicate_groups": [],
        "not_duplicates": [],
        "uncertain": [],
    }
    return json.dumps(verdict, ensure_ascii=False)
def _mock_rewrite_llm(prompt: str) -> str:
    """Mock rewrite LLM: echo each item's raw title and summary back unchanged.

    ``prompt`` must be a JSON object whose ``items`` list holds entries with
    ``id``, ``title_raw`` and ``summary_raw``. The reply is a JSON object with one
    ``rewrites`` entry per input item, in input order, each with empty ``flags``.
    """
    request = json.loads(prompt)
    rewrites = []
    for entry in request["items"]:
        rewrites.append(
            {
                "id": entry["id"],
                "title": entry["title_raw"],
                "summary": entry["summary_raw"],
                "flags": [],
            }
        )
    reply = {"rewrites": rewrites}
    return json.dumps(reply, ensure_ascii=False)
def _mock_guide_llm(prompt: str) -> str:
    """Mock reading-guide LLM: return a fixed guide with one thread.

    ``prompt`` must be a JSON object with an ``items`` list. Only the first three
    items are looked at, and their ``id`` values become the thread's ``item_ids``.
    """
    request = json.loads(prompt)
    # Slice before extracting ids so items past the third never need an "id" key.
    first_three = request["items"][:3]
    thread_ids = [entry["id"] for entry in first_three]

    thread = {
        "title": "本地链路验证",
        "text": "采集、改写、分类、导览、Markdown 和发布报告都已通过 mock 数据串联。",
        "item_ids": thread_ids,
        "kind": "thread",
    }
    guide = {
        "intro": "本地 mock 模式已生成 AI 日报,用于验证流水线。",
        "theme": "本地 mock 模式已生成 AI 日报,用于验证流水线。",
        "threads": [thread],
        "conclusion": "本地 mock 结果可用于确认定时任务入口和文件输出是否正常。",
    }
    return json.dumps(guide, ensure_ascii=False)
def _fetch_accepts_retries(fetch_text) -> bool:
    """Return True when ``fetch_text(url, timeout_seconds, retries=...)`` is a valid call.

    The check runs once, against the callable's signature. The alternative is to
    catch ``TypeError`` around the real request, but that also catches a
    ``TypeError`` raised *inside* a fetch (e.g. a parsing bug). The error would be
    swallowed and the request silently re-issued without retries.
    """
    try:
        signature = inspect.signature(fetch_text)
    except (TypeError, ValueError):
        # No introspectable signature (some C-implemented callables): assume the
        # current ``clients.fetch_text`` calling convention.
        return True
    try:
        signature.bind("", 0, retries=0)
    except TypeError:
        return False
    return True


def _build_live_fetcher(fetch_text):
    """Build the live-mode fetcher, which dispatches each source to its registered fetcher.

    The returned ``fetcher(config, current_date)`` does two things on every call:
    it looks up ``get_source_fetcher(config.type)``, and it hands that function a
    ``(url, timeout_seconds)`` text fetcher. The text fetcher forwards
    ``config.retries`` when ``fetch_text`` accepts a ``retries`` keyword.
    """
    pass_retries = _fetch_accepts_retries(fetch_text)

    def fetcher(config: SourceConfig, current_date: str) -> list[dict[str, Any]]:
        source_fetcher = get_source_fetcher(config.type)

        def configured_fetch_text(url: str, timeout_seconds: int) -> str:
            if pass_retries:
                return fetch_text(url, timeout_seconds, retries=config.retries)
            return fetch_text(url, timeout_seconds)

        return source_fetcher(config, current_date, configured_fetch_text)

    return fetcher


def _dump_json(path: Path, value: Any) -> None:
    """Write ``value`` to ``path`` as indented UTF-8 JSON, serializing dataclasses."""
    path.write_text(
        json.dumps(value, ensure_ascii=False, indent=2, default=_json_default),
        encoding="utf-8",
    )


def _write_run_outputs(run_dir: Path, result: dict[str, Any], artifacts: dict[str, Any]) -> None:
    """Write the Markdown post, the run report and one JSON file per artifact into ``run_dir``."""
    run_dir.mkdir(parents=True, exist_ok=True)
    (run_dir / "blog_markdown.md").write_text(result["markdown"], encoding="utf-8")
    _dump_json(run_dir / "run_report.json", result["reports"])
    for artifact_name, artifact_value in artifacts.items():
        _dump_json(run_dir / f"{artifact_name}.json", artifact_value)


def run_daily_report(
    *,
    run_date: str,
    mode: str,
    source_mode: str,
    llm_mode: str,
    out_dir: Path,
    base_url: str,
    sources_path: Path | None = None,
    pipeline_path: Path | None = None,
    history_path: Path | None = None,
    fetch_text=None,
    env: dict[str, str] | None = None,
    llm_client_factory=OpenAICompatibleClient,
    blog_client_factory=BlogApiClient,
) -> dict[str, Any]:
    """Run ``run_stage0_to_stage8`` for one date and write its outputs to disk.

    Args:
        run_date: Date string for the run. It is also the name of the output
            sub-directory under ``out_dir``.
        mode: Run mode forwarded to the pipeline. ``"draft"`` and ``"publish"``
            require a blog token and build a blog client. Any other value runs
            with ``client=None``.
        source_mode: ``"mock"`` uses a built-in fake source. ``"live"`` loads
            sources from ``sources_path``.
        llm_mode: ``"mock"`` uses canned LLM replies. ``"live"`` uses
            ``llm_client_factory(**resolve_llm_config(env)).chat``.
        out_dir: Root directory for run outputs.
        base_url: Blog service base URL, passed to the blog client and the pipeline.
        sources_path: Source config file. Defaults to ``config/sources.json``
            (live source mode only).
        pipeline_path: Pipeline config file. Defaults to ``config/pipeline.json``.
        history_path: Published-URL history file. Defaults to
            ``cross_day_dedup.history_path`` from the pipeline config, or
            ``~/.hermes/scripts/ai_morning_out/published_urls.json``. Only the
            defaults are passed through ``expanduser()``.
        fetch_text: Override for the text fetcher. Defaults to ``clients.fetch_text``.
        env: Environment mapping. Loaded with ``load_env()`` when omitted.
        llm_client_factory: Constructor for the live LLM client.
        blog_client_factory: Constructor for the blog client.

    Returns:
        A dict with ``run_dir`` (str), ``markdown``, ``reports`` (including
        ``llm_observability``), ``publish`` and ``artifacts``.

    Raises:
        ValueError: ``source_mode`` or ``llm_mode`` is not ``"mock"``/``"live"``, or
            no blog token is available in draft/publish mode.

    Side effects:
        Writes ``blog_markdown.md``, ``run_report.json`` and ``<artifact>.json``
        files into ``out_dir / run_date``. The history file is updated only when
        cross-day dedup is enabled and the publish result has mode ``"publish"``
        and status ``"ok"``.
    """
    fetch_text = fetch_text or default_fetch_text
    env = env if env is not None else load_env()

    pipeline_config_path = pipeline_path or Path("config") / "pipeline.json"
    pipeline_config = load_pipeline_config(pipeline_config_path)
    cross_day_config = pipeline_config.get("cross_day_dedup", {}) or {}
    cross_day_enabled = bool(cross_day_config.get("enabled", True))
    cross_day_max_age_days = int(cross_day_config.get("max_age_days", 7))
    semantic_dedup_max_deletion_ratio = float(pipeline_config.get("semantic_dedup_max_deletion_ratio", 0.5))
    rewrite_batch_size = int(pipeline_config.get("rewrite_batch_size", 30))
    semantic_candidate_recall_config = pipeline_config.get("semantic_candidate_recall", {}) or {}
    quality_gate_config = pipeline_config.get("quality_gate", {}) or {}
    publish_idempotency_config = pipeline_config.get("publish_idempotency", {}) or {}

    configured_history_path = history_path or Path(
        str(cross_day_config.get("history_path") or "~/.hermes/scripts/ai_morning_out/published_urls.json")
    ).expanduser()
    published_urls = load_published_urls(configured_history_path) if cross_day_enabled else None

    if source_mode == "mock":
        source_configs = _mock_source_configs()
        fetcher = _mock_fetcher
    elif source_mode == "live":
        if sources_path is None:
            sources_path = Path("config") / "sources.json"
        source_configs = load_source_configs(sources_path)
        fetcher = _build_live_fetcher(fetch_text)
    else:
        raise ValueError("source_mode must be 'mock' or 'live'")

    llm_observability_config = pipeline_config.get("llm_observability", {}) or {}
    llm_observers: list[LlmCallObserver] = []
    observe_llm = bool(llm_observability_config.get("enabled", True))
    prompt_preview_chars = int(llm_observability_config.get("prompt_preview_chars", 500))
    response_preview_chars = int(llm_observability_config.get("response_preview_chars", 500))

    def maybe_observe(stage: str, call):
        # When observability is on, wrap the call so the observer can be summarized
        # into the run report after the pipeline finishes.
        if not observe_llm:
            return call
        observer = LlmCallObserver(
            call=call,
            stage=stage,
            prompt_preview_chars=prompt_preview_chars,
            response_preview_chars=response_preview_chars,
        )
        llm_observers.append(observer)
        return observer

    if llm_mode == "mock":
        semantic_llm_call = maybe_observe("stage3", _mock_semantic_llm)
        rewrite_llm_call = maybe_observe("stage4", _mock_rewrite_llm)
        guide_llm_call = maybe_observe("stage6", _mock_guide_llm)
    elif llm_mode == "live":
        llm_client = llm_client_factory(**resolve_llm_config(env))
        semantic_llm_call = maybe_observe("stage3", llm_client.chat)
        rewrite_llm_call = maybe_observe("stage4", llm_client.chat)
        guide_llm_call = maybe_observe("stage6", llm_client.chat)
    else:
        raise ValueError("llm_mode must be 'mock' or 'live'")

    blog_client = None
    if mode in ("draft", "publish"):
        token = resolve_blog_token(env)
        if not token:
            raise ValueError("missing_blog_token: set BLOG_SERVICE_TOKEN or EPHRON_SERVICE_TOKEN")
        blog_client = blog_client_factory(base_url=base_url, token=token)

    result = run_stage0_to_stage8(
        source_configs,
        run_date,
        fetcher=fetcher,
        semantic_llm_call=semantic_llm_call,
        rewrite_llm_call=rewrite_llm_call,
        guide_llm_call=guide_llm_call,
        mode=mode,
        base_url=base_url,
        client=blog_client,
        published_urls=published_urls,
        cross_day_dedup_enabled=cross_day_enabled,
        cross_day_dedup_max_age_days=cross_day_max_age_days,
        semantic_dedup_max_deletion_ratio=semantic_dedup_max_deletion_ratio,
        rewrite_batch_size=rewrite_batch_size,
        semantic_candidate_recall_config=semantic_candidate_recall_config,
        quality_gate_config=quality_gate_config,
        publish_idempotency_config=publish_idempotency_config,
    )

    # Only record URLs as published after a successful real publish, so drafts and
    # failed runs do not suppress items on later days.
    if cross_day_enabled and result["publish"].mode == "publish" and result["publish"].status == "ok":
        update_published_urls(
            configured_history_path,
            result["items"],
            run_date=run_date,
            max_age_days=cross_day_max_age_days,
        )

    result["reports"]["llm_observability"] = summarize_observed_calls(llm_observers)

    # Tolerate a pipeline result whose "artifacts" is missing or None.
    artifacts = result.get("artifacts") or {}
    run_dir = out_dir / run_date
    _write_run_outputs(run_dir, result, artifacts)

    return {
        "run_dir": str(run_dir),
        "markdown": result["markdown"],
        "reports": result["reports"],
        "publish": result["publish"],
        "artifacts": artifacts,
    }