Skip to content

Commit 2c7c19d

Browse files
hlky, hanouticelina, and Wauplin
authored
Daily Papers API (#2554)
* Daily Papers API
* Apply suggestions from code review (Co-authored-by: Celina Hanouti <[email protected]>)
* Apply suggestions from code review (Co-authored-by: Celina Hanouti <[email protected]>)
* Fix tests
* Run papers API tests independently
* Apply suggestions from code review (Co-authored-by: Lucain <[email protected]>)
* Remove date
* Additional test and update docstring

Co-authored-by: Celina Hanouti <[email protected]>
Co-authored-by: Lucain <[email protected]>
1 parent 613b591 commit 2c7c19d

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

src/huggingface_hub/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@
216216
"list_metrics",
217217
"list_models",
218218
"list_organization_members",
219+
"list_papers",
219220
"list_pending_access_requests",
220221
"list_rejected_access_requests",
221222
"list_repo_commits",
@@ -230,6 +231,7 @@
230231
"merge_pull_request",
231232
"model_info",
232233
"move_repo",
234+
"paper_info",
233235
"parse_safetensors_file_metadata",
234236
"pause_inference_endpoint",
235237
"pause_space",
@@ -741,6 +743,7 @@ def __dir__():
741743
list_metrics, # noqa: F401
742744
list_models, # noqa: F401
743745
list_organization_members, # noqa: F401
746+
list_papers, # noqa: F401
744747
list_pending_access_requests, # noqa: F401
745748
list_rejected_access_requests, # noqa: F401
746749
list_repo_commits, # noqa: F401
@@ -755,6 +758,7 @@ def __dir__():
755758
merge_pull_request, # noqa: F401
756759
model_info, # noqa: F401
757760
move_repo, # noqa: F401
761+
paper_info, # noqa: F401
758762
parse_safetensors_file_metadata, # noqa: F401
759763
pause_inference_endpoint, # noqa: F401
760764
pause_space, # noqa: F401

src/huggingface_hub/hf_api.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,70 @@ def __init__(self, **kwargs) -> None:
14571457
self.__dict__.update(**kwargs)
14581458

14591459

1460+
@dataclass
class PaperInfo:
    """
    Contains information about a paper on the Hub.

    The constructor accepts the raw JSON payload as keyword arguments. Fields may
    be given either at the top level or nested under a `"paper"` key; both layouts
    are read. Unknown top-level keys are kept as plain attributes for forward
    compatibility.

    Attributes:
        id (`str`):
            arXiv paper ID.
        authors (`List[str]`, **optional**):
            Names of paper authors.
        published_at (`datetime`, **optional**):
            Date paper published.
        title (`str`, **optional**):
            Title of the paper.
        summary (`str`, **optional**):
            Summary of the paper.
        upvotes (`int`, **optional**):
            Number of upvotes for the paper on the Hub.
        discussion_id (`str`, **optional**):
            Discussion ID for the paper on the Hub.
        source (`str`, **optional**):
            Source of the paper.
        comments (`int`, **optional**):
            Number of comments for the paper on the Hub. Defaults to `0` when the
            payload has no `"numComments"` key.
        submitted_at (`datetime`, **optional**):
            Date paper appeared in daily papers on the Hub.
        submitted_by (`User`, **optional**):
            Information about who submitted the daily paper.
    """

    id: str
    authors: Optional[List[str]]
    published_at: Optional[datetime]
    title: Optional[str]
    summary: Optional[str]
    upvotes: Optional[int]
    discussion_id: Optional[str]
    source: Optional[str]
    comments: Optional[int]
    submitted_at: Optional[datetime]
    submitted_by: Optional["User"]

    def __init__(self, **kwargs) -> None:
        # Shallow-copy the nested payload so that popping keys below never
        # mutates the dict object owned by the caller. `None` is tolerated.
        paper = dict(kwargs.pop("paper", None) or {})

        def take(key: str):
            # Consume `key` from BOTH payload levels, preferring the nested value.
            # Popping the top-level copy as well prevents the forward-compatibility
            # update at the end from overwriting the parsed attribute with raw data.
            # `is not None` (rather than truthiness) keeps valid falsy values like 0.
            nested = paper.pop(key, None)
            flat = kwargs.pop(key, None)
            return nested if nested is not None else flat

        flat_id = kwargs.pop("id", None)
        nested_id = paper.pop("id", None)
        self.id = flat_id or nested_id

        authors = take("authors")
        # `.get` instead of `.pop`: the author dicts belong to the caller.
        self.authors = [author.get("name") for author in authors] if authors else None

        # "publishedAt" is overloaded: inside `paper` it is the publication date;
        # at the top level it is the publication date only when there is no nested
        # value, otherwise it is the date the paper was submitted to daily papers.
        nested_published = paper.pop("publishedAt", None)
        flat_published = kwargs.pop("publishedAt", None)
        if nested_published:
            published_at, submitted_at = nested_published, flat_published
        else:
            published_at, submitted_at = flat_published, None
        self.published_at = parse_datetime(published_at) if published_at else None

        flat_title = kwargs.pop("title", None)
        nested_title = paper.pop("title", None)
        self.title = flat_title if flat_title is not None else nested_title

        self.source = kwargs.pop("source", None)
        self.summary = take("summary")
        self.upvotes = take("upvotes")
        self.discussion_id = take("discussionId")
        self.comments = kwargs.pop("numComments", 0)

        submitted_at = submitted_at or kwargs.pop("submittedOnDailyAt", None)
        self.submitted_at = parse_datetime(submitted_at) if submitted_at else None
        submitted_by = kwargs.pop("submittedBy", None) or kwargs.pop("submittedOnDailyBy", None)
        self.submitted_by = User(**submitted_by) if submitted_by else None

        # forward compatibility: expose any remaining, unknown top-level keys
        self.__dict__.update(**kwargs)
1522+
1523+
14601524
def future_compatible(fn: CallableT) -> CallableT:
14611525
"""Wrap a method of `HfApi` to handle `run_as_future=True`.
14621526
@@ -9673,6 +9737,72 @@ def list_user_following(self, username: str, token: Union[bool, str, None] = Non
96739737
):
96749738
yield User(**followed_user)
96759739

9740+
def list_papers(
    self,
    *,
    query: Optional[str] = None,
    token: Union[bool, str, None] = None,
) -> Iterable[PaperInfo]:
    """
    List daily papers on the Hugging Face Hub given a search query.

    Args:
        query (`str`, *optional*):
            A search query string to find papers.
            If provided, returns papers that match the query.
        token (`Union[bool, str, None]`, *optional*):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects.

    Example:

    ```python
    >>> from huggingface_hub import HfApi

    >>> api = HfApi()

    # List all papers with "attention" in their title
    >>> api.list_papers(query="attention")
    ```
    """
    search_url = f"{self.endpoint}/api/papers/search"
    # Only send the `q` parameter when a non-empty query was supplied.
    search_params = {"q": query} if query else {}
    response = get_session().get(
        search_url,
        params=search_params,
        headers=self._build_hf_headers(token=token),
    )
    hf_raise_for_status(response)
    # Each item of the JSON array is expanded into one `PaperInfo`.
    yield from (PaperInfo(**entry) for entry in response.json())
9785+
9786+
def paper_info(self, id: str, *, token: Union[bool, str, None] = None) -> PaperInfo:
    """
    Get information for a paper on the Hub.

    Args:
        id (`str`):
            ArXiv id of the paper.
        token (`Union[bool, str, None]`, *optional*):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `PaperInfo`: A `PaperInfo` object.

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 404 If the paper does not exist on the Hub.
    """
    path = f"{self.endpoint}/api/papers/{id}"
    # Build the request headers through the same helper the other endpoints of
    # this API use, instead of sending a bare request.
    r = get_session().get(path, headers=self._build_hf_headers(token=token))
    hf_raise_for_status(r)
    return PaperInfo(**r.json())
9805+
96769806
def auth_check(
96779807
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
96789808
) -> None:
@@ -9768,6 +9898,9 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:
97689898
list_spaces = api.list_spaces
97699899
space_info = api.space_info
97709900

9901+
list_papers = api.list_papers
9902+
paper_info = api.paper_info
9903+
97719904
repo_exists = api.repo_exists
97729905
revision_exists = api.revision_exists
97739906
file_exists = api.file_exists

tests/test_hf_api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4106,6 +4106,28 @@ def test_user_following(self) -> None:
41064106
assert len(list(following)) > 500
41074107

41084108

4109+
class PaperApiTest(unittest.TestCase):
    """Tests for `HfApi.list_papers` and `HfApi.paper_info`."""

    @classmethod
    @with_production_testing
    def setUpClass(cls) -> None:
        # One shared client for every test in this class.
        cls.api = HfApi()
        return super().setUpClass()

    def test_papers_by_query(self) -> None:
        """A search for "llama" returns results that include the Llama 3 paper."""
        papers = list(self.api.list_papers(query="llama"))
        assert len(papers) > 0
        titles = [paper.title for paper in papers]
        assert "The Llama 3 Herd of Models" in titles

    def test_get_paper_by_id_success(self) -> None:
        """Looking up a known arXiv id yields the expected title."""
        paper = self.api.paper_info("2407.21783")
        assert paper.title == "The Llama 3 Herd of Models"

    def test_get_paper_by_id_not_found(self) -> None:
        """An id the test expects to be missing raises an HTTP error with status 404."""
        with self.assertRaises(HfHubHTTPError) as context:
            self.api.paper_info("1234.56789")
        status_code = context.exception.response.status_code
        assert status_code == 404
4129+
4130+
41094131
class WebhookApiTest(HfApiCommonTest):
41104132
def setUp(self) -> None:
41114133
super().setUp()

0 commit comments

Comments
 (0)