first commit

This commit is contained in:
Hermes Agent
2026-05-10 13:52:46 +08:00
commit ccc63d1e70
4583 changed files with 584341 additions and 0 deletions

View File

@@ -0,0 +1,238 @@
#!/usr/bin/env python3
"""Semantic Scholar 引用追溯查询论文的参考文献backward和被引论文forward"""
from __future__ import annotations

import argparse
import re
import sys

from search_utils import get_client, make_item, print_json
# Base URL of the Semantic Scholar Graph API paper endpoints.
API_BASE = "https://api.semanticscholar.org/graph/v1/paper"
# Paper-level fields (requested nested under citedPaper / citingPaper).
# NOTE: tldr is deliberately not requested; per the original author's note it
# tends to trigger rate limiting in nested requests.
PAPER_FIELDS = [
"title", "abstract", "year", "venue", "publicationDate",
"authors", "citationCount", "influentialCitationCount",
"isOpenAccess", "openAccessPdf", "externalIds", "fieldsOfStudy",
]
# Edge-level fields (attributes of the citation relationship itself).
EDGE_FIELDS = ["contexts", "intents"]
def resolve_paper_id(identifier: str) -> str:
    """Normalize a paper identifier into the form used in Semantic Scholar API paths.

    Supported inputs:
      - Semantic Scholar paper ID (40-char hex): returned unchanged.
      - DOI: ``10.xxxx/...`` or ``doi:10.xxxx/...`` -> ``DOI:10.xxxx/...``
      - DOI URL: ``https://doi.org/10.xxxx/...`` -> ``DOI:10.xxxx/...``
      - arXiv ID: ``2301.07041``, ``2301.07041v2`` or ``arxiv:<id>`` -> ``ARXIV:<id>``
      - PubMed ID: ``pmid:12345678`` -> ``PMID:12345678``
      - Semantic Scholar URL: the last path segment is returned.

    Any other value is returned as-is (with surrounding whitespace removed) and
    is assumed to already be an identifier the API understands.

    Args:
        identifier: Raw identifier as typed by the user.

    Returns:
        The normalized identifier string.

    Raises:
        ValueError: If ``identifier`` is empty or contains only whitespace.
    """
    identifier = identifier.strip()
    if not identifier:
        raise ValueError("paper identifier must not be empty")

    # Semantic Scholar URL: the ID is the last path segment. Drop any query
    # string or fragment first so they do not end up inside the ID.
    if "semanticscholar.org/paper/" in identifier:
        path = identifier.split("?", 1)[0].split("#", 1)[0]
        return path.rstrip("/").rsplit("/", 1)[-1]

    lowered = identifier.lower()

    # Bare DOI.
    if identifier.startswith("10."):
        return f"DOI:{identifier}"
    # Prefixed DOI: normalize only the prefix, keep the DOI itself untouched.
    if lowered.startswith("doi:"):
        return f"DOI:{identifier[len('doi:'):].strip()}"
    # DOI resolver URL such as https://doi.org/10.xxxx/... or dx.doi.org.
    doi_marker = "doi.org/"
    marker_at = lowered.find(doi_marker)
    if marker_at != -1:
        candidate = identifier[marker_at + len(doi_marker):]
        if candidate.startswith("10."):
            return f"DOI:{candidate}"

    # Prefixed arXiv ID: upper-case only the prefix. Upper-casing the whole
    # string would corrupt version suffixes ("v2") and old-style category IDs
    # ("hep-th/9901001").
    if lowered.startswith("arxiv:"):
        return f"ARXIV:{identifier[len('arxiv:'):].strip()}"
    # Bare new-style arXiv ID: YYMM.NNNN or YYMM.NNNNN, optional version suffix.
    if re.fullmatch(r"\d{4}\.\d{4,5}(?:v\d+)?", identifier):
        return f"ARXIV:{identifier}"

    # PubMed ID.
    if lowered.startswith("pmid:"):
        return f"PMID:{identifier[len('pmid:'):].strip()}"

    # Fall back to treating the value as a Semantic Scholar paper ID.
    return identifier
def fetch_refs(
    paper_id: str,
    direction: str,
    limit: int,
    min_citations: int,
    year_min: int | None,
    year_max: int | None,
    api_key: str | None = None,
) -> dict:
    """Fetch a paper's references or citations from Semantic Scholar.

    A single request for up to 1000 entries is issued (no offset pagination),
    the entries are filtered locally, sorted by citation count in descending
    order and truncated to ``limit`` items.

    Args:
        paper_id: Any identifier understood by ``resolve_paper_id``.
        direction: ``"references"`` (backward) or ``"citations"`` (forward).
        limit: Maximum number of items to return; negative values are treated as 0.
        min_citations: Drop papers whose citation count is below this value.
        year_min: Drop papers published before this year (papers with an
            unknown year are kept).
        year_max: Drop papers published after this year (papers with an
            unknown year are kept).
        api_key: Optional API key, sent as the ``x-api-key`` header.

    Returns:
        A result dict with ``success``, ``paper_id``, ``direction``,
        ``provider``, ``items``, ``total_available`` (number of entries in the
        fetched page, not the paper's overall count), ``returned`` and
        ``error``; ``source_paper`` is added when the metadata lookup succeeds.

    Raises:
        ValueError: If ``direction`` is not one of the two supported values.
        Exception: Whatever ``raise_for_status()`` on the main response raises
            for an HTTP error (the concrete type depends on the client that
            ``get_client`` returns).
    """
    if direction not in ("references", "citations"):
        raise ValueError(
            f"direction must be 'references' or 'citations', got {direction!r}"
        )

    resolved = resolve_paper_id(paper_id)
    endpoint = f"{API_BASE}/{resolved}/{direction}"

    headers: dict[str, str] = {}
    if api_key:
        headers["x-api-key"] = api_key

    # Response shape: {"data": [{<paper_key>: {...}, "contexts": [...], "intents": [...]}]}
    # where <paper_key> is "citedPaper" for references and "citingPaper" for citations.
    # In the request, paper fields carry that key as a prefix while edge fields
    # are listed directly, e.g. fields=contexts,intents,citedPaper.title,...
    paper_key = "citedPaper" if direction == "references" else "citingPaper"
    fields = ",".join(EDGE_FIELDS + [f"{paper_key}.{name}" for name in PAPER_FIELDS])
    params = {
        "fields": fields,
        # Per the original author's notes: 1000 is the per-request maximum, and
        # the citations endpoint returns newest first, so a large page is needed
        # to reach highly cited papers. References are usually few, so
        # over-fetching is harmless.
        "limit": 1000,
    }

    with get_client(timeout=30, headers=headers) as client:
        resp = client.get(endpoint, params=params)
        resp.raise_for_status()
        data = resp.json()

    # Metadata of the queried paper itself, used only as context in the output.
    source_paper = None
    with get_client(timeout=15, headers=headers) as client:
        try:
            meta_resp = client.get(
                f"{API_BASE}/{resolved}",
                params={"fields": "title,year,citationCount"},
            )
            meta_resp.raise_for_status()
            source_paper = meta_resp.json()
        except Exception:
            # Deliberate best effort: the main listing already succeeded, so a
            # failed metadata lookup only means "source_paper" is omitted.
            source_paper = None

    # "data" may be missing or null; treat both as an empty listing.
    entries = data.get("data") or []

    items = []
    for entry in entries:
        paper = entry.get(paper_key) or {}
        if not paper.get("title"):
            continue
        if not _passes_filters(paper, min_citations, year_min, year_max):
            continue
        items.append(_build_item(entry, paper))

    # Rank by citation count and keep the top N. Clamp the limit so a negative
    # value cannot silently drop items from the tail via slicing.
    items.sort(key=lambda x: x.get("citation_count", 0), reverse=True)
    items = items[: max(limit, 0)]

    result = {
        "success": True,
        "paper_id": resolved,
        "direction": direction,
        "provider": "semantic_scholar",
        "items": items,
        "total_available": len(entries),
        "returned": len(items),
        "error": None,
    }
    if source_paper:
        result["source_paper"] = {
            "title": source_paper.get("title"),
            "year": source_paper.get("year"),
            "citation_count": source_paper.get("citationCount"),
        }
    return result


def _passes_filters(paper, min_citations, year_min, year_max):
    """Return True when ``paper`` satisfies the citation-count and year filters."""
    if (paper.get("citationCount") or 0) < min_citations:
        return False
    year = paper.get("year")
    if year is None:
        # Unknown publication year: never excluded by the year bounds.
        return True
    if year_min is not None and year < year_min:
        return False
    if year_max is not None and year > year_max:
        return False
    return True


def _build_item(entry, paper):
    """Convert one API entry (edge data plus nested paper) into a ``make_item`` result."""
    # "authors" and "externalIds" can be null in the payload, not just absent.
    authors = [a.get("name", "") for a in (paper.get("authors") or [])]
    external_ids = paper.get("externalIds") or {}
    s2_id = paper.get("paperId", "")
    url = f"https://www.semanticscholar.org/paper/{s2_id}" if s2_id else ""

    pdf_info = paper.get("openAccessPdf")
    open_access_pdf = pdf_info.get("url") if pdf_info else None

    # Edge-level data describing the citation relationship itself.
    contexts = entry.get("contexts") or []
    intents = entry.get("intents") or []

    return make_item(
        title=paper.get("title", ""),
        url=url,
        snippet=paper.get("abstract") or "",
        authors=authors,
        year=paper.get("year"),
        venue=paper.get("venue") or None,
        publication_date=paper.get("publicationDate"),
        citation_count=paper.get("citationCount") or 0,
        influential_citation_count=paper.get("influentialCitationCount"),
        is_open_access=paper.get("isOpenAccess"),
        open_access_pdf=open_access_pdf,
        fields_of_study=paper.get("fieldsOfStudy") or None,
        doi=external_ids.get("DOI"),
        arxiv_id=external_ids.get("ArXiv"),
        paper_id=s2_id,
        citation_contexts=contexts[:3] or None,  # at most 3 context snippets
        citation_intents=intents or None,
    )
def main():
    """Command-line entry point.

    Parses the paper identifier, direction and filter options, calls
    ``fetch_refs`` and prints the result as JSON. On any exception an error
    payload is printed instead and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(
        description="查询论文的参考文献backward或被引论文forward"
    )
    cli.add_argument(
        "paper_id",
        help="论文标识符S2 ID、DOI如 10.1234/...、ArXiv ID如 2301.07041、PMID如 PMID:12345678",
    )
    cli.add_argument(
        "direction",
        choices=["references", "citations"],
        help="references=参考文献backwardcitations=被引论文forward",
    )
    # Integer-valued options, registered in the same order as before so the
    # generated --help output is unchanged.
    int_options = (
        (("--limit", "-n"), 20, "返回结果数量(默认 20"),
        (("--min-citations",), 0, "最低引用数过滤(默认 0"),
        (("--year-min",), None, "最早年份过滤"),
        (("--year-max",), None, "最晚年份过滤"),
    )
    for flags, default, help_text in int_options:
        cli.add_argument(*flags, type=int, default=default, help=help_text)
    cli.add_argument("--api-key", help="Semantic Scholar API Key可选")
    opts = cli.parse_args()

    try:
        outcome = fetch_refs(
            paper_id=opts.paper_id,
            direction=opts.direction,
            limit=opts.limit,
            min_citations=opts.min_citations,
            year_min=opts.year_min,
            year_max=opts.year_max,
            api_key=opts.api_key,
        )
        print_json(outcome)
    except Exception as exc:
        # Top-level boundary: report the failure in the same JSON envelope.
        failure = {
            "success": False,
            "paper_id": opts.paper_id,
            "direction": opts.direction,
            "provider": "semantic_scholar",
            "items": [],
            "error": str(exc),
        }
        print_json(failure)
        sys.exit(1)


if __name__ == "__main__":
    main()