Files
agent-skills/sn-search-code/scripts/huggingface_search.py
Hermes Agent ccc63d1e70 first commit
2026-05-10 13:52:46 +08:00

139 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""HuggingFace 搜索模型、数据集、Space。通过 HuggingFace Hub API。"""
from __future__ import annotations
import sys
from search_utils import build_parser, get_client, get_key, make_item, make_result, print_json
# Root of the HuggingFace Hub REST API; endpoint names are appended to it.
API_BASE = "https://huggingface.co/api"

# Maps each accepted search-type name to the Hub API endpoint it queries.
SEARCH_TYPES = {
    # Canonical (plural) names.
    "models": "models",
    "datasets": "datasets",
    "spaces": "spaces",
    # Singular aliases.
    "model": "models",
    "dataset": "datasets",
    "space": "spaces",
}

# Prefixes of uninformative internal tags (region, deployment, cited
# papers, and so on) that are filtered out of the reported tag lists.
_TAG_NOISE_PREFIXES = (
    "region:",
    "deploy:",
    "arxiv:",
    "dataset:",
    "endpoints_",
)
def search(query: str, limit: int, search_type: str = "models", token: str | None = None) -> list[dict]:
"""执行 HuggingFace 搜索。"""
endpoint = SEARCH_TYPES.get(search_type, "models")
url = f"{API_BASE}/{endpoint}"
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
params = {
"search": query,
"limit": min(limit, 100),
"full": "true",
}
with get_client(headers=headers) as client:
resp = client.get(url, params=params)
resp.raise_for_status()
data = resp.json()
items = []
for item in data[:limit]:
if endpoint == "models":
items.append(_parse_model(item))
elif endpoint == "datasets":
items.append(_parse_dataset(item))
elif endpoint == "spaces":
items.append(_parse_space(item))
return items
def _parse_model(item: dict) -> dict:
    """Convert one raw model entry from the API into a result item.

    Args:
        item: A single element of the ``/models`` response list.

    Returns:
        The value built by ``make_item``: the model id as title, its Hub URL,
        a snippet from ``_model_snippet``, and the pipeline tag, library,
        download/like counts, filtered tags and last-modified value read
        from ``item``.
    """
    model_id = item.get("id", "")
    # ``or []`` also covers an explicit ``"tags": null``, the same guard
    # that _model_snippet applies.
    tags = _filter_tags(item.get("tags") or [])
    return make_item(
        title=model_id,
        url=f"https://huggingface.co/{model_id}",
        snippet=_model_snippet(item),
        pipeline_tag=item.get("pipeline_tag"),
        library=item.get("library_name"),
        downloads=item.get("downloads"),
        likes=item.get("likes"),
        tags=tags or None,  # an empty tag list is reported as None
        last_modified=item.get("lastModified"),
    )
def _parse_dataset(item: dict) -> dict:
    """Convert one raw dataset entry from the API into a result item.

    Args:
        item: A single element of the ``/datasets`` response list.

    Returns:
        The value built by ``make_item``: the dataset id as title, its Hub
        URL, the whitespace-stripped description as snippet (empty when the
        description is missing or null), and the download/like counts,
        filtered tags and last-modified value read from ``item``.
    """
    dataset_id = item.get("id", "")
    description = (item.get("description") or "").strip()
    # ``or []`` also covers an explicit ``"tags": null``, the same guard
    # that _model_snippet applies.
    tags = _filter_tags(item.get("tags") or [])
    return make_item(
        title=dataset_id,
        url=f"https://huggingface.co/datasets/{dataset_id}",
        snippet=description,
        downloads=item.get("downloads"),
        likes=item.get("likes"),
        tags=tags or None,  # an empty tag list is reported as None
        last_modified=item.get("lastModified"),
    )
def _parse_space(item: dict) -> dict:
    """Convert one raw Space entry from the API into a result item.

    Args:
        item: A single element of the ``/spaces`` response list.

    Returns:
        The value built by ``make_item``: the Space id as title, its Hub URL,
        ``shortDescription`` as snippet (empty when missing or null), and the
        SDK, like count, filtered tags and last-modified value read from
        ``item``.
    """
    space_id = item.get("id", "")
    # ``or []`` also covers an explicit ``"tags": null``, the same guard
    # that _model_snippet applies.
    tags = _filter_tags(item.get("tags") or [])
    return make_item(
        title=space_id,
        url=f"https://huggingface.co/spaces/{space_id}",
        snippet=item.get("shortDescription") or "",
        sdk=item.get("sdk"),
        likes=item.get("likes"),
        tags=tags or None,  # an empty tag list is reported as None
        last_modified=item.get("lastModified"),
    )
def _model_snippet(item: dict) -> str:
    """Build a short description from the pipeline tag, library and languages.

    The non-empty parts are joined with ``" | "``, in this order: the
    ``pipeline_tag`` value, the ``library_name`` value, and a
    ``lang:a,b,c`` part listing up to three language tags.

    Args:
        item: A raw model entry; a missing or null ``tags`` value is allowed.

    Returns:
        The joined description, or an empty string when no part is present.
    """
    # Heuristic: a purely alphabetic tag of at most three characters is
    # treated as a language code (e.g. "en", "zh").
    languages = [
        tag
        for tag in (item.get("tags") or [])
        if len(tag) <= 3 and tag.isalpha()
    ]
    candidates = (
        item.get("pipeline_tag"),
        item.get("library_name"),
        "lang:" + ",".join(languages[:3]) if languages else None,
    )
    return " | ".join(part for part in candidates if part)
def _filter_tags(tags: list[str]) -> list[str]:
"""过滤掉无信息量的内部 tag。"""
return [t for t in tags if not any(t.startswith(p) for p in _TAG_NOISE_PREFIXES)]
def main():
    """Command-line entry point: parse arguments, search, print JSON.

    On success the result built by ``make_result`` is printed with
    ``print_json``. If anything in the search or the printing raises, a
    failure result carrying the exception text is printed instead and the
    process exits with status 1.
    """
    cli = build_parser("搜索 HuggingFace 模型、数据集、Space")
    # NOTE(review): the parentheses in both help texts below are unbalanced,
    # possibly lost in an encoding round-trip; confirm the intended wording.
    cli.add_argument(
        "--type",
        "-t",
        choices=list(SEARCH_TYPES),
        default="models",
        help="搜索类型(默认 models",
    )
    cli.add_argument(
        "--token",
        help="HuggingFace Token也可通过 HF_TOKEN 环境变量设置,可选,提高限额)",
    )
    opts = cli.parse_args()

    # Presumably get_key falls back to the HF_TOKEN environment variable
    # when --token is not given; verify against search_utils.
    hf_token = get_key("HF_TOKEN", opts.token)

    try:
        found = search(opts.query, opts.limit, opts.type, hf_token)
        print_json(make_result(True, opts.query, "huggingface", found))
    except Exception as exc:
        # Top-level boundary: report any failure as a JSON result.
        print_json(make_result(False, opts.query, "huggingface", [], str(exc)))
        sys.exit(1)


if __name__ == "__main__":
    main()