Bug description
Hello,
I am using ModelCheckpoint together with WandbLogger and want to create periodic snapshot checkpoints, i.e. a checkpoint every N epochs. When I create my checkpoint callback and logger as in the snippet below, the checkpoints on my local machine are correct: there are checkpoints for epoch 49, 99, and so on. However, my wandb project gets an artifact entry for every single epoch, so both the local wandb cache directory and the wandb storage usage grow quickly.
I believe the documentation of ModelCheckpoint is unclear about how to achieve periodic checkpoints: passing every_n_epochs=50 alone, or together with save_top_k=3, monitor="epoch", only produced a single checkpoint in my case (roughly the configurations sketched below). Is there any other way to achieve periodic checkpointing?
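For reference, these are roughly the configurations I tried, reconstructed from the description above (treat this as a sketch rather than my exact script):

from pytorch_lightning.callbacks import ModelCheckpoint

# every_n_epochs alone: only one checkpoint was kept in my runs
checkpoint_callback = ModelCheckpoint(every_n_epochs=50)

# every_n_epochs combined with save_top_k/monitor: same result
checkpoint_callback = ModelCheckpoint(every_n_epochs=50, save_top_k=3, monitor="epoch")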
I am experiencing this issue because of the save_last=True argument combined with the WandbLogger implementation, which runs the following code block after every saved checkpoint. I believe the immediate upload of the last checkpoint after each epoch contradicts the behavior described in the documentation for log_model=True, where checkpoints should only be uploaded once training has completed. With the current behavior there seems to be no way to get periodic checkpoints together with a last checkpoint without uploading an artifact every epoch.
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
    self._scan_and_log_checkpoints(checkpoint_callback)
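If I read this correctly, Python parses a or b and c as a or (b and c), so with my settings (log_model=True, save_top_k=-1) the condition is true and checkpoints are uploaded as soon as they are written, not only at the end of training:

log_model = True
save_top_k = -1

# `a or b and c` is parsed as `a or (b and c)`
condition = (log_model == "all") or (log_model is True and save_top_k == -1)
print(condition)  # True -> _scan_and_log_checkpoints runs after every saved checkpoint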
What version are you seeing the problem on?
v2.0
How to reproduce the bug
import pytorch_lightning as L
from pathlib import Path

# ENTITY, LOG_DIR, and ID are placeholders for my actual values
wandb_logger = L.loggers.WandbLogger(
    entity=ENTITY,
    save_dir=LOG_DIR,
    log_model=True,  # uploads the checkpoints created at the end of training
)
checkpoint_callback = L.callbacks.ModelCheckpoint(
    dirpath=Path(LOG_DIR).joinpath(ID, "checkpoints"),
    save_last=True,
    save_top_k=-1,
    every_n_epochs=50,
)
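A possible workaround (an untested sketch; it relies on WandbLogger internals such as the private _checkpoint_callback attribute and _scan_and_log_checkpoints, which may change between versions) would be to defer all checkpoint uploads to finalize():

from pytorch_lightning.loggers import WandbLogger

class DeferredWandbLogger(WandbLogger):
    # Hypothetical subclass: never upload checkpoints mid-training.
    # Only remember the callback; WandbLogger.finalize() then scans and
    # uploads its checkpoints once at the end of the run.
    def after_save_checkpoint(self, checkpoint_callback):
        self._checkpoint_callback = checkpoint_callback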
Error messages and logs
No response
Environment
Current environment
#- Lightning Component: Trainer, LightningModule
#- PyTorch Lightning Version: 2.0.3
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.11.2
#- Wandb version: 0.15.4
#- OS (e.g., Linux): Linux
#- How you installed Lightning(`conda`, `pip`, source): conda/micromamba
#- Running environment of LightningApp (e.g. local, cloud): local
More info
No response