Commit 15df497

wip

1 parent b16caab commit 15df497

File tree

5 files changed: +413 -0 lines changed

composer/datasets/__init__.py

(file mode changed: 100644 -> 100755)
Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 from composer.datasets.imagenet import ImagenetDatasetHparams as ImagenetDatasetHparams
 from composer.datasets.lm_datasets import LMDatasetHparams as LMDatasetHparams
 from composer.datasets.mnist import MNISTDatasetHparams as MNISTDatasetHparams
+from composer.datasets.streaming_lm_datasets import StreamingLMDatasetHparams as StreamingLMDatasetHparams
 from composer.datasets.synthetic import MemoryFormat as MemoryFormat
 from composer.datasets.synthetic import SyntheticBatchPairDataset as SyntheticBatchPairDataset
 from composer.datasets.synthetic import SyntheticDataLabelType as SyntheticDataLabelType
composer/datasets/streaming_lm_datasets.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
# Copyright 2021 MosaicML. All Rights Reserved.

import logging
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import List, Optional

import datasets
import torch
import yahp as hp

from composer.core.types import Batch
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DataloaderSpec, DatasetHparams
from composer.utils import dist

log = logging.getLogger(__name__)


def _split_dict_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
    # Split a dict-of-tensors batch into n_microbatches smaller dict batches.
    if isinstance(batch, dict):
        chunked = {k: v.chunk(n_microbatches) for k, v in batch.items()}
        num_chunks = len(list(chunked.values())[0])
        return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]
    else:
        raise ValueError(f'Expected batch from dataloader to be of type Dict[str, Tensor], but got {type(batch)}')

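For illustration, a minimal sketch (toy tensors, not part of the commit) of how this split function carves a dict batch into microbatches, as a `split_fn` consumed via the DataloaderSpec below:

import torch

batch = {"input_ids": torch.arange(8).reshape(4, 2), "labels": torch.arange(8).reshape(4, 2)}
microbatches = _split_dict_fn(batch, n_microbatches=2)
print(len(microbatches))             # 2
print(microbatches[0]["input_ids"])  # tensor([[0, 1], [2, 3]])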
CACHED_DATASET_SIZES = {"c4": {"en": {"train": (1024, 356317), "validation": (8, 45576)}}}


@dataclass
class StreamingLMDatasetHparams(DatasetHparams):
    """Defines a generic streaming dataset class for autoregressive and masked language models."""

    dataset_name: str = hp.optional("Name of the dataset to load.", default=None)
    dataset_config_name: Optional[str] = hp.optional(
        "If required, the specific configuration of the dataset that you would like to use.", default=None)
    split: str = hp.optional("Which split of the dataset to use (e.g. 'train', 'validation', or 'test').",
                             default=None)
    max_shards: int = hp.optional("Max number of shards, used to deterministically reduce dataset size.", default=-1)
    max_samples: int = hp.optional(
        "Max number of post-processed samples; note that the subset depends on the seed and world size.", default=-1)
    tokenizer_name: str = hp.optional("The name of the tokenizer to preprocess text with.", default=None)
    max_seq_len: int = hp.optional("The max sequence length of each token sample.", default=None)
    group_method: str = hp.optional("How to group text samples into token samples.", default=None)
    use_masked_lm: bool = hp.optional("Whether the dataset should be encoded with masked language modeling or not.",
                                      default=None)
    mlm_probability: float = hp.optional("If using masked language modeling, the probability with which to mask tokens.",
                                         default=0.15)
    seed: int = hp.optional("Which seed to use to generate train and validation splits.", default=5)
    shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch.", default=True)
    drop_last: bool = hp.optional("Whether to drop the last samples for the last batch.", default=False)

    def validate(self):
        assert self.group_method in ["truncate", "concat"], f"Unknown group_method: '{self.group_method}'"
        assert self.drop_last, "No support for 'drop_last'=False currently."
        if self.group_method == "concat":
            assert self.max_samples > 0, "Must provide 'max_samples' if 'group_method'='concat'"
        if self.use_masked_lm:
            if self.mlm_probability <= 0.0:
                raise ValueError(
                    "If using Masked Language Modeling, you must replace tokens with a non-zero probability.")

    def _load_dataset(self):
        return datasets.load_dataset(path=self.dataset_name,
                                     name=self.dataset_config_name,
                                     split=self.split,
                                     streaming=True)

    def _get_approx_num_samples(self):
        try:
            if self.max_samples > 0:
                return self.max_samples
            else:
                n_shards, samples_per_shard = CACHED_DATASET_SIZES[self.dataset_name][self.dataset_config_name][
                    self.split]
                n_shards = self.max_shards if self.max_shards > 0 else n_shards
                return n_shards * samples_per_shard
        except KeyError:
            raise NotImplementedError

    def _get_approx_num_tokens(self):
        return 1e12

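As a sanity check on the cached-size arithmetic (plain arithmetic from the CACHED_DATASET_SIZES entry above, not from the commit):

# c4/en as cached above: (n_shards, samples_per_shard)
train_shards, train_per_shard = 1024, 356317
val_shards, val_per_shard = 8, 45576
print(train_shards * train_per_shard)  # 364868608 approx. train samples
print(val_shards * val_per_shard)      # 364608 approx. validation samples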
    def _subsample(self, device_offset, text_batch):
        # Only return the i-th item out of N sequential items
        for k, v in text_batch.items():
            text_batch[k] = v[device_offset:device_offset + 1]
        return text_batch

    def _shard_dataset(self, dataset):
        # Select a subset of filepaths for sharded DDP training
        world_size = dist.get_world_size()
        rank = dist.get_global_rank()
        filepaths = dataset._ex_iterable.kwargs['filepaths']
        # If subsampling using 'max_shards', deterministically choose shards
        if self.max_shards > 0:
            filepaths = filepaths[:self.max_shards]
        num_shards = len(filepaths)

        devices_per_shard = 1
        if world_size > num_shards:
            log.warning(
                f"Not enough unique shards ({num_shards}) for world size ({world_size}). Splitting shards among devices."
            )
            assert world_size % num_shards == 0, "Cannot evenly split shards among devices"
            devices_per_shard = world_size // num_shards
        shard_offset = rank // devices_per_shard
        device_offset = rank % devices_per_shard

        device_filepaths = filepaths[shard_offset::world_size]
        dataset._ex_iterable.kwargs['filepaths'] = device_filepaths

        # Subsample the dataset if each shard is shared among devices.
        # NOTE: Mapping is executed in batched mode for better CPU utilization,
        # but the returned dataset is still an iterable over text samples.
        if devices_per_shard > 1:
            dataset = dataset.map(
                partial(self._subsample, device_offset),
                batched=True,
                batch_size=devices_per_shard,
            )
        return dataset

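To make the sharding arithmetic concrete, a small sketch (hypothetical shard and device counts, not taken from the commit) of how ranks map to shard and device offsets when there are fewer shards than devices:

# Hypothetical: 2 shards, world_size 4 -> devices_per_shard = 2
world_size, num_shards = 4, 2
devices_per_shard = world_size // num_shards
for rank in range(world_size):
    shard_offset = rank // devices_per_shard
    device_offset = rank % devices_per_shard
    print(rank, shard_offset, device_offset)
# rank 0 -> shard 0, offset 0; rank 1 -> shard 0, offset 1
# rank 2 -> shard 1, offset 0; rank 3 -> shard 1, offset 1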
    def _tokenize(self, text_batch):
        # Convert a text batch to a token batch
        if self.group_method == "truncate":
            truncation = True
            padding = 'max_length'
            max_length = self.max_seq_len
        else:
            truncation = False
            padding = False
            max_length = None
        return self.tokenizer(text_batch["text"], truncation=truncation, padding=padding, max_length=max_length)

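For a rough illustration of the 'concat' tokenization path (a sketch, not from the commit; assumes the gpt2 tokenizer can be downloaded via transformers):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 'concat' path: no truncation or padding; fixed-length grouping happens later
tokens = tokenizer(["a short example document"], truncation=False, padding=False, max_length=None)
print(len(tokens["input_ids"][0]))  # varies with the text; not clipped to max_seq_len
# The 'truncate' path instead passes truncation=True, padding='max_length',
# max_length=max_seq_len, so every sample comes out exactly max_seq_len long.
# (For gpt2 that padding would fail, since the tokenizer defines no pad token;
# see the collate_fn workaround in initialize_object below.)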
    def _group_tokens(self, token_batch):
        if self.group_method == "concat":
            # Concatenate all tokens.
            concat_tokens = {}
            num_tokens = None
            for k, v in token_batch.items():
                concat_v = list(chain(*v))
                concat_tokens[k] = concat_v
                if num_tokens is None:
                    num_tokens = len(concat_v)
                else:
                    assert num_tokens == len(concat_v), "Not all values in concat_tokens dict have the same len()"

            # Drop the small remainder of tokens at the end of the batch;
            # in the future we could support padding instead.
            if num_tokens >= self.max_seq_len:
                num_tokens = (num_tokens // self.max_seq_len) * self.max_seq_len

            # Split into token samples of size max_seq_len.
            result = {
                k: [v[i:i + self.max_seq_len] for i in range(0, num_tokens, self.max_seq_len)
                   ] for k, v in concat_tokens.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result
        else:
            raise ValueError(f"Unknown group_method: '{self.group_method}'")

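A toy run of the 'concat' grouping logic (a sketch with made-up token ids and max_seq_len=4, not from the commit):

from itertools import chain

token_batch = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}
max_seq_len = 4
concat = list(chain(*token_batch["input_ids"]))          # [1, 2, ..., 10]
num_tokens = (len(concat) // max_seq_len) * max_seq_len  # 8: drop the remainder
samples = [concat[i:i + max_seq_len] for i in range(0, num_tokens, max_seq_len)]
print(samples)  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- tokens 9 and 10 are dropped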
    def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataloaderSpec:
        assert dataloader_hparams.num_workers == 1, "LM Streaming Dataloader only supports num_workers=1"

        try:
            import datasets
            import transformers
        except ImportError:
            raise ImportError('huggingface transformers and datasets are not installed. '
                              'Please install with `pip install mosaicml-composer[nlp]`')
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.tokenizer_name)  # type: ignore (thirdparty)
        self.config = transformers.AutoConfig.from_pretrained(self.tokenizer_name)  # type: ignore (thirdparty)

        # Load and shard the dataset
        text_dataset = self._load_dataset()
        text_dataset = self._shard_dataset(text_dataset)

        # Shuffle
        if self.shuffle:
            text_dataset = text_dataset.shuffle(buffer_size=10000, seed=self.seed)

        # Map text samples to token samples.
        # NOTE: Mapping is executed in batched mode for better CPU utilization,
        # but the returned dataset is still an iterable over tokenized samples.
        text_sample_batch_size = 1000
        token_dataset = text_dataset.map(
            self._tokenize,
            batched=True,
            batch_size=text_sample_batch_size,
        )

        if self.group_method != "truncate":
            # Map variable-length token samples to fixed-length token samples.
            # NOTE: Mapping is executed in batched mode for better CPU utilization,
            # but the returned dataset is still an iterable over tokenized samples.
            # NOTE: Depending on the 'group_method', this step may alter the number of
            # token samples in the dataset, and may mix neighboring token samples together.
            token_sample_batch_size = 1000
            token_dataset = token_dataset.map(
                self._group_tokens,
                batched=True,
                batch_size=token_sample_batch_size,
            )

        # Maybe limit the number of post-processed samples
        if self.max_samples > 0:
            token_dataset = token_dataset.take(self.max_samples // dist.get_world_size())

        # Add the approximate number of samples and create a SizedIterableDataset
        sized_iterable_dataset = SizedIterableDataset(token_dataset, self._get_approx_num_samples())

        # Get collate_fn
        if self.tokenizer_name in ["gpt2"]:
            # Really annoying, but the GPT2 tokenizer has no padding token, which causes bugs
            collate_fn = transformers.default_data_collator
        else:
            collate_fn = transformers.DataCollatorForLanguageModeling(tokenizer=self.tokenizer,
                                                                      mlm=self.use_masked_lm,
                                                                      mlm_probability=self.mlm_probability)
        # Return DataloaderSpec
        return DataloaderSpec(dataloader=dataloader_hparams.initialize_object(
            dataset=sized_iterable_dataset,
            batch_size=batch_size,
            sampler=None,
            drop_last=self.drop_last,
            collate_fn=collate_fn,
        ),
                              split_fn=_split_dict_fn)

class SizedIterableDataset(torch.utils.data.IterableDataset):
    """Wraps a HuggingFace iterable dataset and reports an approximate length."""

    def __init__(self, hf_iterable_dataset, num_samples):
        self.hf_iterable_dataset = hf_iterable_dataset
        self.num_samples = num_samples

    def __iter__(self):
        return iter(self.hf_iterable_dataset)

    def __len__(self):
        return self.num_samples
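To show how the pieces fit together, a minimal usage sketch (hypothetical values mirroring the YAML below; assumes both hparams dataclasses can be constructed directly with keyword arguments and that `dist` is initialized):

from composer.datasets import StreamingLMDatasetHparams
from composer.datasets.dataloader import DataloaderHparams

# Hypothetical construction, not part of this commit.
dataset_hparams = StreamingLMDatasetHparams(dataset_name="c4", dataset_config_name="en",
                                            split="validation", max_samples=128000,
                                            tokenizer_name="gpt2", max_seq_len=1024,
                                            group_method="concat", use_masked_lm=False,
                                            shuffle=False, drop_last=True)
dataset_hparams.validate()
dataloader_hparams = DataloaderHparams(num_workers=1, prefetch_factor=2,
                                       persistent_workers=True, pin_memory=True, timeout=0)
spec = dataset_hparams.initialize_object(batch_size=8, dataloader_hparams=dataloader_hparams)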

composer/trainer/trainer_hparams.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@
     "mnist": datasets.MNISTDatasetHparams,
     "lm": datasets.LMDatasetHparams,
     "glue": datasets.GLUEHparams,
+    "streaming_lm": datasets.StreamingLMDatasetHparams,
 }

 algorithms_registry = get_algorithm_registry()
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# GPT2-125m with streaming C4 dataset

train_dataset:
  streaming_lm:
    dataset_name: c4
    dataset_config_name: en
    split: train
    max_shards: -1
    max_samples: 7168000
    max_seq_len: 1024
    group_method: concat
    tokenizer_name: gpt2
    use_masked_lm: false
    seed: 17
    shuffle: true
    drop_last: true
val_dataset:
  streaming_lm:
    dataset_name: c4
    dataset_config_name: en
    split: validation
    max_shards: -1
    max_samples: 128000
    max_seq_len: 1024
    group_method: concat
    tokenizer_name: gpt2
    use_masked_lm: false
    seed: 17
    shuffle: false
    drop_last: true

model:
  gpt2:
    use_pretrained: false
    tokenizer_name: gpt2
    model_config:
      activation_function: gelu_new
      architectures:
      - GPT2LMHeadModel
      attn_pdrop: 0.1
      bos_token_id: 50256
      embd_pdrop: 0.1
      eos_token_id: 50256
      initializer_range: 0.02
      layer_norm_epsilon: 1.0e-05
      model_type: gpt2
      n_ctx: 1024
      n_embd: 768
      n_head: 12
      n_inner: 3072
      n_layer: 12
      n_positions: 1024
      resid_pdrop: 0.1
      scale_attn_weights: true
      summary_activation: null
      summary_first_dropout: 0.1
      summary_proj_to_labels: true
      summary_type: cls_index
      summary_use_proj: true
      task_specific_params:
        text-generation:
          do_sample: true
          max_length: 50
      transformers_version: 4.11.0.dev0
      use_cache: true
      vocab_size: 50257
optimizer:
  adamw:
    lr: 6.0e-4
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.0
schedulers:
- warmup:
    warmup_method: linear
    warmup_factor: 0
    interval: step
    warmup_iters: 140ba
- cosine_decay:
    interval: step
    eta_min: 0
    verbose: false
    T_max: 13860ba
loggers:
- file:
    log_level: batch
    filename: stdout
    buffer_size: 1
    flush_every_n_batches: 100
    every_n_batches: 100
    every_n_epochs: 1
max_epochs: 1
train_batch_size: 512
eval_batch_size: 8  # use micro_bs_per_gpu = 1 to accommodate 10GB limit
seed: 17
device:
  gpu: {}
dataloader:
  pin_memory: true
  persistent_workers: true
  num_workers: 1
  timeout: 0
  prefetch_factor: 2
precision: amp
grad_clip_norm: 1.0
grad_accum: 22
validate_every_n_batches: 1000
validate_every_n_epochs: 1
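One consistency check worth noting on this config (plain arithmetic, not from the commit): the sample budget and the LR schedule line up exactly.

# max_samples / train_batch_size gives the total number of optimizer steps:
assert 7168000 // 512 == 14000
# which matches the schedule exactly: 140ba of warmup + 13860ba of cosine decay.
assert 140 + 13860 == 14000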
