Commit 0d6b3af

Moin Nadeem, ravi-mosaicml, and abhi-mosaic authored
Add BERT Base to Composer (#195)

* Load Checkpoints from Cloud Storage
  Added support to load checkpoints stored in object storage (rather than just on the local disk). Closes #192.
  - Refactored the run directory uploader to separate out object-store-related utilities into composer.utils.object_store (and added test coverage).
  - Updated the checkpointer hparams to optionally take `composer.utils.object_store.ObjectStoreProviderHparams`, which is used to download the checkpoint from storage.
  - Updated the trainer init to propagate this change.
* Libcloud intersphinx
* rebasing off of dev
* starting an LR sweep
* adding proper dataset and batch size
* 2.0e-3 LR causes NaNs, lowering LR
* changing adam
* adding SST-2
* adding validation tracking
* adding SST-2 -- training but not at the right accuracy
* cleaning up code & debugging why training loss is so large
* finalized YAML for SST-2, starting hparam sweeps
* updating hparams to sweep
* finalized current setup for SST-2
* starting hparam sweeps on RTE
* adding support for warmup_ratio
* adding non-standard metrics
* adding support for duration as a time abstraction
* adding compatibility with DataloaderSpec changes
* adding a linear learning rate decay
* adding linear LR warmup
* finalizing GLUE
* refactoring implementation to add regression tasks
* fixing checkpoint bug
* finalizing fine-tuning a checkpointed model
* fixing checkpoint bug
* adding validation
* adding mid-training
* starting LR sweep
* adding checkpointing feedback part 1
* fix validation interval
* address PR feedback
* address PR feedback
* adding save_checkpoint and load_checkpoint hparams interface
* adding save_checkpoint and load_checkpoint hparams interface
* yapf & pyright
* fixed error with logging pre-training validation loss
* cleaning up model forward pass
* cleaning up custom metrics
* renaming Checkpointer -> CheckpointSaver
* addressing pyright
* adding tests
* moving commits to BERT branch
* changing folder to be relative to run dir
* formatting
* adding tests
* adding initial YAML changes
* removing a copy of outdated files
* adding GLUE default params
* addressing pyright
* finalizing task-specific YAMLs
* code cleanup
* yapf
* adding license
* addressing tests
* formatting
* adding tests for the duration abstraction
* can I sue pyright for emotional damages?
* final formatting
* adding in finalized pre-training hyperparameters
* Update composer/models/bert/bert_hparams.py
  Co-authored-by: Abhi Venigalla <[email protected]>
* Load Checkpoints from Cloud Storage
  Added support to load checkpoints stored in object storage (rather than just on the local disk). Closes #192.
  - Refactored the run directory uploader to separate out object-store-related utilities into composer.utils.object_store (and added test coverage).
  - Updated the checkpointer hparams to optionally take `composer.utils.object_store.ObjectStoreProviderHparams`, which is used to download the checkpoint from storage.
  - Updated the trainer init to propagate this change.
* Libcloud intersphinx
* addressing PR feedback
* changing checkpoints into a cloud URL
* addressing Landan's feedback
* filepath -> checkpoint in the YAMLs
* Fixed merge
* Removed auto-parsing of s3 and gs URLs, as libcloud requires authentication. Fixed tests.
* Flattened run directory uploader hparams
* Fixed object store provider hparams
* updating sampler to be composer.dist
* Added tqdm progress bars and chunk sizing parameterization; refactored checkpoint storage
* Fix pyright
* Fixed timeout
* Fix checkpointing
* Fixed deepspeed checkpoints
* Cleaned up PR
* finalized checkpoint loading
* refactored metric to avoid lists
* addressing pyright
* updating YAMLs with checkpoints
* final change
* adding unit tests
* adding LICENSE
* addressing conflicts & tests
* isort
* removing finished TODOs
* adding new GPT-2 YAMLs

Co-authored-by: Ravi Rahman <[email protected]>
Co-authored-by: Moin Nadeem <[email protected]>
Co-authored-by: Abhi Venigalla <[email protected]>
1 parent aad5330 commit 0d6b3af

38 files changed: +1582 -47 lines changed
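
The commit message's first item describes loading checkpoints from object storage via composer.utils.object_store, which is built on Apache Libcloud. As a loose illustration of the kind of download call involved (the provider, credentials, bucket, and object key below are placeholders, and this is not the ObjectStoreProvider code added by this commit):

from libcloud.storage.providers import get_driver
from libcloud.storage.types import Provider

# Hypothetical sketch: fetch a checkpoint object with Libcloud directly.
# Credentials, bucket name, and object key are placeholders.
driver_cls = get_driver(Provider.S3)
driver = driver_cls("ACCESS_KEY_ID", "SECRET_ACCESS_KEY")

checkpoint_obj = driver.get_object(container_name="my-run-bucket", object_name="checkpoints/ep10.tar")
driver.download_object(checkpoint_obj, destination_path="/tmp/ep10.tar", overwrite_existing=True)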

composer/core/state.py

Lines changed: 0 additions & 1 deletion

@@ -276,7 +276,6 @@ def load_model_state(self, state_dict: types.StateDict, strict: bool):
         """
         if state_dict["_is_model_ddp_wrapped"] and not isinstance(self.model, DistributedDataParallel):
             torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], "module.")
-
         missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
         if len(missing_keys) > 0:
             logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
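
The context above shows State.load_model_state stripping the "module." prefix that DistributedDataParallel adds to parameter names before a checkpoint is loaded into an unwrapped model. A small, self-contained illustration of that prefix stripping (the keys and tensors here are made up for the example):

# Illustrative only: shows what consume_prefix_in_state_dict_if_present does
# to DDP-style checkpoint keys; the keys and tensors are placeholders.
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

state_dict = {"module.fc.weight": torch.zeros(2, 2), "module.fc.bias": torch.zeros(2)}
consume_prefix_in_state_dict_if_present(state_dict, "module.")
print(sorted(state_dict.keys()))  # ['fc.bias', 'fc.weight'] -- "module." prefix removed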

composer/datasets/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
 from composer.datasets.dataloader import DDPDataLoader as DDPDataLoader
 from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
+from composer.datasets.glue import GLUEHparams as GLUEHparams
 from composer.datasets.hparams import DatasetHparams as DatasetHparams
 from composer.datasets.hparams import SyntheticHparamsMixin as SyntheticHparamsMixin
 from composer.datasets.imagenet import ImagenetDatasetHparams as ImagenetDatasetHparams

composer/datasets/glue.py

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+
+import logging
+from dataclasses import dataclass
+from multiprocessing import cpu_count
+
+import yahp as hp
+
+from composer.core import DataSpec
+from composer.datasets.dataloader import DataloaderHparams
+from composer.datasets.hparams import DatasetHparams
+from composer.datasets.lm_datasets import _split_dict_fn
+from composer.utils import dist
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class GLUEHparams(DatasetHparams):
+    """
+    Sets up a generic GLUE dataset loader.
+
+    Args:
+        task (str): the GLUE task to train on; one of CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, or STS-B.
+        tokenizer_name (str): The name of the HuggingFace tokenizer to preprocess text with.
+        split (str): Whether to use the 'train', 'validation', or 'test' split.
+        max_seq_length (int): Optionally, a custom sequence length for the training dataset. Default: 256.
+
+    Returns:
+        A :class:`~composer.core.DataSpec` object.
+    """
+
+    task: str = hp.optional(
+        "The GLUE task to train on, choose one from: CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, and STS-B.", default=None)
+    tokenizer_name: str = hp.optional("The name of the tokenizer to preprocess text with.", default=None)
+    split: str = hp.optional("Whether to use 'train', 'validation' or 'test' split.", default=None)
+    max_seq_length: int = hp.optional(
+        default=256, doc='Optionally, the ability to set a custom sequence length for the training dataset.')
+
+    def validate(self):
+        self.task_to_keys = {
+            "cola": ("sentence", None),
+            "mnli": ("premise", "hypothesis"),
+            "mrpc": ("sentence1", "sentence2"),
+            "qnli": ("question", "sentence"),
+            "qqp": ("question1", "question2"),
+            "rte": ("sentence1", "sentence2"),
+            "sst2": ("sentence", None),
+            "stsb": ("sentence1", "sentence2"),
+        }
+
+        if self.task not in self.task_to_keys.keys():
+            raise ValueError(f"The task must be a valid GLUE task, options are {', '.join(self.task_to_keys.keys())}.")
+
+        if (self.max_seq_length % 8) != 0:
+            log.warning("For best hardware acceleration, it is recommended that sequence lengths be multiples of 8.")
+
+        if self.tokenizer_name is None:
+            raise ValueError("A tokenizer name must be specified to tokenize the dataset.")
+
+        if self.split is None:
+            raise ValueError("A dataset split must be specified.")
+
+    def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
+        # TODO (Moin): I think this code is copied verbatim in a few different places. Move this into a function.
+        try:
+            import datasets
+            import transformers
+        except ImportError:
+            raise ImportError('huggingface transformers and datasets are not installed. '
+                              'Please install with `pip install mosaicml-composer[nlp]`')
+
+        self.validate()
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.tokenizer_name)  #type: ignore (thirdparty)
+
+        log.info(f"Loading {self.task.upper()}...")
+        self.dataset = datasets.load_dataset("glue", self.task, split=self.split)
+
+        n_cpus = cpu_count()
+        log.info(f"Starting tokenization step by preprocessing over {n_cpus} threads!")
+        text_column_names = self.task_to_keys[self.task]
+
+        def tokenize_function(inp):
+            # truncates sentences to max_length or pads them to max_length
+            first_half = inp[text_column_names[0]]
+            second_half = inp[text_column_names[1]] if text_column_names[1] in inp else None
+            return self.tokenizer(
+                text=first_half,
+                text_pair=second_half,
+                padding="max_length",
+                max_length=self.max_seq_length,
+                truncation=True,
+            )
+
+        columns_to_remove = ["idx"] + [i for i in text_column_names if i is not None]
+        assert isinstance(self.dataset, datasets.Dataset)
+        dataset = self.dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=n_cpus,
+            batch_size=1000,
+            remove_columns=columns_to_remove,
+            new_fingerprint=f"{self.task}-tokenization-{self.split}",
+            load_from_cache_file=True,
+        )
+
+        data_collator = transformers.data.data_collator.default_data_collator
+        sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+
+        return DataSpec(
+            dataloader=dataloader_hparams.initialize_object(
+                dataset=dataset,  #type: ignore (thirdparty)
+                batch_size=batch_size,
+                sampler=sampler,
+                drop_last=self.drop_last,
+                collate_fn=data_collator,
+            ),
+            split_batch=_split_dict_fn)
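
For context, a minimal sketch of how the new GLUEHparams might be used from Python rather than YAML; the DataloaderHparams field names and values below are assumptions for illustration, not taken from this commit:

from composer.datasets import GLUEHparams
from composer.datasets.dataloader import DataloaderHparams

# Hypothetical usage sketch; the dataloader settings are illustrative guesses.
glue = GLUEHparams(
    task="sst2",                          # must be one of the task_to_keys entries above
    tokenizer_name="bert-base-uncased",
    split="train",
    max_seq_length=256,
)

dataloader_hparams = DataloaderHparams(num_workers=8, pin_memory=True, timeout=0,
                                       prefetch_factor=2, persistent_workers=True)
data_spec = glue.initialize_object(batch_size=16, dataloader_hparams=dataloader_hparams)
# data_spec.dataloader iterates over the tokenized SST-2 training split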

composer/datasets/lm_datasets.py

Lines changed: 37 additions & 9 deletions

@@ -28,38 +28,65 @@ def _split_dict_fn(batch: Batch, n_microbatches: int) -> List[Batch]:
 @dataclass
 class LMDatasetHparams(DatasetHparams):
     """
-    Defines a generic dataset class for autoregressive language models.
+    Defines a generic dataset class for autoregressive and masked language models trained with self-supervised learning.
     """

     # TODO(moin): Switch datadir to be a string, rather than a list of strings, to be similar to the
     # other datasets
     datadir: List[str] = hp.optional(  # type: ignore
         "Path to the Huggingface Datasets directory.", default_factory=list)
+
     split: Optional[str] = hp.optional("Whether to use 'train', 'validation' or 'test' split.", default=None)
     tokenizer_name: Optional[str] = hp.optional("The name of the tokenizer to preprocess text with.", default=None)
+    use_masked_lm: bool = hp.optional("Whether the dataset should be encoded with masked language modeling or not.",
+                                      default=None)
     num_tokens: int = hp.optional(doc='If desired, the number of tokens to truncate the dataset to.', default=0)
+    mlm_probability: float = hp.optional("If using masked language modeling, the probability to mask tokens with.",
+                                         default=0.15)
     seed: int = hp.optional("Which seed to use to generate train and validation splits.", default=5)
     subsample_ratio: float = hp.optional(default=1.0, doc='If desired, the percentage of the dataset to use.')
     train_sequence_length: int = hp.optional(
         default=1024, doc='Optionally, the ability to set a custom sequence length for the training dataset.')
     val_sequence_length: int = hp.optional(
         default=1024, doc='Optionally, the ability to set a custom sequence length for the validation dataset.')

+    def validate(self):
+        if self.datadir is None:
+            raise ValueError("A data directory must be specified.")
+
+        if self.split not in ['train', 'validation', 'test']:
+            raise ValueError("The dataset split must be one of 'train', 'validation', or 'test'.")
+
+        if self.tokenizer_name is None:
+            raise ValueError("A tokenizer name must be specified to tokenize the dataset.")
+
+        if self.use_masked_lm is None:
+            raise ValueError("To determine masking, use_masked_lm must be specified.")
+
+        if self.use_masked_lm:
+            if self.mlm_probability <= 0.0:
+                raise ValueError(
+                    "If using Masked Language Modeling, you must replace tokens with a non-zero probability.")
+
+        if self.num_tokens > 0 and self.subsample_ratio < 1.0:
+            raise Exception("Must specify one of num_tokens OR subsample_ratio, cannot specify both.")
+
+        if (self.train_sequence_length % 8 != 0) or (self.val_sequence_length % 8 != 0):
+            log.warning("For best hardware acceleration, it is recommended that sequence lengths be multiples of 8.")
+
     def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
         try:
             import datasets
             import transformers
         except ImportError as e:
             raise ImportError('huggingface transformers and datasets are not installed. '
                               'Please install with `pip install mosaicml-composer[nlp]`') from e
+
+        self.validate()
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.tokenizer_name)  #type: ignore (thirdparty)
         self.config = transformers.AutoConfig.from_pretrained(self.tokenizer_name)  #type: ignore (thirdparty)
         lm_datasets = [datasets.load_from_disk(i) for i in self.datadir]  #type: ignore (thirdparty)

-        # TODO: this re-loads a large dataset into memory three times
-        if self.split not in ['train', 'validation', 'test']:
-            raise ValueError("The dataset split must be one of 'train', 'validation', or 'test'.")
-
         # merge the dataset to re-sample from
         if self.split is None:
             raise ValueError("split is required")

@@ -74,9 +101,6 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
         # shuffle the dataset
         lm_datasets = lm_datasets.shuffle(indices_cache_file_name=indices_cache_file_name, seed=self.seed)

-        if self.num_tokens > 0 and self.subsample_ratio < 1.0:
-            raise Exception("Must specify one of num_tokens OR subsample_ratio, cannot specify both.")
-
         total_num_samples = len(lm_datasets)
         tokens_per_sample = len(lm_datasets[0]['input_ids'])
         total_num_tokens = total_num_samples * tokens_per_sample

@@ -91,6 +115,8 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
         elif self.subsample_ratio < 1.0:
             num_samples = round(total_num_samples * self.subsample_ratio)
             self.num_tokens = num_samples * tokens_per_sample
+        elif self.subsample_ratio == 1.0 and self.num_tokens == 0:
+            self.num_tokens = total_num_tokens
         else:
             log.warning("No subsampling going on!")

@@ -100,8 +126,10 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
         log.info(f"Total number of samples: {num_samples:e}")
         log.info(f"Total number of tokens: {self.num_tokens:e}")
         dataset = lm_datasets
-        data_collator = transformers.default_data_collator

+        data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=self.tokenizer,
+                                                                     mlm=self.use_masked_lm,
+                                                                     mlm_probability=self.mlm_probability)
         sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

         return DataSpec(dataloader=dataloader_hparams.initialize_object(
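
The new use_masked_lm and mlm_probability fields switch the collator from default_data_collator to DataCollatorForLanguageModeling. A rough sketch of configuring the updated LMDatasetHparams for masked-LM pre-training follows; the dataset path and values are hypothetical, not this commit's YAML settings:

from composer.datasets.lm_datasets import LMDatasetHparams

# Hypothetical configuration; the on-disk dataset path is a placeholder.
mlm_data = LMDatasetHparams(
    datadir=["/datasets/c4_tokenized"],   # pre-tokenized HuggingFace dataset saved to disk
    split="train",
    tokenizer_name="bert-base-uncased",
    use_masked_lm=True,                   # new field: use DataCollatorForLanguageModeling
    mlm_probability=0.15,                 # new field: fraction of tokens to mask
    subsample_ratio=1.0,
)
mlm_data.validate()                       # raises if, e.g., use_masked_lm were left unset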

composer/models/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,9 @@
 
 from composer.models.base import BaseMosaicModel as BaseMosaicModel
 from composer.models.base import MosaicClassifier as MosaicClassifier
+from composer.models.bert import BERTForClassificationHparams as BERTForClassificationHparams
+from composer.models.bert import BERTHparams as BERTHparams
+from composer.models.bert import BERTModel as BERTModel
 from composer.models.classify_mnist import MNIST_Classifier as MNIST_Classifier
 from composer.models.classify_mnist import MnistClassifierHparams as MnistClassifierHparams
 from composer.models.efficientnetb0 import EfficientNetB0 as EfficientNetB0

composer/models/bert/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+
+from composer.models.bert.bert_hparams import BERTForClassificationHparams as BERTForClassificationHparams
+from composer.models.bert.bert_hparams import BERTHparams as BERTHparams
+from composer.models.bert.model import BERTModel as BERTModel

composer/models/bert/bert_hparams.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import yahp as hp
+
+from composer.models.transformer_hparams import TransformerHparams
+
+if TYPE_CHECKING:
+    from composer.models.transformer_shared import MosaicTransformer
+
+
+@dataclass
+class BERTForClassificationHparams(TransformerHparams):
+    num_labels: int = hp.optional(doc="The number of possible labels for the task.", default=2)
+
+    def validate(self):
+        if self.num_labels < 1:
+            raise ValueError("The number of target labels must be at least one.")
+
+    def initialize_object(self) -> "MosaicTransformer":
+        try:
+            import transformers
+        except ImportError as e:
+            raise ImportError('transformers is not installed. '
+                              'Please install with `pip install mosaicml-composer[nlp]`') from e
+
+        from composer.models.bert.model import BERTModel
+        self.validate()
+
+        model_hparams = {"num_labels": self.num_labels}
+
+        if self.model_config:
+            config = transformers.BertConfig.from_dict(self.model_config, **model_hparams)
+        elif self.pretrained_model_name is not None:
+            config = transformers.BertConfig.from_pretrained(self.pretrained_model_name, **model_hparams)
+        else:
+            raise ValueError('One of pretrained_model_name or model_config needed.')
+        config.num_labels = self.num_labels
+
+        if self.use_pretrained:
+            # TODO (Moin): handle the warnings on not using the seq_relationship head
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(self.pretrained_model_name,
+                                                                                    **model_hparams)
+        else:
+            model = transformers.AutoModelForSequenceClassification.from_config(  #type: ignore (thirdparty)
+                config, **model_hparams)
+
+        return BERTModel(
+            module=model,
+            config=config,  #type: ignore (thirdparty)
+            tokenizer_name=self.tokenizer_name,
+        )
+
+
+@dataclass
+class BERTHparams(TransformerHparams):
+
+    def initialize_object(self) -> "MosaicTransformer":
+        try:
+            import transformers
+        except ImportError as e:
+            raise ImportError('transformers is not installed. '
+                              'Please install with `pip install mosaicml-composer[nlp]`') from e
+
+        from composer.models.bert.model import BERTModel
+        self.validate()
+
+        if self.model_config:
+            config = transformers.BertConfig.from_dict(self.model_config)
+        elif self.pretrained_model_name is not None:
+            config = transformers.BertConfig.from_pretrained(self.pretrained_model_name)
+        else:
+            raise ValueError('One of pretrained_model_name or model_config needed.')
+
+        # set the number of labels to the vocab size, used for measuring MLM accuracy
+        config.num_labels = config.vocab_size
+
+        if self.use_pretrained:
+            # TODO (Moin): handle the warnings on not using the seq_relationship head
+            model = transformers.AutoModelForMaskedLM.from_pretrained(self.pretrained_model_name)
+        else:
+            model = transformers.AutoModelForMaskedLM.from_config(config)  #type: ignore (thirdparty)
+
+        return BERTModel(
+            module=model,
+            config=config,  #type: ignore (thirdparty)
+            tokenizer_name=self.tokenizer_name,
+        )
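
A rough sketch of how the two new hparams classes might be used to build models for pre-training and fine-tuning; the model and tokenizer names are common HuggingFace identifiers used here as assumptions, not the exact settings from this commit's YAMLs:

from composer.models import BERTForClassificationHparams, BERTHparams

# Hypothetical usage sketch; values are illustrative, not this commit's YAML hyperparameters.
# Masked-LM pre-training: build the config from a known model name, but randomly initialize weights.
mlm_hparams = BERTHparams(
    pretrained_model_name="bert-base-uncased",
    tokenizer_name="bert-base-uncased",
    use_pretrained=False,
)
mlm_model = mlm_hparams.initialize_object()   # BERTModel wrapping AutoModelForMaskedLM

# Fine-tuning a two-way classifier (e.g. SST-2) from pretrained weights.
cls_hparams = BERTForClassificationHparams(
    pretrained_model_name="bert-base-uncased",
    tokenizer_name="bert-base-uncased",
    use_pretrained=True,
    num_labels=2,
)
cls_model = cls_hparams.initialize_object()   # BERTModel wrapping AutoModelForSequenceClassification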
