first commit
This commit is contained in:
138
sn-search-code/scripts/huggingface_search.py
Normal file
138
sn-search-code/scripts/huggingface_search.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
"""HuggingFace 搜索:模型、数据集、Space。通过 HuggingFace Hub API。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from search_utils import build_parser, get_client, get_key, make_item, make_result, print_json
|
||||
|
||||
|
||||
API_BASE = "https://huggingface.co/api"
|
||||
|
||||
SEARCH_TYPES = {
|
||||
"models": "models",
|
||||
"datasets": "datasets",
|
||||
"spaces": "spaces",
|
||||
"model": "models", # 别名
|
||||
"dataset": "datasets", # 别名
|
||||
"space": "spaces", # 别名
|
||||
}
|
||||
|
||||
# 过滤掉无信息量的内部 tag(地区、部署、引用文献等)
|
||||
_TAG_NOISE_PREFIXES = ("region:", "deploy:", "arxiv:", "dataset:", "endpoints_")
|
||||
|
||||
|
||||
def search(query: str, limit: int, search_type: str = "models", token: str | None = None) -> list[dict]:
|
||||
"""执行 HuggingFace 搜索。"""
|
||||
endpoint = SEARCH_TYPES.get(search_type, "models")
|
||||
url = f"{API_BASE}/{endpoint}"
|
||||
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
params = {
|
||||
"search": query,
|
||||
"limit": min(limit, 100),
|
||||
"full": "true",
|
||||
}
|
||||
|
||||
with get_client(headers=headers) as client:
|
||||
resp = client.get(url, params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
items = []
|
||||
for item in data[:limit]:
|
||||
if endpoint == "models":
|
||||
items.append(_parse_model(item))
|
||||
elif endpoint == "datasets":
|
||||
items.append(_parse_dataset(item))
|
||||
elif endpoint == "spaces":
|
||||
items.append(_parse_space(item))
|
||||
return items
|
||||
|
||||
|
||||
def _parse_model(item: dict) -> dict:
|
||||
model_id = item.get("id", "")
|
||||
tags = _filter_tags(item.get("tags", []))
|
||||
return make_item(
|
||||
title=model_id,
|
||||
url=f"https://huggingface.co/{model_id}",
|
||||
snippet=_model_snippet(item),
|
||||
pipeline_tag=item.get("pipeline_tag"),
|
||||
library=item.get("library_name"),
|
||||
downloads=item.get("downloads"),
|
||||
likes=item.get("likes"),
|
||||
tags=tags or None,
|
||||
last_modified=item.get("lastModified"),
|
||||
)
|
||||
|
||||
|
||||
def _parse_dataset(item: dict) -> dict:
|
||||
dataset_id = item.get("id", "")
|
||||
description = (item.get("description") or "").strip()
|
||||
tags = _filter_tags(item.get("tags", []))
|
||||
return make_item(
|
||||
title=dataset_id,
|
||||
url=f"https://huggingface.co/datasets/{dataset_id}",
|
||||
snippet=description,
|
||||
downloads=item.get("downloads"),
|
||||
likes=item.get("likes"),
|
||||
tags=tags or None,
|
||||
last_modified=item.get("lastModified"),
|
||||
)
|
||||
|
||||
|
||||
def _parse_space(item: dict) -> dict:
|
||||
space_id = item.get("id", "")
|
||||
tags = _filter_tags(item.get("tags", []))
|
||||
return make_item(
|
||||
title=space_id,
|
||||
url=f"https://huggingface.co/spaces/{space_id}",
|
||||
snippet=item.get("shortDescription") or "",
|
||||
sdk=item.get("sdk"),
|
||||
likes=item.get("likes"),
|
||||
tags=tags or None,
|
||||
last_modified=item.get("lastModified"),
|
||||
)
|
||||
|
||||
|
||||
def _model_snippet(item: dict) -> str:
|
||||
"""用 pipeline_tag + 关键 tag 拼出简短描述。"""
|
||||
parts = []
|
||||
if item.get("pipeline_tag"):
|
||||
parts.append(item["pipeline_tag"])
|
||||
if item.get("library_name"):
|
||||
parts.append(item["library_name"])
|
||||
# 保留语言 tag(如 en, zh)
|
||||
lang_tags = [t for t in (item.get("tags") or []) if len(t) <= 3 and t.isalpha()]
|
||||
if lang_tags:
|
||||
parts.append("lang:" + ",".join(lang_tags[:3]))
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
def _filter_tags(tags: list[str]) -> list[str]:
|
||||
"""过滤掉无信息量的内部 tag。"""
|
||||
return [t for t in tags if not any(t.startswith(p) for p in _TAG_NOISE_PREFIXES)]
|
||||
|
||||
|
||||
def main():
|
||||
parser = build_parser("搜索 HuggingFace 模型、数据集、Space")
|
||||
parser.add_argument("--type", "-t", default="models",
|
||||
choices=list(SEARCH_TYPES.keys()),
|
||||
help="搜索类型(默认 models)")
|
||||
parser.add_argument("--token", help="HuggingFace Token(也可通过 HF_TOKEN 环境变量设置,可选,提高限额)")
|
||||
args = parser.parse_args()
|
||||
|
||||
token = get_key("HF_TOKEN", args.token)
|
||||
try:
|
||||
items = search(args.query, args.limit, args.type, token)
|
||||
print_json(make_result(True, args.query, "huggingface", items))
|
||||
except Exception as e:
|
||||
print_json(make_result(False, args.query, "huggingface", [], str(e)))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user