
Periodic checkpoints logged for every epoch to wandb with ModelCheckpoint & WandbLogger  #17916

@gsaltintas

Bug description

Hello,
I am using ModelCheckpoint along with WandbLogger and I want to create snapshot checkpoints, i.e. a checkpoint every N epochs. When I create my checkpoint callback and logger as shown in the snippet below, the checkpoints created on my local machine are correct, i.e. there are checkpoints corresponding to epoch 49, 99, and so on. However, in my wandb project there is an artifact entry for every epoch, and as a consequence both the local wandb cache directory and the wandb storage usage grow quickly.

I believe the documentation of ModelCheckpoint is unclear about how to achieve periodic checkpoints: providing every_n_epochs=50 alone, or combined with save_top_k=3, monitor="epoch", only created one checkpoint in my case. Is there any other way to achieve periodic checkpointing?
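
For reference, the only combination that kept the periodic checkpoints on disk for me was every_n_epochs together with save_top_k=-1; a minimal sketch (the value 50 is just an example):

from pytorch_lightning.callbacks import ModelCheckpoint

# With the default save_top_k=1 only the most recent checkpoint is retained,
# so every_n_epochs=50 alone leaves a single file; save_top_k=-1 keeps them all.
periodic_ckpt = ModelCheckpoint(every_n_epochs=50, save_top_k=-1)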

I am experiencing this issue because of the save_last=True argument combined with the implementation of WandbLogger, which runs the following code block after every checkpoint save (a small evaluation sketch follows the excerpt). I believe the immediate upload of the last checkpoint after each epoch contradicts the behavior described in the documentation: with log_model=True these checkpoints should only be uploaded after training is completed. With the current behavior there seems to be no way to get periodic checkpoints along with a last checkpoint before training ends.

# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
    self._scan_and_log_checkpoints(checkpoint_callback)
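
Because `and` binds tighter than `or`, that condition reads as (log_model == "all") or (log_model is True and save_top_k == -1). A minimal, self-contained sketch of how it evaluates with my settings, using plain variables standing in for the logger/callback attributes:

log_model = True    # WandbLogger(log_model=True)
save_top_k = -1     # ModelCheckpoint(save_top_k=-1)

uploads_immediately = log_model == "all" or log_model is True and save_top_k == -1
print(uploads_immediately)  # True -> _scan_and_log_checkpoints runs after every save, including last.ckpt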

What version are you seeing the problem on?

v2.0

How to reproduce the bug

from pathlib import Path

import pytorch_lightning as L

# ENTITY, LOG_DIR, and ID are placeholders for the wandb entity, the logging
# directory, and the run id.
wandb_logger = L.loggers.WandbLogger(
    entity=ENTITY,
    save_dir=LOG_DIR,
    log_model=True,  # uploads the checkpoints that are created at the end of the training
)

checkpoint_callback = L.callbacks.ModelCheckpoint(
    dirpath=Path(LOG_DIR).joinpath(ID, "checkpoints"),
    save_last=True,
    save_top_k=-1,
    every_n_epochs=50,
)
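
One possible workaround I can think of is to defer all uploads to the end of training by overriding after_save_checkpoint. This is only a sketch: it relies on the internal _checkpoint_callback attribute that WandbLogger.finalize() scans after a successful run, and the class name DeferredWandbLogger is mine.

from pytorch_lightning.loggers import WandbLogger

class DeferredWandbLogger(WandbLogger):
    # Sketch only: skip the per-save upload and just remember the callback so the
    # parent class uploads the saved checkpoints once in finalize() at the end.
    def after_save_checkpoint(self, checkpoint_callback) -> None:
        self._checkpoint_callback = checkpoint_callback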

Error messages and logs

No response

Environment

Current environment
#- Lightning Component: Trainer, LightningModule
#- PyTorch Lightning Version: 2.0.3
#- PyTorch Version (e.g., 2.0): 2.0.1
#- Python version (e.g., 3.9): 3.11.2
#- Wandb version: 0.15.4
#- OS (e.g., Linux): Linux
#- How you installed Lightning (`conda`, `pip`, source): conda/micromamba
#- Running environment of LightningApp (e.g. local, cloud): local

More info

No response

cc @lantiga @morganmcg1 @borisdayma @scottire @parambharat
