
ModelCheckpoint's save_last does not adhere to documentation #20245

@godaup

Bug description

According to the discussion in #4335, the intended functionality of save_last is to provide deterministic access to the most recently stored checkpoint so that training can be resumed. The documentation states the same:

save_last: When True, saves a last.ckpt copy whenever a checkpoint file gets saved. Can be set to 'link' on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint in a deterministic manner. Default: None.
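
For context, last.ckpt is meant to be the deterministic entry point when resuming an interrupted run, roughly like this (a minimal sketch, assuming the demo BoringModel; the checkpoint path is a placeholder):

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(max_epochs=10)
# Resume from the last.ckpt written by ModelCheckpoint in a previous run.
trainer.fit(BoringModel(), ckpt_path="lightning_logs/version_0/checkpoints/last.ckpt")  # placeholder path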

Unfortunately, _save_last_checkpoint may be called even though no previous checkpoint exists. An example configuration:

ModelCheckpoint(
    filename='temp_backup',
    save_last=True,
    train_time_interval=datetime.timedelta(minutes=30), 
)

This will create a last.ckpt at the end of every epoch (either after training or after validation, depending on save_on_train_epoch_end), even though no other checkpointing has been triggered before. The reason is that _save_last_checkpoint is called unconditionally at the end of the epoch:

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Save a checkpoint at the end of the training epoch."""
    if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): 
        monitor_candidates = self._monitor_candidates(trainer)
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)  # < --- here

Hence, in the end, save_last ironically does exactly what it was not intended to do: save the last epoch. With save_last='link' this resolves itself after the first real trigger, since from then on the regular _save_last_checkpoint calls only rewrite identical symlinks to the actual checkpoint file. There are more configurations in which this happens:

ModelCheckpoint(
    monitor='val_loss',
    save_last=True,
    every_n_train_steps=1000
)

Even if, after 2000 steps, _save_topk_checkpoint skips creating a new file (because _save_monitor_checkpoint detects no improvement), _save_last_checkpoint is still triggered.

These additional (unintended) checkpoint writes can cause severe delays between epochs. I am also fairly sure this is the cause of numerous reported issues; a quick search revealed e.g. #20075, #19845 and #17916.
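
To make the first configuration concrete, here is a minimal repro sketch (assuming the demo BoringModel; dirpath and the batch limits are placeholders, and the 30-minute interval ensures the time-based trigger never fires during this short run):

import datetime
import os

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

ckpt = ModelCheckpoint(
    dirpath="tmp_ckpts",  # placeholder directory, just to make the result easy to inspect
    filename="temp_backup",
    save_last=True,
    train_time_interval=datetime.timedelta(minutes=30),
)
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, callbacks=[ckpt], logger=False)
trainer.fit(BoringModel())

# last.ckpt shows up even though the 30-minute interval never elapsed,
# i.e. no "real" checkpoint was ever triggered.
print(os.listdir("tmp_ckpts"))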

Another problem is that this breaks the interplay between multiple ModelCheckpoint callbacks, making it possible that last.ckpt is not actually the latest checkpoint.

cpt_a = ModelCheckpoint(
    monitor='val_loss',
    filename='epoch{epoch:02d}-val_loss{val_loss:.2f}',
    save_last='link',
    save_on_train_epoch_end=False,
    enable_version_counter=False,
)
cpt_b = ModelCheckpoint(
    filename="temp_backup",
    save_last="link",
    train_time_interval=datetime.timedelta(minutes=30),
    enable_version_counter=False,
)
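
Both callbacks would then be attached to the same Trainer, roughly like this (a sketch continuing from the definitions above; the model and max_epochs are placeholders):

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# Both callbacks run their own _save_last_checkpoint, so they compete for ownership of last.ckpt.
trainer = Trainer(max_epochs=100, callbacks=[cpt_a, cpt_b])
trainer.fit(BoringModel())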

The intended behavior would be to end up (after a while) with one checkpoint file corresponding to the best loss value, one temp_backup.ckpt file, and a symlink last.ckpt pointing to whichever of the two was stored last (in case training gets interrupted and needs to be resumed). But the following may happen:

  • some epoch finishes, cpt_a detects a new best loss value and stores the full checkpoint file; last.ckpt points to that file
  • cpt_b gets triggered and stores temp_backup.ckpt; last.ckpt is updated to point to temp_backup.ckpt
  • from now on, the last.ckpt pointer is determined solely by the order in which the callbacks run: whenever another epoch finishes, the _save_last_checkpoint call of both cpt_a and cpt_b makes last.ckpt point to their respective checkpoint. Even when cpt_a detects a new best validation loss, cpt_b is triggered after cpt_a, so last.ckpt will point to temp_backup.ckpt most of the time

Also note this example where last.ckpt will be a corrupted symlink:

ModelCheckpoint(
    filename="temp_backup",
    save_last="link",
    train_time_interval=datetime.timedelta(minutes=30),
)

During the first end-of-epoch call of _save_last_checkpoint it is detected that no previous checkpoint has been stored (self._last_checkpoint_saved is initialized to an empty string):

if self.save_last == "link" and self._last_checkpoint_saved and self.save_top_k != 0:
    self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
    self._save_checkpoint(trainer, filepath)

but in the next iteration (the next epoch) the first branch is taken and ModelCheckpoint symlinks last.ckpt to itself...
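
The effect can be illustrated outside of Lightning (a standalone sketch, not Lightning code): a symlink that points at its own name can no longer be opened.

import os

os.symlink("last.ckpt", "last.ckpt")   # create a symlink whose target is itself
print(os.path.islink("last.ckpt"))     # True
open("last.ckpt")                      # OSError: Too many levels of symbolic links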

As a clean solution I would argue that no "top-level" _save_last_checkpoint should be required at all; the logic to update/create last.ckpt should live downstream, triggered by an actual _save_checkpoint call. As an ad-hoc workaround I tested the following:

from typing import Dict

import torch
from lightning.pytorch.callbacks import ModelCheckpoint

class MyModelCheckpoint(ModelCheckpoint):
    def _save_last_checkpoint(self, trainer: "lightning.Trainer", monitor_candidates: Dict[str, torch.Tensor]) -> None:
        """Only update last.ckpt in case a new checkpoint has just been written."""
        if self._last_global_step_saved == trainer.global_step:
            super()._save_last_checkpoint(trainer=trainer, monitor_candidates=monitor_candidates)

which resolved all the problems mentioned above. As an additional speed improvement, the _save_last_checkpoint calls at the end of each epoch could also be slightly rearranged:

def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
    """Save a checkpoint at the end of the training epoch."""
    if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
        monitor_candidates = self._monitor_candidates(trainer)
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)
            self._save_last_checkpoint(trainer, monitor_candidates)
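
For reference, the workaround subclass is used exactly like the stock callback (a sketch; the model and max_epochs are placeholders):

import datetime

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

checkpoint_cb = MyModelCheckpoint(
    filename="temp_backup",
    save_last=True,
    train_time_interval=datetime.timedelta(minutes=30),
)
# last.ckpt is now only refreshed when a checkpoint was actually written at the same global step.
trainer = Trainer(max_epochs=100, callbacks=[checkpoint_cb])
trainer.fit(BoringModel())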

What version are you seeing the problem on?

v2.3

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version 2.3.3
#- PyTorch Version 2.4.0
#- Python version 3.8.13
#- OS Linux
#- CUDA version 12.2
#- GPU models and configuration: Nvidia GeForce RTX 3090
#- How you installed Lightning: via pip inside a conda env

More info

No response

cc @lantiga @Borda
