
Commit c19fc86: "wip"
Parent: 15df497

File tree: 6 files changed (+167 / -52 lines)


composer/datasets/streaming_lm_datasets.py (18 additions, 19 deletions)

@@ -14,9 +14,9 @@
 import yahp as hp
 from transformers.testing_utils import CaptureLogger

-from composer.core.types import Batch
+from composer.core.types import Batch, DataSpec
 from composer.datasets.dataloader import DataloaderHparams
-from composer.datasets.hparams import DataloaderSpec, DatasetHparams
+from composer.datasets.hparams import DatasetHparams
 from composer.utils import dist
 from composer.utils.data import get_subset_dataset

@@ -73,18 +73,18 @@ def _load_dataset(self):
                                             split=self.split,
                                             streaming=True)

-    def _get_approx_num_samples(self):
+    def _get_approx_num_samples_per_device(self):
         try:
             if self.max_samples > 0:
-                return self.max_samples
+                return self.max_samples // dist.get_world_size()
             else:
                 n_shards, samples_per_shard = CACHED_DATASET_SIZES[self.dataset_name][self.dataset_config_name][self.split]
                 n_shards = self.max_shards if self.max_shards > 0 else n_shards
-                return n_shards * samples_per_shard
+                return n_shards * samples_per_shard // dist.get_world_size()
         except:
             raise NotImplementedError

-    def _get_approx_num_tokens(self):
+    def _get_approx_num_tokens_per_device(self):
         return 1e12

     def _subsample(self, device_offset, text_batch):

@@ -166,7 +166,7 @@ def _group_tokens(self, token_batch):
         else:
             raise ValueError(f"Unknown group_method: '{group_method}'")

-    def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataloaderSpec:
+    def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataSpec:
         assert dataloader_hparams.num_workers == 1, "LM Streaming Dataloader only supports num_workers=1"

         try:

@@ -209,13 +209,12 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
             batch_size=token_sample_batch_size,
         )

-        # Maybe limit the number of post-processed samples
-        if self.max_samples > 0:
-            token_dataset = token_dataset.take(self.max_samples // dist.get_world_size())
-
-        # Add approx num samples and create a SizedIterableDataset
-        sized_iterable_dataset = SizedIterableDataset(token_dataset, self._get_approx_num_samples())
+        # Limit the number of post-processed samples
+        num_samples_per_device = self._get_approx_num_samples_per_device()
+        token_dataset = token_dataset.take(num_samples_per_device)

+        # HACK: create a SizedIterableDataset
+        sized_iterable_dataset = SizedIterableDataset(token_dataset, num_samples_per_device)

         # Get collate_fn
         if self.tokenizer_name in ["gpt2"]:

@@ -225,25 +224,25 @@ def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHpara
         collate_fn = transformers.DataCollatorForLanguageModeling(tokenizer=self.tokenizer,
                                                                   mlm=self.use_masked_lm,
                                                                   mlm_probability=self.mlm_probability)
-        # Return DataloaderSpec
-        return DataloaderSpec(dataloader=dataloader_hparams.initialize_object(
+        # Return DataSpec
+        return DataSpec(dataloader=dataloader_hparams.initialize_object(
             dataset=sized_iterable_dataset,
             batch_size=batch_size,
             sampler=None,
             drop_last=self.drop_last,
             collate_fn=collate_fn,
         ),
-            split_fn=_split_dict_fn)
+            split_batch=_split_dict_fn)


 class SizedIterableDataset(torch.utils.data.IterableDataset):

-    def __init__(self, hf_iterable_dataset, num_samples):
+    def __init__(self, hf_iterable_dataset, num_samples_per_device):
         self.hf_iterable_dataset = hf_iterable_dataset
-        self.num_samples = num_samples
+        self.num_samples_per_device = num_samples_per_device

     def __iter__(self):
         return iter(self.hf_iterable_dataset)

     def __len__(self):
-        return self.num_samples
+        return self.num_samples_per_device

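Why the counts above are per-device: with a sharded streaming dataset each rank iterates only its own slice, so both the .take() limit and the length reported by __len__ must be divided by dist.get_world_size(), otherwise len(dataloader) overstates the number of batches on every rank. Below is a minimal sketch of that relationship using generic PyTorch pieces; the wrapper name, world size of 8, and per-device batch split are illustrative assumptions, not Composer's actual classes.

from torch.utils.data import DataLoader, IterableDataset


class SizedIterable(IterableDataset):
    """Wraps an iterable and reports an explicit, per-device length."""

    def __init__(self, iterable, num_samples_per_device: int):
        self.iterable = iterable
        self.num_samples_per_device = num_samples_per_device

    def __iter__(self):
        return iter(self.iterable)

    def __len__(self):
        # Per-device count: each rank sees only its own shard, so a global
        # count here would inflate len(dataloader) on every rank.
        return self.num_samples_per_device


world_size = 8                                             # hypothetical GPU count
global_max_samples = 5120                                  # max_samples from the new YAML in this commit
per_device_samples = global_max_samples // world_size      # 640
per_device_batch_size = 512 // world_size                  # 64, assuming train_batch_size: 512 is global

loader = DataLoader(SizedIterable(range(per_device_samples), per_device_samples),
                    batch_size=per_device_batch_size)
print(len(loader))                                         # 10 batches per rank, i.e. 10 optimizer steps

With a global count instead, len(loader) would report 80 here, and anything keyed off dataloader length (progress accounting, schedulers, validation cadence) would be off by a factor of world_size.
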
composer/trainer/deepspeed.py (1 addition, 1 deletion)

@@ -83,7 +83,7 @@ def initialize_object(self, state: State, grad_clip_norm: Optional[float]):
         elif state.precision == Precision.FP16:
             deepspeed_config["fp16"] = {
                 "enabled": True,
-                "initial_scale_power": 16,
+                "initial_scale_power": 0,
                 "loss_scale_window": 2000,
             }

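For context on the one-line change above, going by DeepSpeed's documented fp16 options rather than anything stated in the diff itself: initial_scale_power is the exponent of the starting dynamic loss scale, i.e. the scaler begins at 2 ** initial_scale_power. The helper below is purely illustrative arithmetic, not DeepSpeed code.

def initial_loss_scale(initial_scale_power: int) -> float:
    # DeepSpeed's dynamic loss scaler starts at 2 ** initial_scale_power and
    # raises the scale again after loss_scale_window overflow-free steps.
    return float(2 ** initial_scale_power)


print(initial_loss_scale(16))   # 65536.0  (previous setting)
print(initial_loss_scale(0))    # 1.0      (value set in this commit)

Starting at 1.0 presumably skips the initial overflow-and-backoff cycle that a 65536 starting scale goes through, at the cost of less headroom against fp16 gradient underflow early in training.
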
composer/trainer/trainer_hparams.py (2 additions, 3 deletions)

@@ -204,11 +204,10 @@ class TrainerHparams(hp.Hparams):
     def validate(self):
         super().validate()

-        if self.deepspeed is not None:
-
+        if self.deepspeed is None:
             if self.precision == Precision.FP16:
                 raise ValueError("FP16 precision is only supported when training with DeepSpeed.")
-
+        else:
             if isinstance(self.device, CPUDeviceHparams):
                 raise ValueError("Training on CPUs is not supported with DeepSpeed.")

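Read as plain logic, the fix above inverts the guard: fp16 without DeepSpeed is the invalid combination, and the CPU check only applies when DeepSpeed is enabled. A standalone restatement of the corrected rule (a hypothetical helper, not the real TrainerHparams.validate):

def validate_config(using_deepspeed: bool, precision: str, device: str) -> None:
    if not using_deepspeed:
        # Without DeepSpeed, fp16 is not supported.
        if precision == "fp16":
            raise ValueError("FP16 precision is only supported when training with DeepSpeed.")
    else:
        # With DeepSpeed, CPU devices are not supported.
        if device == "cpu":
            raise ValueError("Training on CPUs is not supported with DeepSpeed.")


validate_config(using_deepspeed=True, precision="fp16", device="gpu")    # passes
validate_config(using_deepspeed=False, precision="amp", device="gpu")    # passes
# validate_config(using_deepspeed=False, precision="fp16", device="gpu") # raises ValueError
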
composer/utils/_time_conversion.py (10 additions, 9 deletions)

@@ -42,9 +42,9 @@ def convert(
             converting to or from :attr:`TimeUnit.TOKEN`.
         max_training_duration (str or Time, optional): The total training duration. Required only
             if converting to or from :attr:`TimeUnit.DURATION`.
-
+
     Raises:
-        ValueError: If it is not possible to perform the conversion.
+        ValueError: If it is not possible to perform the conversion.

     Returns:
         Time: The time, in the specified ``unit``.

@@ -76,12 +76,13 @@ def convert(
                                               dataset_num_tokens=dataset_num_tokens)
             return _convert_to_duration(time_in_max_duration_unit, max_training_duration=max_training_duration)
         else:
-            converted_time = _convert_from_duration(time, max_training_duration=max_training_duration)
-            return convert(converted_time,
-                           unit,
-                           steps_per_epoch=steps_per_epoch,
-                           samples_per_epoch=samples_per_epoch,
-                           dataset_num_tokens=dataset_num_tokens)
+            max_training_duration_in_unit = convert(max_training_duration,
+                                                    unit,
+                                                    steps_per_epoch=steps_per_epoch,
+                                                    samples_per_epoch=samples_per_epoch,
+                                                    dataset_num_tokens=dataset_num_tokens)
+            converted_time = _convert_from_duration(time, max_training_duration=max_training_duration_in_unit)
+            return converted_time

     if time.unit == TimeUnit.EPOCH:
         if unit == TimeUnit.BATCH:

@@ -260,7 +261,7 @@ def _convert_sample_to_batch(
         time (Time): The time
         steps_per_epoch (int): The number of optimization steps per epoch.
         samples_per_epoch (int): The number of samples per epoch.
-
+
     Raises:
         RuntimeError: Raised if ``time.unit != TimeUnit.SAMPLE``
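The rewritten else branch changes the order of operations: it first converts max_training_duration into the requested unit, then resolves the DURATION value against that converted total, instead of recursing on the partially converted time. A numeric sketch of that ordering with made-up values (the function below is a simplified stand-in, not the library's convert):

def duration_to_batches(duration_fraction: float,
                        max_training_duration_epochs: int,
                        steps_per_epoch: int) -> int:
    # Step 1: convert the max training duration into the target unit (batches).
    max_duration_in_batches = max_training_duration_epochs * steps_per_epoch
    # Step 2: resolve the duration fraction against that converted total.
    return int(duration_fraction * max_duration_in_batches)


print(duration_to_batches(0.5, max_training_duration_epochs=10, steps_per_epoch=100))  # 500
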
116 additions, 0 deletions (new file)

@@ -0,0 +1,116 @@
# GPT3-125m with streaming C4 dataset

train_dataset:
  streaming_lm:
    dataset_name: c4
    dataset_config_name: en
    split: train
    max_shards: -1
    max_samples: 5120 # 512sa * 10ba
    max_seq_len: 2048
    group_method: concat
    tokenizer_name: gpt2
    use_masked_lm: false
    seed: 17
    shuffle: true
    drop_last: true
val_dataset:
  streaming_lm:
    dataset_name: c4
    dataset_config_name: en
    split: validation
    max_shards: -1
    max_samples: 100
    max_seq_len: 2048
    group_method: concat
    tokenizer_name: gpt2
    use_masked_lm: false
    seed: 17
    shuffle: false
    drop_last: true

model:
  gpt2:
    use_pretrained: false
    tokenizer_name: gpt2
    model_config:
      activation_function: gelu_new
      architectures:
      - GPT2LMHeadModel
      attn_pdrop: 0.1
      bos_token_id: 50256
      embd_pdrop: 0.1
      eos_token_id: 50256
      initializer_range: 0.02
      layer_norm_epsilon: 1.0e-05
      model_type: gpt2
      n_embd: 2048
      n_head: 16
      n_inner: 8192
      n_layer: 24
      n_positions: 2048
      resid_pdrop: 0.1
      scale_attn_weights: true
      summary_activation: null
      summary_first_dropout: 0.1
      summary_proj_to_labels: true
      summary_type: cls_index
      summary_use_proj: true
      task_specific_params:
        text-generation:
          do_sample: true
          max_length: 50
      transformers_version: 4.11.0.dev0
      use_cache: true
      vocab_size: 50257
optimizer:
  decoupled_adamw:
    lr: 2.0e-4
    betas:
    - 0.9
    - 0.95
    eps: 1.0e-08
    weight_decay: 0.0
schedulers:
- warmup:
    warmup_method: linear
    warmup_iters: 0.2dur
    warmup_factor: 0
    interval: batch
- linear_decay:
    start_factor: 1.0
    end_factor: 0.0
    total_iters: 0.8dur
    interval: batch
    verbose: false
loggers:
- file:
    log_level: batch
    filename: stdout
    buffer_size: 1
    flush_every_n_batches: 100
    every_n_batches: 1
    every_n_epochs: 1
max_duration: 1ep
train_batch_size: 512
eval_batch_size: 8 # use micro_bs_per_gpu = 1 to accomodate 10GB limit
seed: 17
device:
  gpu: {}
deepspeed:
  zero_stage: 0
  # optimizer_offload: true
  # parameter_offload: true
  # overlap_comm: false
  # gradient_checkpointing: false
dataloader:
  pin_memory: true
  persistent_workers: true
  num_workers: 1
  timeout: 0
  prefetch_factor: 2
precision: fp16
grad_clip_norm: 1.0
grad_accum: 1
validate_every_n_batches: 3
validate_every_n_epochs: 1

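The inline "512sa * 10ba" comment in the new config reads as a smoke-test budget: with train_batch_size: 512, max_samples: 5120 works out to 10 optimizer batches, the 0.2dur/0.8dur scheduler split resolves to 2 warmup and 8 decay batches, and validate_every_n_batches: 3 presumably fires at batches 3, 6, and 9. A short sketch of that arithmetic; the interpretation is an assumption based on the values above, not something the commit states.

max_samples = 5120
train_batch_size = 512
total_batches = max_samples // train_batch_size            # 10 ("512sa * 10ba")

warmup_batches = int(0.2 * total_batches)                   # warmup_iters: 0.2dur -> 2
decay_batches = int(0.8 * total_batches)                    # total_iters: 0.8dur  -> 8

validate_every = 3                                          # validate_every_n_batches
eval_points = list(range(validate_every, total_batches + 1, validate_every))
print(total_batches, warmup_batches, decay_batches, eval_points)   # 10 2 8 [3, 6, 9]
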
composer/yamls/models/gpt2_125m_streaming.yaml renamed to composer/yamls/models/gpt3_125m.yaml (20 additions, 20 deletions)

@@ -1,13 +1,13 @@
-# GPT2-125m with streaming C4 dataset
+# GPT3-125m with streaming C4 dataset

 train_dataset:
   streaming_lm:
     dataset_name: c4
     dataset_config_name: en
     split: train
     max_shards: -1
-    max_samples: 7168000
-    max_seq_len: 1024
+    max_samples: 2560 # 256sa * 20ba
+    max_seq_len: 2048
     group_method: concat
     tokenizer_name: gpt2
     use_masked_lm: false

@@ -20,8 +20,8 @@ val_dataset:
     dataset_config_name: en
     split: validation
     max_shards: -1
-    max_samples: 128000
-    max_seq_len: 1024
+    max_samples: 100
+    max_seq_len: 2048
     group_method: concat
     tokenizer_name: gpt2
     use_masked_lm: false

@@ -44,12 +44,11 @@ model:
       initializer_range: 0.02
       layer_norm_epsilon: 1.0e-05
       model_type: gpt2
-      n_ctx: 1024
       n_embd: 768
       n_head: 12
       n_inner: 3072
       n_layer: 12
-      n_positions: 1024
+      n_positions: 2048
       resid_pdrop: 0.1
       scale_attn_weights: true
       summary_activation: null

@@ -65,34 +64,35 @@ model:
       use_cache: true
       vocab_size: 50257
 optimizer:
-  adamw:
+  decoupled_adamw:
     lr: 6.0e-4
     betas:
     - 0.9
-    - 0.999
+    - 0.95
     eps: 1.0e-08
     weight_decay: 0.0
 schedulers:
 - warmup:
     warmup_method: linear
+    warmup_iters: 0.2dur
     warmup_factor: 0
-    interval: step
-    warmup_iters: 140ba
-- cosine_decay:
-    interval: step
-    eta_min: 0
+    interval: batch
+- linear_decay:
+    start_factor: 1.0
+    end_factor: 0.0
+    total_iters: 0.8dur
+    interval: batch
     verbose: false
-    T_max: 13860ba
 loggers:
 - file:
     log_level: batch
     filename: stdout
     buffer_size: 1
     flush_every_n_batches: 100
-    every_n_batches: 100
+    every_n_batches: 1
     every_n_epochs: 1
-max_epochs: 1
-train_batch_size: 512
+max_duration: 1ep
+train_batch_size: 256
 eval_batch_size: 8 # use micro_bs_per_gpu = 1 to accomodate 10GB limit
 seed: 17
 device:

@@ -105,6 +105,6 @@ dataloader:
     prefetch_factor: 2
 precision: amp
 grad_clip_norm: 1.0
-grad_accum: 22
-validate_every_n_batches: 1000
+grad_accum: 1
+validate_every_n_batches: 3
 validate_every_n_epochs: 1

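One practical reading of the scheduler change in this rename (an interpretation of the two YAML versions, not a claim made by the commit): the old schedule pinned warmup and decay to absolute batch counts, 140ba plus 13860ba for 14000 batches total, while the new 0.2dur/0.8dur form is resolved against whatever max_duration works out to, so the same YAML scales from a 10-batch smoke test to a full run. A small sketch of the difference:

# Batch counts below come from the old and new YAML versions; the helper is illustrative.
old_warmup_ba, old_decay_ba = 140, 13860
print(old_warmup_ba + old_decay_ba)             # 14000 batches, fixed regardless of run length


def resolve_duration_fractions(total_batches: int, warmup_frac: float = 0.2, decay_frac: float = 0.8):
    # Duration-relative schedule: fractions resolve against the actual run length.
    return int(warmup_frac * total_batches), int(decay_frac * total_batches)


print(resolve_duration_fractions(2560 // 256))  # (2, 8) for the new 10-batch config
print(resolve_duration_fractions(14000))        # (2800, 11200) if the run were 14000 batches long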