Conversation

nopperl commented Feb 13, 2024

When loading a checkpoint that was saved with a tp degree different from the currently configured one, the following error is raised:

```
Traceback (most recent call last):
  File "/home/nanotron/run_train.py", line 132, in <module>
    trainer = DistributedTrainer(config_file)
  File "/home/nanotron/src/nanotron/trainer.py", line 162, in __init__
    load_optimizer(
  File "/home/conda/envs/linux/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nanotron/src/nanotron/serialize/optimizer.py", line 222, in load_optimizer
    ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
KeyError: None
```
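
For reference, KeyError: None is what Python raises when a dict is indexed with a key of None; presumably the checkpoint's parameter-to-state-index lookup yields None for model.lm_head.pp_block.weight. A minimal, hypothetical illustration (the dict layout below is an assumption for illustration, not the actual checkpoint format):

```python
# Toy stand-in for ckp_optim_state; the real checkpoint layout may differ.
ckp_optim_state = {"state": {0: {"exp_avg": "...", "exp_avg_sq": "..."}}}

# Hypothetical: the index lookup for the tied lm_head parameter finds nothing.
optim_state_index = None

# Reproduces the failure mode shown in the traceback above.
ckp_optim_state["state"][optim_state_index]  # KeyError: None
```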

This happens only for the model.lm_head.pp_block.weight parameter. I assume this is because the optimizer states for this parameter are stored under the tied model.token_position_embeddings.pp_block.token_embedding.weight parameter, so the lookup for lm_head itself finds nothing. This PR fixes the error by skipping the lm_head optimizer states when loading. This mirrors weight loading, where the model.token_position_embeddings.pp_block.token_embedding.weight weights are loaded in place of model.lm_head.pp_block.weight (see https://github.com/huggingface/nanotron/blob/main/src/nanotron/serialize/weights.py#L347); for the optimizer states, however, I think simply skipping them is enough.
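
A rough sketch of the kind of guard this describes, written as a standalone function; the loop structure, the param_name_to_index mapping, and the variable names are assumptions for illustration, not the actual nanotron code or the diff in this PR:

```python
# Illustrative only: skip parameters that have no optimizer-state entry of
# their own (e.g. model.lm_head.pp_block.weight, whose state is stored under
# the tied token embedding parameter) instead of indexing the checkpoint
# with None.
def load_optimizer_states(ckp_optim_state, param_name_to_index):
    loaded = {}
    for param_name, optim_state_index in param_name_to_index.items():
        if optim_state_index is None:
            # Tied lm_head: nothing to load here, its state lives under the
            # embedding parameter's entry.
            continue
        loaded[param_name] = {
            state_key: ckp_optim_state["state"][optim_state_index][state_key]
            for state_key in ("exp_avg", "exp_avg_sq")
        }
    return loaded
```

The embedding parameter's states are still loaded normally, so nothing should be lost by skipping the tied lm_head entry.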

To reproduce:

Set up the config files:

```bash
cat > examples/debug_topology_agnostic.yaml << EOL
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/debug_topology_agnostic.yaml

checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints/debug_topology_agnostic
  checkpoints_path_is_shared_file_system: true
  save_initial_state: false
data:
  dataset:
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: debug
  run: tiny_llama
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: float16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 32
    initializer_range: 0.02
    intermediate_size: 64
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 20
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  pp_engine: 1f1b
  recompute_granularity: SELECTIVE
  tp: 4
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 32
  train_steps: 10
  val_check_interval: -1
EOL
```

```bash
cat > examples/debug_topology_agnostic_continue.yaml << EOL
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/debug_topology_agnostic_continue.yaml

checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints/debug_topology_agnostic_continue/
  checkpoints_path_is_shared_file_system: true
  resume_checkpoint_path: checkpoints/debug_topology_agnostic
  save_initial_state: false
data:
  dataset:
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: debug
  run: tiny_llama
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: float16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 32
    initializer_range: 0.02
    intermediate_size: 64
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 20
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  pp_engine: 1f1b
  recompute_granularity: SELECTIVE
  tp: 2
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 32
  train_steps: 20
  val_check_interval: -1
EOL
```

Train first using tp=4:

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/debug_topology_agnostic.yaml
```

Then continue with tp=2:

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/debug_topology_agnostic_continue.yaml
```

On main, this will lead to the above error.

nopperl commented Feb 15, 2024

closing in favour of #71

nopperl closed this Feb 15, 2024