diff --git a/databricks/sdk/service/files.py b/databricks/sdk/service/files.py index 394aa8697..d8b29af36 100755 --- a/databricks/sdk/service/files.py +++ b/databricks/sdk/service/files.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any, BinaryIO, Dict, Iterator, List, Optional +from requests.utils import super_len + from ._internal import _escape_multi_segment_path_parameter, _repeated_dict _LOG = logging.getLogger("databricks.sdk") @@ -1259,6 +1261,10 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool "Content-Type": "application/octet-stream", } + length = super_len(contents) + if length > 0: + headers["Content-Length"] = str(length) + self._api.do( "PUT", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", diff --git a/tests/test_files.py b/tests/test_files.py index e25035523..46790ebb8 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1279,8 +1279,6 @@ def run(self, config: Config): session = requests.Session() with requests_mock.Mocker(session=session) as session_mock: - session_mock.get(f"http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}", status_code=200) - upload_state = SingleShotUploadState() def custom_matcher(request): @@ -1855,3 +1853,135 @@ def to_string(test_case): ) def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase): test_case.run(config) + + +class SingleShotUploadContentLengthTestCase: + + def __init__( + self, + name: str, + contents: Callable[[], io.IOBase], + expected_content_length: Optional[int], + cleanup: Callable[[io.IOBase], None] = None, + ): + super().__init__() + self.name = name + self.contents = contents + self.expected_content_length = expected_content_length + self.cleanup = cleanup + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + def run(self, config: Config): + config = config.copy() + config.enable_experimental_files_api_client = False # enforce single-shot upload + + file_path = "/test.txt" + contents = self.contents() + + try: + with requests_mock.Mocker() as session_mock: + + def custom_matcher(request): + request_url = urlparse(request.url) + + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{file_path}" + and request.method == "PUT" + ): + body = request.body.read() + + if self.expected_content_length: + content_length = request.headers["Content-Length"] + assert self.expected_content_length == int(content_length) + assert len(body) == int(content_length) + else: + assert request.headers.get("Content-Length") is None + + resp = requests.Response() + resp.status_code = 204 + resp.request = request + resp._content = b"" + return resp + return None + + session_mock.add_matcher(matcher=custom_matcher) + + w = WorkspaceClient(config=config) + w.files.upload(file_path, contents) + finally: + if self.cleanup: + self.cleanup(contents) + + +def make_non_seekable(stream: io.IOBase, disable_seek: bool = False, disable_tell: bool = False): + def raise_(ex): + raise ex + + stream.seekable = lambda: False # checked by BaseClient._is_seekable_stream() + + # requests.super_len() does not check seekable(), it calls seek() and tell() directly + if disable_seek: + stream.seek = lambda offset, whence: raise_(OSError()) + if disable_tell: + stream.tell = lambda: raise_(OSError()) + return stream + + +def create_file_stream(length: int) -> io.IOBase: + fd, temp_file = mkstemp() + with open(fd, "wb") as f: + f.write(os.urandom(length)) + + stream = open(temp_file, "rb") + + def cleanup(): + try: + stream.close() + finally: + os.remove(temp_file) + + stream.cleanup = lambda: cleanup() + return stream + + +@pytest.mark.parametrize( + "test_case", + [ + SingleShotUploadContentLengthTestCase( + "Empty contents treated as unknown length", contents=lambda: io.BytesIO(b""), expected_content_length=None + ), + SingleShotUploadContentLengthTestCase("Bytes", contents=lambda: io.BytesIO(b"abc"), expected_content_length=3), + SingleShotUploadContentLengthTestCase( + "seek disabled: length unknown", + contents=lambda: make_non_seekable(io.BytesIO(b"abc"), disable_seek=True), + expected_content_length=None, + ), + SingleShotUploadContentLengthTestCase( + "tell disabled: length unknown", + contents=lambda: make_non_seekable(io.BytesIO(b"abc"), disable_tell=True), + expected_content_length=None, + ), + SingleShotUploadContentLengthTestCase( + "File stream: length reported", + contents=lambda: create_file_stream(566), + expected_content_length=566, + cleanup=lambda stream: stream.cleanup(), + ), + SingleShotUploadContentLengthTestCase( + "File stream with tell disabled: length unknown", + contents=lambda: make_non_seekable(create_file_stream(239), disable_tell=True), + expected_content_length=None, + cleanup=lambda stream: stream.cleanup(), + ), + ], + ids=SingleShotUploadContentLengthTestCase.to_string, +) +def test_content_length(config: Config, test_case: SingleShotUploadContentLengthTestCase): + test_case.run(config)