2 changes: 1 addition & 1 deletion composer/core/types.py
@@ -150,6 +150,6 @@ def __len__(self) -> int:

JSON = Union[str, float, int, None, List['JSON'], Dict[str, 'JSON']]

TPrefetchFn = Callable[[Batch], Batch]
TDeviceTransformFn = Callable[[Batch], Batch]

StateDict = Dict[str, Any]
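
For reference, the renamed TDeviceTransformFn type is simply a callable that maps a batch to a batch after the batch has been moved onto the device. A minimal sketch of such a function, assuming the common (inputs, targets) tuple layout; the uint8-to-float rescaling is illustrative and not part of this PR:

from composer.core.types import Batch


def scale_images_on_device(batch: Batch) -> Batch:
    # Runs after the batch is already on the device: converts uint8 image
    # tensors to float and rescales them to [0, 1], leaving targets untouched.
    inputs, targets = batch
    return inputs.float() / 255.0, targets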
32 changes: 15 additions & 17 deletions composer/datasets/brats.py
@@ -11,7 +11,10 @@
import yahp as hp
from torch.utils.data import Dataset

from composer.datasets.hparams import DataloaderSpec, DatasetHparams
from composer.core.types import DataLoader, Dataset
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams
from composer.utils import ddp

PATCH_SIZE = [1, 192, 160]

@@ -46,27 +49,22 @@ class BratsDatasetHparams(DatasetHparams):
shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch", default=True)
oversampling: float = hp.optional("oversampling", default=0.33)

def initialize_object(self) -> DataloaderSpec:
def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:

datadir = self.datadir
oversampling = self.oversampling

x_train, y_train, x_val, y_val = get_data_split(datadir)
train_dataset = PytTrain(x_train, y_train, oversampling)
val_dataset = PytVal(x_val, y_val)
if self.is_train:
return DataloaderSpec(
dataset=train_dataset,
drop_last=self.drop_last,
shuffle=self.shuffle,
)
else:
return DataloaderSpec(
dataset=val_dataset,
drop_last=self.drop_last,
shuffle=self.shuffle,
collate_fn=my_collate, # type: ignore
)
dataset = PytTrain(x_train, y_train, oversampling) if self.is_train else PytVal(x_val, y_val)
collate_fn = None if self.is_train else my_collate
sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
return dataloader_hparams.initialize_object(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=self.drop_last,
collate_fn=collate_fn,
)


def coin_flip(prob):
28 changes: 17 additions & 11 deletions composer/datasets/cifar10.py
@@ -6,7 +6,10 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10

from composer.datasets.hparams import DataloaderSpec, DatasetHparams
from composer.core.types import DataLoader
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams
from composer.utils import ddp


@dataclass
@@ -27,7 +30,7 @@ class CIFAR10DatasetHparams(DatasetHparams):
drop_last: bool = hp.optional("Whether to drop the last samples for the last batch", default=True)
shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch", default=True)

def initialize_object(self) -> DataloaderSpec:
def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:
cifar10_mean, cifar10_std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
datadir = self.datadir

@@ -44,13 +47,16 @@ def initialize_object(self) -> DataloaderSpec:
transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

return DataloaderSpec(
dataset=CIFAR10(
datadir,
train=self.is_train,
download=self.download,
transform=transformation,
),
drop_last=self.drop_last,
shuffle=self.shuffle,
dataset = CIFAR10(
datadir,
train=self.is_train,
download=self.download,
transform=transformation,
)

sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return dataloader_hparams.initialize_object(dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=self.drop_last)
109 changes: 65 additions & 44 deletions composer/datasets/dataloader.py
@@ -2,19 +2,18 @@

from __future__ import annotations

import textwrap
import warnings
from dataclasses import dataclass
from typing import Any, Iterator, Optional
from typing import Any, Callable, Iterator, Optional

import torch
import torch.distributed
import torch.utils.data
import yahp as hp
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import Sampler

from composer.core.types import Batch, DataLoader
from composer.datasets.hparams import DataloaderSpec
from composer.core.types import Batch, DataLoader, Dataset


class WrappedDataLoader(DataLoader):
@@ -47,17 +46,17 @@ def __setattr__(self, name: str, value: Any) -> None:


class DDPDataLoader(WrappedDataLoader):
"""Ensure sampler.set_epoch() is called after each iteration.

DDPDataLoader wraps a dataloader and a distributed sampler and is
called after each iteration (epoch) through the dataset.
See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
"""Wraps the dataset to ensure that, if the dataset sampler is a
:class:`~torch.utils.data.distributed.DistributedSampler`, then
:meth:`~torch.utils.data.distributed.DistributedSampler.set_epoch`
is called after each epoch.

If the dataloader's sampler is not a :class:`~torch.utils.data.distributed.DistributedSampler`,
then this wrapper is a no-op.
"""

def __init__(self, dataloader: DataLoader) -> None:
super().__init__(dataloader)
if not isinstance(self.dataloader.sampler, DistributedSampler):
raise ValueError("When using the DDP data loader, the sampler must be a DistributedSampler")
self._iterator: Optional[Iterator[Batch]] = None

def __iter__(self) -> DDPDataLoader:
@@ -68,8 +67,8 @@ def __iter__(self) -> DDPDataLoader:
"The dataloader is skipping ahead to the start of the next epoch. "
"Multiple simultaneous iterations through the DDP dataloader prohibited, since "
"it automatically tracks the current epoch.")
assert isinstance(self.sampler, DistributedSampler)
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
self._iterator = iter(self.dataloader)
return self

@@ -79,47 +78,69 @@ def __next__(self) -> Batch:
return next(self._iterator)
except StopIteration:
self._iterator = None
assert isinstance(self.sampler, DistributedSampler)
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
raise


@dataclass
class DataloaderHparams(hp.Hparams):
"""Hyperparameters to initialize a ``torch.utils.data.Dataloader``."""
"""Hyperparameters to initialize a :class:`~torch.utils.data.Dataloader`.

Parameters:
num_workers (int): Number of CPU workers to use per device to fetch data.
prefetch_factor (int): Number of samples loaded in advance by each worker.
2 means there will be a total of 2 * num_workers samples prefetched across all workers.
persistent_workers (bool): Whether or not to shutdown workers after the dataset has been consumed once.
pin_memory (bool): Whether or not to copy Tensors into CUDA pinned memory before returning them.
timeout (float): Timeout, in seconds, for collecting a batch from workers. Set to 0 for no timeout.

"""

num_workers: int = hp.required(doc="Number of CPU workers to use per gpu", template_default=8)
prefetch_factor: int = hp.required(doc="Number of samples loaded in advance by each worker", template_default=2)
persistent_workers: bool = hp.required(doc="Whether or not to shutdown workers after the dataset"
" has been consumed once",
num_workers: int = hp.required("Number of CPU workers to use per device to fetch data.", template_default=8)
prefetch_factor: int = hp.required("Number of samples loaded in advance by each worker", template_default=2)
persistent_workers: bool = hp.required(textwrap.dedent("""Whether or not to shutdown workers after the dataset
has been consumed once"""),
template_default=True)
pin_memory: bool = hp.required(doc="Whether or not to copy Tensors into CUDA pinned memory"
" before returning them",
pin_memory: bool = hp.required(textwrap.dedent("""Whether or not to copy Tensors into CUDA pinned memory
before returning them"""),
template_default=True)
timeout: int = hp.required(doc="Timeout value for collecting a batch from workers. 0 for no timeout.",
template_default=0)
timeout: float = hp.required("Timeout, in seconds, for collecting a batch from workers. Set to 0 for no timeout",
template_default=0)

def initialize_object(
self,
dataset: Dataset,
*,
batch_size: int,
sampler: Sampler,
dataloader_spec: DataloaderSpec,
sampler: torch.utils.data.Sampler[int],
drop_last: bool,
collate_fn: Optional[Callable] = None,
worker_init_fn: Optional[Callable] = None,
) -> DataLoader:
"""Initializes the dataloader."""

dataloader = torch.utils.data.DataLoader(
dataloader_spec.dataset,
batch_size=batch_size,
shuffle=False, # set in the sampler
num_workers=self.num_workers,
pin_memory=self.pin_memory,
drop_last=dataloader_spec.drop_last,
sampler=sampler,
collate_fn=dataloader_spec.collate_fn,
worker_init_fn=dataloader_spec.worker_init_fn,
multiprocessing_context=dataloader_spec.multiprocessing_context,
generator=dataloader_spec.generator,
timeout=self.timeout,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers)
return dataloader
"""Create a dataloader.

Args:
dataset (Dataset): The dataset.
batch_size (int): The per-device batch size.
sampler (torch.utils.data.Sampler[int]): The sampler to use for the dataloader.
drop_last (bool): Whether to drop the last batch if the number of
samples is not evenly divisible by the batch size.
collate_fn (callable, optional): Custom collate function. Defaults to None.
worker_init_fn (callable, optional): Custom worker init function. Defaults to None.

Returns:
DataLoader: The dataloader.
"""

return torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
drop_last=drop_last,
sampler=sampler,
collate_fn=collate_fn,
worker_init_fn=worker_init_fn,
timeout=self.timeout,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers)
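
Taken together, ddp.get_sampler, DataloaderHparams.initialize_object, and DDPDataLoader suggest the usage pattern that the dataset hparams in this PR follow. A minimal sketch under stated assumptions: the in-memory TensorDataset, the batch size, and the hyperparameter values are hypothetical, and any distributed process group is assumed to be initialized before ddp.get_sampler is called.

import torch
from torch.utils.data import TensorDataset

from composer.datasets.dataloader import DataloaderHparams, DDPDataLoader
from composer.utils import ddp

# Hypothetical stand-in for a dataset built by a DatasetHparams subclass.
dataset = TensorDataset(torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,)))

# Values mirror the template defaults declared above.
dl_hparams = DataloaderHparams(
    num_workers=8,
    prefetch_factor=2,
    persistent_workers=True,
    pin_memory=True,
    timeout=0,
)

# The sampler owns shuffling and DDP sharding, so shuffle is never passed to
# the DataLoader itself (initialize_object does not accept it).
sampler = ddp.get_sampler(dataset, drop_last=True, shuffle=True)
dataloader = dl_hparams.initialize_object(
    dataset,
    batch_size=128,  # per-device batch size
    sampler=sampler,
    drop_last=True,
)

# DDPDataLoader advances a DistributedSampler's epoch after each full pass and
# is a no-op for any other sampler.
dataloader = DDPDataLoader(dataloader)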
36 changes: 12 additions & 24 deletions composer/datasets/hparams.py
@@ -2,12 +2,12 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List, NamedTuple, Optional, Sequence
from typing import Callable, List, NamedTuple, Optional, Sequence, Union

import torch
import yahp as hp

from composer.core.types import Batch, Dataset, Tensor, TPrefetchFn
from composer.core.types import Batch, DataLoader, TDeviceTransformFn, Tensor
from composer.datasets.dataloader import DataloaderHparams


def _split_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
@@ -30,29 +30,16 @@ class DataloaderSpec(NamedTuple):
"""Specification for initializing a dataloader.

Attributes:
dataset (Dataset): The initialized dataset from which to load data.
drop_last (bool): Whether the final batch of an epoch should be discarded
if there are fewer samples than the batch size.
shuffle (bool): Whether the data should be shuffled.
collate_fn (List[Any] -> Batch, optional): A function to collate
data before returning it from the dataloader.
worker_init_fn (int -> None, optional): A function to be ran
on each worker before dataloading begins.
multiprocessing_context (Any, optional): The context to use for multiprocessing.
generator (torch.Generator, optional): An RNG to be used for seeding workers.
prefetch_fn (TPrefetchFn, optional): A function to run for prefetching data.
dataloader (DataLoader): The initialized dataloader.
device_transform_fn (TDeviceTransformFn, optional):
A function to modify the data once it has been loaded onto the device (for example, GPU-based batch normalization)
This function is invoked with a batch of data after it has been moved onto the device,
and it is expected to return a batch.
split_fn (Batch, int -> List[Batch]): A function to
run to split batches into microbatches.
"""

dataset: Dataset
drop_last: bool
shuffle: bool
collate_fn: Optional[Callable[[List[Any]], Batch]] = None
worker_init_fn: Optional[Callable[[int], None]] = None
multiprocessing_context: Any = None
generator: Optional[torch.Generator] = None
prefetch_fn: Optional[TPrefetchFn] = None
dataloader: DataLoader
device_transform_fn: Optional[TDeviceTransformFn] = None
split_fn: Callable[[Batch, int], List[Batch]] = _split_fn


@@ -63,7 +50,8 @@ class DatasetHparams(hp.Hparams, ABC):
pass

@abstractmethod
def initialize_object(self) -> DataloaderSpec:
def initialize_object(self, batch_size: int,
dataloader_hparams: DataloaderHparams) -> Union[DataLoader, DataloaderSpec]:
"""Initializes a :class:`DataloaderSpec` for this dataset."""

pass
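
The slimmed-down DataloaderSpec now carries only the dataloader plus two device-side hooks, so the consuming loop (not part of this diff) reduces to roughly the sketch below; the tuple-of-tensors batch layout and the iterate_spec helper are assumptions for illustration, not Composer's actual trainer code.

from typing import Iterator

import torch

from composer.core.types import Batch
from composer.datasets.hparams import DataloaderSpec


def iterate_spec(spec: DataloaderSpec, device: torch.device, n_microbatches: int) -> Iterator[Batch]:
    for batch in spec.dataloader:
        # Assume an (inputs, targets) tuple of tensors for the device move.
        batch = tuple(x.to(device) for x in batch)
        if spec.device_transform_fn is not None:
            # Invoked only after the batch is on the device, e.g. GPU-side normalization.
            batch = spec.device_transform_fn(batch)
        # split_fn defaults to _split_fn above and slices the batch into microbatches.
        yield from spec.split_fn(batch, n_microbatches)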
23 changes: 16 additions & 7 deletions composer/datasets/imagenet.py
@@ -13,10 +13,12 @@
from torchvision.datasets import ImageFolder

from composer.core.types import Batch, Tensor
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DataloaderSpec, DatasetHparams
from composer.utils import ddp


class PreprocessingFn:
class TransformationFn:

def __init__(self) -> None:
self.mean: Optional[Tensor] = None
@@ -82,7 +84,7 @@ class ImagenetDatasetHparams(DatasetHparams):
drop_last: bool = hp.optional("Whether to drop the last samples for the last batch", default=True)
shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch", default=True)

def initialize_object(self) -> DataloaderSpec:
def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataloaderSpec:
datadir = self.datadir
is_train = self.is_train

@@ -107,10 +109,17 @@ def initialize_object(self) -> DataloaderSpec:

split = "train" if is_train else "val"

dataset = ImageFolder(os.path.join(datadir, split), transformation)

sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return DataloaderSpec(
dataset=ImageFolder(os.path.join(datadir, split), transformation),
drop_last=self.drop_last,
collate_fn=fast_collate,
shuffle=self.shuffle,
prefetch_fn=PreprocessingFn(),
dataloader=dataloader_hparams.initialize_object(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=self.drop_last,
collate_fn=fast_collate,
),
device_transform_fn=TransformationFn(),
)