
Commit d907c0c

[Feature] Files API client: recover on download failures (#844) (#845)
## What changes are proposed in this pull request?

1. Extend the Files API client to support resuming downloads after failures. The new implementation tracks the current offset in the input stream and, on error, issues a new download request starting from that offset.
2. The new code path is enabled by the `DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT` config parameter.

## How is this tested?

Added unit tests for the new code path: `% python3 -m pytest tests/test_files.py`

Signed-off-by: Kirill Safonov <[email protected]>
1 parent 6d6923e commit d907c0c
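In practice, the new code path is opted into via the environment flag and consumed through bounded reads. A minimal usage sketch: the flag name and `download()` come from this change, while the workspace auth, the volume path, and the read loop are illustrative assumptions.

```python
# Minimal usage sketch; assumes a configured Databricks workspace and an
# existing file at the given (illustrative) path.
import os

# Opt in to the experimental, failure-recovering Files API client.
os.environ['DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT'] = 'true'

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
resp = w.files.download('/Volumes/main/default/my_volume/large_file.bin')

# Consume the stream with bounded reads, as the download() docstring below
# recommends, so a mid-stream failure can be recovered from the current offset.
with resp.contents as stream:
    while True:
        chunk = stream.read(1024 * 1024)
        if not chunk:
            break
        # ... process chunk ...
```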

File tree: 6 files changed (+559, -11 lines)

databricks/sdk/__init__.py

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default.

databricks/sdk/_base_client.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -1,6 +1,7 @@
 import io
 import logging
 import urllib.parse
+from abc import ABC, abstractmethod
 from datetime import timedelta
 from types import TracebackType
 from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
@@ -285,8 +286,20 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
         logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())


+class _RawResponse(ABC):
+
+    @abstractmethod
+    # follows Response signature: https://github.com/psf/requests/blob/main/src/requests/models.py#L799
+    def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass
+
+
 class _StreamingResponse(BinaryIO):
-    _response: requests.Response
+    _response: _RawResponse
     _buffer: bytes
     _content: Union[Iterator[bytes], None]
     _chunk_size: Union[int, None]
@@ -298,7 +311,7 @@ def fileno(self) -> int:
     def flush(self) -> int:
         pass

-    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
+    def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None):
        self._response = response
        self._buffer = b''
        self._content = None
@@ -308,7 +321,7 @@ def _open(self) -> None:
         if self._closed:
             raise ValueError("I/O operation on closed file")
         if not self._content:
-            self._content = self._response.iter_content(chunk_size=self._chunk_size)
+            self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False)

     def __enter__(self) -> BinaryIO:
         self._open()
```
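The new `_RawResponse` ABC abstracts the two methods `_StreamingResponse` actually uses, which lets the Files client substitute its own resilient implementation for a plain `requests.Response`. A `requests.Response` already satisfies the interface structurally; the adapter below is a hypothetical sketch to make that relationship concrete, not code from this change.

```python
# Hypothetical adapter: requests.Response already provides iter_content() and
# close() with compatible signatures, so an explicit wrapper is only needed
# where an actual _RawResponse subclass is required.
import requests

from databricks.sdk._base_client import _RawResponse, _StreamingResponse


class _RequestsRawResponse(_RawResponse):

    def __init__(self, response: requests.Response):
        self._response = response

    def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
        # Delegate chunked iteration to the underlying HTTP response.
        return self._response.iter_content(chunk_size=chunk_size,
                                           decode_unicode=decode_unicode)

    def close(self):
        self._response.close()


# Usage: wrap a streaming HTTP response and read it through _StreamingResponse.
raw = _RequestsRawResponse(requests.get('https://example.com/file', stream=True))
stream = _StreamingResponse(raw, chunk_size=1024 * 1024)
```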

databricks/sdk/config.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -92,6 +92,11 @@ class Config:
     max_connections_per_pool: int = ConfigAttribute()
     databricks_environment: Optional[DatabricksEnvironment] = None

+    enable_experimental_files_api_client: bool = ConfigAttribute(
+        env='DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT')
+    files_api_client_download_max_total_recovers = None
+    files_api_client_download_max_total_recovers_without_progressing = 1
+
     def __init__(
         self,
         *,
```
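The two new knobs bound the recovery loop: `files_api_client_download_max_total_recovers = None` means no cap on total recovers across a download, and `files_api_client_download_max_total_recovers_without_progressing = 1` gives up after one consecutive recover that delivers no new bytes. A tuning sketch; the attribute names come from the diff above, but setting them directly on a `Config` instance, and ambient auth being available, are assumptions.

```python
# Tuning sketch; assumes ambient auth (e.g. DATABRICKS_HOST / DATABRICKS_TOKEN)
# so that Config() can resolve credentials.
from databricks.sdk.core import Config

cfg = Config(enable_experimental_files_api_client=True)

# None = no cap on total recovers across the whole download (the default).
cfg.files_api_client_download_max_total_recovers = None
# Stop after one consecutive recover that delivers no new bytes (the default),
# which keeps a persistently failing offset from retrying forever.
cfg.files_api_client_download_max_total_recovers_without_progressing = 1
```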

databricks/sdk/mixins/files.py

Lines changed: 184 additions & 1 deletion
```diff
@@ -1,26 +1,35 @@
 from __future__ import annotations

 import base64
+import logging
 import os
 import pathlib
 import platform
 import shutil
 import sys
 from abc import ABC, abstractmethod
 from collections import deque
+from collections.abc import Iterator
 from io import BytesIO
 from types import TracebackType
 from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable,
-                    Iterator, Type, Union)
+                    Optional, Type, Union)
 from urllib import parse

+from requests import RequestException
+
+from .._base_client import _RawResponse, _StreamingResponse
 from .._property import _cached_property
 from ..errors import NotFound
 from ..service import files
+from ..service._internal import _escape_multi_segment_path_parameter
+from ..service.files import DownloadResponse

 if TYPE_CHECKING:
     from _typeshed import Self

+_LOG = logging.getLogger(__name__)
+

 class _DbfsIO(BinaryIO):
     MAX_CHUNK_SIZE = 1024 * 1024
@@ -636,3 +645,177 @@ def delete(self, path: str, *, recursive=False):
         if p.is_dir and not recursive:
             raise IOError('deleting directories requires recursive flag')
         p.delete(recursive=recursive)
+
+
+class FilesExt(files.FilesAPI):
+    __doc__ = files.FilesAPI.__doc__
+
+    def __init__(self, api_client, config: Config):
+        super().__init__(api_client)
+        self._config = config.copy()
+
+    def download(self, file_path: str) -> DownloadResponse:
+        """Download a file.
+
+        Downloads a file of any size. The file contents are the response body.
+        This is a standard HTTP file download, not a JSON RPC.
+
+        It is strongly recommended, for fault tolerance reasons,
+        to iteratively consume from the stream with a maximum read(size)
+        defined instead of using indefinite-size reads.
+
+        :param file_path: str
+          The remote path of the file, e.g. /Volumes/path/to/your/file
+
+        :returns: :class:`DownloadResponse`
+        """
+
+        initial_response: DownloadResponse = self._download_raw_stream(file_path=file_path,
+                                                                       start_byte_offset=0,
+                                                                       if_unmodified_since_timestamp=None)
+
+        wrapped_response = self._wrap_stream(file_path, initial_response)
+        initial_response.contents._response = wrapped_response
+        return initial_response
+
+    def _download_raw_stream(self,
+                             file_path: str,
+                             start_byte_offset: int,
+                             if_unmodified_since_timestamp: Optional[str] = None) -> DownloadResponse:
+        headers = {'Accept': 'application/octet-stream', }
+
+        if start_byte_offset and not if_unmodified_since_timestamp:
+            raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified")
+
+        if start_byte_offset:
+            headers['Range'] = f'bytes={start_byte_offset}-'
+
+        if if_unmodified_since_timestamp:
+            headers['If-Unmodified-Since'] = if_unmodified_since_timestamp
+
+        response_headers = ['content-length', 'content-type', 'last-modified', ]
+        res = self._api.do('GET',
+                           f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}',
+                           headers=headers,
+                           response_headers=response_headers,
+                           raw=True)
+
+        result = DownloadResponse.from_dict(res)
+        if not isinstance(result.contents, _StreamingResponse):
+            raise Exception("Internal error: response contents is of unexpected type: " +
+                            type(result.contents).__name__)
+
+        return result
+
+    def _wrap_stream(self, file_path: str, downloadResponse: DownloadResponse):
+        underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
+        return _ResilientResponse(self,
+                                  file_path,
+                                  downloadResponse.last_modified,
+                                  offset=0,
+                                  underlying_response=underlying_response)
+
+
+class _ResilientResponse(_RawResponse):
+
+    def __init__(self, api: FilesExt, file_path: str, file_last_modified: str, offset: int,
+                 underlying_response: _RawResponse):
+        self.api = api
+        self.file_path = file_path
+        self.underlying_response = underlying_response
+        self.offset = offset
+        self.file_last_modified = file_last_modified
+
+    def iter_content(self, chunk_size=1, decode_unicode=False):
+        if decode_unicode:
+            raise ValueError('Decode unicode is not supported')
+
+        iterator = self.underlying_response.iter_content(chunk_size=chunk_size, decode_unicode=False)
+        self.iterator = _ResilientIterator(iterator, self.file_path, self.file_last_modified, self.offset,
+                                           self.api, chunk_size)
+        return self.iterator
+
+    def close(self):
+        self.iterator.close()
+
+
+class _ResilientIterator(Iterator):
+    # This class tracks current offset (returned to the client code)
+    # and recovers from failures by requesting download from the current offset.
+
+    @staticmethod
+    def _extract_raw_response(download_response: DownloadResponse) -> _RawResponse:
+        streaming_response: _StreamingResponse = download_response.contents  # this is an instance of _StreamingResponse
+        return streaming_response._response
+
+    def __init__(self, underlying_iterator, file_path: str, file_last_modified: str, offset: int,
+                 api: FilesExt, chunk_size: int):
+        self._underlying_iterator = underlying_iterator
+        self._api = api
+        self._file_path = file_path
+
+        # Absolute current offset (0-based), i.e. number of bytes from the beginning of the file
+        # that were so far returned to the caller code.
+        self._offset = offset
+        self._file_last_modified = file_last_modified
+        self._chunk_size = chunk_size
+
+        self._total_recovers_count: int = 0
+        self._recovers_without_progressing_count: int = 0
+        self._closed: bool = False
+
+    def _should_recover(self) -> bool:
+        if self._total_recovers_count == self._api._config.files_api_client_download_max_total_recovers:
+            _LOG.debug("Total recovers limit exceeded")
+            return False
+        if self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None and self._recovers_without_progressing_count >= self._api._config.files_api_client_download_max_total_recovers_without_progressing:
+            _LOG.debug("No progression recovers limit exceeded")
+            return False
+        return True
+
+    def _recover(self) -> bool:
+        if not self._should_recover():
+            return False  # recover suppressed, rethrow original exception
+
+        self._total_recovers_count += 1
+        self._recovers_without_progressing_count += 1
+
+        try:
+            self._underlying_iterator.close()
+
+            _LOG.debug("Trying to recover from offset " + str(self._offset))
+
+            # following call includes all the required network retries
+            downloadResponse = self._api._download_raw_stream(self._file_path, self._offset,
+                                                              self._file_last_modified)
+            underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
+            self._underlying_iterator = underlying_response.iter_content(chunk_size=self._chunk_size,
+                                                                         decode_unicode=False)
+            _LOG.debug("Recover succeeded")
+            return True
+        except:
+            return False  # recover failed, rethrow original exception
+
+    def __next__(self):
+        if self._closed:
+            # following _BaseClient
+            raise ValueError("I/O operation on closed file")
+
+        while True:
+            try:
+                returned_bytes = next(self._underlying_iterator)
+                self._offset += len(returned_bytes)
+                self._recovers_without_progressing_count = 0
+                return returned_bytes

+            except StopIteration:
+                raise
+
+            # https://requests.readthedocs.io/en/latest/user/quickstart/#errors-and-exceptions
+            except RequestException:
+                if not self._recover():
+                    raise
+
+    def close(self):
+        self._underlying_iterator.close()
+        self._closed = True
```
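At the HTTP level, recovery hinges on the two headers `_download_raw_stream` sets: `Range: bytes=<offset>-` resumes from the first byte not yet delivered, and `If-Unmodified-Since` pins the file version so chunks from different versions are never stitched together. A standalone sketch of the same request, using plain `requests` against a hypothetical URL (the headers mirror `_download_raw_stream` above; the URL and helper are illustrative):

```python
# Standalone illustration of the recovery request; not SDK code.
import requests


def resume_download(url: str, offset: int, last_modified: str) -> requests.Response:
    headers = {'Accept': 'application/octet-stream'}
    if offset:
        # Resume from the first byte not yet returned to the caller.
        headers['Range'] = f'bytes={offset}-'
        # Refuse the range if the file changed after the original download
        # started, so chunks from different file versions are never mixed.
        headers['If-Unmodified-Since'] = last_modified
    return requests.get(url, headers=headers, stream=True)
```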

tests/test_base_client.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -5,17 +5,17 @@
 from unittest.mock import Mock

 import pytest
-import requests

 from databricks.sdk import errors, useragent
-from databricks.sdk._base_client import _BaseClient, _StreamingResponse
+from databricks.sdk._base_client import (_BaseClient, _RawResponse,
+                                         _StreamingResponse)
 from databricks.sdk.core import DatabricksError

 from .clock import FakeClock
 from .fixture_server import http_fixture_server


-class DummyResponse(requests.Response):
+class DummyResponse(_RawResponse):
     _content: Iterator[bytes]
     _closed: bool = False

@@ -293,9 +293,9 @@ def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
     test_data = bytes(rng.getrandbits(8) for _ in range(data_size))

     content_chunks = []
-    mock_response = Mock(spec=requests.Response)
+    mock_response = Mock(spec=_RawResponse)

-    def mock_iter_content(chunk_size):
+    def mock_iter_content(chunk_size: int, decode_unicode: bool):
         # Simulate how requests would chunk the data.
         for i in range(0, len(test_data), chunk_size):
             chunk = test_data[i:i + chunk_size]
```
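The new unit tests themselves live in `tests/test_files.py`, which is not rendered in this diff. The sketch below is a hedged reconstruction of one such recovery test; the mock wiring is an assumption, not the actual test code.

```python
# Sketch of a recovery test in the spirit of the new tests/test_files.py.
from unittest.mock import MagicMock

from requests import RequestException

from databricks.sdk.mixins.files import _ResilientIterator


def test_resilient_iterator_recovers_from_offset():
    def flaky():
        # First attempt: one chunk, then a mid-stream network failure.
        yield b'hello'
        raise RequestException('connection reset')

    # The recovery request serves the remaining bytes from offset 5.
    recovered_raw = MagicMock()
    recovered_raw.iter_content.return_value = iter([b' world'])

    api = MagicMock()
    api._config.files_api_client_download_max_total_recovers = None
    api._config.files_api_client_download_max_total_recovers_without_progressing = 1
    api._download_raw_stream.return_value.contents._response = recovered_raw

    it = _ResilientIterator(flaky(), '/Volumes/f', 'Thu, 01 Jan 1970 00:00:00 GMT',
                            offset=0, api=api, chunk_size=5)

    assert list(it) == [b'hello', b' world']
    # The recovery must restart exactly after the 5 bytes already delivered.
    api._download_raw_stream.assert_called_once_with('/Volumes/f', 5,
                                                     'Thu, 01 Jan 1970 00:00:00 GMT')
```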
