Skip to content

Commit 684becb

Browse files
committed
revert cast, lint
1 parent 9dbf769 commit 684becb

File tree

4 files changed

+30
-24
lines changed

4 files changed

+30
-24
lines changed

composer/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
"""
2222
from composer.datasets.ade20k import ADE20kDatasetHparams as ADE20kDatasetHparams
2323
from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
24+
from composer.datasets.c4 import C4DatasetHparams as C4DatasetHparams
2425
from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
2526
from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
2627
from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
2728
from composer.datasets.dataset_registry import get_dataset_registry as get_dataset_registry
2829
from composer.datasets.evaluator import EvaluatorHparams as EvaluatorHparams
2930
from composer.datasets.glue import GLUEHparams as GLUEHparams
30-
from composer.datasets.c4 import C4DatasetHparams as C4DatasetHparams
3131
from composer.datasets.hparams import DatasetHparams as DatasetHparams
3232
from composer.datasets.hparams import SyntheticHparamsMixin as SyntheticHparamsMixin
3333
from composer.datasets.imagenet import ImagenetDatasetHparams as ImagenetDatasetHparams

composer/datasets/c4.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,35 @@
44
from dataclasses import dataclass
55
from functools import partial
66
from itertools import chain
7-
from typing import List, Optional
7+
from typing import List
88

99
import yahp as hp
1010
from torch.utils.data import IterableDataset, get_worker_info
1111

1212
from composer.core.types import Batch, DataSpec
1313
from composer.datasets.dataloader import DataloaderHparams
14-
from composer.datasets.hparams import DatasetHparams
1514
from composer.utils import dist
1615

1716
log = logging.getLogger(__name__)
1817

1918

19+
def _split_dict_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
20+
if isinstance(batch, dict):
21+
chunked = {k: v.chunk(n_microbatches) for k, v in batch.items()}
22+
for k, v in chunked.items():
23+
if len(v) != n_microbatches:
24+
raise ValueError(
25+
f"Unable to split batch into microbatches. "
26+
f"Key '{k}' has chunked list: {v} with length {len(v)}, but expected length {n_microbatches}. ")
27+
microbatches = []
28+
for idx in range(n_microbatches):
29+
mb = {k: v[idx] for k, v in chunked.items()}
30+
microbatches.append(mb)
31+
return microbatches
32+
else:
33+
raise ValueError(f'Expected batch to be of type Dict[str, Tensor], but got {type(batch)}')
34+
35+
2036
@dataclass
2137
class C4DatasetHparams(hp.Hparams):
2238
"""Builds a DataSpec for the C4 (Colossal Cleaned CommonCrawl) dataset.
@@ -64,23 +80,13 @@ def validate(self):
6480
if self.mlm and self.mlm_probability <= 0:
6581
raise ValueError("Must provide a positive 'mlm_probability' when using masked language modeling.")
6682

67-
def _split_dict_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
68-
if isinstance(batch, dict):
69-
chunked = {k: v.chunk(n_microbatches) for k, v in batch.items()}
70-
for k, v in chunked.items():
71-
if len(v) != n_microbatches:
72-
raise ValueError(
73-
f"Unable to split batch into microbatches. "
74-
f"Key '{k}' has chunked list: {v} with length {len(v)}, but expected length {n_microbatches}. ")
75-
microbatches = []
76-
for idx in range(n_microbatches):
77-
mb = {k: v[idx] for k, v in chunked.items()}
78-
microbatches.append(mb)
79-
return microbatches
80-
else:
81-
raise ValueError(f'Expected batch to be of type Dict[str, Tensor], but got {type(batch)}')
82-
8383
def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
84+
try:
85+
import transformers
86+
except ImportError:
87+
raise ImportError('HuggingFace transformers not installed. '
88+
'Please install with `pip install composer[nlp]`')
89+
8490
# Get C4 dataset
8591
c4_dataset = C4Dataset(split=self.split,
8692
max_samples=self.max_samples,
@@ -103,7 +109,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
103109
drop_last=self.drop_last,
104110
collate_fn=collate_fn,
105111
),
106-
split_batch=self._split_dict_fn)
112+
split_batch=_split_dict_fn)
107113

108114

109115
class C4Dataset(IterableDataset):

composer/datasets/dataset_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from composer.datasets.ade20k import ADE20kDatasetHparams
44
from composer.datasets.brats import BratsDatasetHparams
5+
from composer.datasets.c4 import C4DatasetHparams
56
from composer.datasets.cifar10 import CIFAR10DatasetHparams
67
from composer.datasets.glue import GLUEHparams
7-
from composer.datasets.c4 import C4DatasetHparams
88
from composer.datasets.imagenet import ImagenetDatasetHparams
99
from composer.datasets.lm_datasets import LMDatasetHparams
1010
from composer.datasets.mnist import MNISTDatasetHparams

composer/trainer/trainer_hparams.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import textwrap
99
import warnings
1010
from dataclasses import dataclass
11-
from typing import TYPE_CHECKING, Dict, List, Optional
11+
from typing import TYPE_CHECKING, Dict, List, Optional, cast
1212

1313
import yahp as hp
1414

@@ -306,8 +306,8 @@ def validate(self):
306306
super().validate()
307307

308308
if self.deepspeed is not None:
309-
self.deepspeed["zero_stage"] = self.deepspeed.get("zero_stage", 0)
310-
self.deepspeed["steps_per_print"] = self.deepspeed.get("steps_per_print", 1e20)
309+
self.deepspeed["zero_stage"] = cast(int, self.deepspeed.get("zero_stage", 0))
310+
self.deepspeed["steps_per_print"] = cast(int, self.deepspeed.get("steps_per_print", 1e20))
311311

312312
if self.deterministic_mode and self.deepspeed["zero_stage"] > 0:
313313
raise ValueError("Deepspeed with zero stage > 0 is not compatible with deterministic mode")

0 commit comments

Comments
 (0)