# Copyright 2021 MosaicML. All Rights Reserved.

import logging
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import List, Optional

import datasets
import torch
import yahp as hp

from composer.core.types import Batch
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DataloaderSpec, DatasetHparams
from composer.utils import dist

log = logging.getLogger(__name__)


def _split_dict_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
    """Split a dict batch of tensors into ``n_microbatches`` smaller dict batches."""
    if isinstance(batch, dict):
        chunked = {k: v.chunk(n_microbatches) for k, v in batch.items()}
        num_chunks = len(list(chunked.values())[0])
        return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]
    else:
        raise ValueError(f"Expected the batch from the dataloader to be of type Dict[str, Tensor], but got {type(batch)}")
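
# Illustration of _split_dict_fn (hypothetical shapes, not executed here): a batch such as
# {"input_ids": Tensor[8, 512], "labels": Tensor[8, 512]} split with n_microbatches=2
# yields two dicts whose tensors each have shape [4, 512].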


# Shard counts and per-shard sample counts for known streaming datasets, keyed as
# {dataset_name: {config_name: {split: (num_shards, samples_per_shard)}}}.
# Used to report an approximate dataset length.
CACHED_DATASET_SIZES = {"c4": {"en": {"train": (1024, 356317), "validation": (8, 45576)}}}


@dataclass
class StreamingLMDatasetHparams(DatasetHparams):
    """Defines a generic streaming dataset for autoregressive and masked language models."""

    dataset_name: str = hp.optional("Name of the dataset to load.", default=None)
    dataset_config_name: Optional[str] = hp.optional(
        "If required, the specific configuration of the dataset to use.", default=None)
    split: str = hp.optional("Which split of the dataset to use (e.g. 'train', 'validation', or 'test').", default=None)
    max_shards: int = hp.optional("Max number of shards, used to deterministically reduce the dataset size.", default=-1)
    max_samples: int = hp.optional(
        "Max number of post-processed samples. Note that the chosen subset depends on the seed and world size.",
        default=-1)
    tokenizer_name: str = hp.optional("The name of the tokenizer to preprocess text with.", default=None)
    max_seq_len: int = hp.optional("The max sequence length of each token sample.", default=None)
    group_method: str = hp.optional("How to group text samples into token samples ('truncate' or 'concat').",
                                    default=None)
    use_masked_lm: bool = hp.optional("Whether the dataset should be encoded with masked language modeling or not.",
                                      default=None)
    mlm_probability: float = hp.optional("If using masked language modeling, the probability of masking each token.",
                                         default=0.15)
    seed: int = hp.optional("Which seed to use to generate train and validation splits.", default=5)
    shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch.", default=True)
    drop_last: bool = hp.optional("Whether to drop the samples that would form the last, incomplete batch.",
                                  default=False)

    def validate(self):
        assert self.group_method in ["truncate", "concat"], f"Unknown group_method: '{self.group_method}'"
        assert self.drop_last, "No support for 'drop_last'=False currently."
        if self.group_method == "concat":
            assert self.max_samples > 0, "Must provide 'max_samples' if 'group_method'='concat'"
        if self.use_masked_lm:
            if self.mlm_probability <= 0.0:
                raise ValueError(
                    "If using masked language modeling, you must replace tokens with a non-zero probability.")

    def _load_dataset(self):
        return datasets.load_dataset(path=self.dataset_name,
                                     name=self.dataset_config_name,
                                     split=self.split,
                                     streaming=True)

    def _get_approx_num_samples(self):
        if self.max_samples > 0:
            return self.max_samples
        try:
            n_shards, samples_per_shard = CACHED_DATASET_SIZES[self.dataset_name][self.dataset_config_name][self.split]
            n_shards = self.max_shards if self.max_shards > 0 else n_shards
            return n_shards * samples_per_shard
        except KeyError:
            raise NotImplementedError(
                "Unknown dataset size. Either provide 'max_samples' or add an entry to CACHED_DATASET_SIZES.")

    def _get_approx_num_tokens(self):
        # Placeholder upper bound; exact token counts are not tracked for streaming datasets.
        return 1e12

    def _subsample(self, device_offset, text_batch):
        # Only return the i-th item out of N sequential items
        for k, v in text_batch.items():
            text_batch[k] = v[device_offset:device_offset + 1]
        return text_batch

    def _shard_dataset(self, dataset):
        # Select a subset of filepaths for sharded DDP training
        world_size = dist.get_world_size()
        rank = dist.get_global_rank()
        filepaths = dataset._ex_iterable.kwargs['filepaths']
        # If subsampling using 'max_shards', deterministically choose shards
        if self.max_shards > 0:
            filepaths = filepaths[:self.max_shards]
        num_shards = len(filepaths)

        devices_per_shard = 1
        if world_size > num_shards:
            log.warning(
                f"Not enough unique shards ({num_shards}) for world size ({world_size}). Splitting shards among devices."
            )
            assert world_size % num_shards == 0, f"Cannot evenly split {num_shards} shards among {world_size} devices"
            devices_per_shard = world_size // num_shards
        shard_offset = rank // devices_per_shard
        device_offset = rank % devices_per_shard

        device_filepaths = filepaths[shard_offset::world_size]
        dataset._ex_iterable.kwargs['filepaths'] = device_filepaths

        # Subsample dataset if shard is being shared among devices
        # NOTE: Mapping is executed in batched mode for better CPU utilization,
        # but the returned dataset is still an iterable over text samples
        if devices_per_shard > 1:
            dataset = dataset.map(
                partial(self._subsample, device_offset),
                batched=True,
                batch_size=devices_per_shard,
            )
        return dataset
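
    # Example of the sharding arithmetic in _shard_dataset above (hypothetical values):
    # with 4 shard files and world_size=8, each shard is shared by devices_per_shard=2
    # ranks; e.g. ranks 0 and 1 both read shard 0 but keep alternating samples
    # (device_offset 0 and 1, respectively).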

    def _tokenize(self, text_batch):
        # Convert a text batch to a token batch
        if self.group_method == "truncate":
            truncation = True
            padding = 'max_length'
            max_length = self.max_seq_len
        else:
            truncation = False
            padding = False
            max_length = None
        return self.tokenizer(text_batch["text"], truncation=truncation, padding=padding, max_length=max_length)

    def _group_tokens(self, token_batch):
        if self.group_method == "concat":
            # Concatenate all tokens.
            concat_tokens = {}
            num_tokens = None
            for k, v in token_batch.items():
                concat_v = list(chain(*v))
                concat_tokens[k] = concat_v
                if num_tokens is None:
                    num_tokens = len(concat_v)
                else:
                    assert num_tokens == len(concat_v), "Not all values in the concat_tokens dict have the same len()"

            # Drop the small remainder of tokens at the end of the batch.
            # In the future we could support padding instead.
            if num_tokens >= self.max_seq_len:
                num_tokens = (num_tokens // self.max_seq_len) * self.max_seq_len

            # Split into token samples of size max_seq_len.
            result = {
                k: [v[i:i + self.max_seq_len] for i in range(0, num_tokens, self.max_seq_len)]
                for k, v in concat_tokens.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result
        else:
            raise ValueError(f"Unknown group_method: '{self.group_method}'")
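
    # Example of the grouping arithmetic in _group_tokens above (hypothetical values):
    # with max_seq_len=1024, a batch that concatenates to 2500 tokens yields two token
    # samples of length 1024, and the trailing 452 tokens are dropped.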

    def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataloaderSpec:
        assert dataloader_hparams.num_workers == 1, "LM Streaming Dataloader only supports num_workers=1"

        try:
            import datasets
            import transformers
        except ImportError:
            raise ImportError('huggingface transformers and datasets are not installed. '
                              'Please install with `pip install mosaicml-composer[nlp]`')
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.tokenizer_name)  # type: ignore (thirdparty)
        self.config = transformers.AutoConfig.from_pretrained(self.tokenizer_name)  # type: ignore (thirdparty)

        # Load and shard dataset
        text_dataset = self._load_dataset()
        text_dataset = self._shard_dataset(text_dataset)

        # Shuffle
        if self.shuffle:
            text_dataset = text_dataset.shuffle(buffer_size=10000, seed=self.seed)

        # Map text samples to token samples
        # NOTE: Mapping is executed in batched mode for better CPU utilization,
        # but the returned dataset is still an iterable over tokenized samples
        text_sample_batch_size = 1000
        token_dataset = text_dataset.map(
            self._tokenize,
            batched=True,
            batch_size=text_sample_batch_size,
        )

        if self.group_method != "truncate":
            # Map variable-length token samples to fixed-length token samples
            # NOTE: Mapping is executed in batched mode for better CPU utilization,
            # but the returned dataset is still an iterable over tokenized samples.
            # NOTE: Depending on the 'group_method', this step may alter the number of
            # token samples in the dataset, and may mix neighboring token samples together.
            token_sample_batch_size = 1000
            token_dataset = token_dataset.map(
                self._group_tokens,
                batched=True,
                batch_size=token_sample_batch_size,
            )

        # Maybe limit the number of post-processed samples
        if self.max_samples > 0:
            token_dataset = token_dataset.take(self.max_samples // dist.get_world_size())

        # Attach the approximate num samples by wrapping in a SizedIterableDataset
        sized_iterable_dataset = SizedIterableDataset(token_dataset, self._get_approx_num_samples())

        # Get collate_fn
        if self.tokenizer_name in ["gpt2"]:
            # The GPT2 tokenizer has no padding token, which breaks the language-modeling
            # collator, so fall back to the default collator.
            collate_fn = transformers.default_data_collator
        else:
            collate_fn = transformers.DataCollatorForLanguageModeling(tokenizer=self.tokenizer,
                                                                      mlm=self.use_masked_lm,
                                                                      mlm_probability=self.mlm_probability)
        # Return DataloaderSpec
        return DataloaderSpec(dataloader=dataloader_hparams.initialize_object(
            dataset=sized_iterable_dataset,
            batch_size=batch_size,
            sampler=None,
            drop_last=self.drop_last,
            collate_fn=collate_fn,
        ),
                              split_fn=_split_dict_fn)


class SizedIterableDataset(torch.utils.data.IterableDataset):
    """Wraps a HuggingFace iterable dataset and reports an approximate number of samples."""

    def __init__(self, hf_iterable_dataset, num_samples):
        self.hf_iterable_dataset = hf_iterable_dataset
        self.num_samples = num_samples

    def __iter__(self):
        return iter(self.hf_iterable_dataset)

    def __len__(self):
        return self.num_samples
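

# A minimal usage sketch (illustrative only): in practice these hparams are driven by a
# YAML/CLI config, and the exact DataloaderHparams construction is configuration-specific,
# so it is only hinted at here rather than spelled out.
#
#   hparams = StreamingLMDatasetHparams(
#       dataset_name="c4",
#       dataset_config_name="en",
#       split="validation",
#       tokenizer_name="gpt2",
#       max_seq_len=1024,
#       group_method="truncate",
#       use_masked_lm=False,
#       drop_last=True,
#   )
#   hparams.validate()
#   spec = hparams.initialize_object(batch_size=8, dataloader_hparams=dataloader_hparams)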