diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1283968 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.env +.env.* +!.env.example +__pycache__/ +*.py[cod] +.pytest_cache/ +runs/ +runs-*/ +.idea/ diff --git a/ai_daily_report/__init__.py b/ai_daily_report/__init__.py new file mode 100644 index 0000000..5f84311 --- /dev/null +++ b/ai_daily_report/__init__.py @@ -0,0 +1,2 @@ +"""Core package for the AI daily report pipeline.""" + diff --git a/ai_daily_report/assemble.py b/ai_daily_report/assemble.py new file mode 100644 index 0000000..b66e6ea --- /dev/null +++ b/ai_daily_report/assemble.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import re +from typing import Any + +from .classify import SECTION_ORDER +from .models import NewsItem +from .validate import validate_markdown + + +END_PUNCTUATION = "。!?;.!?;" + + +def _clean_text(text: str) -> str: + value = re.sub(r"^```(?:\w+)?\s*\n?", "", (text or "").strip()) + value = re.sub(r"\n?```\s*$", "", value) + value = re.sub(r"^\s*>\s*", "", value) + value = re.sub(r"\[\d+\]|\[N\]", "", value) + value = re.sub(r"主线判断[::]\s*", "", value) + value = re.sub(r"\s+", " ", value).strip() + return value + + +def _ensure_sentence(text: str) -> str: + value = _clean_text(text) + if value and value[-1] not in END_PUNCTUATION: + value += "。" + return value + + +def _source_link(item: NewsItem) -> str: + source = item.source_label or item.source_group or "来源" + if item.url: + return f"[{source} ↗]({item.url})" + return source + + +def assemble_markdown(items: list[NewsItem], guide: dict[str, Any] | None = None) -> tuple[str, dict[str, Any]]: + guide = guide or {"theme": "", "threads": []} + lines: list[str] = [] + + theme = _clean_text(str(guide.get("theme") or "")) + if theme: + lines.extend(["## 导览", "", f"> {theme}", ""]) + + item_number = 1 + for section in SECTION_ORDER: + section_items = [item for item in items if item.section == section] + if not section_items: + continue + lines.extend([f"## {section}", ""]) + for item in section_items: + title = _clean_text(item.title or item.title_raw) + summary = _ensure_sentence(item.summary or item.summary_raw or "该条目暂无摘要。") + lines.extend( + [ + f"**{item_number}. {title}**", + "", + f"> {summary}{_source_link(item)}", + "", + ] + ) + item_number += 1 + + threads = guide.get("threads", []) or [] + if threads: + lines.extend(["## 今日脉络", ""]) + for thread in threads: + title = _clean_text(str(thread.get("title") or "")) + text = _ensure_sentence(str(thread.get("text") or "")) + if not title or not text: + continue + lines.extend([f"- **{title}**", f" {text}", ""]) + + markdown = "\n".join(lines).strip() + report = validate_markdown(markdown, items) + return markdown, report diff --git a/ai_daily_report/classify.py b/ai_daily_report/classify.py new file mode 100644 index 0000000..4beca1f --- /dev/null +++ b/ai_daily_report/classify.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from collections import Counter +from typing import Any + +from .models import NewsItem + + +SECTION_ORDER = [ + "模型与能力", + "产品与应用", + "开发与基础设施", + "公司与资本", + "政策与安全", + "论文与研究", + "观点与教程", + "人物与动态", +] + +SECTION_ALIASES = { + "模型发布/更新": "模型与能力", + "产品发布/更新": "产品与应用", + "产品与工具": "产品与应用", + "开发与工程": "开发与基础设施", + "行业动态": "公司与资本", + "行业与公司": "公司与资本", + "论文研究": "论文与研究", + "论文与研究": "论文与研究", + "技巧与观点": "观点与教程", + "观点与教程": "观点与教程", + "人物与花絮": "人物与动态", +} + + +RULES = [ + ("政策与安全", ("监管", "政策", "安全", "风险", "滥用", "攻击", "合规", "版权")), + ("论文与研究", ("论文", "研究", "arxiv", "cvpr", "benchmark", "评测", "实验")), + ("开发与基础设施", ("sdk", "api", "mcp", "kubernetes", "框架", "开源", "github", "部署", "基础设施")), + ("公司与资本", ("融资", "ipo", "上市", "招股书", "合作", "估值", "收购", "资本")), + ("模型与能力", ("模型", "gpt", "claude", "gemini", "grok", "token", "参数", "多模态", "语音", "推理")), + ("产品与应用", ("agent", "应用", "产品", "平台", "上线", "工具", "智能体")), + ("观点与教程", ("教程", "观点", "方法论", "guide", "实践", "技巧")), + ("人物与动态", ("黄仁勋", "纳德拉", "访谈", "演讲", "人物")), +] + + +def normalize_section_hint(section_hint: str) -> str: + hint = (section_hint or "").strip() + if hint in SECTION_ORDER: + return hint + return SECTION_ALIASES.get(hint, "") + + +def rule_classify(item: NewsItem) -> str: + text = f"{item.title or item.title_raw} {item.summary or item.summary_raw}".lower() + for section, keywords in RULES: + if any(keyword.lower() in text for keyword in keywords): + return section + return "公司与资本" + + +def rank_score(item: NewsItem) -> int: + text = f"{item.title or item.title_raw} {item.summary or item.summary_raw}" + score = max(0, 200 - item.source_priority) + if item.source_role == "primary": + score += 10 + if item.canonical_url: + score += 10 + if any(ch.isdigit() for ch in text): + score += 10 + if item.duplicate_sources: + score += min(20, len(item.duplicate_sources) * 5) + score -= len(item.quality_flags) * 10 + return score + + +def classify_and_order_items(items: list[NewsItem]) -> tuple[list[NewsItem], dict[str, Any]]: + hint_classified = 0 + rule_classified = 0 + + for item in items: + mapped = normalize_section_hint(item.section_hint) + if mapped: + item.section = mapped + hint_classified += 1 + else: + item.section = rule_classify(item) + rule_classified += 1 + + section_index = {section: index for index, section in enumerate(SECTION_ORDER)} + ordered = sorted( + items, + key=lambda item: ( + section_index.get(item.section or "", len(SECTION_ORDER)), + -rank_score(item), + item.title or item.title_raw, + ), + ) + section_counts = Counter(item.section for item in ordered if item.section) + report = { + "input_count": len(items), + "section_counts": dict(section_counts), + "hint_classified": hint_classified, + "rule_classified": rule_classified, + "llm_classified": 0, + "fallback_classified": 0, + "invalid_section_count": sum(1 for item in ordered if item.section not in SECTION_ORDER), + } + return ordered, report diff --git a/ai_daily_report/cli.py b/ai_daily_report/cli.py new file mode 100644 index 0000000..539cbce --- /dev/null +++ b/ai_daily_report/cli.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +from .runner import run_daily_report + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog="ai-daily-report") + subcommands = parser.add_subparsers(dest="command") + run = subcommands.add_parser("run") + run.add_argument("--date", default="today") + run.add_argument("--mode", choices=["dry-run", "draft", "publish"], default="dry-run") + run.add_argument("--source-mode", choices=["mock", "live"], default="mock") + run.add_argument("--llm-mode", choices=["mock", "live"], default="mock") + run.add_argument("--out-dir", default="runs") + run.add_argument("--base-url", default="https://blog.ephron.ren") + run.add_argument("--sources-path", default=None) + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + if args.command == "run": + run_daily_report( + run_date=args.date, + mode=args.mode, + source_mode=args.source_mode, + llm_mode=args.llm_mode, + out_dir=Path(args.out_dir), + base_url=args.base_url, + sources_path=Path(args.sources_path) if args.sources_path else None, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ai_daily_report/clients.py b/ai_daily_report/clients.py new file mode 100644 index 0000000..2fd3359 --- /dev/null +++ b/ai_daily_report/clients.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import json +import urllib.request +from typing import Any + + +UA = "Mozilla/5.0 (compatible; ai-daily-report/1.0)" + + +def fetch_text(url: str, timeout_seconds: int) -> str: + req = urllib.request.Request(url, headers={"User-Agent": UA}) + with urllib.request.urlopen(req, timeout=timeout_seconds) as response: + return response.read().decode("utf-8", "ignore") + + +class OpenAICompatibleClient: + def __init__(self, *, api_key: str, base_url: str, model: str, timeout_seconds: int = 600): + self.api_key = api_key + self.base_url = base_url.rstrip("/") + self.model = model + self.timeout_seconds = timeout_seconds + + def chat(self, prompt: str) -> str: + payload = json.dumps( + { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.2, + "max_tokens": 8000, + }, + ensure_ascii=False, + ).encode("utf-8") + req = urllib.request.Request( + f"{self.base_url}/chat/completions", + data=payload, + headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=self.timeout_seconds) as response: + data = json.loads(response.read().decode("utf-8")) + return data["choices"][0]["message"]["content"].strip() + + +class BlogApiClient: + def __init__(self, *, base_url: str, token: str, timeout_seconds: int = 25): + self.base_url = base_url.rstrip("/") + self.token = token + self.timeout_seconds = timeout_seconds + + def _request(self, method: str, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]: + data = None + headers = {"Authorization": f"Bearer {self.token}", "User-Agent": UA} + if payload is not None: + data = json.dumps(payload, ensure_ascii=False).encode("utf-8") + headers["Content-Type"] = "application/json" + req = urllib.request.Request(f"{self.base_url}{path}", data=data, headers=headers, method=method) + with urllib.request.urlopen(req, timeout=self.timeout_seconds) as response: + return json.loads(response.read().decode("utf-8")) + + def create_post(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "/api/service/posts", payload) + + def publish_post(self, slug: str) -> None: + self._request("POST", f"/api/service/posts/{slug}/publish") diff --git a/ai_daily_report/collect.py b/ai_daily_report/collect.py new file mode 100644 index 0000000..b1c947e --- /dev/null +++ b/ai_daily_report/collect.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timezone +from time import perf_counter +from typing import Callable, Iterable, Any + +from .models import SourceConfig, SourceResult + + +Fetcher = Callable[[SourceConfig, str], list[dict[str, Any]]] + + +def _status_from_exception(exc: Exception) -> str: + if isinstance(exc, TimeoutError): + return "timeout" + return "error" + + +def _collect_one(config: SourceConfig, run_date: str, fetcher: Fetcher) -> SourceResult: + fetched_at = datetime.now(timezone.utc).isoformat() + if not config.enabled: + return SourceResult( + source=config.name, + role=config.role, + ok=False, + status="disabled", + fetched_at=fetched_at, + ) + + started = perf_counter() + try: + items = fetcher(config, run_date) + elapsed_ms = int((perf_counter() - started) * 1000) + status = "ok" if items else "empty" + return SourceResult( + source=config.name, + role=config.role, + ok=status == "ok", + status=status, + items=items, + elapsed_ms=elapsed_ms, + fetched_at=fetched_at, + ) + except Exception as exc: + elapsed_ms = int((perf_counter() - started) * 1000) + return SourceResult( + source=config.name, + role=config.role, + ok=False, + status=_status_from_exception(exc), + error=f"{type(exc).__name__}: {exc}", + elapsed_ms=elapsed_ms, + fetched_at=fetched_at, + ) + + +def collect_sources( + configs: Iterable[SourceConfig], + run_date: str, + *, + fetcher: Fetcher, + max_workers: int | None = None, +) -> tuple[list[SourceResult], dict[str, Any]]: + ordered_configs = list(configs) + if not ordered_configs: + return [], { + "input_source_count": 0, + "ok_source_count": 0, + "failed_source_count": 0, + "raw_item_count": 0, + } + + workers = max_workers or min(8, len(ordered_configs)) + result_by_name: dict[str, SourceResult] = {} + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(_collect_one, config, run_date, fetcher): config + for config in ordered_configs + } + for future in as_completed(futures): + config = futures[future] + result_by_name[config.name] = future.result() + + results = [result_by_name[config.name] for config in ordered_configs] + report = { + "input_source_count": len(results), + "ok_source_count": sum(1 for result in results if result.ok), + "failed_source_count": sum(1 for result in results if not result.ok), + "raw_item_count": sum(len(result.items) for result in results), + "source_counts": {result.source: len(result.items) for result in results}, + "statuses": {result.source: result.status for result in results}, + } + return results, report diff --git a/ai_daily_report/config.py b/ai_daily_report/config.py new file mode 100644 index 0000000..03b426d --- /dev/null +++ b/ai_daily_report/config.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from .models import SourceConfig +from .pipeline import _source_config_from_dict + + +def load_json(path: Path) -> Any: + return json.loads(path.read_text(encoding="utf-8")) + + +def load_source_configs(path: Path) -> list[SourceConfig]: + raw = load_json(path) + if not isinstance(raw, list): + raise ValueError("sources config must be a list") + return [_source_config_from_dict(item) for item in raw] diff --git a/ai_daily_report/dedupe.py b/ai_daily_report/dedupe.py new file mode 100644 index 0000000..6a9e426 --- /dev/null +++ b/ai_daily_report/dedupe.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import difflib +from typing import Any + +from .models import NewsItem + + +def _item_score(item: NewsItem) -> int: + score = 0 + score += max(0, 200 - item.source_priority) + if item.canonical_url: + score += 20 + if item.summary_raw: + score += min(40, len(item.summary_raw)) + if item.section_hint: + score += 10 + if item.source_role == "primary": + score += 10 + score -= len(item.quality_flags) * 10 + return score + + +def _merge_group(group: list[NewsItem], reason: str) -> tuple[NewsItem, list[NewsItem], dict[str, Any]]: + keep = max(group, key=_item_score) + removed = [item for item in group if item is not keep] + for removed_item in removed: + keep.duplicate_sources.append( + { + "id": removed_item.id, + "source_group": removed_item.source_group, + "source_label": removed_item.source_label, + "url": removed_item.url, + "reason": reason, + } + ) + report_group = { + "reason": reason, + "keep_id": keep.id, + "removed_ids": [item.id for item in removed], + "confidence": "high", + } + return keep, removed, report_group + + +def _group_by_key(items: list[NewsItem], key_name: str) -> dict[str, list[NewsItem]]: + groups: dict[str, list[NewsItem]] = {} + for item in items: + key = getattr(item, key_name) + if key: + groups.setdefault(key, []).append(item) + return {key: group for key, group in groups.items() if len(group) > 1} + + +def _possible_duplicates(items: list[NewsItem]) -> list[dict[str, Any]]: + possible: list[dict[str, Any]] = [] + for index, left in enumerate(items): + for right in items[index + 1 :]: + if not left.title_norm or not right.title_norm: + continue + ratio = difflib.SequenceMatcher(None, left.title_norm, right.title_norm).ratio() + if ratio >= 0.65: + possible.append( + { + "item_ids": [left.id, right.id], + "reason": "title_similarity", + "similarity": round(ratio, 3), + "confidence": "medium", + } + ) + return possible + + +def hard_dedup_items(items: list[NewsItem]) -> tuple[list[NewsItem], dict[str, Any]]: + remaining = list(items) + removed_object_ids: set[int] = set() + groups_report: list[dict[str, Any]] = [] + + for key_name, reason in ( + ("canonical_url", "same_canonical_url"), + ("title_norm", "same_title_norm"), + ): + grouped = _group_by_key([item for item in remaining if id(item) not in removed_object_ids], key_name) + for group in grouped.values(): + active_group = [item for item in group if id(item) not in removed_object_ids] + if len(active_group) < 2: + continue + keep, removed, report_group = _merge_group(active_group, reason) + removed_object_ids.update(id(item) for item in removed) + groups_report.append(report_group) + + deduped = [item for item in remaining if id(item) not in removed_object_ids] + report = { + "input_count": len(items), + "output_count": len(deduped), + "removed_count": len(removed_object_ids), + "groups": groups_report, + "possible_duplicates": _possible_duplicates(deduped), + } + return deduped, report diff --git a/ai_daily_report/env.py b/ai_daily_report/env.py new file mode 100644 index 0000000..a5697f0 --- /dev/null +++ b/ai_daily_report/env.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import os +import json +from pathlib import Path + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + + +def read_env_file(env_path: Path) -> dict[str, str]: + env: dict[str, str] = {} + if not env_path.exists(): + return env + text = env_path.read_text(encoding="utf-8", errors="ignore") + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + env[key.strip()] = value.strip().strip('"').strip("'") + return env + + +def load_env() -> dict[str, str]: + env: dict[str, str] = {} + env.update(read_env_file(PROJECT_ROOT / ".env")) + env.update(read_env_file(Path.home() / ".hermes" / ".env")) + env.update({key: value for key, value in os.environ.items() if value}) + return env + + +def first_env(env: dict[str, str], *names: str) -> str: + for name in names: + value = (env.get(name) or "").strip() + if value: + return value + return "" + + +def _load_simple_yaml(path: Path) -> dict[str, object]: + if not path.exists(): + return {} + root: dict[str, object] = {} + stack: list[tuple[int, dict[str, object]]] = [(-1, root)] + for raw_line in path.read_text(encoding="utf-8", errors="ignore").splitlines(): + if not raw_line.strip() or raw_line.lstrip().startswith("#") or ":" not in raw_line: + continue + indent = len(raw_line) - len(raw_line.lstrip(" ")) + key, value = raw_line.strip().split(":", 1) + key = key.strip() + value = value.strip().strip('"').strip("'") + while stack and indent <= stack[-1][0]: + stack.pop() + current = stack[-1][1] + if value: + current[key] = value + else: + child: dict[str, object] = {} + current[key] = child + stack.append((indent, child)) + return root + + +def _env_with_hermes(env: dict[str, str], hermes_dir: Path) -> dict[str, str]: + merged = dict(read_env_file(hermes_dir / ".env")) + merged.update(env) + return merged + + +def _provider_env_names(provider: str) -> tuple[str, str, str]: + prefix = provider.upper().replace("-", "_") + return f"{prefix}_API_KEY", f"{prefix}_BASE_URL", f"{prefix}_MODEL" + + +def _auth_json_key(env: dict[str, str], hermes_dir: Path, provider: str) -> str: + auth_path = hermes_dir / "auth.json" + if not auth_path.exists() or not provider: + return "" + try: + auth = json.loads(auth_path.read_text(encoding="utf-8")) + except Exception: + return "" + pool = auth.get("credential_pool", {}) or {} + provider_keys = [provider, provider.replace("-", "_")] + for key in provider_keys: + creds = pool.get(key, []) or [] + if not creds: + continue + cred = creds[0] + source = str(cred.get("source") or "") + if source.startswith("env:"): + resolved = first_env(env, source[4:]) + if resolved: + return resolved + token = str(cred.get("access_token") or "").strip() + if token: + return token + return "" + + +def resolve_llm_config(env: dict[str, str], *, hermes_dir: Path | None = None) -> dict[str, str]: + hermes_dir = hermes_dir or Path.home() / ".hermes" + env = _env_with_hermes(env, hermes_dir) + hermes_config = _load_simple_yaml(hermes_dir / "config.yaml") + model_config = hermes_config.get("model", {}) if isinstance(hermes_config.get("model"), dict) else {} + provider = str(model_config.get("provider") or "").strip() + provider_key, provider_base_url, provider_model = _provider_env_names(provider) if provider else ("", "", "") + + api_key = first_env(env, "LLM_API_KEY") + base_url = first_env(env, "LLM_BASE_URL") + model = first_env(env, "LLM_MODEL") + + if not api_key and provider: + api_key = first_env(env, provider_key) or _auth_json_key(env, hermes_dir, provider) + if not base_url and provider: + base_url = first_env(env, provider_base_url) or str(model_config.get("base_url") or "").strip() + if not model and provider: + model = first_env(env, provider_model) or str(model_config.get("default") or "").strip() + + if not api_key: + api_key = first_env(env, "SUB2API_API_KEY", "XIAOMI_API_KEY", "OPENROUTER_API_KEY") + if not base_url: + base_url = first_env(env, "SUB2API_BASE_URL", "XIAOMI_BASE_URL", "OPENROUTER_BASE_URL") + if not model: + model = first_env(env, "SUB2API_MODEL", "XIAOMI_MODEL") + + missing = [ + name + for name, value in ( + ("LLM_API_KEY", api_key), + ("LLM_BASE_URL", base_url), + ("LLM_MODEL", model), + ) + if not value + ] + if missing: + raise ValueError("missing_llm_config: " + ",".join(missing)) + return {"api_key": api_key, "base_url": base_url, "model": model} + + +def resolve_blog_token(env: dict[str, str]) -> str: + return first_env(env, "BLOG_SERVICE_TOKEN", "EPHRON_SERVICE_TOKEN") diff --git a/ai_daily_report/guide.py b/ai_daily_report/guide.py new file mode 100644 index 0000000..63d8b89 --- /dev/null +++ b/ai_daily_report/guide.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Callable + +from .llm import parse_json_object +from .models import NewsItem + + +GuideLlmCall = Callable[[str], str] + + +def _clean_text(text: str, limit: int | None = None) -> str: + value = re.sub(r"^\s*>\s*", "", text or "").strip() + value = re.sub(r"\[\d+\]|\[N\]", "", value) + value = re.sub(r"\s+", " ", value).strip() + if limit and len(value) > limit: + value = value[:limit].rstrip() + return value + + +def _build_prompt(items: list[NewsItem]) -> str: + payload = { + "task": ( + "Generate a concise AI daily report guide. Return JSON only. Do not use 强信号/中信号/待验证. " + "Use a short theme and 2-4 daily threads. Every thread must reference existing item_ids." + ), + "items": [ + { + "id": item.id, + "title": item.title or item.title_raw, + "summary": item.summary or item.summary_raw, + "section": item.section, + "source": item.source_label, + } + for item in items + ], + "output_schema": { + "theme": "one sentence under 120 Chinese characters", + "threads": [ + { + "title": "thread title", + "text": "one or two sentences", + "item_ids": ["existing item id"], + "kind": "thread|uncertain", + } + ], + }, + } + return json.dumps(payload, ensure_ascii=False) + + +def generate_guide( + items: list[NewsItem], + *, + llm_call: GuideLlmCall, +) -> tuple[dict[str, Any], dict[str, Any]]: + if not items: + return { + "theme": "", + "threads": [], + }, { + "input_count": 0, + "theme_present": False, + "thread_count": 0, + "dropped_thread_count": 0, + "fallback_used": False, + "errors": [], + } + + try: + obj = parse_json_object(llm_call(_build_prompt(items))) + except Exception as exc: + return { + "theme": "", + "threads": [], + }, { + "input_count": len(items), + "theme_present": False, + "thread_count": 0, + "dropped_thread_count": 0, + "fallback_used": True, + "errors": [f"{type(exc).__name__}: {exc}"], + } + + valid_ids = {item.id for item in items} + threads: list[dict[str, Any]] = [] + dropped = 0 + for thread in obj.get("threads", []) or []: + item_ids = [item_id for item_id in thread.get("item_ids", []) if item_id in valid_ids] + if not item_ids: + dropped += 1 + continue + title = _clean_text(str(thread.get("title") or ""), limit=80) + text = _clean_text(str(thread.get("text") or ""), limit=220) + if not title or not text: + dropped += 1 + continue + kind = thread.get("kind") if thread.get("kind") in ("thread", "uncertain") else "thread" + threads.append({"title": title, "text": text, "item_ids": item_ids, "kind": kind}) + + theme = _clean_text(str(obj.get("theme") or ""), limit=120) + guide = {"theme": theme, "threads": threads} + report = { + "input_count": len(items), + "theme_present": bool(theme), + "thread_count": len(threads), + "dropped_thread_count": dropped, + "fallback_used": False, + "errors": [], + } + return guide, report diff --git a/ai_daily_report/llm.py b/ai_daily_report/llm.py new file mode 100644 index 0000000..33c8769 --- /dev/null +++ b/ai_daily_report/llm.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Callable + + +LlmCall = Callable[[str], str] + + +def parse_json_object(text: str) -> dict[str, Any]: + text = re.sub(r"^```(?:json)?\s*\n?", "", text.strip()) + text = re.sub(r"\n?```\s*$", "", text) + match = re.search(r"\{.*\}\s*$", text, re.S) + if not match: + raise ValueError("LLM output does not contain a JSON object") + return json.loads(match.group(0)) + diff --git a/ai_daily_report/models.py b/ai_daily_report/models.py new file mode 100644 index 0000000..756b629 --- /dev/null +++ b/ai_daily_report/models.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class SourceConfig: + name: str + type: str + role: str = "supplement" + priority: int = 100 + required: bool = False + enabled: bool = True + timeout_seconds: int = 25 + retries: int = 0 + min_items: int = 0 + url: str = "" + + +@dataclass +class SourceResult: + source: str + role: str + ok: bool + status: str + items: list[dict[str, Any]] = field(default_factory=list) + error: str | None = None + elapsed_ms: int = 0 + retry_count: int = 0 + fetched_at: str = "" + + +@dataclass +class NewsItem: + id: str + source_group: str + source_label: str + source_role: str + source_priority: int + title_raw: str + title_norm: str + summary_raw: str + url: str + canonical_url: str + published_at: str | None = None + collected_at: str = "" + origin_type: str = "" + section_hint: str = "" + language_hint: str = "" + title: str | None = None + summary: str | None = None + section: str | None = None + quality_flags: list[str] = field(default_factory=list) + duplicate_sources: list[dict[str, Any]] = field(default_factory=list) diff --git a/ai_daily_report/normalize.py b/ai_daily_report/normalize.py new file mode 100644 index 0000000..dda9dd5 --- /dev/null +++ b/ai_daily_report/normalize.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import hashlib +import html +import re +import unicodedata +from collections import Counter +from datetime import datetime, timezone +from typing import Any +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +from .models import NewsItem, SourceResult + + +TRACKING_QUERY_PREFIXES = ("utm_",) +TRACKING_QUERY_KEYS = {"fbclid", "gclid", "spm", "from", "ref"} + + +def clean_text(value: str) -> str: + text = html.unescape(value or "") + text = re.sub(r"<[^>]+>", " ", text) + text = re.sub(r"\s+", " ", text).strip() + return text + + +def canonicalize_url(url: str) -> str: + if not url: + return "" + parsed = urlparse(url.strip()) + scheme = (parsed.scheme or "https").lower() + host = (parsed.netloc or "").lower() + if host.startswith("www."): + host = host[4:] + if host == "twitter.com": + host = "x.com" + + query = [] + for key, value in parse_qsl(parsed.query, keep_blank_values=True): + key_lower = key.lower() + if key_lower in TRACKING_QUERY_KEYS: + continue + if any(key_lower.startswith(prefix) for prefix in TRACKING_QUERY_PREFIXES): + continue + query.append((key, value)) + + path = parsed.path or "" + if len(path) > 1: + path = path.rstrip("/") + + return urlunparse((scheme, host, path, "", urlencode(query), "")) + + +def normalize_title(title: str) -> str: + text = unicodedata.normalize("NFKC", title or "").lower() + text = re.sub(r"[^\w\u4e00-\u9fff]+", "", text) + return text + + +def _item_id(canonical_url: str, source_group: str, title_norm: str, published_at: str | None) -> str: + seed = canonical_url or "|".join([source_group, title_norm, published_at or ""]) + digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()[:16] + return f"item_{digest}" + + +def _quality_flags(title: str, summary: str, url: str) -> list[str]: + flags: list[str] = [] + if not url: + flags.append("missing_url") + if not summary: + flags.append("missing_summary") + if len(normalize_title(title)) < 3: + flags.append("short_title") + return flags + + +def normalize_items( + source_results: list[SourceResult], + *, + run_date: str, + source_priorities: dict[str, int] | None = None, +) -> tuple[list[NewsItem], dict[str, Any]]: + source_priorities = source_priorities or {} + collected_at = datetime.now(timezone.utc).isoformat() + items: list[NewsItem] = [] + flag_counts: Counter[str] = Counter() + id_counts: Counter[str] = Counter() + input_count = 0 + + for source_result in source_results: + for raw in source_result.items: + input_count += 1 + title = clean_text(str(raw.get("title_raw") or raw.get("title") or "")) + summary = clean_text(str(raw.get("summary_raw") or raw.get("summary") or "")) + url = str(raw.get("url") or "").strip() + canonical_url = canonicalize_url(url) + title_norm = normalize_title(title) + flags = _quality_flags(title, summary, canonical_url) + flag_counts.update(flags) + source_label = clean_text(str(raw.get("source_label") or source_result.source)) + published_at = raw.get("published_at") + base_id = _item_id(canonical_url, source_result.source, title_norm, published_at) + id_counts[base_id] += 1 + item_id = base_id if id_counts[base_id] == 1 else f"{base_id}_{id_counts[base_id]}" + + items.append( + NewsItem( + id=item_id, + source_group=source_result.source, + source_label=source_label, + source_role=source_result.role, + source_priority=source_priorities.get(source_result.source, 100), + title_raw=title, + title_norm=title_norm, + summary_raw=summary, + url=url, + canonical_url=canonical_url, + published_at=published_at, + collected_at=collected_at, + origin_type=str(raw.get("origin_type") or ""), + section_hint=str(raw.get("section_hint") or ""), + language_hint=str(raw.get("language_hint") or ""), + quality_flags=flags, + ) + ) + + report = { + "run_date": run_date, + "input_count": input_count, + "output_count": len(items), + "quality_flag_counts": dict(flag_counts), + } + return items, report diff --git a/ai_daily_report/pipeline.py b/ai_daily_report/pipeline.py new file mode 100644 index 0000000..e2bc8a9 --- /dev/null +++ b/ai_daily_report/pipeline.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +from typing import Any + +from .assemble import assemble_markdown +from .classify import classify_and_order_items +from .collect import Fetcher, collect_sources +from .dedupe import hard_dedup_items +from .guide import GuideLlmCall, generate_guide +from .models import SourceConfig +from .normalize import normalize_items +from .publish import BlogClient, publish_markdown +from .rewrite import RewriteLlmCall, rewrite_items +from .semantic_dedupe import SemanticLlmCall, semantic_dedup_items + + +def _source_config_from_dict(value: dict[str, Any]) -> SourceConfig: + return SourceConfig( + name=value["name"], + type=value["type"], + role=value.get("role", "supplement"), + priority=int(value.get("priority", 100)), + required=bool(value.get("required", False)), + enabled=bool(value.get("enabled", True)), + timeout_seconds=int(value.get("timeout_seconds", 25)), + retries=int(value.get("retries", 0)), + min_items=int(value.get("min_items", 0)), + url=value.get("url", ""), + ) + + +def run_stage0_to_stage2( + source_configs: list[dict[str, Any] | SourceConfig], + run_date: str, + *, + fetcher: Fetcher, +) -> dict[str, Any]: + configs = [ + config if isinstance(config, SourceConfig) else _source_config_from_dict(config) + for config in source_configs + ] + source_results, stage0_report = collect_sources(configs, run_date, fetcher=fetcher) + source_priorities = {config.name: config.priority for config in configs} + normalized_items, stage1_report = normalize_items( + source_results, + run_date=run_date, + source_priorities=source_priorities, + ) + deduped_items, stage2_report = hard_dedup_items(normalized_items) + return { + "source_results": source_results, + "items": deduped_items, + "reports": { + "stage0": stage0_report, + "stage1": stage1_report, + "stage2": stage2_report, + }, + } + + +def run_stage0_to_stage4( + source_configs: list[dict[str, Any] | SourceConfig], + run_date: str, + *, + fetcher: Fetcher, + semantic_llm_call: SemanticLlmCall, + rewrite_llm_call: RewriteLlmCall, +) -> dict[str, Any]: + stage2_result = run_stage0_to_stage2(source_configs, run_date, fetcher=fetcher) + items = stage2_result["items"] + candidates = stage2_result["reports"]["stage2"].get("possible_duplicates", []) + semantic_items, stage3_report = semantic_dedup_items( + items, + candidates, + llm_call=semantic_llm_call, + ) + rewritten_items, stage4_report = rewrite_items( + semantic_items, + llm_call=rewrite_llm_call, + ) + reports = dict(stage2_result["reports"]) + reports["stage3"] = stage3_report + reports["stage4"] = stage4_report + return { + "source_results": stage2_result["source_results"], + "items": rewritten_items, + "reports": reports, + } + + +def run_stage0_to_stage5( + source_configs: list[dict[str, Any] | SourceConfig], + run_date: str, + *, + fetcher: Fetcher, + semantic_llm_call: SemanticLlmCall, + rewrite_llm_call: RewriteLlmCall, +) -> dict[str, Any]: + stage4_result = run_stage0_to_stage4( + source_configs, + run_date, + fetcher=fetcher, + semantic_llm_call=semantic_llm_call, + rewrite_llm_call=rewrite_llm_call, + ) + classified_items, stage5_report = classify_and_order_items(stage4_result["items"]) + reports = dict(stage4_result["reports"]) + reports["stage5"] = stage5_report + return { + "source_results": stage4_result["source_results"], + "items": classified_items, + "reports": reports, + } + + +def run_stage0_to_stage6( + source_configs: list[dict[str, Any] | SourceConfig], + run_date: str, + *, + fetcher: Fetcher, + semantic_llm_call: SemanticLlmCall, + rewrite_llm_call: RewriteLlmCall, + guide_llm_call: GuideLlmCall, +) -> dict[str, Any]: + stage5_result = run_stage0_to_stage5( + source_configs, + run_date, + fetcher=fetcher, + semantic_llm_call=semantic_llm_call, + rewrite_llm_call=rewrite_llm_call, + ) + guide, stage6_report = generate_guide(stage5_result["items"], llm_call=guide_llm_call) + reports = dict(stage5_result["reports"]) + reports["stage6"] = stage6_report + return { + "source_results": stage5_result["source_results"], + "items": stage5_result["items"], + "guide": guide, + "reports": reports, + } + + +def run_stage0_to_stage7( + source_configs: list[dict[str, Any] | SourceConfig], + run_date: str, + *, + fetcher: Fetcher, + semantic_llm_call: SemanticLlmCall, + rewrite_llm_call: RewriteLlmCall, + guide_llm_call: GuideLlmCall, +) -> dict[str, Any]: + stage6_result = run_stage0_to_stage6( + source_configs, + run_date, + fetcher=fetcher, + semantic_llm_call=semantic_llm_call, + rewrite_llm_call=rewrite_llm_call, + guide_llm_call=guide_llm_call, + ) + markdown, stage7_report = assemble_markdown(stage6_result["items"], stage6_result["guide"]) + reports = dict(stage6_result["reports"]) + reports["stage7"] = stage7_report + return { + "source_results": stage6_result["source_results"], + "items": stage6_result["items"], + "guide": stage6_result["guide"], + "markdown": markdown, + "reports": reports, + } + + +def run_stage0_to_stage8( + source_configs: list[dict[str, Any] | SourceConfig], + run_date: str, + *, + fetcher: Fetcher, + semantic_llm_call: SemanticLlmCall, + rewrite_llm_call: RewriteLlmCall, + guide_llm_call: GuideLlmCall, + mode: str, + base_url: str, + client: BlogClient | None, +) -> dict[str, Any]: + stage7_result = run_stage0_to_stage7( + source_configs, + run_date, + fetcher=fetcher, + semantic_llm_call=semantic_llm_call, + rewrite_llm_call=rewrite_llm_call, + guide_llm_call=guide_llm_call, + ) + slug = f"ai-{run_date}" + publish_result = publish_markdown( + title=f"AI日报 · {run_date}", + markdown=stage7_result["markdown"], + tags=["AI日报", "AI资讯", "人工智能"], + slug=slug, + base_url=base_url, + mode=mode, + markdown_report=stage7_result["reports"]["stage7"], + client=client, + ) + reports = dict(stage7_result["reports"]) + reports["stage8"] = { + "mode": publish_result.mode, + "status": publish_result.status, + "slug": publish_result.slug, + "blog_url": publish_result.blog_url, + "public_ok": publish_result.public_ok, + "error": publish_result.error, + } + return { + "source_results": stage7_result["source_results"], + "items": stage7_result["items"], + "guide": stage7_result["guide"], + "markdown": stage7_result["markdown"], + "publish": publish_result, + "reports": reports, + } diff --git a/ai_daily_report/publish.py b/ai_daily_report/publish.py new file mode 100644 index 0000000..7cf3ccd --- /dev/null +++ b/ai_daily_report/publish.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Protocol + + +@dataclass +class PublishResult: + mode: str + status: str + slug: str + blog_url: str + public_ok: bool = False + error: str | None = None + + +class BlogClient(Protocol): + def create_post(self, payload: dict[str, Any]) -> dict[str, Any]: + ... + + def publish_post(self, slug: str) -> None: + ... + + +def dry_run_publish(slug: str, base_url: str) -> PublishResult: + return PublishResult( + mode="dry-run", + status="ok", + slug=slug, + blog_url=f"{base_url.rstrip('/')}/posts/{slug}", + public_ok=True, + ) + + +def publish_markdown( + *, + title: str, + markdown: str, + tags: list[str], + slug: str, + base_url: str, + mode: str, + markdown_report: dict[str, Any], + client: BlogClient | None, +) -> PublishResult: + blocking_errors = markdown_report.get("blocking_errors", []) or [] + blog_url = f"{base_url.rstrip('/')}/posts/{slug}" + if blocking_errors: + return PublishResult( + mode=mode, + status="blocked", + slug=slug, + blog_url=blog_url, + public_ok=False, + error=";".join(blocking_errors), + ) + if mode == "dry-run": + return dry_run_publish(slug, base_url) + if client is None: + return PublishResult( + mode=mode, + status="failed", + slug=slug, + blog_url=blog_url, + public_ok=False, + error="missing_blog_client", + ) + + payload = {"title": title, "content": markdown, "tags": tags, "slug": slug} + try: + create_resp = client.create_post(payload) + created_slug = create_resp.get("slug") or slug + if mode == "publish": + client.publish_post(created_slug) + return PublishResult( + mode=mode, + status="ok", + slug=created_slug, + blog_url=f"{base_url.rstrip('/')}/posts/{created_slug}", + public_ok=mode == "publish", + ) + except Exception as exc: + return PublishResult( + mode=mode, + status="failed", + slug=slug, + blog_url=blog_url, + public_ok=False, + error=f"{type(exc).__name__}: {exc}", + ) diff --git a/ai_daily_report/rewrite.py b/ai_daily_report/rewrite.py new file mode 100644 index 0000000..6bc9063 --- /dev/null +++ b/ai_daily_report/rewrite.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import json +from typing import Any, Callable + +from .llm import parse_json_object +from .models import NewsItem + + +RewriteLlmCall = Callable[[str], str] + + +def _chunks(items: list[NewsItem], size: int) -> list[list[NewsItem]]: + return [items[index : index + size] for index in range(0, len(items), size)] + + +def _build_prompt(batch: list[NewsItem]) -> str: + payload = { + "task": ( + "Rewrite AI news titles and summaries into concise Chinese. Preserve brand/model/API names " + "such as GPT-5, Codex, Gemini, Claude, API, MCP. Do not add facts." + ), + "items": [ + { + "id": item.id, + "title_raw": item.title_raw, + "summary_raw": item.summary_raw, + "source": item.source_label, + "language_hint": item.language_hint, + } + for item in batch + ], + "output_schema": { + "rewrites": [ + { + "id": "item id", + "title": "display title", + "summary": "display summary", + "flags": [], + } + ] + }, + } + return json.dumps(payload, ensure_ascii=False) + + +def _fallback(item: NewsItem) -> None: + item.title = item.title_raw + item.summary = item.summary_raw or "该条目暂无摘要。" + + +def _apply_rewrite_batch(batch: list[NewsItem], llm_call: RewriteLlmCall) -> int: + obj = parse_json_object(llm_call(_build_prompt(batch))) + rewrites = obj.get("rewrites", []) + if not isinstance(rewrites, list): + raise ValueError("rewrites is not a list") + by_id = {item.id: item for item in batch} + seen_ids: set[str] = set() + for entry in rewrites: + item_id = entry.get("id") + title = str(entry.get("title") or "").strip() + summary = str(entry.get("summary") or "").strip() + if item_id in by_id and title and summary: + by_id[item_id].title = title + by_id[item_id].summary = summary + seen_ids.add(item_id) + for item in batch: + if item.id not in seen_ids: + raise ValueError(f"missing_rewrite_for_item: {item.id}") + return len(seen_ids) + + +def rewrite_items( + items: list[NewsItem], + *, + llm_call: RewriteLlmCall, + batch_size: int = 10, +) -> tuple[list[NewsItem], dict[str, Any]]: + rewritten_count = 0 + fallback_count = 0 + errors: list[str] = [] + + for batch in _chunks(items, max(1, batch_size)): + try: + rewritten_count += _apply_rewrite_batch(batch, llm_call) + except Exception as exc: + errors.append(f"batch:{type(exc).__name__}: {exc}") + for item in batch: + try: + rewritten_count += _apply_rewrite_batch([item], llm_call) + except Exception as item_exc: + errors.append(f"item:{item.id}:{type(item_exc).__name__}: {item_exc}") + _fallback(item) + fallback_count += 1 + + report = { + "input_count": len(items), + "rewritten_count": rewritten_count, + "fallback_count": fallback_count, + "batch_count": len(_chunks(items, max(1, batch_size))), + "errors": errors, + } + return items, report diff --git a/ai_daily_report/runner.py b/ai_daily_report/runner.py new file mode 100644 index 0000000..295316c --- /dev/null +++ b/ai_daily_report/runner.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import json +from dataclasses import asdict, is_dataclass +from pathlib import Path +from typing import Any + +from .clients import BlogApiClient, OpenAICompatibleClient, fetch_text as default_fetch_text +from .config import load_source_configs +from .env import load_env, resolve_blog_token, resolve_llm_config +from .models import SourceConfig +from .pipeline import run_stage0_to_stage8 +from .sources.registry import get_source_fetcher + + +def _json_default(value: Any): + if is_dataclass(value): + return asdict(value) + raise TypeError(f"Object is not JSON serializable: {type(value).__name__}") + + +def _mock_source_configs() -> list[SourceConfig]: + return [SourceConfig(name="Mock AI HOT", type="mock", role="primary", priority=10)] + + +def _mock_fetcher(config: SourceConfig, run_date: str) -> list[dict[str, Any]]: + return [ + { + "title_raw": "GPT-5 API 发布", + "summary_raw": "OpenAI 发布 GPT-5 API,用于本地 mock 测试。", + "url": "https://example.com/gpt5", + "source_label": "OpenAI:Blog", + "section_hint": "模型发布/更新", + "origin_type": "mock", + "language_hint": "zh", + } + ] + + +def _mock_semantic_llm(prompt: str) -> str: + return json.dumps({"duplicate_groups": [], "not_duplicates": [], "uncertain": []}, ensure_ascii=False) + + +def _mock_rewrite_llm(prompt: str) -> str: + payload = json.loads(prompt) + return json.dumps( + { + "rewrites": [ + { + "id": item["id"], + "title": item["title_raw"], + "summary": item["summary_raw"], + "flags": [], + } + for item in payload["items"] + ] + }, + ensure_ascii=False, + ) + + +def _mock_guide_llm(prompt: str) -> str: + payload = json.loads(prompt) + item_ids = [item["id"] for item in payload["items"][:3]] + return json.dumps( + { + "theme": "本地 mock 模式已生成 AI 日报,用于验证流水线。", + "threads": [ + { + "title": "本地链路验证", + "text": "采集、改写、分类、导览、Markdown 和发布报告都已通过 mock 数据串联。", + "item_ids": item_ids, + "kind": "thread", + } + ], + }, + ensure_ascii=False, + ) + + +def run_daily_report( + *, + run_date: str, + mode: str, + source_mode: str, + llm_mode: str, + out_dir: Path, + base_url: str, + sources_path: Path | None = None, + fetch_text=None, + env: dict[str, str] | None = None, + llm_client_factory=OpenAICompatibleClient, + blog_client_factory=BlogApiClient, +) -> dict[str, Any]: + fetch_text = fetch_text or default_fetch_text + env = env if env is not None else load_env() + + if source_mode == "mock": + source_configs = _mock_source_configs() + fetcher = _mock_fetcher + elif source_mode == "live": + if sources_path is None: + sources_path = Path("config") / "sources.json" + source_configs = load_source_configs(sources_path) + + def fetcher(config: SourceConfig, current_date: str) -> list[dict[str, Any]]: + source_fetcher = get_source_fetcher(config.type) + return source_fetcher(config, current_date, fetch_text) + + else: + raise ValueError("source_mode must be 'mock' or 'live'") + + if llm_mode == "mock": + semantic_llm_call = _mock_semantic_llm + rewrite_llm_call = _mock_rewrite_llm + guide_llm_call = _mock_guide_llm + elif llm_mode == "live": + llm_client = llm_client_factory(**resolve_llm_config(env)) + semantic_llm_call = llm_client.chat + rewrite_llm_call = llm_client.chat + guide_llm_call = llm_client.chat + else: + raise ValueError("llm_mode must be 'mock' or 'live'") + + blog_client = None + if mode in ("draft", "publish"): + token = resolve_blog_token(env) + if not token: + raise ValueError("missing_blog_token: set BLOG_SERVICE_TOKEN or EPHRON_SERVICE_TOKEN") + blog_client = blog_client_factory(base_url=base_url, token=token) + + result = run_stage0_to_stage8( + source_configs, + run_date, + fetcher=fetcher, + semantic_llm_call=semantic_llm_call, + rewrite_llm_call=rewrite_llm_call, + guide_llm_call=guide_llm_call, + mode=mode, + base_url=base_url, + client=blog_client, + ) + + run_dir = out_dir / run_date + run_dir.mkdir(parents=True, exist_ok=True) + (run_dir / "blog_markdown.md").write_text(result["markdown"], encoding="utf-8") + (run_dir / "run_report.json").write_text( + json.dumps(result["reports"], ensure_ascii=False, indent=2, default=_json_default), + encoding="utf-8", + ) + return { + "run_dir": str(run_dir), + "markdown": result["markdown"], + "reports": result["reports"], + "publish": result["publish"], + } diff --git a/ai_daily_report/semantic_dedupe.py b/ai_daily_report/semantic_dedupe.py new file mode 100644 index 0000000..815d298 --- /dev/null +++ b/ai_daily_report/semantic_dedupe.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import json +from typing import Any, Callable + +from .llm import parse_json_object +from .models import NewsItem + + +SemanticLlmCall = Callable[[str], str] + + +def _build_prompt(items: list[NewsItem], candidates: list[dict[str, Any]]) -> str: + item_payload = [ + { + "id": item.id, + "title": item.title or item.title_raw, + "summary": item.summary or item.summary_raw, + "source": item.source_label, + "section_hint": item.section_hint, + } + for item in items + ] + prompt = { + "task": "Identify only high-confidence semantic duplicates. Do not curate or remove by importance.", + "items": item_payload, + "candidates": candidates, + "output_schema": { + "duplicate_groups": [ + { + "keep_id": "item id", + "remove_ids": ["item id"], + "confidence": "high|medium|low", + "reason": "same concrete event reason", + } + ], + "not_duplicates": [], + "uncertain": [], + }, + } + return json.dumps(prompt, ensure_ascii=False) + + +def _score(item: NewsItem) -> int: + score = max(0, 200 - item.source_priority) + if item.source_role == "primary": + score += 10 + if item.summary_raw: + score += min(40, len(item.summary_raw)) + if item.canonical_url: + score += 20 + score -= len(item.quality_flags) * 10 + return score + + +def _choose_keep(group_items: list[NewsItem], suggested_keep_id: str) -> NewsItem: + suggested = [item for item in group_items if item.id == suggested_keep_id] + if suggested: + best = max(group_items, key=_score) + if _score(suggested[0]) >= _score(best) - 10: + return suggested[0] + return max(group_items, key=_score) + + +def semantic_dedup_items( + items: list[NewsItem], + candidates: list[dict[str, Any]], + *, + llm_call: SemanticLlmCall, + max_deletion_ratio: float = 0.5, +) -> tuple[list[NewsItem], dict[str, Any]]: + if not items or not candidates: + return items, { + "input_count": len(items), + "candidate_group_count": len(candidates), + "removed_count": 0, + "duplicate_groups": [], + "uncertain": [], + "errors": [], + "skipped_for_deletion_ratio": False, + } + + errors: list[str] = [] + try: + obj = parse_json_object(llm_call(_build_prompt(items, candidates))) + except Exception as exc: + return items, { + "input_count": len(items), + "candidate_group_count": len(candidates), + "removed_count": 0, + "duplicate_groups": [], + "uncertain": [], + "errors": [f"{type(exc).__name__}: {exc}"], + "skipped_for_deletion_ratio": False, + } + + by_id = {item.id: item for item in items} + candidate_sets = { + frozenset(item_id for item_id in candidate.get("item_ids", []) if isinstance(item_id, str)) + for candidate in candidates + } + candidate_removals: set[str] = set() + valid_groups: list[dict[str, Any]] = [] + + for group in obj.get("duplicate_groups", []) or []: + if group.get("confidence") != "high": + continue + ids = [group.get("keep_id")] + list(group.get("remove_ids") or []) + if any(not isinstance(item_id, str) or item_id not in by_id for item_id in ids): + errors.append(f"invalid_ids_in_group: {group}") + continue + group_set = frozenset(ids) + if not any(group_set.issubset(candidate_set) for candidate_set in candidate_sets): + errors.append(f"group_outside_candidates: {group}") + continue + group_items = [by_id[item_id] for item_id in ids] + keep = _choose_keep(group_items, str(group.get("keep_id"))) + remove_items = [item for item in group_items if item is not keep] + candidate_removals.update(item.id for item in remove_items) + valid_groups.append( + { + "keep_id": keep.id, + "remove_ids": [item.id for item in remove_items], + "confidence": "high", + "reason": str(group.get("reason") or "semantic_duplicate"), + } + ) + + deletion_ratio = len(candidate_removals) / len(items) if items else 0 + if deletion_ratio > max_deletion_ratio: + return items, { + "input_count": len(items), + "candidate_group_count": len(candidates), + "removed_count": 0, + "duplicate_groups": valid_groups, + "uncertain": obj.get("uncertain", []) or [], + "errors": errors, + "skipped_for_deletion_ratio": True, + } + + removed_ids: set[str] = set() + for group in valid_groups: + keep = by_id[group["keep_id"]] + for remove_id in group["remove_ids"]: + removed = by_id[remove_id] + keep.duplicate_sources.append( + { + "id": removed.id, + "source_group": removed.source_group, + "source_label": removed.source_label, + "url": removed.url, + "reason": group["reason"], + } + ) + removed_ids.add(remove_id) + + deduped = [item for item in items if item.id not in removed_ids] + report = { + "input_count": len(items), + "candidate_group_count": len(candidates), + "removed_count": len(removed_ids), + "duplicate_groups": valid_groups, + "uncertain": obj.get("uncertain", []) or [], + "errors": errors, + "skipped_for_deletion_ratio": False, + } + return deduped, report diff --git a/ai_daily_report/sources/__init__.py b/ai_daily_report/sources/__init__.py new file mode 100644 index 0000000..54ac9e1 --- /dev/null +++ b/ai_daily_report/sources/__init__.py @@ -0,0 +1,2 @@ +"""Source adapters for the AI daily report pipeline.""" + diff --git a/ai_daily_report/sources/aihot.py b/ai_daily_report/sources/aihot.py new file mode 100644 index 0000000..9c13d55 --- /dev/null +++ b/ai_daily_report/sources/aihot.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import json +from typing import Any, Callable + +from ai_daily_report.models import SourceConfig + + +FetchText = Callable[[str, int], str] + + +def fetch_aihot(config: SourceConfig, run_date: str, fetch_text: FetchText) -> list[dict[str, Any]]: + data = json.loads(fetch_text(f"https://aihot.virxact.com/api/public/daily/{run_date}", config.timeout_seconds)) + items: list[dict[str, Any]] = [] + generated = data.get("generatedAt") + for section in data.get("sections", []) or []: + for raw in section.get("items", []) or []: + items.append( + { + "source_group": config.name, + "source_label": raw.get("sourceName") or config.name, + "title_raw": raw.get("title") or "", + "summary_raw": raw.get("summary") or "", + "url": raw.get("sourceUrl") or "", + "published_at": generated, + "origin_type": "aihot_json", + "section_hint": section.get("label") or "", + "language_hint": "zh", + } + ) + return items + diff --git a/ai_daily_report/sources/juya.py b/ai_daily_report/sources/juya.py new file mode 100644 index 0000000..533fbbf --- /dev/null +++ b/ai_daily_report/sources/juya.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import re +import xml.etree.ElementTree as ET +from typing import Any, Callable + +from ai_daily_report.models import SourceConfig +from ai_daily_report.normalize import clean_text +from ai_daily_report.sources.labels import source_label_from_url + + +FetchText = Callable[[str, int], str] + + +def parse_juya_rss(config: SourceConfig, xml_text: str, run_date: str) -> list[dict[str, Any]]: + root = ET.fromstring(xml_text) + channel = root.find("channel") + raw_items = channel.findall("item") if channel is not None else [] + article_html = "" + for raw in raw_items: + if (raw.findtext("title") or "").strip() != run_date: + continue + content_el = raw.find("{http://purl.org/rss/1.0/modules/content/}encoded") + article_html = content_el.text if content_el is not None and content_el.text else "" + break + if not article_html: + return [] + + block_pattern = re.compile( + r'
#(?P\d+) \s*#(?P\d+) \s*]*>|', '\n', body_text, flags=re.I) - body_text = re.sub(r'|||
#1MiniMax M3 加速。
+ +OpenAI 发布更新。
", + "url": "https://openai.com/blog/gpt-5?utm_campaign=test", + "source_label": "OpenAI:Blog", + "section_hint": "模型发布/更新", + } + ], + ) + + items, report = normalize_items([source_result], run_date="2026-06-04") + + self.assertEqual(len(items), 1) + self.assertTrue(items[0].id.startswith("item_")) + self.assertEqual(items[0].canonical_url, "https://openai.com/blog/gpt-5") + self.assertEqual(items[0].title_norm, normalize_title("GPT-5 发布:速度提升 2x!")) + self.assertEqual(items[0].summary_raw, "OpenAI 发布更新。") + self.assertEqual(items[0].source_role, "primary") + self.assertEqual(report["input_count"], 1) + self.assertEqual(report["output_count"], 1) + + def test_normalize_items_marks_quality_flags_without_dropping_item(self): + source_result = SourceResult( + source="RSS", + role="supplement", + ok=True, + status="ok", + items=[{"title_raw": "短", "summary_raw": "", "url": ""}], + ) + + items, report = normalize_items([source_result], run_date="2026-06-04") + + self.assertEqual(len(items), 1) + self.assertIn("missing_url", items[0].quality_flags) + self.assertIn("missing_summary", items[0].quality_flags) + self.assertIn("short_title", items[0].quality_flags) + self.assertEqual(report["quality_flag_counts"]["missing_url"], 1) + + def test_normalize_items_keeps_ids_unique_for_same_canonical_url(self): + source_result = SourceResult( + source="AI HOT", + role="primary", + ok=True, + status="ok", + items=[ + { + "title_raw": "OpenAI 发布 GPT-5", + "summary_raw": "summary a", + "url": "https://example.com/news?utm_source=a", + }, + { + "title_raw": "OpenAI 发布 GPT-5", + "summary_raw": "summary b", + "url": "https://example.com/news", + }, + ], + ) + + items, _ = normalize_items([source_result], run_date="2026-06-04") + + self.assertEqual(len({item.id for item in items}), 2) + self.assertEqual(items[0].canonical_url, items[1].canonical_url) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage2_dedupe.py b/tests/test_stage2_dedupe.py new file mode 100644 index 0000000..0809889 --- /dev/null +++ b/tests/test_stage2_dedupe.py @@ -0,0 +1,63 @@ +import unittest + +from ai_daily_report.dedupe import hard_dedup_items +from ai_daily_report.models import NewsItem + + +def item( + item_id, + title, + title_norm, + url, + canonical_url, + source_group="AI HOT", + source_label="AI HOT", + source_priority=100, + summary="summary", +): + return NewsItem( + id=item_id, + source_group=source_group, + source_label=source_label, + source_role="primary" if source_group == "AI HOT" else "supplement", + source_priority=source_priority, + title_raw=title, + title_norm=title_norm, + summary_raw=summary, + url=url, + canonical_url=canonical_url, + ) + + +class Stage2DedupeTests(unittest.TestCase): + def test_hard_dedup_merges_same_canonical_url_and_keeps_better_item(self): + items = [ + item("a", "OpenAI 发布 GPT-5", "openai发布gpt5", "https://example.com/a?utm_source=x", "https://example.com/a", source_group="RSS", source_priority=50, summary="short"), + item("b", "OpenAI 发布 GPT-5", "openai发布gpt5", "https://example.com/a", "https://example.com/a", source_group="AI HOT", source_priority=10, summary="longer summary"), + ] + + deduped, report = hard_dedup_items(items) + + self.assertEqual([i.id for i in deduped], ["b"]) + self.assertEqual(report["input_count"], 2) + self.assertEqual(report["output_count"], 1) + self.assertEqual(report["removed_count"], 1) + self.assertEqual(report["groups"][0]["reason"], "same_canonical_url") + self.assertEqual(deduped[0].duplicate_sources[0]["source_group"], "RSS") + + def test_hard_dedup_marks_similar_titles_without_removing(self): + items = [ + item("a", "Grok API 上线 Cloudflare Gateway", "grokapi上线cloudflaregateway", "https://x.com/a", "https://x.com/a"), + item("b", "Grok 模型登陆 Cloudflare AI Gateway", "grok模型登陆cloudflareaigateway", "https://x.com/b", "https://x.com/b"), + ] + + deduped, report = hard_dedup_items(items) + + self.assertEqual(len(deduped), 2) + self.assertEqual(report["removed_count"], 0) + self.assertEqual(len(report["possible_duplicates"]), 1) + self.assertEqual(set(report["possible_duplicates"][0]["item_ids"]), {"a", "b"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage3_semantic_dedupe.py b/tests/test_stage3_semantic_dedupe.py new file mode 100644 index 0000000..ed876a5 --- /dev/null +++ b/tests/test_stage3_semantic_dedupe.py @@ -0,0 +1,129 @@ +import json +import unittest + +from ai_daily_report.models import NewsItem +from ai_daily_report.semantic_dedupe import semantic_dedup_items + + +def news_item(item_id, title, source_group="AI HOT"): + return NewsItem( + id=item_id, + source_group=source_group, + source_label=source_group, + source_role="primary" if source_group == "AI HOT" else "supplement", + source_priority=10 if source_group == "AI HOT" else 50, + title_raw=title, + title_norm=title.lower(), + summary_raw=f"{title} summary", + url=f"https://example.com/{item_id}", + canonical_url=f"https://example.com/{item_id}", + ) + + +class Stage3SemanticDedupeTests(unittest.TestCase): + def test_semantic_dedup_removes_only_high_confidence_duplicates(self): + items = [ + news_item("a", "Anthropic 提交 IPO 招股书", "AI HOT"), + news_item("b", "刚刚,Anthropic 提交了招股书", "量子位"), + news_item("c", "Grok 上线 Cloudflare Gateway", "AI HOT"), + ] + candidates = [{"item_ids": ["a", "b"], "reason": "title_similarity"}] + + def llm_call(prompt): + return json.dumps( + { + "duplicate_groups": [ + { + "keep_id": "a", + "remove_ids": ["b"], + "confidence": "high", + "reason": "same IPO filing event", + } + ], + "not_duplicates": [], + "uncertain": [], + } + ) + + deduped, report = semantic_dedup_items(items, candidates, llm_call=llm_call) + + self.assertEqual([item.id for item in deduped], ["a", "c"]) + self.assertEqual(report["removed_count"], 1) + self.assertEqual(report["duplicate_groups"][0]["reason"], "same IPO filing event") + self.assertEqual(deduped[0].duplicate_sources[0]["id"], "b") + + def test_semantic_dedup_skips_deletion_when_ratio_exceeds_limit(self): + items = [ + news_item("a", "A"), + news_item("b", "B"), + news_item("c", "C"), + ] + candidates = [{"item_ids": ["a", "b", "c"], "reason": "llm_candidate"}] + + def llm_call(prompt): + return json.dumps( + { + "duplicate_groups": [ + { + "keep_id": "a", + "remove_ids": ["b", "c"], + "confidence": "high", + "reason": "too broad", + } + ], + "not_duplicates": [], + "uncertain": [], + } + ) + + deduped, report = semantic_dedup_items( + items, + candidates, + llm_call=llm_call, + max_deletion_ratio=0.5, + ) + + self.assertEqual(len(deduped), 3) + self.assertEqual(report["removed_count"], 0) + self.assertTrue(report["skipped_for_deletion_ratio"]) + + def test_semantic_dedup_ignores_groups_outside_candidate_sets(self): + items = [ + news_item("a", "Suno 完成融资"), + news_item("b", "Suno 完成 D 轮融资"), + news_item("c", "Ideogram 发布 v4"), + news_item("d", "OpenClaw 发布新版"), + ] + candidates = [{"item_ids": ["a", "b"], "reason": "title_similarity"}] + + def llm_call(prompt): + return json.dumps( + { + "duplicate_groups": [ + { + "keep_id": "a", + "remove_ids": ["b"], + "confidence": "high", + "reason": "same Suno event", + }, + { + "keep_id": "c", + "remove_ids": ["d"], + "confidence": "high", + "reason": "not part of candidates", + }, + ], + "not_duplicates": [], + "uncertain": [], + } + ) + + deduped, report = semantic_dedup_items(items, candidates, llm_call=llm_call) + + self.assertEqual([item.id for item in deduped], ["a", "c", "d"]) + self.assertEqual(report["removed_count"], 1) + self.assertIn("group_outside_candidates", report["errors"][0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage4_rewrite.py b/tests/test_stage4_rewrite.py new file mode 100644 index 0000000..62ef346 --- /dev/null +++ b/tests/test_stage4_rewrite.py @@ -0,0 +1,96 @@ +import json +import unittest + +from ai_daily_report.models import NewsItem +from ai_daily_report.rewrite import rewrite_items + + +def news_item(item_id="a"): + return NewsItem( + id=item_id, + source_group="AI HOT", + source_label="AI HOT", + source_role="primary", + source_priority=10, + title_raw="OpenAI launches GPT-5 API", + title_norm="openailaunchesgpt5api", + summary_raw="OpenAI launched the GPT-5 API with better latency.", + url="https://example.com/a", + canonical_url="https://example.com/a", + ) + + +class Stage4RewriteTests(unittest.TestCase): + def test_rewrite_items_writes_display_fields_without_overwriting_raw(self): + items = [news_item("a")] + + def llm_call(prompt): + return json.dumps( + { + "rewrites": [ + { + "id": "a", + "title": "OpenAI 发布 GPT-5 API", + "summary": "OpenAI 发布 GPT-5 API,延迟表现更好。", + "flags": [], + } + ] + }, + ensure_ascii=False, + ) + + rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=10) + + self.assertEqual(rewritten[0].title, "OpenAI 发布 GPT-5 API") + self.assertEqual(rewritten[0].summary, "OpenAI 发布 GPT-5 API,延迟表现更好。") + self.assertEqual(rewritten[0].title_raw, "OpenAI launches GPT-5 API") + self.assertEqual(report["rewritten_count"], 1) + self.assertEqual(report["fallback_count"], 0) + + def test_rewrite_items_falls_back_when_llm_fails(self): + items = [news_item("a")] + + def llm_call(prompt): + raise TimeoutError("slow") + + rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=10) + + self.assertEqual(rewritten[0].title, "OpenAI launches GPT-5 API") + self.assertEqual(rewritten[0].summary, "OpenAI launched the GPT-5 API with better latency.") + self.assertEqual(report["rewritten_count"], 0) + self.assertEqual(report["fallback_count"], 1) + self.assertIn("TimeoutError", report["errors"][0]) + + def test_rewrite_items_retries_failed_batch_as_single_items(self): + items = [news_item("a"), news_item("b")] + calls = [] + + def llm_call(prompt): + payload = json.loads(prompt) + ids = [item["id"] for item in payload["items"]] + calls.append(ids) + if len(ids) > 1: + return "not json" + return json.dumps( + { + "rewrites": [ + { + "id": ids[0], + "title": f"title {ids[0]}", + "summary": f"summary {ids[0]}", + "flags": [], + } + ] + } + ) + + rewritten, report = rewrite_items(items, llm_call=llm_call, batch_size=2) + + self.assertEqual([item.title for item in rewritten], ["title a", "title b"]) + self.assertEqual(report["rewritten_count"], 2) + self.assertEqual(report["fallback_count"], 0) + self.assertEqual(calls, [["a", "b"], ["a"], ["b"]]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage5_classify.py b/tests/test_stage5_classify.py new file mode 100644 index 0000000..a158ca3 --- /dev/null +++ b/tests/test_stage5_classify.py @@ -0,0 +1,61 @@ +import unittest + +from ai_daily_report.classify import SECTION_ORDER, classify_and_order_items +from ai_daily_report.models import NewsItem + + +def news_item(item_id, title, summary="", section_hint="", source_priority=50): + return NewsItem( + id=item_id, + source_group="AI HOT", + source_label="AI HOT", + source_role="primary", + source_priority=source_priority, + title_raw=title, + title_norm=title.lower(), + summary_raw=summary or f"{title} summary", + title=title, + summary=summary or f"{title} summary", + url=f"https://example.com/{item_id}", + canonical_url=f"https://example.com/{item_id}", + section_hint=section_hint, + ) + + +class Stage5ClassifyTests(unittest.TestCase): + def test_classify_maps_legacy_section_hints_to_new_sections(self): + items = [news_item("a", "GPT-5 发布", section_hint="模型发布/更新")] + + classified, report = classify_and_order_items(items) + + self.assertEqual(classified[0].section, "模型与能力") + self.assertEqual(report["hint_classified"], 1) + self.assertIn("模型与能力", SECTION_ORDER) + + def test_classify_uses_rules_when_hint_is_missing(self): + items = [ + news_item("a", "Anthropic 提交 IPO 文件", summary="Anthropic 计划上市并提交文件。"), + news_item("b", "MCP SDK 发布新版", summary="开发者可用新版 SDK 构建工具。"), + ] + + classified, report = classify_and_order_items(items) + by_id = {item.id: item for item in classified} + + self.assertEqual(by_id["a"].section, "公司与资本") + self.assertEqual(by_id["b"].section, "开发与基础设施") + self.assertEqual(report["rule_classified"], 2) + + def test_classify_orders_items_by_local_rank_score_within_sections(self): + items = [ + news_item("low", "普通模型更新", section_hint="模型发布/更新", source_priority=80), + news_item("high", "GPT-5 API 发布,延迟降低 30%", section_hint="模型发布/更新", source_priority=10), + ] + + classified, report = classify_and_order_items(items) + + self.assertEqual([item.id for item in classified], ["high", "low"]) + self.assertEqual(report["section_counts"]["模型与能力"], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage6_guide.py b/tests/test_stage6_guide.py new file mode 100644 index 0000000..4399c4b --- /dev/null +++ b/tests/test_stage6_guide.py @@ -0,0 +1,77 @@ +import json +import unittest + +from ai_daily_report.guide import generate_guide +from ai_daily_report.models import NewsItem + + +def news_item(item_id, title, section="模型与能力"): + return NewsItem( + id=item_id, + source_group="AI HOT", + source_label="AI HOT", + source_role="primary", + source_priority=10, + title_raw=title, + title_norm=title.lower(), + summary_raw=f"{title} summary", + title=title, + summary=f"{title} summary", + url=f"https://example.com/{item_id}", + canonical_url=f"https://example.com/{item_id}", + section=section, + ) + + +class Stage6GuideTests(unittest.TestCase): + def test_generate_guide_returns_theme_and_valid_threads(self): + items = [ + news_item("a", "GPT-5 API 发布"), + news_item("b", "Miso One 开源语音模型"), + ] + + def llm_call(prompt): + return json.dumps( + { + "theme": "模型能力继续向 API 和实时语音两端推进。", + "threads": [ + { + "title": "模型能力继续推进", + "text": "GPT-5 API 和 Miso One 分别代表 API 能力和语音模型更新。", + "item_ids": ["a", "b"], + "kind": "thread", + }, + { + "title": "无效脉络", + "text": "这条引用了不存在的条目。", + "item_ids": ["missing"], + "kind": "thread", + }, + ], + }, + ensure_ascii=False, + ) + + guide, report = generate_guide(items, llm_call=llm_call) + + self.assertEqual(guide["theme"], "模型能力继续向 API 和实时语音两端推进。") + self.assertEqual(len(guide["threads"]), 1) + self.assertEqual(guide["threads"][0]["item_ids"], ["a", "b"]) + self.assertEqual(report["dropped_thread_count"], 1) + + def test_generate_guide_falls_back_when_llm_fails(self): + items = [news_item("a", "GPT-5 API 发布")] + + def llm_call(prompt): + raise TimeoutError("slow") + + guide, report = generate_guide(items, llm_call=llm_call) + + self.assertEqual(guide["theme"], "") + self.assertEqual(guide["threads"], []) + self.assertTrue(report["fallback_used"]) + self.assertIn("TimeoutError", report["errors"][0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage7_assemble.py b/tests/test_stage7_assemble.py new file mode 100644 index 0000000..e79b7e1 --- /dev/null +++ b/tests/test_stage7_assemble.py @@ -0,0 +1,65 @@ +import unittest + +from ai_daily_report.assemble import assemble_markdown, validate_markdown +from ai_daily_report.models import NewsItem + + +def news_item(item_id, title, section): + return NewsItem( + id=item_id, + source_group="AI HOT", + source_label="OpenAI:Blog", + source_role="primary", + source_priority=10, + title_raw=title, + title_norm=title.lower(), + summary_raw=f"{title} summary", + title=title, + summary=f"{title} summary", + url=f"https://example.com/{item_id}", + canonical_url=f"https://example.com/{item_id}", + section=section, + ) + + +class Stage7AssembleTests(unittest.TestCase): + def test_assemble_markdown_renders_sections_and_daily_threads(self): + items = [ + news_item("a", "GPT-5 API 发布", "模型与能力"), + news_item("b", "Anthropic 提交 IPO 文件", "公司与资本"), + ] + guide = { + "theme": "> 模型和资本两条线都在推进。[1]", + "threads": [ + { + "title": "模型能力产品化", + "text": "GPT-5 API 发布,说明模型能力继续进入产品入口。", + "item_ids": ["a"], + "kind": "thread", + } + ], + } + + md, report = assemble_markdown(items, guide) + + self.assertIn("## 导览", md) + self.assertIn("> 模型和资本两条线都在推进。", md) + self.assertIn("## 模型与能力", md) + self.assertIn("**1. GPT-5 API 发布**", md) + self.assertIn("**2. Anthropic 提交 IPO 文件**", md) + self.assertIn("## 今日脉络", md) + self.assertIn("- **模型能力产品化**", md) + self.assertNotIn("> >", md) + self.assertNotIn("[1]", md) + self.assertEqual(report["item_count"], 2) + self.assertEqual(report["blocking_errors"], []) + + def test_validate_markdown_blocks_empty_report(self): + report = validate_markdown("", []) + + self.assertIn("no_items", report["blocking_errors"]) + self.assertIn("markdown_too_short", report["blocking_errors"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_stage8_publish.py b/tests/test_stage8_publish.py new file mode 100644 index 0000000..0f7e342 --- /dev/null +++ b/tests/test_stage8_publish.py @@ -0,0 +1,76 @@ +import unittest + +from ai_daily_report.publish import publish_markdown + + +class FakeBlogClient: + def __init__(self): + self.created_payload = None + self.published_slug = None + + def create_post(self, payload): + self.created_payload = payload + return {"slug": "ai-2026-06-04"} + + def publish_post(self, slug): + self.published_slug = slug + + +class Stage8PublishTests(unittest.TestCase): + def test_publish_markdown_dry_run_does_not_call_client(self): + result = publish_markdown( + title="AI日报 · 2026-06-04", + markdown="## 导览\n\n> ok", + tags=["AI日报"], + slug="ai-2026-06-04", + base_url="https://blog.example", + mode="dry-run", + markdown_report={"blocking_errors": []}, + client=None, + ) + + self.assertEqual(result.status, "ok") + self.assertEqual(result.mode, "dry-run") + self.assertEqual(result.blog_url, "https://blog.example/posts/ai-2026-06-04") + self.assertTrue(result.public_ok) + + def test_publish_markdown_blocks_when_markdown_has_errors(self): + client = FakeBlogClient() + + result = publish_markdown( + title="AI日报 · 2026-06-04", + markdown="bad", + tags=["AI日报"], + slug="ai-2026-06-04", + base_url="https://blog.example", + mode="publish", + markdown_report={"blocking_errors": ["markdown_too_short"]}, + client=client, + ) + + self.assertEqual(result.status, "blocked") + self.assertIsNone(client.created_payload) + self.assertIn("markdown_too_short", result.error) + + def test_publish_markdown_publish_mode_calls_client(self): + client = FakeBlogClient() + + result = publish_markdown( + title="AI日报 · 2026-06-04", + markdown="## 导览\n\n> ok", + tags=["AI日报"], + slug="ai-2026-06-04", + base_url="https://blog.example", + mode="publish", + markdown_report={"blocking_errors": []}, + client=client, + ) + + self.assertEqual(result.status, "ok") + self.assertEqual(client.created_payload["title"], "AI日报 · 2026-06-04") + self.assertEqual(client.published_slug, "ai-2026-06-04") + self.assertEqual(result.blog_url, "https://blog.example/posts/ai-2026-06-04") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_validate.py b/tests/test_validate.py new file mode 100644 index 0000000..48a42f3 --- /dev/null +++ b/tests/test_validate.py @@ -0,0 +1,14 @@ +import unittest + +from ai_daily_report.validate import validate_report_markdown + + +class ValidateTests(unittest.TestCase): + def test_validate_report_markdown_delegates_markdown_checks(self): + report = validate_report_markdown("", []) + + self.assertIn("no_items", report["blocking_errors"]) + + +if __name__ == "__main__": + unittest.main()