diff --git a/composer/loggers/logger_hparams_registry.py b/composer/loggers/logger_hparams_registry.py index 9f027ba9b6..6f3c1d600a 100644 --- a/composer/loggers/logger_hparams_registry.py +++ b/composer/loggers/logger_hparams_registry.py @@ -36,7 +36,7 @@ class RemoteUploaderDownloaderHparams(hp.Hparams): Args: object_store_hparams (ObjectStoreHparams): The object store provider hparams. - object_name (str, optional): See :class:`.RemoteUploaderDownloader`. + file_path_format_string (str, optional): See :class:`.RemoteUploaderDownloader`. num_concurrent_uploads (int, optional): See :class:`.RemoteUploaderDownloader`. upload_staging_folder (str, optional): See :class:`.RemoteUploaderDownloader`. use_procs (bool, optional): See :class:`.RemoteUploaderDownloader`. @@ -46,8 +46,9 @@ class RemoteUploaderDownloaderHparams(hp.Hparams): 'object_store_hparams': object_store_registry, } - object_store_hparams: ObjectStoreHparams = hp.required('Object store provider hparams.') - object_name: str = hp.auto(RemoteUploaderDownloader, 'object_name') + bucket_uri: str = hp.required('Remote bucket uri') + object_store_hparams: Optional[ObjectStoreHparams] = hp.optional('Object store provider hparams.', default=None) + file_path_format_string: str = hp.auto(RemoteUploaderDownloader, 'file_path_format_string') num_concurrent_uploads: int = hp.auto(RemoteUploaderDownloader, 'num_concurrent_uploads') use_procs: bool = hp.auto(RemoteUploaderDownloader, 'use_procs') upload_staging_folder: Optional[str] = hp.auto(RemoteUploaderDownloader, 'upload_staging_folder') @@ -55,9 +56,9 @@ class RemoteUploaderDownloaderHparams(hp.Hparams): def initialize_object(self) -> RemoteUploaderDownloader: return RemoteUploaderDownloader( - object_store_cls=self.object_store_hparams.get_object_store_cls(), - object_store_kwargs=self.object_store_hparams.get_kwargs(), - object_name=self.object_name, + bucket_uri=self.bucket_uri, + backend_kwargs=self.object_store_hparams.get_kwargs() if self.object_store_hparams is not None else {}, + file_path_format_string=self.file_path_format_string, num_concurrent_uploads=self.num_concurrent_uploads, upload_staging_folder=self.upload_staging_folder, use_procs=self.use_procs, diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index b876216c14..a945e956d6 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -18,24 +18,32 @@ import warnings from multiprocessing.context import SpawnProcess from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from urllib.parse import urlparse from composer.core.state import State from composer.loggers.logger import Logger from composer.loggers.logger_destination import LoggerDestination -from composer.utils import ObjectStore, ObjectStoreTransientError, dist, format_name_with_dist, get_file, retry +from composer.utils import (LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, S3ObjectStore, SFTPObjectStore, + dist, format_name_with_dist, get_file, retry) log = logging.getLogger(__name__) __all__ = ['RemoteUploaderDownloader'] -def _build_remote_backend(object_store_cls: Type[ObjectStore], object_store_kwargs: Dict[str, Any]): - # error: Expected no arguments to "ObjectStore" constructor - return object_store_cls(**object_store_kwargs) # type: ignore +def _build_remote_backend(remote_backend_name: str, backend_kwargs: Dict[str, Any]): + remote_backend_name_to_cls = {'s3': S3ObjectStore, 'sftp': SFTPObjectStore, 
'libcloud': LibcloudObjectStore} + remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None) + if remote_backend_cls is None: + raise ValueError( + f'The remote backend {remote_backend_name} is not supported. Please use one of ({list(remote_backend_name_to_cls.keys())})' + ) + + return remote_backend_cls(**backend_kwargs) class RemoteUploaderDownloader(LoggerDestination): - r"""Logger destination that uploads (downloads) files to (from) an object store. + r"""Logger destination that uploads (downloads) files to (from) a remote backend. This logger destination handles calls to :meth:`.Logger.upload_file` and uploads files to :class:`.ObjectStore`, such as AWS S3 or Google Cloud Storage. To minimize the training @@ -47,8 +55,26 @@ class RemoteUploaderDownloader(LoggerDestination): from composer.utils import LibcloudObjectStore remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=LibcloudObjectStore, - object_store_kwargs={ + bucket_uri="s3://my-bucket", + file_path_format_string="path/to/my/checkpoints/{remote_file_name}", + ) + + # Construct the trainer using this logger + trainer = Trainer( + ..., + loggers=[remote_uploader_downloader], + ) + + or + + .. testcode:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__ + + from composer.loggers import RemoteUploaderDownloader + from composer.utils import LibcloudObjectStore + + remote_uploader_downloader = RemoteUploaderDownloader( + bucket_uri="libcloud://my-bucket", + backend_kwargs={ 'provider': 's3', 'container': 'my-bucket', 'provider_kwargs=': { @@ -70,8 +96,6 @@ class RemoteUploaderDownloader(LoggerDestination): This callback blocks the training loop to upload each file, as the uploading happens in the background. Here are some additional tips for minimizing the performance impact: - * Set ``should_log`` to filter which files will be uploaded. By default, all files are uploaded. - * Set ``use_procs=True`` (the default) to use background processes, instead of threads, to perform the file uploads. Processes are recommended to ensure that the GIL is not blocking the training loop when performing CPU operations on uploaded files (e.g. computing and comparing checksums). Network I/O happens @@ -82,17 +106,17 @@ class RemoteUploaderDownloader(LoggerDestination): be raised. Args: - object_store_cls (Type[ObjectStore]): The object store class. + bucket_uri (str): The remote uri for the bucket to use (e.g. s3://my-bucket). As individual :class:`.ObjectStore` instances are not necessarily thread safe, each worker will construct - its own :class:`.ObjectStore` instance from ``object_store_cls`` and ``object_store_kwargs``. + its own :class:`.ObjectStore` instance from ``remote_backend`` and ``backend_kwargs``. - object_store_kwargs (Dict[str, Any]): The keyword arguments to construct ``object_store_cls``. + backend_kwargs (Dict[str, Any]): The keyword arguments to construct the remote backend indicated by ``bucket_uri``. As individual :class:`.ObjectStore` instances are not necessarily thread safe, each worker will construct - its own :class:`.ObjectStore` instance from ``object_store_cls`` and ``object_store_kwargs``. + its own :class:`.ObjectStore` instance from ``remote_backend`` and ``backend_kwargs``. - object_name (str, optional): A format string used to determine the object name. + file_path_format_string (str, optional): A format string used to determine the remote file path (within the specified bucket). 
The following format variables are available: @@ -124,7 +148,7 @@ class RemoteUploaderDownloader(LoggerDestination): Consider the following example, which subfolders the remote files by their rank: - .. testsetup:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.object_name + .. testsetup:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.file_path_format_string import os @@ -133,23 +157,23 @@ class RemoteUploaderDownloader(LoggerDestination): with open('path/to/file.txt', 'w+') as f: f.write('hi') - .. doctest:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.object_name + .. doctest:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.file_path_format_string - >>> remote_uploader_downloader = RemoteUploaderDownloader(..., object_name='rank_{rank}/{remote_file_name}') + >>> remote_uploader_downloader = RemoteUploaderDownloader(..., file_path_format_string='rank_{rank}/{remote_file_name}') >>> trainer = Trainer(..., run_name='foo', loggers=[remote_uploader_downloader]) >>> trainer.logger.upload_file( ... remote_file_name='bar.txt', ... file_path='path/to/file.txt', ... ) - .. testcleanup:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.object_name + .. testcleanup:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.file_path_format_string # Shut down the uploader remote_uploader_downloader._check_workers() remote_uploader_downloader.post_close() - Assuming that the process's rank is ``0``, the object store would store the contents of - ``'path/to/file.txt'`` in an object named ``'rank0/bar.txt'``. + Assuming that the process's rank is ``0``, the remote backend would store the contents of + ``'path/to/file.txt'`` at ``'rank0/bar.txt'``. 
Default: ``'{remote_file_name}'`` @@ -163,16 +187,24 @@ class RemoteUploaderDownloader(LoggerDestination): """ def __init__(self, - object_store_cls: Type[ObjectStore], - object_store_kwargs: Dict[str, Any], - object_name: str = '{remote_file_name}', + bucket_uri: str, + backend_kwargs: Optional[Dict[str, Any]] = None, + file_path_format_string: str = '{remote_file_name}', num_concurrent_uploads: int = 4, upload_staging_folder: Optional[str] = None, use_procs: bool = True, num_attempts: int = 3) -> None: - self.object_store_cls = object_store_cls - self.object_store_kwargs = object_store_kwargs - self.object_name = object_name + parsed_remote_bucket = urlparse(bucket_uri) + self.remote_backend_name, remote_bucket_name = parsed_remote_bucket.scheme, parsed_remote_bucket.netloc + self.backend_kwargs = backend_kwargs if backend_kwargs is not None else {} + if self.remote_backend_name == 's3' and 'bucket' not in self.backend_kwargs: + self.backend_kwargs['bucket'] = remote_bucket_name + elif self.remote_backend_name == 'sftp' and 'host' not in self.backend_kwargs: + self.backend_kwargs['host'] = f'sftp://{remote_bucket_name}' + elif self.remote_backend_name == 'libcloud' and 'container' not in self.backend_kwargs: + self.backend_kwargs['container'] = remote_bucket_name + + self.file_path_format_string = file_path_format_string self.num_attempts = num_attempts self._run_name = None @@ -192,7 +224,7 @@ def __init__(self, # The object store might keep the earlier file rather than the latter file as the "latest" version # To work around this, each object name can appear at most once in `self._file_upload_queue` - # The main separately keeps track of {object_name: tempfile_path} for each API call to self.upload_file + # The main separately keeps track of {file_path_format_string: tempfile_path} for each API call to self.upload_file # and then periodically transfers items from this dictionary onto the file upload queue # Lock for modifying `logged_objects` or `enqueued_objects` @@ -232,7 +264,7 @@ def __init__(self, def remote_backend(self) -> ObjectStore: """The :class:`.ObjectStore` instance for the main thread.""" if self._remote_backend is None: - self._remote_backend = _build_remote_backend(self.object_store_cls, self.object_store_kwargs) + self._remote_backend = _build_remote_backend(self.remote_backend_name, self.backend_kwargs) return self._remote_backend def init(self, state: State, logger: Logger) -> None: @@ -241,7 +273,7 @@ def init(self, state: State, logger: Logger) -> None: raise RuntimeError('The RemoteUploaderDownloader is already initialized.') self._worker_flag = self._finished_cls() self._run_name = state.run_name - object_name_to_test = self._object_name('.credentials_validated_successfully') + file_name_to_test = self._remote_file_name('.credentials_validated_successfully') # Create the enqueue thread self._enqueue_thread_flag = self._finished_cls() @@ -250,7 +282,7 @@ def init(self, state: State, logger: Logger) -> None: if dist.get_global_rank() == 0: retry(ObjectStoreTransientError, - self.num_attempts)(lambda: _validate_credentials(self.remote_backend, object_name_to_test))() + self.num_attempts)(lambda: _validate_credentials(self.remote_backend, file_name_to_test))() assert len(self._workers) == 0, 'workers should be empty if self._worker_flag was None' for _ in range(self._num_concurrent_uploads): worker = self._proc_class( @@ -258,8 +290,8 @@ def init(self, state: State, logger: Logger) -> None: kwargs={ 'file_queue': self._file_upload_queue, 'is_finished': 
self._worker_flag, - 'object_store_cls': self.object_store_cls, - 'object_store_kwargs': self.object_store_kwargs, + 'remote_backend_name': self.remote_backend_name, + 'backend_kwargs': self.backend_kwargs, 'num_attempts': self.num_attempts, 'completed_queue': self._completed_queue, }, @@ -302,11 +334,12 @@ def upload_file( copied_path = os.path.join(self._upload_staging_folder, str(uuid.uuid4())) os.makedirs(self._upload_staging_folder, exist_ok=True) shutil.copy2(file_path, copied_path) - object_name = self._object_name(remote_file_name) + formatted_remote_file_name = self._remote_file_name(remote_file_name) with self._object_lock: - if object_name in self._logged_objects and not overwrite: - raise FileExistsError(f'Object {object_name} was already enqueued to be uploaded, but overwrite=False.') - self._logged_objects[object_name] = (copied_path, overwrite) + if formatted_remote_file_name in self._logged_objects and not overwrite: + raise FileExistsError( + f'Object {formatted_remote_file_name} was already enqueued to be uploaded, but overwrite=False.') + self._logged_objects[formatted_remote_file_name] = (copied_path, overwrite) def can_upload_files(self) -> bool: """Whether the logger supports uploading files.""" @@ -459,15 +492,15 @@ def get_uri_for_file(self, remote_file_name: str) -> str: Returns: str: The uri corresponding to the uploaded location of the remote file. """ - obj_name = self._object_name(remote_file_name) - return self.remote_backend.get_uri(obj_name.lstrip('/')) + formatted_remote_file_name = self._remote_file_name(remote_file_name) + return self.remote_backend.get_uri(formatted_remote_file_name.lstrip('/')) - def _object_name(self, remote_file_name: str): - """Format the ``remote_file_name`` according to the ``object_name_string``.""" + def _remote_file_name(self, remote_file_name: str): + """Format the ``remote_file_name`` according to the ``file_path_format_string``.""" if self._run_name is None: raise RuntimeError('The run name is not set. It should have been set on Event.INIT.') key_name = format_name_with_dist( - self.object_name, + self.file_path_format_string, run_name=self._run_name, remote_file_name=remote_file_name, ) @@ -477,15 +510,15 @@ def _object_name(self, remote_file_name: str): def _validate_credentials( - object_store: ObjectStore, - object_name_to_test: str, + remote_backend: ObjectStore, + remote_file_name_to_test: str, ) -> None: # Validates the credentials by attempting to touch a file in the bucket # raises an error if there was a credentials failure. with tempfile.NamedTemporaryFile('wb') as f: f.write(b'credentials_validated_successfully') - object_store.upload_object( - object_name=object_name_to_test, + remote_backend.upload_object( + object_name=remote_file_name_to_test, filename=f.name, ) @@ -494,8 +527,8 @@ def _upload_worker( file_queue: Union[queue.Queue[Tuple[str, str, bool]], multiprocessing.JoinableQueue[Tuple[str, str, bool]]], completed_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]], is_finished: Union[multiprocessing._EventType, threading.Event], - object_store_cls: Type[ObjectStore], - object_store_kwargs: Dict[str, Any], + remote_backend_name: str, + backend_kwargs: Dict[str, Any], num_attempts: int, ): """A long-running function to handle uploading files to the object store. @@ -503,23 +536,23 @@ def _upload_worker( The worker will continuously poll ``file_queue`` for files to upload. Once ``is_finished`` is set, the worker will exit once ``file_queue`` is empty. 
""" - object_store = _build_remote_backend(object_store_cls, object_store_kwargs) + remote_backend = _build_remote_backend(remote_backend_name, backend_kwargs) while True: try: - file_path_to_upload, object_name, overwrite = file_queue.get(block=True, timeout=0.5) + file_path_to_upload, remote_file_name, overwrite = file_queue.get(block=True, timeout=0.5) except queue.Empty: if is_finished.is_set(): break else: continue - uri = object_store.get_uri(object_name) + uri = remote_backend.get_uri(remote_file_name) # defining as a function-in-function to use decorator notation with num_attempts as an argument @retry(ObjectStoreTransientError, num_attempts=num_attempts) def upload_file(): if not overwrite: try: - object_store.get_object_size(object_name) + remote_backend.get_object_size(remote_file_name) except FileNotFoundError: # Good! It shouldn't exist. pass @@ -527,12 +560,12 @@ def upload_file(): # Exceptions will be detected on the next batch_end or epoch_end event raise FileExistsError(f'Object {uri} already exists, but allow_overwrite was set to False.') log.info('Uploading file %s to %s', file_path_to_upload, uri) - object_store.upload_object( - object_name=object_name, + remote_backend.upload_object( + object_name=remote_file_name, filename=file_path_to_upload, ) os.remove(file_path_to_upload) file_queue.task_done() - completed_queue.put_nowait(object_name) + completed_queue.put_nowait(remote_file_name) upload_file() diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py index daaa3e4d38..07b2b0a5f6 100644 --- a/docs/source/doctest_fixtures.py +++ b/docs/source/doctest_fixtures.py @@ -212,8 +212,8 @@ def _new_RemoteUploaderDownloader_init(self, fake_ellipses: None = None, **kwarg os.makedirs('./object_store', exist_ok=True) kwargs.update(use_procs=False, num_concurrent_uploads=1, - object_store_cls=LibcloudObjectStore, - object_store_kwargs={ + bucket_uri='libcloud://.', + backend_kwargs={ 'provider': 'local', 'container': '.', 'provider_kwargs': { diff --git a/docs/source/notes/resumption.rst b/docs/source/notes/resumption.rst index d66f64612e..7f5a3bc1ff 100644 --- a/docs/source/notes/resumption.rst +++ b/docs/source/notes/resumption.rst @@ -85,10 +85,7 @@ A typical use case is saving checkpoints to object store (e.g. 
S3) when there is # this assumes credentials are already configured via boto3 remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=S3ObjectStore, - object_store_kwargs={ - "bucket": "checkpoint-debugging", - }, + bucket_uri=f"s3://checkpoint-debugging", ) trainer = Trainer( @@ -118,10 +115,7 @@ To run fine-tuning on a spot instance, ``load_path`` would be set to the origina from composer.utils.object_store import S3ObjectStore remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=S3ObjectStore, - object_store_kwargs={ - "bucket": "checkpoint-debugging_2", - }, + bucket_uri=f"s3://checkpoint-debugging_2", ) # Train to generate and save the "pretrained_weights/model.pt", diff --git a/docs/source/trainer/checkpointing.rst b/docs/source/trainer/checkpointing.rst index 8e352a9c5e..9a60959088 100644 --- a/docs/source/trainer/checkpointing.rst +++ b/docs/source/trainer/checkpointing.rst @@ -304,8 +304,8 @@ Behind the scenes, the :class:`.RemoteUploaderDownloader` uses :doc:`Apache Libc from composer.utils import LibcloudObjectStore remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=LibcloudObjectStore, - object_store_kwargs={ + bucket_uri="libcloud://my_bucket", + backend_kwargs={ "provider": "s3", # The Apache Libcloud provider name "container": "my_bucket", # The name of the cloud container (i.e. bucket) to use. "provider_kwargs": { # The Apache Libcloud provider driver initialization arguments @@ -339,8 +339,8 @@ Once you've configured your object store logger per above, all that's left is to from composer.loggers import RemoteUploaderDownloader remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=LibcloudObjectStore, - object_store_kwargs={ + bucket_uri="libcloud://checkpoint-debugging", + backend_kwargs={ "provider": "s3", # The Apache Libcloud provider name "container": "checkpoint-debugging", # The name of the cloud container (i.e. bucket) to use. "provider_kwargs": { # The Apache Libcloud provider driver initialization arguments diff --git a/docs/source/trainer/file_uploading.rst b/docs/source/trainer/file_uploading.rst index dbd463aed8..9286db2d2a 100644 --- a/docs/source/trainer/file_uploading.rst +++ b/docs/source/trainer/file_uploading.rst @@ -176,12 +176,7 @@ with the :class:`~composer.utils.object_store.s3_object_store.S3ObjectStore` bac # Configure the logger logger = RemoteUploaderDownloader( - object_store_cls=S3ObjectStore, - object_store_kwargs={ - # Keyword arguments for the S3ObjectStore constructor. - # See the API reference for all available arguments - 'bucket': 'my-bucket-name', - }, + bucket_uri="s3://my-bucket-name", ) # Define the trainer @@ -212,12 +207,7 @@ Similar to the S3 Example above, we can upload files to a remote SFTP filesystem # Configure the logger logger = RemoteUploaderDownloader( - object_store_cls=SFTPObjectStore, - object_store_kwargs={ - # Keyword arguments for the SFTPObjectStore constructor. 
- # See the API reference for all available arguments - 'host': 'sftp_server.example.com', - }, + bucket_uri="sftp://sftp_server.example.com", ) # Define the trainer diff --git a/examples/training_without_local_storage.ipynb b/examples/training_without_local_storage.ipynb index 22f363f7a1..9e4130c036 100644 --- a/examples/training_without_local_storage.ipynb +++ b/examples/training_without_local_storage.ipynb @@ -297,12 +297,9 @@ "\n", "def get_remote_uploader_downloader():\n", " return RemoteUploaderDownloader(\n", - " object_store_cls=S3ObjectStore,\n", - " # Keyword arguments passed to the S3ObjectStore constructor\n", - " object_store_kwargs={\n", - " 'bucket': s3_bucket_name,\n", - " 'prefix': bucket_prefix,\n", - " },\n", + " bucket_uri=f\"s3://{s3_bucket_name}\",\n", + " # This creates a format string for where to store the checkpoints within the S3 bucket\n", + " file_path_format_string=bucket_prefix + \"/{remote_file_name}\"\n", " # In Jupyter, we set use_procs to False, since subprocess do not work\n", " # well within notebooks. Outside of Jupyter, it is recommended to let\n", " # use_procs default to True for performance\n", @@ -941,12 +938,12 @@ ")\n", "\n", "cloud_logger = RemoteUploaderDownloader(\n", - " object_store_cls=S3ObjectStore,\n", - " # Keyword arguments passed to the S3ObjectStore constructor\n", - " object_store_kwargs={\n", - " 'bucket': s3_bucket_name,\n", - " 'prefix': bucket_prefix,\n", - " },\n", + " bucket_uri=f\"s3://{s3_bucket_name}\",\n", + " # This creates a format string for where to store the checkpoints within the S3 bucket\n", + " file_path_format_string=bucket_prefix + \"/{remote_file_name}\"\n", + " # In Jupyter, we set use_procs to False, since subprocess do not work\n", + " # well within notebooks. Outside of Jupyter, it is recommended to let\n", + " # use_procs default to True for performance\n", " use_procs=False,\n", ")\n", "\n", diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py index 0214073e6e..f4250318b9 100644 --- a/tests/callbacks/callback_settings.py +++ b/tests/callbacks/callback_settings.py @@ -19,7 +19,6 @@ from composer.loggers.logger_destination import LoggerDestination from composer.loggers.logger_hparams_registry import RemoteUploaderDownloaderHparams, logger_registry from composer.loggers.progress_bar_logger import ProgressBarLogger -from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore from tests.common import get_module_subclasses try: @@ -63,8 +62,8 @@ _callback_kwargs: Dict[Union[Type[Callback], Type[hp.Hparams]], Dict[str, Any],] = { RemoteUploaderDownloader: { - 'object_store_cls': LibcloudObjectStore, - 'object_store_kwargs': { + 'bucket_uri': 'libcloud://.', + 'backend_kwargs': { 'provider': 'local', 'container': '.', 'provider_kwargs': { @@ -95,6 +94,7 @@ 'window_size': 1, }, RemoteUploaderDownloaderHparams: { + 'bucket_uri': 'libcloud://.', 'object_store_hparams': { 'libcloud': { 'provider': 'local', diff --git a/tests/callbacks/test_loggers_across_callbacks.py b/tests/callbacks/test_loggers_across_callbacks.py index 920fd5767d..38d236850f 100644 --- a/tests/callbacks/test_loggers_across_callbacks.py +++ b/tests/callbacks/test_loggers_across_callbacks.py @@ -21,8 +21,8 @@ def test_loggers_on_callbacks(logger_cls: Type[LoggerDestination], callback_cls: logger_kwargs = get_cb_kwargs(logger_cls) if issubclass(logger_cls, RemoteUploaderDownloader): # Ensure that the remote directory does not conflict with any directory used by callbacks - 
logger_kwargs['object_store_kwargs']['provider_kwargs']['key'] = './remote' - os.makedirs(logger_kwargs['object_store_kwargs']['provider_kwargs']['key'], exist_ok=True) + logger_kwargs['backend_kwargs']['provider_kwargs']['key'] = './remote' + os.makedirs(logger_kwargs['backend_kwargs']['provider_kwargs']['key'], exist_ok=True) logger = logger_cls(**logger_kwargs) callback_kwargs = get_cb_kwargs(callback_cls) callback = callback_cls(**callback_kwargs) diff --git a/tests/loggers/test_remote_uploader_downloader.py b/tests/loggers/test_remote_uploader_downloader.py index afc6f1d76e..95488981d2 100644 --- a/tests/loggers/test_remote_uploader_downloader.py +++ b/tests/loggers/test_remote_uploader_downloader.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import multiprocessing import os import pathlib import random import shutil import time -from typing import Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Union +from unittest.mock import patch import pytest @@ -21,7 +23,7 @@ class DummyObjectStore(ObjectStore): """Dummy ObjectStore implementation that is backed by a local directory.""" - def __init__(self, dir: pathlib.Path, always_fail: bool = False) -> None: + def __init__(self, dir: pathlib.Path, always_fail: bool = False, **kwargs: Dict[str, Any]) -> None: self.dir = str(dir) self.always_fail = always_fail os.makedirs(self.dir, exist_ok=True) @@ -69,80 +71,86 @@ def object_store_test_helper( remote_dir = str(tmp_path / 'object_store') os.makedirs(remote_dir, exist_ok=True) - remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=DummyObjectStore, - object_store_kwargs={ - 'dir': remote_dir, - }, - num_concurrent_uploads=1, - use_procs=use_procs, - upload_staging_folder=str(tmp_path / 'staging_folder'), - num_attempts=1, - ) - logger = Logger(dummy_state, destinations=[remote_uploader_downloader]) - remote_file_name = 'remote_file_name' - - remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger) - - file_path = os.path.join(tmp_path, f'file') - with open(file_path, 'w+') as f: - f.write('1') - logger.upload_file(remote_file_name, file_path, overwrite=overwrite) - - file_path_2 = os.path.join(tmp_path, f'file_2') - with open(file_path_2, 'w+') as f: - f.write('2') - - post_close_ctx = contextlib.nullcontext() - - if not overwrite: - # If not `overwrite_delay`, then the `logger.upload_file` will raise a FileExistsException, because the upload will not even be enqueued - # Otherwise, with a sufficient will be uploaded, and cleared from the queue, causing a runtime error to be raised on Event.BATCH_END or Event.EPOCH_END - # A 2 second sleep should be enough here -- the DummyObjectStore will block for at most 0.5 seconds, and the RemoteUploaderDownloader polls every 0.1 seconds - if overwrite_delay: - post_close_ctx = pytest.warns( - RuntimeWarning, - match=r'The following objects may not have been uploaded, likely due to a worker crash: remote_file_name' + # Patching does not work when using multiprocessing with spawn, so we also + # patch to use fork + fork_context = multiprocessing.get_context('fork') + with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore): + with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context): + remote_uploader_downloader = RemoteUploaderDownloader( + bucket_uri='s3://{remote_dir}', + backend_kwargs={ + 'dir': remote_dir, + }, + num_concurrent_uploads=1, + use_procs=use_procs, + 
upload_staging_folder=str(tmp_path / 'staging_folder'), + num_attempts=1, ) - # Wait for the first upload to go through - time.sleep(2) - # Do the second upload -- it should enqueue - logger.upload_file(remote_file_name, file_path_2, overwrite=overwrite) - # Give it some time to finish the second upload - # (Since the upload is really a file copy, it should be fast) - time.sleep(2) - # Then, crashes are detected on the next batch end / epoch end event - with pytest.raises(RuntimeError): - remote_uploader_downloader.run_event(Event.BATCH_END, dummy_state, logger) - - with pytest.raises(RuntimeError): - remote_uploader_downloader.run_event(Event.EPOCH_END, dummy_state, logger) - else: - # Otherwise, if no delay, it should error when being enqueued - with pytest.raises(FileExistsError): - logger.upload_file(remote_file_name, file_path_2, overwrite=overwrite) - - remote_uploader_downloader.close(dummy_state, logger) - - with post_close_ctx: - remote_uploader_downloader.post_close() - - # verify upload uri for file is correct - upload_uri = remote_uploader_downloader.get_uri_for_file(remote_file_name) - expected_upload_uri = f'local://{remote_file_name}' - assert upload_uri == expected_upload_uri - - # Test downloading file - download_path = os.path.join(tmp_path, 'download') - remote_uploader_downloader.download_file(remote_file_name, download_path) - with open(download_path, 'r') as f: - assert f.read() == '1' if not overwrite else '2' - - # now assert that we have a dummy file in the remote folder - remote_file = os.path.join(str(remote_dir), remote_file_name) - # Verify file contains the correct value - with open(remote_file, 'r') as f: - assert f.read() == '1' if not overwrite else '2' + logger = Logger(dummy_state, destinations=[remote_uploader_downloader]) + remote_file_name = 'remote_file_name' + + remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger) + + file_path = os.path.join(tmp_path, f'file') + with open(file_path, 'w+') as f: + f.write('1') + logger.upload_file(remote_file_name, file_path, overwrite=overwrite) + + file_path_2 = os.path.join(tmp_path, f'file_2') + with open(file_path_2, 'w+') as f: + f.write('2') + + post_close_ctx = contextlib.nullcontext() + + if not overwrite: + # If not `overwrite_delay`, then the `logger.upload_file` will raise a FileExistsException, because the upload will not even be enqueued + # Otherwise, with a sufficient will be uploaded, and cleared from the queue, causing a runtime error to be raised on Event.BATCH_END or Event.EPOCH_END + # A 2 second sleep should be enough here -- the DummyObjectStore will block for at most 0.5 seconds, and the RemoteUploaderDownloader polls every 0.1 seconds + if overwrite_delay: + post_close_ctx = pytest.warns( + RuntimeWarning, + match= + r'The following objects may not have been uploaded, likely due to a worker crash: remote_file_name' + ) + # Wait for the first upload to go through + time.sleep(2) + # Do the second upload -- it should enqueue + logger.upload_file(remote_file_name, file_path_2, overwrite=overwrite) + # Give it some time to finish the second upload + # (Since the upload is really a file copy, it should be fast) + time.sleep(2) + # Then, crashes are detected on the next batch end / epoch end event + with pytest.raises(RuntimeError): + remote_uploader_downloader.run_event(Event.BATCH_END, dummy_state, logger) + + with pytest.raises(RuntimeError): + remote_uploader_downloader.run_event(Event.EPOCH_END, dummy_state, logger) + else: + # Otherwise, if no delay, it should error when 
being enqueued + with pytest.raises(FileExistsError): + logger.upload_file(remote_file_name, file_path_2, overwrite=overwrite) + + remote_uploader_downloader.close(dummy_state, logger) + + with post_close_ctx: + remote_uploader_downloader.post_close() + + # verify upload uri for file is correct + upload_uri = remote_uploader_downloader.get_uri_for_file(remote_file_name) + expected_upload_uri = f'local://{remote_file_name}' + assert upload_uri == expected_upload_uri + + # Test downloading file + download_path = os.path.join(tmp_path, 'download') + remote_uploader_downloader.download_file(remote_file_name, download_path) + with open(download_path, 'r') as f: + assert f.read() == '1' if not overwrite else '2' + + # now assert that we have a dummy file in the remote folder + remote_file = os.path.join(str(remote_dir), remote_file_name) + # Verify file contains the correct value + with open(remote_file, 'r') as f: + assert f.read() == '1' if not overwrite else '2' def test_remote_uploader_downloader(tmp_path: pathlib.Path, dummy_state: State): @@ -174,81 +182,108 @@ def test_race_with_overwrite(tmp_path: pathlib.Path, use_procs: bool, dummy_stat with open(tmp_path / 'samples' / f'sample_{i}', 'w+') as f: f.write(str(i)) - # Create the object store logger - remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=DummyObjectStore, - object_store_kwargs={ - 'dir': tmp_path / 'object_store_backend', - }, - num_concurrent_uploads=4, - use_procs=use_procs, - upload_staging_folder=str(tmp_path / 'staging_folder'), - num_attempts=1, - ) + # Patching does not work when using multiprocessing with spawn, so we also + # patch to use fork + fork_context = multiprocessing.get_context('fork') + with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore): + with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context): + # Create the object store logger + remote_uploader_downloader = RemoteUploaderDownloader( + bucket_uri=f"s3://{tmp_path}/'object_store_backend", + backend_kwargs={ + 'dir': tmp_path / 'object_store_backend', + }, + num_concurrent_uploads=4, + use_procs=use_procs, + upload_staging_folder=str(tmp_path / 'staging_folder'), + num_attempts=1, + ) - logger = Logger(dummy_state, destinations=[remote_uploader_downloader]) + logger = Logger(dummy_state, destinations=[remote_uploader_downloader]) - remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger) + remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger) - # Queue the files for upload in rapid succession to the same remote_file_name - remote_file_name = 'remote_file_name' - for i in range(num_files): - file_path = tmp_path / 'samples' / f'sample_{i}' - remote_uploader_downloader.upload_file(dummy_state, remote_file_name, file_path, overwrite=True) + # Queue the files for upload in rapid succession to the same remote_file_name + remote_file_name = 'remote_file_name' + for i in range(num_files): + file_path = tmp_path / 'samples' / f'sample_{i}' + remote_uploader_downloader.upload_file(dummy_state, remote_file_name, file_path, overwrite=True) - # Shutdown the logger. This should wait until all objects are uploaded - remote_uploader_downloader.close(dummy_state, logger=logger) - remote_uploader_downloader.post_close() + # Shutdown the logger. 
This should wait until all objects are uploaded + remote_uploader_downloader.close(dummy_state, logger=logger) + remote_uploader_downloader.post_close() - # Assert that the file called "remote_file_name" has the content of the last file uploaded file -- i.e. `num_files` - 1 - destination = tmp_path / 'downloaded_file' - remote_uploader_downloader.download_file(remote_file_name, str(destination), overwrite=False, progress_bar=False) - with open(destination, 'r') as f: - assert f.read() == str(num_files - 1) + # Assert that the file called "remote_file_name" has the content of the last file uploaded file -- i.e. `num_files` - 1 + destination = tmp_path / 'downloaded_file' + remote_uploader_downloader.download_file(remote_file_name, + str(destination), + overwrite=False, + progress_bar=False) + with open(destination, 'r') as f: + assert f.read() == str(num_files - 1) @pytest.mark.filterwarnings(r'ignore:Exception in thread:pytest.PytestUnhandledThreadExceptionWarning') def test_close_on_failure(tmp_path: pathlib.Path, dummy_state: State): """Test that .close() and .post_close() does not hang even when a worker crashes.""" - # Create the object store logger - remote_uploader_downloader = RemoteUploaderDownloader( - object_store_cls=DummyObjectStore, - object_store_kwargs={ - 'dir': tmp_path / 'object_store_backend', - 'always_fail': True, - }, - num_concurrent_uploads=1, - use_procs=False, - upload_staging_folder=str(tmp_path / 'staging_folder'), - num_attempts=1, - ) - - # Enqueue a file. Because `always_fail` is True, it will cause the worker to crash - - tmpfile_path = tmp_path / 'dummy_file' - with open(tmpfile_path, 'w+') as f: - f.write('hi') - - logger = Logger(dummy_state, destinations=[remote_uploader_downloader]) - - remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger) - - logger.upload_file('dummy_remote_file_name', tmpfile_path) - - # Wait enough time for the file to be enqueued - time.sleep(0.5) - - # Assert that the worker crashed - with pytest.raises(RuntimeError): - remote_uploader_downloader.run_event(Event.EPOCH_END, dummy_state, logger) - - # Enqueue the file again to ensure that the buffers are dirty - logger.upload_file('dummy_remote_file_name', tmpfile_path) - - # Shutdown the logger. This should not hang or cause any exception - remote_uploader_downloader.close(dummy_state, logger=logger) - with pytest.warns( - RuntimeWarning, - match= - r'The following objects may not have been uploaded, likely due to a worker crash: dummy_remote_file_name'): - remote_uploader_downloader.post_close() + + with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore): + # Create the object store logger + remote_uploader_downloader = RemoteUploaderDownloader( + bucket_uri=f"s3://{tmp_path}/'object_store_backend", + backend_kwargs={ + 'dir': tmp_path / 'object_store_backend', + 'always_fail': True, + }, + num_concurrent_uploads=1, + use_procs=False, + upload_staging_folder=str(tmp_path / 'staging_folder'), + num_attempts=1, + ) + + # Enqueue a file. 
Because `always_fail` is True, it will cause the worker to crash + + tmpfile_path = tmp_path / 'dummy_file' + with open(tmpfile_path, 'w+') as f: + f.write('hi') + + logger = Logger(dummy_state, destinations=[remote_uploader_downloader]) + + remote_uploader_downloader.run_event(Event.INIT, dummy_state, logger) + + logger.upload_file('dummy_remote_file_name', tmpfile_path) + + # Wait enough time for the file to be enqueued + time.sleep(0.5) + + # Assert that the worker crashed + with pytest.raises(RuntimeError): + remote_uploader_downloader.run_event(Event.EPOCH_END, dummy_state, logger) + + # Enqueue the file again to ensure that the buffers are dirty + logger.upload_file('dummy_remote_file_name', tmpfile_path) + + # Shutdown the logger. This should not hang or cause any exception + remote_uploader_downloader.close(dummy_state, logger=logger) + with pytest.warns( + RuntimeWarning, + match= + r'The following objects may not have been uploaded, likely due to a worker crash: dummy_remote_file_name' + ): + remote_uploader_downloader.post_close() + + +def test_valid_backend_names(): + valid_backend_names = ['s3', 'libcloud', 'sftp'] + with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore') as _, \ + patch('composer.loggers.remote_uploader_downloader.SFTPObjectStore') as _, \ + patch('composer.loggers.remote_uploader_downloader.LibcloudObjectStore') as _: + for name in valid_backend_names: + remote_uploader_downloader = RemoteUploaderDownloader(bucket_uri=f'{name}://not-a-real-bucket') + # Access the remote_backend property so that it is built + _ = remote_uploader_downloader.remote_backend + + with pytest.raises(ValueError): + remote_uploader_downloader = RemoteUploaderDownloader(bucket_uri='magicloud://not-a-real-bucket') + # Access the remote_backend property so that it is built + _ = remote_uploader_downloader.remote_backend diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 569d7eb727..d6f8227601 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -25,7 +25,6 @@ from composer.trainer.trainer import Trainer from composer.utils import dist, is_tar from composer.utils.checkpoint import glob_filter -from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore from tests.common import RandomImageDataset, SimpleConvModel, deep_compare, device @@ -174,8 +173,8 @@ def get_logger(self, tmp_path: pathlib.Path): os.makedirs(remote_dir, exist_ok=True) return RemoteUploaderDownloader( - object_store_cls=LibcloudObjectStore, - object_store_kwargs={ + bucket_uri='libcloud://.', + backend_kwargs={ 'provider': 'local', 'container': '.', 'provider_kwargs': { diff --git a/tests/utils/object_store/test_object_store.py b/tests/utils/object_store/test_object_store.py index e59185226e..e95dcccc9a 100644 --- a/tests/utils/object_store/test_object_store.py +++ b/tests/utils/object_store/test_object_store.py @@ -4,7 +4,8 @@ import contextlib import copy import pathlib -from typing import Any, Dict, Tuple, Type +from typing import Any, Dict, Tuple +from urllib.parse import urlparse import pytest @@ -14,13 +15,14 @@ @pytest.fixture -def object_store_cls_and_kwargs(request, s3_bucket: str, sftp_uri: str, test_session_name: str): +def bucket_uri_and_kwargs(request, s3_bucket: str, sftp_uri: str, test_session_name: str): remote = request.node.get_closest_marker('remote') is not None if request.param is LibcloudObjectStore: if remote: pytest.skip('Libcloud object store has no remote tests') else: + bucket_uri = 
'libcloud://.' kwargs = { 'provider': 'local', 'container': '.', @@ -31,16 +33,20 @@ def object_store_cls_and_kwargs(request, s3_bucket: str, sftp_uri: str, test_ses } elif request.param is S3ObjectStore: if remote: + bucket_uri = f's3://{s3_bucket}' kwargs = {'bucket': s3_bucket, 'prefix': test_session_name} else: + bucket_uri = 's3://my-bucket' kwargs = {'bucket': 'my-bucket', 'prefix': 'folder/subfolder'} elif request.param is SFTPObjectStore: if remote: + bucket_uri = f"sftp://{sftp_uri.rstrip('/') + '/' + test_session_name}" kwargs = { 'host': sftp_uri.rstrip('/') + '/' + test_session_name, 'missing_host_key_policy': 'WarningPolicy', } else: + bucket_uri = 'sftp://localhost:23' kwargs = { 'host': 'localhost', 'port': 23, @@ -48,7 +54,7 @@ def object_store_cls_and_kwargs(request, s3_bucket: str, sftp_uri: str, test_ses } else: raise ValueError(f'Invalid object store type: {request.param.__name__}') - return request.param, kwargs + return bucket_uri, kwargs class MockCallback: @@ -68,23 +74,29 @@ def assert_all_data_transferred(self): assert self.total_num_bytes == self.transferred_bytes -@pytest.mark.parametrize('object_store_cls_and_kwargs', object_stores, indirect=True) +@pytest.mark.parametrize('bucket_uri_and_kwargs', object_stores, indirect=True) @pytest.mark.parametrize('remote', [False, pytest.param(True, marks=pytest.mark.remote)]) class TestObjectStore: @pytest.fixture def object_store( self, - object_store_cls_and_kwargs: Tuple[Type[ObjectStore], Dict[str, Any]], + bucket_uri_and_kwargs: Tuple[str, Dict[str, Any]], monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, remote: bool, ): - object_store_cls, kwargs = object_store_cls_and_kwargs - with get_object_store_ctx(object_store_cls, kwargs, monkeypatch, tmp_path, remote=remote): + remote_backend_name_to_class = {'s3': S3ObjectStore, 'sftp': SFTPObjectStore, 'libcloud': LibcloudObjectStore} + bucket_uri, kwargs = bucket_uri_and_kwargs + remote_backend_name = urlparse(bucket_uri).scheme + with get_object_store_ctx(remote_backend_name_to_class[remote_backend_name], + kwargs, + monkeypatch, + tmp_path, + remote=remote): copied_config = copy.deepcopy(kwargs) # type error: Type[ObjectStore] is not callable - object_store = object_store_cls(**copied_config) # type: ignore + object_store = remote_backend_name_to_class[remote_backend_name](**copied_config) # type: ignore with object_store: yield object_store
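For reference, a minimal sketch of how ``RemoteUploaderDownloader`` now interprets ``bucket_uri``, mirroring the ``__init__`` logic added above (the helper name ``parse_bucket_uri`` is illustrative only and is not part of Composer's API):

    # Sketch only: mirrors the backend_kwargs seeding in RemoteUploaderDownloader.__init__;
    # parse_bucket_uri is a hypothetical helper, not a Composer function.
    from typing import Any, Dict, Tuple
    from urllib.parse import urlparse

    def parse_bucket_uri(bucket_uri: str, backend_kwargs: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
        parsed = urlparse(bucket_uri)
        backend_name, bucket_name = parsed.scheme, parsed.netloc
        kwargs = dict(backend_kwargs)
        # Each backend expects the bucket/host under a different keyword, so it is
        # filled in from the URI unless the caller already supplied it explicitly.
        if backend_name == 's3' and 'bucket' not in kwargs:
            kwargs['bucket'] = bucket_name
        elif backend_name == 'sftp' and 'host' not in kwargs:
            kwargs['host'] = f'sftp://{bucket_name}'
        elif backend_name == 'libcloud' and 'container' not in kwargs:
            kwargs['container'] = bucket_name
        return backend_name, kwargs

    assert parse_bucket_uri('s3://my-bucket', {}) == ('s3', {'bucket': 'my-bucket'})
    assert parse_bucket_uri('libcloud://my-bucket', {'provider': 's3'}) == \
        ('libcloud', {'provider': 's3', 'container': 'my-bucket'})

An unrecognized scheme is not rejected at this point; it only raises a ``ValueError`` from ``_build_remote_backend``, which runs lazily the first time the ``remote_backend`` property is accessed (exercised by ``test_valid_backend_names`` above).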