#!/usr/bin/env python3
"""Semantic Scholar citation tracing.

Look up a paper's references (backward) or the papers that cite it (forward).
"""

from __future__ import annotations

import argparse
import re
import sys

from search_utils import get_client, make_item, print_json

API_BASE = "https://api.semanticscholar.org/graph/v1/paper"

# Paper-level fields (requested nested under citedPaper / citingPaper).
# NOTE: tldr is deliberately not requested; in nested requests it tends to
# trigger rate limiting.
PAPER_FIELDS = [
    "title",
    "abstract",
    "year",
    "venue",
    "publicationDate",
    "authors",
    "citationCount",
    "influentialCitationCount",
    "isOpenAccess",
    "openAccessPdf",
    "externalIds",
    "fieldsOfStudy",
]

# Edge-level fields (attributes of the citation relationship itself).
EDGE_FIELDS = ["contexts", "intents"]

# Identifier prefixes that are normalized to upper case by resolve_paper_id.
_KNOWN_PREFIXES = ("DOI", "ARXIV", "PMID")

# New-style arXiv identifier, e.g. 2301.07041 or 2301.07041v2.
_ARXIV_ID_RE = re.compile(r"[0-9]{4}\.[0-9]{4,5}(?:v[0-9]+)?")

# Leading "https://doi.org/" (or dx.doi.org, with or without scheme).
_DOI_URL_RE = re.compile(r"^(?:https?://)?(?:dx\.)?doi\.org/", re.IGNORECASE)


def resolve_paper_id(identifier: str) -> str:
    """Convert a paper identifier into the form used in the API path.

    Supported inputs:
      - Semantic Scholar paper ID (returned unchanged)
      - Semantic Scholar URL: .../paper/<slug>/<id> -> <id>
      - DOI: 10.xxxx/... or doi:10.xxxx/... or https://doi.org/10.xxxx/...
        -> DOI:10.xxxx/...
      - arXiv ID: 2301.07041 or arxiv:2301.07041v2 -> ARXIV:2301.07041v2
      - PubMed ID: pmid:12345678 -> PMID:12345678

    Only the prefix is upper-cased; the identifier body keeps its original
    case (so an arXiv version suffix such as ``v2`` is not turned into ``V2``).
    Anything unrecognized is returned as-is and treated as a paper ID.
    """
    identifier = identifier.strip()

    # Semantic Scholar URL: the ID is the last path segment. Drop any query
    # string or fragment first so it does not end up inside the ID.
    if "semanticscholar.org/paper/" in identifier:
        path = identifier.split("#", 1)[0].split("?", 1)[0]
        return path.rstrip("/").rsplit("/", 1)[-1]

    # A doi.org URL is reduced to the bare DOI and handled below.
    identifier = _DOI_URL_RE.sub("", identifier, count=1)

    # Bare DOI.
    if identifier.startswith("10."):
        return f"DOI:{identifier}"

    # Explicitly prefixed identifier (doi:, arxiv:, pmid: in any case).
    prefix, sep, value = identifier.partition(":")
    if sep and prefix.upper() in _KNOWN_PREFIXES:
        return f"{prefix.upper()}:{value.strip()}"

    # Bare new-style arXiv ID.
    if _ARXIV_ID_RE.fullmatch(identifier):
        return f"ARXIV:{identifier}"

    # Otherwise assume it already is a Semantic Scholar paper ID.
    return identifier


def _passes_filters(
    paper: dict,
    min_citations: int,
    year_min: int | None,
    year_max: int | None,
) -> bool:
    """Return True if *paper* satisfies the citation-count and year filters.

    Papers without a year are never rejected by the year bounds.
    """
    if (paper.get("citationCount") or 0) < min_citations:
        return False
    year = paper.get("year")
    if not year:
        return True
    if year_min is not None and year < year_min:
        return False
    if year_max is not None and year > year_max:
        return False
    return True


def _entry_to_item(entry: dict, paper: dict):
    """Build one output item (via ``make_item``) from an API edge entry.

    *entry* is the edge record (carrying ``contexts`` / ``intents``) and
    *paper* is the nested paper record taken from it.
    """
    # "or" fallbacks guard against keys that are present but null.
    authors = [a.get("name", "") for a in (paper.get("authors") or [])]
    external_ids = paper.get("externalIds") or {}
    s2_id = paper.get("paperId", "")
    url = f"https://www.semanticscholar.org/paper/{s2_id}" if s2_id else ""
    open_access_pdf = (paper.get("openAccessPdf") or {}).get("url")

    # contexts: sentences in which the citation occurs.
    contexts = entry.get("contexts") or []
    intents = entry.get("intents") or []

    return make_item(
        title=paper.get("title", ""),
        url=url,
        snippet=paper.get("abstract") or "",
        authors=authors,
        year=paper.get("year"),
        venue=paper.get("venue") or None,
        publication_date=paper.get("publicationDate"),
        citation_count=paper.get("citationCount") or 0,
        influential_citation_count=paper.get("influentialCitationCount"),
        is_open_access=paper.get("isOpenAccess"),
        open_access_pdf=open_access_pdf,
        fields_of_study=paper.get("fieldsOfStudy") or None,
        doi=external_ids.get("DOI"),
        arxiv_id=external_ids.get("ArXiv"),
        paper_id=s2_id,
        citation_contexts=contexts[:3] if contexts else None,  # at most 3
        citation_intents=intents if intents else None,
    )


def fetch_refs(
    paper_id: str,
    direction: str,
    limit: int,
    min_citations: int,
    year_min: int | None,
    year_max: int | None,
    api_key: str | None = None,
) -> dict:
    """Fetch a paper's references or citations.

    Args:
        paper_id: Any identifier accepted by ``resolve_paper_id``.
        direction: ``"references"`` or ``"citations"``; used as the last
            path segment of the endpoint. Any value other than
            ``"references"`` reads entries from the ``citingPaper`` key.
        limit: Maximum number of items returned (after filtering and
            sorting). Negative values are treated as 0.
        min_citations: Drop papers with fewer citations than this.
        year_min: Drop papers published before this year (if the year is known).
        year_max: Drop papers published after this year (if the year is known).
        api_key: Optional key sent as the ``x-api-key`` header.

    Returns:
        A dict with ``success``, ``paper_id``, ``direction``, ``provider``,
        ``items``, ``total_available``, ``returned`` and ``error``; plus
        ``source_paper`` when the metadata lookup for the paper succeeded.

    Raises:
        Whatever the client's ``get`` / ``raise_for_status`` raise for the
        main request; the caller is expected to handle it.
    """
    resolved = resolve_paper_id(paper_id)
    # Response shape: {"data": [{"<paper_key>": {...}, "contexts": [...],
    # "intents": [...]}]} with paper_key depending on the direction.
    paper_key = "citedPaper" if direction == "references" else "citingPaper"
    headers: dict[str, str] = {"x-api-key": api_key} if api_key else {}

    # Edge fields are listed directly; paper fields take the nested prefix,
    # e.g. fields=contexts,intents,citedPaper.title,citedPaper.year,...
    nested_fields = [f"{paper_key}.{name}" for name in PAPER_FIELDS]
    params = {
        "fields": ",".join(EDGE_FIELDS + nested_fields),
        # Over-fetch on purpose: per the original author's note, citations
        # come back newest-first, so a large page is needed to find
        # highly-cited papers; references are usually few, so it is harmless.
        "limit": 1000,
    }

    with get_client(timeout=30, headers=headers) as client:
        resp = client.get(f"{API_BASE}/{resolved}/{direction}", params=params)
        resp.raise_for_status()
        data = resp.json()

    # Metadata of the queried paper itself, for context in the output.
    # Deliberately best-effort: a failure here must not discard the main
    # result, so the error is swallowed and "source_paper" is omitted.
    source_paper = None
    with get_client(timeout=15, headers=headers) as client:
        try:
            meta = client.get(
                f"{API_BASE}/{resolved}",
                params={"fields": "title,year,citationCount"},
            )
            meta.raise_for_status()
            source_paper = meta.json()
        except Exception:
            pass

    # "data" may be absent or null; treat both as no entries.
    entries = data.get("data") or []

    items = []
    for entry in entries:
        paper = entry.get(paper_key)
        # Skip edges whose paper record is missing or has no title.
        if not paper or not paper.get("title"):
            continue
        if not _passes_filters(paper, min_citations, year_min, year_max):
            continue
        items.append(_entry_to_item(entry, paper))

    # Most-cited first, then keep the top N. A negative limit is clamped so
    # it cannot slice from the end.
    items.sort(key=lambda item: item.get("citation_count") or 0, reverse=True)
    items = items[: max(limit, 0)]

    result = {
        "success": True,
        "paper_id": resolved,
        "direction": direction,
        "provider": "semantic_scholar",
        "items": items,
        "total_available": len(entries),
        "returned": len(items),
        "error": None,
    }
    if source_paper:
        result["source_paper"] = {
            "title": source_paper.get("title"),
            "year": source_paper.get("year"),
            "citation_count": source_paper.get("citationCount"),
        }
    return result


def main() -> None:
    """Parse CLI arguments, run the lookup and print the result as JSON.

    On any exception a JSON error object is printed and the process exits
    with status 1.
    """
    parser = argparse.ArgumentParser(
        description="查询论文的参考文献(backward)或被引论文(forward)"
    )
    parser.add_argument(
        "paper_id",
        help="论文标识符:S2 ID、DOI(如 10.1234/...)、ArXiv ID(如 2301.07041)、PMID(如 PMID:12345678)",
    )
    parser.add_argument(
        "direction",
        choices=["references", "citations"],
        help="references=参考文献(backward),citations=被引论文(forward)",
    )
    parser.add_argument("--limit", "-n", type=int, default=20, help="返回结果数量(默认 20)")
    parser.add_argument("--min-citations", type=int, default=0, help="最低引用数过滤(默认 0)")
    parser.add_argument("--year-min", type=int, default=None, help="最早年份过滤")
    parser.add_argument("--year-max", type=int, default=None, help="最晚年份过滤")
    parser.add_argument("--api-key", help="Semantic Scholar API Key(可选)")

    args = parser.parse_args()

    try:
        result = fetch_refs(
            args.paper_id,
            args.direction,
            args.limit,
            args.min_citations,
            args.year_min,
            args.year_max,
            args.api_key,
        )
        print_json(result)
    except Exception as e:
        print_json({
            "success": False,
            "paper_id": args.paper_id,
            "direction": args.direction,
            "provider": "semantic_scholar",
            "items": [],
            "error": str(e),
        })
        sys.exit(1)


if __name__ == "__main__":
    main()