Files
agent-skills/sn-search-academic/scripts/semantic_scholar_refs.py
Hermes Agent ccc63d1e70 first commit
2026-05-10 13:52:46 +08:00

239 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Semantic Scholar citation tracing: look up a paper's references (backward) and the papers that cite it (forward)."""
from __future__ import annotations
import argparse
import sys
from search_utils import get_client, make_item, print_json
# Base URL of the Semantic Scholar Graph API paper endpoints.
API_BASE = "https://api.semanticscholar.org/graph/v1/paper"

# Paper-level fields (requested nested under citedPaper / citingPaper).
# NOTE: tldr is deliberately not requested; in nested requests it tends to
# trigger rate limiting.
PAPER_FIELDS = [
    "title",
    "abstract",
    "year",
    "venue",
    "publicationDate",
    "authors",
    "citationCount",
    "influentialCitationCount",
    "isOpenAccess",
    "openAccessPdf",
    "externalIds",
    "fieldsOfStudy",
]

# Edge-level fields (attributes of the citation relationship itself).
EDGE_FIELDS = [
    "contexts",
    "intents",
]
def resolve_paper_id(identifier: str) -> str:
    """Normalize a paper identifier into a form the Semantic Scholar API accepts.

    Supported inputs:
      - Semantic Scholar paper ID (40-char hex): returned unchanged.
      - DOI: ``10.xxxx/...`` or ``doi:10.xxxx/...`` -> ``DOI:10.xxxx/...``
      - arXiv ID: ``2301.07041``, ``2301.07041v2`` or ``arxiv:<id>``
        -> ``ARXIV:<id>``
      - PubMed ID: ``pmid:12345678`` -> ``PMID:12345678``
      - Semantic Scholar URL (``.../semanticscholar.org/paper/...``): the last
        path segment is returned.

    Only the scheme prefix is upper-cased; the identifier itself keeps its
    original case, so arXiv version suffixes (``v2``) and old-style arXiv IDs
    (``hep-th/9901001``) are not altered.

    Args:
        identifier: Raw identifier as typed by the user; surrounding
            whitespace is ignored.

    Returns:
        The normalized identifier string. Anything unrecognized is returned
        stripped but otherwise unchanged and assumed to be an S2 paper ID.
    """
    identifier = identifier.strip()

    # Semantic Scholar URL: the ID is the last path segment. Drop any fragment
    # or query string first so it does not end up inside the returned ID.
    if "semanticscholar.org/paper/" in identifier:
        path = identifier.split("#", 1)[0].split("?", 1)[0]
        return path.rstrip("/").rsplit("/", 1)[-1]

    # Bare DOI.
    if identifier.startswith("10."):
        return f"DOI:{identifier}"

    # Explicit "doi:" / "arxiv:" / "pmid:" prefix in any letter case.
    # Normalize the prefix only and leave the identifier's own case intact.
    scheme, sep, rest = identifier.partition(":")
    if sep and scheme.lower() in ("doi", "arxiv", "pmid"):
        return f"{scheme.upper()}:{rest.strip()}"

    # Bare new-style arXiv ID such as 2301.07041 or 2301.07041v2.
    if "." in identifier and identifier.replace(".", "").replace("v", "").isdigit():
        return f"ARXIV:{identifier}"

    # Otherwise assume it is already a Semantic Scholar paper ID.
    return identifier
def _build_ref_item(entry: dict, paper: dict, citation_count: int):
    """Turn one API edge entry and its nested paper record into an item via make_item."""
    # `authors` may be absent or null; treat both as "no authors".
    authors = [a.get("name", "") for a in paper.get("authors") or []]
    external_ids = paper.get("externalIds") or {}
    s2_id = paper.get("paperId", "")
    url = f"https://www.semanticscholar.org/paper/{s2_id}" if s2_id else ""
    open_access = paper.get("openAccessPdf") or {}
    # contexts: sentences in which the citation occurs
    # (per the original note, mainly meaningful for the citations direction).
    contexts = entry.get("contexts") or []
    intents = entry.get("intents") or []
    return make_item(
        title=paper.get("title", ""),
        url=url,
        snippet=paper.get("abstract") or "",
        authors=authors,
        year=paper.get("year"),
        venue=paper.get("venue") or None,
        publication_date=paper.get("publicationDate"),
        citation_count=citation_count,
        influential_citation_count=paper.get("influentialCitationCount"),
        is_open_access=paper.get("isOpenAccess"),
        open_access_pdf=open_access.get("url"),
        fields_of_study=paper.get("fieldsOfStudy") or None,
        doi=external_ids.get("DOI"),
        arxiv_id=external_ids.get("ArXiv"),
        paper_id=s2_id,
        citation_contexts=contexts[:3] if contexts else None,  # at most 3 contexts
        citation_intents=intents if intents else None,
    )


def fetch_refs(
    paper_id: str,
    direction: str,
    limit: int,
    min_citations: int,
    year_min: int | None,
    year_max: int | None,
    api_key: str | None = None,
) -> dict:
    """Fetch a paper's references or citations from Semantic Scholar.

    Requests up to 1000 entries in a single call (no pagination), filters
    them locally, sorts by citation count in descending order and keeps the
    top ``limit``.

    Args:
        paper_id: Any identifier understood by ``resolve_paper_id``.
        direction: ``"references"`` (papers this one cites) or ``"citations"``
            (papers citing this one); used verbatim as the endpoint suffix.
        limit: Maximum number of items to return; negative values are
            treated as 0.
        min_citations: Drop papers with fewer citations than this.
        year_min: Drop papers published before this year (papers with an
            unknown year are kept).
        year_max: Drop papers published after this year (papers with an
            unknown year are kept).
        api_key: Optional API key, sent as the ``x-api-key`` header.

    Returns:
        A dict with ``success``, ``paper_id`` (resolved), ``direction``,
        ``provider``, ``items``, ``total_available`` (number of entries
        received before filtering), ``returned`` and ``error``; plus
        ``source_paper`` when the paper's own metadata could be fetched.

    Raises:
        Whatever the HTTP client raises for a failed main request
        (``raise_for_status``); failures of the secondary source-paper lookup
        are ignored.
    """
    resolved = resolve_paper_id(paper_id)
    endpoint = f"{API_BASE}/{resolved}/{direction}"

    headers: dict[str, str] = {}
    if api_key:
        headers["x-api-key"] = api_key

    # Response shape:
    #   references -> {"data": [{"citedPaper": {...}, "contexts": [...], "intents": [...]}]}
    #   citations  -> {"data": [{"citingPaper": {...}, "contexts": [...], "intents": [...]}]}
    paper_key = "citedPaper" if direction == "references" else "citingPaper"

    # Paper fields are requested with the nested prefix, edge fields directly:
    #   fields=contexts,intents,citedPaper.title,citedPaper.year,...
    all_fields = ",".join(EDGE_FIELDS + [f"{paper_key}.{f}" for f in PAPER_FIELDS])
    params = {
        "fields": all_fields,
        # Per the original author's note: the citations endpoint returns
        # newest first, so a large page is needed to reach highly cited
        # papers; references are usually few, so over-fetching is harmless.
        # 1000 is the API's per-request maximum (paging would use offset).
        "limit": 1000,
    }

    with get_client(timeout=30, headers=headers) as client:
        resp = client.get(endpoint, params=params)
        resp.raise_for_status()
        data = resp.json()

    # Metadata of the source paper itself, for output context only.
    # Deliberately best-effort: a failure here must not discard the results.
    paper_resp = None
    with get_client(timeout=15, headers=headers) as client:
        try:
            r = client.get(f"{API_BASE}/{resolved}", params={"fields": "title,year,citationCount"})
            r.raise_for_status()
            paper_resp = r.json()
        except Exception:
            paper_resp = None

    # "data" may be present but null; iterating None would raise TypeError.
    # NOTE(review): believed to happen when references are withheld - confirm.
    entries = data.get("data") or []

    items = []
    for entry in entries:
        paper = entry.get(paper_key) or {}
        if not paper.get("title"):
            continue

        year = paper.get("year")
        citation_count = paper.get("citationCount") or 0

        # Filters; papers without a known year pass the year filters.
        if citation_count < min_citations:
            continue
        if year_min and year and year < year_min:
            continue
        if year_max and year and year > year_max:
            continue

        items.append(_build_ref_item(entry, paper, citation_count))

    # Sort by citation count and keep the top N. Clamp the limit so a
    # negative value cannot silently slice items off the tail.
    items.sort(key=lambda x: x.get("citation_count", 0), reverse=True)
    items = items[: max(limit, 0)]

    result = {
        "success": True,
        "paper_id": resolved,
        "direction": direction,
        "provider": "semantic_scholar",
        "items": items,
        "total_available": len(entries),
        "returned": len(items),
        "error": None,
    }
    if paper_resp:
        result["source_paper"] = {
            "title": paper_resp.get("title"),
            "year": paper_resp.get("year"),
            "citation_count": paper_resp.get("citationCount"),
        }
    return result
def main():
    """Command-line entry point.

    Parses the arguments, calls ``fetch_refs`` and prints the result as JSON.
    On any exception a failure payload is printed instead and the process
    exits with status 1.
    """
    cli = argparse.ArgumentParser(
        description="查询论文的参考文献backward或被引论文forward"
    )

    # Positional arguments.
    cli.add_argument(
        "paper_id",
        help="论文标识符S2 ID、DOI如 10.1234/...、ArXiv ID如 2301.07041、PMID如 PMID:12345678",
    )
    cli.add_argument(
        "direction",
        choices=["references", "citations"],
        help="references=参考文献backwardcitations=被引论文forward",
    )

    # Optional arguments, registered in the same order as they appear in --help.
    optional_specs = (
        (("--limit", "-n"), {"type": int, "default": 20, "help": "返回结果数量(默认 20"}),
        (("--min-citations",), {"type": int, "default": 0, "help": "最低引用数过滤(默认 0"}),
        (("--year-min",), {"type": int, "default": None, "help": "最早年份过滤"}),
        (("--year-max",), {"type": int, "default": None, "help": "最晚年份过滤"}),
        (("--api-key",), {"help": "Semantic Scholar API Key可选"}),
    )
    for flags, options in optional_specs:
        cli.add_argument(*flags, **options)

    opts = cli.parse_args()

    try:
        outcome = fetch_refs(
            paper_id=opts.paper_id,
            direction=opts.direction,
            limit=opts.limit,
            min_citations=opts.min_citations,
            year_min=opts.year_min,
            year_max=opts.year_max,
            api_key=opts.api_key,
        )
        print_json(outcome)
    except Exception as exc:
        # Top-level boundary: report the failure as JSON, then exit non-zero.
        failure = {
            "success": False,
            "paper_id": opts.paper_id,
            "direction": opts.direction,
            "provider": "semantic_scholar",
            "items": [],
            "error": str(exc),
        }
        print_json(failure)
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()