Bug description
According to the discussion in #4335, the intended functionality of `save_last` is to provide deterministic access to the last stored checkpoint in order to resume training. The documentation states similarly:
> `save_last`: When `True`, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to `'link'` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint in a deterministic manner. Default: `None`.
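For context, the intended use looks roughly like this (a minimal sketch using Lightning's `BoringModel`; the directory, epoch counts and other arguments are just illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

# First run: save_last keeps a last.ckpt copy next to every saved checkpoint.
trainer = Trainer(max_epochs=2, callbacks=[ModelCheckpoint(dirpath="ckpts", save_last=True)])
trainer.fit(BoringModel())

# Later run: resume deterministically from the last checkpoint.
trainer = Trainer(max_epochs=4, callbacks=[ModelCheckpoint(dirpath="ckpts", save_last=True)])
trainer.fit(BoringModel(), ckpt_path="ckpts/last.ckpt")
```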
Unfortunately, `_save_last_checkpoint` may be called without any previous checkpoint existing. An example configuration:
```python
ModelCheckpoint(
    filename='temp_backup',
    save_last=True,
    train_time_interval=datetime.timedelta(minutes=30),
)
```
This will create a `last.ckpt` at the end of every epoch (either train or validation, depending on `save_on_train_epoch_end`), even though no other checkpointing has been triggered before. This is because `_save_last_checkpoint` is called either way at the end of the epoch:
```python
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Save a checkpoint at the end of the training epoch."""
    if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
        monitor_candidates = self._monitor_candidates(trainer)
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)  # <--- here
```
Hence, in the end, `save_last` (ironically) does what it was not intended to do: save the last epoch. In the case of `save_last='link'` this is resolved after the first trigger; afterwards the regular `_save_last_checkpoint` calls will only write identical soft links to the actual checkpoint file. There are more possible configurations in which this can happen:
```python
ModelCheckpoint(
    monitor='val_loss',
    save_last=True,
    every_n_train_steps=1000,
)
```
Even if, after 2000 steps, `_save_topk_checkpoint` skips the creation of a new file (because `_save_monitor_checkpoint` detects no improvement), `_save_last_checkpoint` is still triggered.
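To make these extra writes visible, one can log every physical save. This is just an instrumentation sketch of mine (the `LoggingModelCheckpoint` name is made up, not part of Lightning); it relies on the fact that, with `save_last=True`, both the monitored and the last checkpoints go through `_save_checkpoint`:

```python
import lightning
from lightning.pytorch.callbacks import ModelCheckpoint


class LoggingModelCheckpoint(ModelCheckpoint):
    """Prints every physical checkpoint write (illustrative only, not part of Lightning)."""

    def _save_checkpoint(self, trainer: "lightning.Trainer", filepath: str) -> None:
        print(f"epoch={trainer.current_epoch} step={trainer.global_step}: writing {filepath}")
        super()._save_checkpoint(trainer, filepath)
```

With the configuration above, this prints a `last.ckpt` write at every epoch end, even in epochs where no monitored checkpoint was saved.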
These additional (unintended) checkpoint writes may cause severe delays between epochs. I am also fairly sure this is the cause of numerous issues; a quick search revealed e.g. #20075, #19845, #17916.
Another problem is that this breaks the interplay between multiple `ModelCheckpoint` callbacks, creating the chance that `last.ckpt` is actually not the latest checkpoint.
```python
cpt_a = ModelCheckpoint(
    monitor='val_loss',
    filename='epoch{epoch:02d}-val_loss{val_loss:.2f}',
    save_last='link',
    save_on_train_epoch_end=False,
    enable_version_counter=False,
)
cpt_b = ModelCheckpoint(
    filename="temp_backup",
    save_last="link",
    train_time_interval=datetime.timedelta(minutes=30),
    enable_version_counter=False,
)
```
The intended behavior would be to end up (after a while) with one checkpoint file for the best loss value, one `temp_backup.ckpt` file, and a symlink `last.ckpt` pointing to whichever of the two was stored last (in case training gets interrupted and needs to be resumed). But the following may happen:
- Some epoch finishes, cpt_a detects a new best loss value and stores the full file; `last.ckpt` points to that file.
- cpt_b gets triggered and stores `temp_backup.ckpt`; `last.ckpt` is updated to point towards `temp_backup.ckpt`.
- From now on the `last.ckpt` pointer is determined only by the order of checkpoint execution. For both cpt_a and cpt_b, whenever another epoch finishes, the call to `_save_last_checkpoint` makes `last.ckpt` point to their checkpoint. Even if cpt_a detects a new best validation loss and cpt_b is triggered after cpt_a, `last.ckpt` will for the majority of the time point to `temp_backup.ckpt` (see the standalone sketch after this list).
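To illustrate the ordering problem in isolation, here is a standalone sketch (plain Python, not Lightning code; the file names mirror the example above and the `relink_last` helper is hypothetical) showing that the link simply follows callback order rather than which checkpoint was actually written most recently:

```python
import os


def relink_last(target: str, link: str = "last.ckpt") -> None:
    """Hypothetical helper mimicking the 'link' branch of _save_last_checkpoint."""
    if os.path.lexists(link):
        os.remove(link)
    os.symlink(target, link)


# cpt_b wrote its time-based backup some epochs ago ...
open("temp_backup.ckpt", "w").close()
# ... and cpt_a just saved a new best checkpoint in the current epoch.
open("epoch03-val_loss0.12.ckpt", "w").close()

# End of epoch: cpt_a runs first and links to its fresh checkpoint ...
relink_last("epoch03-val_loss0.12.ckpt")
# ... but cpt_b's epoch-end hook runs afterwards and relinks to its older file,
# even though it saved nothing new in this epoch.
relink_last("temp_backup.ckpt")

print(os.readlink("last.ckpt"))  # temp_backup.ckpt, not the newest checkpoint
```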
Also note this example, where `last.ckpt` will be a corrupted symlink:
```python
ModelCheckpoint(
    filename="temp_backup",
    save_last="link",
    train_time_interval=datetime.timedelta(minutes=30),
)
```
During the first "end of epoch call" of _save_last_checkpoint
it is detected that no previous checkpoint was stored (self._last_checkpoint_saved
is initialized as empty string):
if self.save_last == "link" and self._last_checkpoint_saved and self.save_top_k != 0:
self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
self._save_checkpoint(trainer, filepath)
On this first call the `else` branch therefore stores `last.ckpt` as a regular file and records it as the last saved checkpoint. In the next iteration (next epoch) the first branch is then satisfied and `ModelCheckpoint` symlinks `last.ckpt` to itself...
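A standalone illustration of that corrupted link (plain Python, not the actual `_link_checkpoint` implementation): once `last.ckpt` itself is the recorded "last saved checkpoint", relinking produces a symlink that points at its own path and can never resolve.

```python
import os

# Epoch 1: no prior checkpoint exists, so last.ckpt is written as a regular file.
open("last.ckpt", "w").close()

# Epoch 2: the existing file is replaced by a link to the recorded path -- which is
# last.ckpt itself -- yielding a self-referencing, broken symlink.
os.remove("last.ckpt")
os.symlink("last.ckpt", "last.ckpt")

print(os.path.islink("last.ckpt"))   # True
print(os.path.exists("last.ckpt"))   # False: the link cannot be resolved
```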
As a clean solution, I would assume that no "top-level" call to `_save_last_checkpoint` should be required; the logic to update/generate `last.ckpt` should instead live downstream, triggered by `_save_checkpoint`. As an ad-hoc solution I tested the following:
```python
from typing import Dict

import torch
import lightning
from lightning.pytorch.callbacks import ModelCheckpoint


class MyModelCheckpoint(ModelCheckpoint):
    def _save_last_checkpoint(self, trainer: "lightning.Trainer", monitor_candidates: Dict[str, torch.Tensor]) -> None:
        """Only update the last checkpoint in case a new checkpoint has just been saved."""
        if self._last_global_step_saved == trainer.global_step:
            super()._save_last_checkpoint(trainer=trainer, monitor_candidates=monitor_candidates)
```
which resolved all the problems mentioned above. As an additional speed improvement, the `_save_last_checkpoint` call at the end of each epoch could also be slightly rearranged:
def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
"""Save a checkpoint at the end of the training epoch."""
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
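For illustration, the `MyModelCheckpoint` subclass above can be used as a drop-in replacement for `ModelCheckpoint` (a sketch only; the arguments simply mirror one of the problematic configurations from earlier):

```python
from lightning.pytorch import Trainer

checkpoint_cb = MyModelCheckpoint(
    monitor="val_loss",
    save_last=True,
    every_n_train_steps=1000,
)
trainer = Trainer(callbacks=[checkpoint_cb])
# trainer.fit(model)  # `model` is a placeholder LightningModule that logs "val_loss"
```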
What version are you seeing the problem on?
v2.3
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
```
#- PyTorch Lightning Version 2.3.3
#- PyTorch Version 2.4.0
#- Python version 3.8.13
#- OS: Linux
#- CUDA version: 12.2
#- GPU models and configuration: Nvidia GeForce RTX 3090
#- How you installed Lightning: via pip inside a conda env
```
More info
No response