1 change: 1 addition & 0 deletions composer/datasets/__init__.py
@@ -21,6 +21,7 @@
 """
 from composer.datasets.ade20k import ADE20kDatasetHparams as ADE20kDatasetHparams
 from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
+from composer.datasets.c4 import C4DatasetHparams as C4DatasetHparams
 from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
 from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
 from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
360 changes: 360 additions & 0 deletions composer/datasets/c4.py

Large diffs are not rendered by default.
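Since the 360-line file is not rendered, here is a rough, hedged sketch of the interface it appears to expose, inferred only from the fields exercised in the YAML config and the updated test below. The class name and base class come from the other diffs in this PR; everything else is an assumption, not the real file.

# Hedged sketch, NOT the real composer/datasets/c4.py: field names inferred
# from gpt3_125m.yaml and tests/test_dataset_registry.py below.
from dataclasses import dataclass
from typing import Optional

from composer.datasets import DatasetHparams  # base class, per the test imports


@dataclass
class C4DatasetHparams(DatasetHparams):
    split: Optional[str] = None           # "train" or "validation"
    max_samples: Optional[int] = None     # cap on samples drawn from the stream
    max_seq_len: Optional[int] = None     # tokens per packed sample
    tokenizer_name: Optional[str] = None  # Hugging Face tokenizer, e.g. "gpt2"
    group_method: Optional[str] = None    # "concat" packs documents up to max_seq_len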

4 changes: 3 additions & 1 deletion composer/datasets/dataset_registry.py
@@ -2,6 +2,7 @@
 
 from composer.datasets.ade20k import ADE20kDatasetHparams
 from composer.datasets.brats import BratsDatasetHparams
+from composer.datasets.c4 import C4DatasetHparams
 from composer.datasets.cifar10 import CIFAR10DatasetHparams
 from composer.datasets.glue import GLUEHparams
 from composer.datasets.imagenet import ImagenetDatasetHparams
@@ -15,7 +16,8 @@
     "cifar10": CIFAR10DatasetHparams,
     "mnist": MNISTDatasetHparams,
     "lm": LMDatasetHparams,
-    "glue": GLUEHparams
+    "glue": GLUEHparams,
+    "c4": C4DatasetHparams,
 }
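For orientation, a minimal sketch of how this registry is consumed. The registry import and the constructor arguments come from this PR's diffs; the lookup pattern itself is an assumption about the surrounding hparams machinery.

# Minimal sketch: the YAML key "c4" resolves through the registry to the
# new hparams class; constructor arguments mirror the test below.
from composer.trainer.trainer_hparams import dataset_registry

hparams_cls = dataset_registry["c4"]  # -> C4DatasetHparams
hparams = hparams_cls(
    split="train",
    max_samples=1000,
    max_seq_len=100,
    tokenizer_name="gpt2",
    group_method="concat",
)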


5 changes: 3 additions & 2 deletions composer/trainer/trainer_hparams.py
@@ -420,9 +420,10 @@ def validate(self):
         super().validate()
 
         if self.deepspeed is not None:
-            zero_stage = cast(int, self.deepspeed.get("zero_stage", 0))
+            self.deepspeed["zero_stage"] = cast(int, self.deepspeed.get("zero_stage", 0))
+            self.deepspeed["steps_per_print"] = cast(int, self.deepspeed.get("steps_per_print", 1e20))
 
-            if self.deterministic_mode and zero_stage > 0:
+            if self.deterministic_mode and self.deepspeed["zero_stage"] > 0:
                 raise ValueError("Deepspeed with zero stage > 0 is not compatible with deterministic mode")
 
             if isinstance(self.device, CPUDeviceHparams):
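The behavioral change is easy to miss: previously the defaulted zero_stage lived only in a local variable, so a user-supplied deepspeed dict that omitted the key never received it; now both defaults are written back into self.deepspeed, where downstream DeepSpeed config construction can see them. A self-contained sketch with hypothetical values (int() stands in for typing.cast, which does not convert at runtime):

# Hypothetical values, not repo code.
ds = {}  # user config omitted both keys

# Old behavior: the default existed only in a local; ds stayed empty.
zero_stage = int(ds.get("zero_stage", 0))

# New behavior: defaults are persisted on the dict itself.
ds["zero_stage"] = int(ds.get("zero_stage", 0))
ds["steps_per_print"] = int(ds.get("steps_per_print", 1e20))
assert ds == {"zero_stage": 0, "steps_per_print": 10**20}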
86 changes: 86 additions & 0 deletions composer/yamls/models/gpt3_125m.yaml
@@ -0,0 +1,86 @@
train_dataset:
  c4:
    split: train
    max_samples: 3584000  # Compute-optimal 7.3e9 tok ~= 256[bs] * 14000[ba] * 2048[msl] = 3584000[sa] * 2048[msl]
    max_seq_len: 2048
    tokenizer_name: gpt2
    group_method: concat
    seed: 17
    shuffle: true
    drop_last: true
val_dataset:
  c4:
    split: validation
    max_samples: 102400  # Approx 100k samples
    max_seq_len: 2048
    tokenizer_name: gpt2
    group_method: concat
    seed: 17
    shuffle: false
    drop_last: false
model:
  gpt2:
    use_pretrained: false
    tokenizer_name: gpt2
    model_config:
      activation_function: gelu_new
      architectures:
      - GPT2LMHeadModel
      attn_pdrop: 0.0
      bos_token_id: 50256
      embd_pdrop: 0.0
      eos_token_id: 50256
      initializer_range: 0.02
      layer_norm_epsilon: 1.0e-05
      model_type: gpt2
      n_embd: 768
      n_head: 12
      n_inner: 3072
      n_layer: 12
      n_positions: 2048
      resid_pdrop: 0.0
      scale_attn_weights: true
      summary_activation: null
      summary_first_dropout: 0.0
      summary_proj_to_labels: true
      summary_type: cls_index
      summary_use_proj: true
      task_specific_params:
        text-generation:
          do_sample: true
          max_length: 50
      transformers_version: 4.16.2
      use_cache: true
      vocab_size: 50257
optimizer:
  decoupled_adamw:
    lr: 6.0e-4
    betas:
    - 0.9
    - 0.95
    eps: 1.0e-08
    weight_decay: 0.0
schedulers:
- cosine_decay_with_warmup:
    warmup_time: 0.01dur
loggers:
- tqdm: {}
max_duration: 1ep
train_batch_size: 256  # 0.5e6 tok ~= 256[bs] * 2048[msl]
grad_accum: 2  # 256[bs] / 8[devices] / 16[per_gpu_microbatch_size] = 2[ga], assuming 8xA100-80GB
eval_batch_size: 128  # 128[bs] / 8[devices] = 16[per_gpu_microbatch_size], assuming 8xA100-80GB
seed: 17
device:
  gpu: {}
dataloader:
  pin_memory: true
  persistent_workers: true
  num_workers: 1
  timeout: 0
  prefetch_factor: 2
deepspeed:
  zero_stage: 0
precision: fp16
grad_clip_norm: 1.0
validate_every_n_batches: 1000
validate_every_n_epochs: 1
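A config like this is consumed through the repo's hparams-driven trainer. A hedged sketch of the entrypoint follows: TrainerHparams and initialize_object exist in composer.trainer.trainer_hparams, but the exact create() signature is assumed from the yahp-style API.

# Hedged sketch; exact method signatures are assumptions.
from composer.trainer.trainer_hparams import TrainerHparams

hparams = TrainerHparams.create("composer/yamls/models/gpt3_125m.yaml")
trainer = hparams.initialize_object()
trainer.fit()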
69 changes: 42 additions & 27 deletions tests/test_dataset_registry.py
@@ -4,40 +4,55 @@
 
 import pytest
 
-from composer.datasets import (ADE20kDatasetHparams, BratsDatasetHparams, CIFAR10DatasetHparams, DataloaderHparams,
-                               DatasetHparams, GLUEHparams, ImagenetDatasetHparams, LMDatasetHparams,
+from composer.datasets import (ADE20kDatasetHparams, BratsDatasetHparams, C4DatasetHparams, CIFAR10DatasetHparams,
+                               DataloaderHparams, DatasetHparams, GLUEHparams, ImagenetDatasetHparams, LMDatasetHparams,
                                MNISTDatasetHparams, SyntheticHparamsMixin)
 from composer.trainer.trainer_hparams import dataset_registry
 
 # for testing, we provide values for required hparams fields
 # to initialize test hparams objects
 default_required_fields: Dict[Type[DatasetHparams], Callable[[], DatasetHparams]] = {
     # hparams with empty dicts have no required fields
-    CIFAR10DatasetHparams: lambda: CIFAR10DatasetHparams(
-        is_train=False,
-        download=False,
-    ),
-    ADE20kDatasetHparams: lambda: ADE20kDatasetHparams(is_train=False),
-    BratsDatasetHparams: lambda: BratsDatasetHparams(is_train=False,),
-    ImagenetDatasetHparams: lambda: ImagenetDatasetHparams(
-        is_train=False,
-        crop_size=224,
-        resize_size=-1,
-    ),
-    MNISTDatasetHparams: lambda: MNISTDatasetHparams(
-        is_train=False,
-        download=False,
-    ),
-    LMDatasetHparams: lambda: LMDatasetHparams(
-        datadir=["hello"],
-        split='train',
-        tokenizer_name='gpt2',
-    ),
-    GLUEHparams: lambda: GLUEHparams(
-        task="rte",
-        tokenizer_name="bert-base-uncased",
-        split="train",
-    ),
+    CIFAR10DatasetHparams:
+        lambda: CIFAR10DatasetHparams(
+            is_train=False,
+            download=False,
+        ),
+    ADE20kDatasetHparams:
+        lambda: ADE20kDatasetHparams(is_train=False),
+    BratsDatasetHparams:
+        lambda: BratsDatasetHparams(is_train=False,),
+    ImagenetDatasetHparams:
+        lambda: ImagenetDatasetHparams(
+            is_train=False,
+            crop_size=224,
+            resize_size=-1,
+        ),
+    MNISTDatasetHparams:
+        lambda: MNISTDatasetHparams(
+            is_train=False,
+            download=False,
+        ),
+    LMDatasetHparams:
+        lambda: LMDatasetHparams(
+            datadir=["hello"],
+            split='train',
+            tokenizer_name='gpt2',
+        ),
+    GLUEHparams:
+        lambda: GLUEHparams(
+            task="rte",
+            tokenizer_name="bert-base-uncased",
+            split="train",
+        ),
+    C4DatasetHparams:
+        lambda: C4DatasetHparams(
+            split="train",
+            max_samples=1000,
+            max_seq_len=100,
+            tokenizer_name="gpt2",
+            group_method="concat",
+        ),
 }
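This dict pairs with dataset_registry, imported at the top of the test: every registered name needs a constructor here, which is why adding "c4" to the registry requires the matching C4DatasetHparams entry. A hedged sketch of how the test presumably walks the pair (the real test body is outside this diff):

# Hedged sketch; the loop structure is assumed, not shown in the diff.
for name, hparams_cls in dataset_registry.items():
    hparams = default_required_fields[hparams_cls]()  # KeyError if an entry is missing
    assert isinstance(hparams, hparams_cls)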

