4
4
from dataclasses import dataclass
5
5
from functools import partial
6
6
from itertools import chain
7
- from typing import List , Optional
7
+ from typing import List
8
8
9
9
import yahp as hp
10
10
from torch .utils .data import IterableDataset , get_worker_info
11
11
12
12
from composer .core .types import Batch , DataSpec
13
13
from composer .datasets .dataloader import DataloaderHparams
14
- from composer .datasets .hparams import DatasetHparams
15
14
from composer .utils import dist
16
15
17
16
log = logging .getLogger (__name__ )
18
17
19
18
19
def _split_dict_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
    """Split a dict-of-tensors batch into ``n_microbatches`` microbatches.

    Each tensor value is chunked along dim 0 with ``Tensor.chunk``; microbatch
    ``i`` is the dict mapping every key to the ``i``-th chunk of its value.

    Args:
        batch: the batch to split. Must be a ``dict`` of ``str -> Tensor``.
        n_microbatches: number of microbatches to produce.

    Returns:
        A list of ``n_microbatches`` dicts with the same keys as ``batch``.

    Raises:
        ValueError: if ``batch`` is not a dict, or if any value splits into a
            number of chunks other than ``n_microbatches`` (``Tensor.chunk``
            may return fewer chunks when a tensor is too short along dim 0).
    """
    # Guard clause: only Dict[str, Tensor] batches are supported.
    if not isinstance(batch, dict):
        raise ValueError(f'Expected batch to be of type Dict[str, Tensor], but got {type(batch)}')
    chunked = {k: v.chunk(n_microbatches) for k, v in batch.items()}
    for k, v in chunked.items():
        # chunk() silently yields fewer pieces for short tensors; fail loudly instead.
        if len(v) != n_microbatches:
            raise ValueError(
                f"Unable to split batch into microbatches. "
                f"Key '{k}' has chunked list: {v} with length {len(v)}, but expected length {n_microbatches}.")
    # Transpose the dict-of-chunk-tuples into a list of per-microbatch dicts.
    return [{k: v[idx] for k, v in chunked.items()} for idx in range(n_microbatches)]
20
36
@dataclass
21
37
class C4DatasetHparams (hp .Hparams ):
22
38
"""Builds a DataSpec for the C4 (Colossal Cleaned CommonCrawl) dataset.
@@ -64,23 +80,13 @@ def validate(self):
64
80
if self .mlm and self .mlm_probability <= 0 :
65
81
raise ValueError ("Must provide a positive 'mlm_probability' when using masked language modeling." )
66
82
67
- def _split_dict_fn (batch : Batch , n_microbatches : int ) -> List [Batch ]:
68
- if isinstance (batch , dict ):
69
- chunked = {k : v .chunk (n_microbatches ) for k , v in batch .items ()}
70
- for k , v in chunked .items ():
71
- if len (v ) != n_microbatches :
72
- raise ValueError (
73
- f"Unable to split batch into microbatches. "
74
- f"Key '{ k } ' has chunked list: { v } with length { len (v )} , but expected length { n_microbatches } . " )
75
- microbatches = []
76
- for idx in range (n_microbatches ):
77
- mb = {k : v [idx ] for k , v in chunked .items ()}
78
- microbatches .append (mb )
79
- return microbatches
80
- else :
81
- raise ValueError (f'Expected batch to be of type Dict[str, Tensor], but got { type (batch )} ' )
82
-
83
83
def initialize_object (self , batch_size : int , dataloader_hparams : DataloaderHparams ) -> DataSpec :
84
+ try :
85
+ import transformers
86
+ except ImportError :
87
+ raise ImportError ('HuggingFace transformers not installed. '
88
+ 'Please install with `pip install composer[nlp]`' )
89
+
84
90
# Get C4 dataset
85
91
c4_dataset = C4Dataset (split = self .split ,
86
92
max_samples = self .max_samples ,
@@ -103,7 +109,7 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
103
109
drop_last = self .drop_last ,
104
110
collate_fn = collate_fn ,
105
111
),
106
- split_batch = self . _split_dict_fn )
112
+ split_batch = _split_dict_fn )
107
113
108
114
109
115
class C4Dataset (IterableDataset ):
0 commit comments