Commit c9907d6

fix: respect EnqueueLinksKwargs for extract_links function (#1213)
### Description

- Fix support for `EnqueueLinksKwargs` parameters in the `extract_links` function.

### Issues

- Closes: #1212
1 parent 7cb3776 commit c9907d6
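
For context, here is a minimal usage sketch of what this fix enables; it is not part of the diff. After this change, `extract_links` honors `EnqueueLinksKwargs` such as `strategy`, `include`/`exclude`, and `limit` rather than only defaulting the strategy. The crawler choice and the `https://example.com` URLs are illustrative placeholders.

```python
import asyncio

from crawlee import Glob
from crawlee.crawlers import ParselCrawler, ParselCrawlingContext


async def main() -> None:
    crawler = ParselCrawler()

    @crawler.router.default_handler
    async def handler(context: ParselCrawlingContext) -> None:
        # With this fix, the kwargs below are applied as filters, mirroring enqueue_links:
        # only same-hostname links are kept, excluded globs are dropped, and at most
        # 10 requests are returned.
        links = await context.extract_links(
            strategy='same-hostname',
            exclude=[Glob('https://example.com/logout/**')],
            limit=10,
        )
        context.log.info(f'Extracted {len(links)} links after filtering')

    await crawler.run(['https://example.com'])


if __name__ == '__main__':
    asyncio.run(main())
```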

File tree: 5 files changed (+120 −43 lines)

src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py

Lines changed: 32 additions & 20 deletions
@@ -4,6 +4,7 @@
 import logging
 from abc import ABC
 from typing import TYPE_CHECKING, Any, Callable, Generic, Union
+from urllib.parse import urlparse

 from pydantic import ValidationError
 from typing_extensions import TypeVar
@@ -155,15 +156,21 @@ async def extract_links(
         | None = None,
         **kwargs: Unpack[EnqueueLinksKwargs],
     ) -> list[Request]:
-        kwargs.setdefault('strategy', 'same-hostname')
-
         requests = list[Request]()
         skipped = list[str]()
         base_user_data = user_data or {}

         robots_txt_file = await self._get_robots_txt_file_for_url(context.request.url)

+        strategy = kwargs.get('strategy', 'same-hostname')
+        include_blobs = kwargs.get('include')
+        exclude_blobs = kwargs.get('exclude')
+        limit_requests = kwargs.get('limit')
+
         for link in self._parser.find_links(parsed_content, selector=selector):
+            if limit_requests and len(requests) >= limit_requests:
+                break
+
             url = link
             if not is_url_absolute(url):
                 base_url = context.request.loaded_url or context.request.url
@@ -173,26 +180,31 @@ async def extract_links(
                 skipped.append(url)
                 continue

-            request_options = RequestOptions(url=url, user_data={**base_user_data}, label=label)
-
-            if transform_request_function:
-                transform_request_options = transform_request_function(request_options)
-                if transform_request_options == 'skip':
+            if self._check_enqueue_strategy(
+                strategy,
+                target_url=urlparse(url),
+                origin_url=urlparse(context.request.url),
+            ) and self._check_url_patterns(url, include_blobs, exclude_blobs):
+                request_options = RequestOptions(url=url, user_data={**base_user_data}, label=label)
+
+                if transform_request_function:
+                    transform_request_options = transform_request_function(request_options)
+                    if transform_request_options == 'skip':
+                        continue
+                    if transform_request_options != 'unchanged':
+                        request_options = transform_request_options
+
+                try:
+                    request = Request.from_url(**request_options)
+                except ValidationError as exc:
+                    context.log.debug(
+                        f'Skipping URL "{url}" due to invalid format: {exc}. '
+                        'This may be caused by a malformed URL or unsupported URL scheme. '
+                        'Please ensure the URL is correct and retry.'
+                    )
                     continue
-                if transform_request_options != 'unchanged':
-                    request_options = transform_request_options
-
-            try:
-                request = Request.from_url(**request_options)
-            except ValidationError as exc:
-                context.log.debug(
-                    f'Skipping URL "{url}" due to invalid format: {exc}. '
-                    'This may be caused by a malformed URL or unsupported URL scheme. '
-                    'Please ensure the URL is correct and retry.'
-                )
-                continue

-            requests.append(request)
+                requests.append(request)

         if skipped:
             skipped_tasks = [

src/crawlee/crawlers/_playwright/_playwright_crawler.py

Lines changed: 32 additions & 20 deletions
@@ -5,6 +5,7 @@
 import warnings
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union
+from urllib.parse import urlparse

 from pydantic import ValidationError
 from typing_extensions import NotRequired, TypedDict, TypeVar
@@ -344,8 +345,6 @@ async def extract_links(

         The `PlaywrightCrawler` implementation of the `ExtractLinksFunction` function.
         """
-        kwargs.setdefault('strategy', 'same-hostname')
-
         requests = list[Request]()
         skipped = list[str]()
         base_user_data = user_data or {}
@@ -354,7 +353,15 @@ async def extract_links(

         robots_txt_file = await self._get_robots_txt_file_for_url(context.request.url)

+        strategy = kwargs.get('strategy', 'same-hostname')
+        include_blobs = kwargs.get('include')
+        exclude_blobs = kwargs.get('exclude')
+        limit_requests = kwargs.get('limit')
+
         for element in elements:
+            if limit_requests and len(requests) >= limit_requests:
+                break
+
             url = await element.get_attribute('href')

             if url:
@@ -368,26 +375,31 @@ async def extract_links(
                     skipped.append(url)
                     continue

-                request_option = RequestOptions({'url': url, 'user_data': {**base_user_data}, 'label': label})
-
-                if transform_request_function:
-                    transform_request_option = transform_request_function(request_option)
-                    if transform_request_option == 'skip':
+                if self._check_enqueue_strategy(
+                    strategy,
+                    target_url=urlparse(url),
+                    origin_url=urlparse(context.request.url),
+                ) and self._check_url_patterns(url, include_blobs, exclude_blobs):
+                    request_option = RequestOptions({'url': url, 'user_data': {**base_user_data}, 'label': label})
+
+                    if transform_request_function:
+                        transform_request_option = transform_request_function(request_option)
+                        if transform_request_option == 'skip':
+                            continue
+                        if transform_request_option != 'unchanged':
+                            request_option = transform_request_option
+
+                    try:
+                        request = Request.from_url(**request_option)
+                    except ValidationError as exc:
+                        context.log.debug(
+                            f'Skipping URL "{url}" due to invalid format: {exc}. '
+                            'This may be caused by a malformed URL or unsupported URL scheme. '
+                            'Please ensure the URL is correct and retry.'
+                        )
                         continue
-                    if transform_request_option != 'unchanged':
-                        request_option = transform_request_option
-
-                    try:
-                        request = Request.from_url(**request_option)
-                    except ValidationError as exc:
-                        context.log.debug(
-                            f'Skipping URL "{url}" due to invalid format: {exc}. '
-                            'This may be caused by a malformed URL or unsupported URL scheme. '
-                            'Please ensure the URL is correct and retry.'
-                        )
-                        continue

-                requests.append(request)
+                    requests.append(request)

         if skipped:
             skipped_tasks = [

tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py

Lines changed: 16 additions & 1 deletion
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING
 from unittest import mock

-from crawlee import ConcurrencySettings, HttpHeaders, RequestTransformAction, SkippedReason
+from crawlee import ConcurrencySettings, Glob, HttpHeaders, RequestTransformAction, SkippedReason
 from crawlee.crawlers import BeautifulSoupCrawler, BeautifulSoupCrawlingContext

 if TYPE_CHECKING:
@@ -183,3 +183,18 @@ async def skipped_hook(url: str, _reason: SkippedReason) -> None:
         str(server_url / 'page_2'),
         str(server_url / 'page_3'),
     }
+
+
+async def test_extract_links(server_url: URL, http_client: HttpClient) -> None:
+    crawler = BeautifulSoupCrawler(http_client=http_client)
+    extracted_links: list[str] = []
+
+    @crawler.router.default_handler
+    async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
+        links = await context.extract_links(exclude=[Glob(f'{server_url}sub_index')])
+        extracted_links.extend(request.url for request in links)
+
+    await crawler.run([str(server_url / 'start_enqueue')])
+
+    assert len(extracted_links) == 1
+    assert extracted_links[0] == str(server_url / 'page_1')

tests/unit/crawlers/_parsel/test_parsel_crawler.py

Lines changed: 16 additions & 1 deletion
@@ -6,7 +6,7 @@

 import pytest

-from crawlee import ConcurrencySettings, HttpHeaders, Request, RequestTransformAction, SkippedReason
+from crawlee import ConcurrencySettings, Glob, HttpHeaders, Request, RequestTransformAction, SkippedReason
 from crawlee.crawlers import ParselCrawler

 if TYPE_CHECKING:
@@ -279,3 +279,18 @@ async def skipped_hook(url: str, _reason: SkippedReason) -> None:
         str(server_url / 'page_2'),
         str(server_url / 'page_3'),
     }
+
+
+async def test_extract_links(server_url: URL, http_client: HttpClient) -> None:
+    crawler = ParselCrawler(http_client=http_client)
+    extracted_links: list[str] = []
+
+    @crawler.router.default_handler
+    async def request_handler(context: ParselCrawlingContext) -> None:
+        links = await context.extract_links(exclude=[Glob(f'{server_url}sub_index')])
+        extracted_links.extend(request.url for request in links)
+
+    await crawler.run([str(server_url / 'start_enqueue')])
+
+    assert len(extracted_links) == 1
+    assert extracted_links[0] == str(server_url / 'page_1')

tests/unit/crawlers/_playwright/test_playwright_crawler.py

Lines changed: 24 additions & 1 deletion
@@ -11,7 +11,15 @@

 import pytest

-from crawlee import ConcurrencySettings, HttpHeaders, Request, RequestTransformAction, SkippedReason, service_locator
+from crawlee import (
+    ConcurrencySettings,
+    Glob,
+    HttpHeaders,
+    Request,
+    RequestTransformAction,
+    SkippedReason,
+    service_locator,
+)
 from crawlee.configuration import Configuration
 from crawlee.crawlers import PlaywrightCrawler
 from crawlee.fingerprint_suite import (
@@ -704,3 +712,18 @@ async def test_overwrite_configuration() -> None:
     PlaywrightCrawler(configuration=configuration)
     used_configuration = service_locator.get_configuration()
     assert used_configuration is configuration
+
+
+async def test_extract_links(server_url: URL) -> None:
+    crawler = PlaywrightCrawler()
+    extracted_links: list[str] = []
+
+    @crawler.router.default_handler
+    async def request_handler(context: PlaywrightCrawlingContext) -> None:
+        links = await context.extract_links(exclude=[Glob(f'{server_url}sub_index')])
+        extracted_links.extend(request.url for request in links)
+
+    await crawler.run([str(server_url / 'start_enqueue')])
+
+    assert len(extracted_links) == 1
+    assert extracted_links[0] == str(server_url / 'page_1')
