Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions src/crawlee/storages/_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,15 @@ async def add_requests_batched(
transformed_requests = self._transform_requests(requests)
wait_time_secs = wait_time_between_batches.total_seconds()

async def _process_batch(batch: Sequence[Request]) -> None:
request_count = len(batch)
response = await self._resource_client.batch_add_requests(batch)
self._assumed_total_count += request_count
logger.debug(f'Added {request_count} requests to the queue, response: {response}')

# Wait for the first batch to be added
first_batch = transformed_requests[:batch_size]
if first_batch:
await _process_batch(first_batch)
await self._process_batch(first_batch, base_retry_wait=wait_time_between_batches)

async def _process_remaining_batches() -> None:
for i in range(batch_size, len(transformed_requests), batch_size):
batch = transformed_requests[i : i + batch_size]
await _process_batch(batch)
await self._process_batch(batch, base_retry_wait=wait_time_between_batches)
if i + batch_size < len(transformed_requests):
await asyncio.sleep(wait_time_secs)

Expand All @@ -270,6 +264,31 @@ async def _process_remaining_batches() -> None:
timeout=wait_for_all_requests_to_be_added_timeout,
)

async def _process_batch(self, batch: Sequence[Request], base_retry_wait: timedelta, attempt: int = 1) -> None:
max_attempts = 5
response = await self._resource_client.batch_add_requests(batch)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're at it, could you please make sure that ApifyStorageClient in the SDK does not do any retries of its own?

Of course, this is assuming we find that undesirable, but I firmly believe that the retry logic should be concentrated in a single place.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think it does. The retries that happens are currently in apify-client-python. First in Http clients and then in RequestQueueClients. Those retries are on different level and are both dealing with status codes of http responses and are not doing retries based on parsed content of successful API call - which is the case in this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So ApifyStorageClient does its own retries via apify-client-python. Because the apify-client implementation also does its own kind of batching, the situation gets extra blurry. It doesn't make sense to change that in this PR, but I think we should make an issue to resolve this some day. If you agree, that is.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if response.unprocessed_requests:
logger.debug(f'Following requests were not processed: {response.unprocessed_requests}.')
if attempt > max_attempts:
logger.warning(
f'Following requests were not processed even after {max_attempts} attempts:\n'
f'{response.unprocessed_requests}'
)
else:
logger.debug('Retry to add requests.')
unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests}
retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys]
await asyncio.sleep((base_retry_wait * attempt).total_seconds())
await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1)

request_count = len(batch) - len(response.unprocessed_requests)
self._assumed_total_count += request_count
if request_count:
logger.debug(
f'Added {request_count} requests to the queue. Processed requests: {response.processed_requests}'
)

async def get_request(self, request_id: str) -> Request | None:
"""Retrieve a request from the queue.

Expand Down
83 changes: 81 additions & 2 deletions tests/unit/storages/test_request_queue.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from itertools import count
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock

import pytest
from pydantic import ValidationError

from crawlee import Request, service_locator
from crawlee._request import RequestState
from crawlee.storage_clients.models import StorageMetadata
from crawlee.storage_clients import MemoryStorageClient, StorageClient
from crawlee.storage_clients._memory import RequestQueueClient
from crawlee.storage_clients.models import (
BatchRequestsOperationResponse,
StorageMetadata,
UnprocessedRequest,
)
from crawlee.storages import RequestQueue

if TYPE_CHECKING:
Expand Down Expand Up @@ -286,3 +294,74 @@ async def test_from_storage_object() -> None:
assert request_queue.name == storage_object.name
assert request_queue.storage_object == storage_object
assert storage_object.model_extra.get('extra_attribute') == 'extra' # type: ignore[union-attr]


async def test_add_batched_requests_with_retry(request_queue: RequestQueue) -> None:
"""Test that unprocessed requests are retried.

Unprocessed requests should not count in `get_total_count`
Test creates situation where in `batch_add_requests` call in first batch 3 requests are unprocessed.
On each following `batch_add_requests` call the last request in batch remains unprocessed.
In this test `batch_add_requests` is called once with batch of 10 requests. With retries only 1 request should
remain unprocessed."""

batch_add_requests_call_counter = count(start=1)
service_locator.get_storage_client()
initial_request_count = 10
expected_added_requests = 9
requests = [f'https://example.com/{i}' for i in range(initial_request_count)]

class MockedRequestQueueClient(RequestQueueClient):
"""Patched memory storage client that simulates unprocessed requests."""

async def _batch_add_requests_without_last_n(
self, batch: Sequence[Request], n: int = 0
) -> BatchRequestsOperationResponse:
response = await super().batch_add_requests(batch[:-n])
response.unprocessed_requests = [
UnprocessedRequest(url=r.url, unique_key=r.unique_key, method=r.method) for r in batch[-n:]
]
return response

async def batch_add_requests(
self,
requests: Sequence[Request],
*,
forefront: bool = False, # noqa: ARG002
) -> BatchRequestsOperationResponse:
"""Mocked client behavior that simulates unprocessed requests.

It processes all except last three at first run, then all except last none.
Overall if tried with the same batch it will process all except the last one.
"""
call_count = next(batch_add_requests_call_counter)
if call_count == 1:
# Process all but last three
return await self._batch_add_requests_without_last_n(requests, n=3)
# Process all but last
return await self._batch_add_requests_without_last_n(requests, n=1)

mocked_storage_client = AsyncMock(spec=StorageClient)
mocked_storage_client.request_queue = MagicMock(
return_value=MockedRequestQueueClient(id='default', memory_storage_client=MemoryStorageClient.from_config())
)

request_queue = RequestQueue(id='default', name='some_name', storage_client=mocked_storage_client)

# Add the requests to the RQ in batches
await request_queue.add_requests_batched(
requests, wait_for_all_requests_to_be_added=True, wait_time_between_batches=timedelta(0)
)

# Ensure the batch was processed correctly
assert await request_queue.get_total_count() == expected_added_requests
# Fetch and validate each request in the queue
for original_request in requests[:expected_added_requests]:
next_request = await request_queue.fetch_next_request()
assert next_request is not None

expected_url = original_request if isinstance(original_request, str) else original_request.url
assert next_request.url == expected_url

# Confirm the queue is empty after processing all requests
assert await request_queue.is_empty() is True
Loading