
ModelCheckpoint saves ckpts at the end of every epoch even with a step-saving strategy #20075

@leonardodalinky


Bug description

When ModelCheckpoint is enabled and no validation set is given, the trainer still saves a model checkpoint at the end of every epoch.

I have set every_n_train_steps=1000 and every_n_epochs=None. According to the ModelCheckpoint API docs, the expected behavior is to save a checkpoint only every 1000 steps and not at the end of each epoch.

This slows down training, especially when the training dataset is small, and it has already cost me 10 TBW of write endurance on my SSDs.

A quick workaround is to pass save_on_train_epoch_end=False to ModelCheckpoint.
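For example (a minimal sketch of the workaround; every argument other than save_on_train_epoch_end is illustrative and not part of the original report):

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    every_n_train_steps=1000,        # save every 1000 training steps
    save_on_train_epoch_end=False,   # workaround: disable the implicit epoch-end save
    save_last=True,
)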

I believe this bug is caused by the check against _should_save_on_train_epoch_end at L320 passing when it should not:

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Save a checkpoint at the end of the training epoch."""
    if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
        monitor_candidates = self._monitor_candidates(trainer)
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)

The logic at L432-437 triggers the bug when no validation dataset is given:

def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
    if self._save_on_train_epoch_end is not None:
        return self._save_on_train_epoch_end
    # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
    # so let's not enforce saving at every training epoch end
    if trainer.check_val_every_n_epoch != 1:
        return False
    # no validation means save on train epoch end
    num_val_batches = (
        sum(trainer.num_val_batches) if isinstance(trainer.num_val_batches, list) else trainer.num_val_batches
    )
    if num_val_batches == 0:
        return True
    # if the user runs validation multiple times per training epoch, then we run after validation
    # instead of on train epoch end
    return trainer.val_check_interval == 1.0
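The fall-through can be shown in isolation. Below is a minimal sketch that calls the private method with a SimpleNamespace standing in for the Trainer attributes it reads (the stand-in object and its values are assumptions for illustration, not a real Trainer):

from types import SimpleNamespace
from lightning.pytorch.callbacks import ModelCheckpoint

# Stand-in for a Trainer that has no validation dataloader; only the
# attributes read by _should_save_on_train_epoch_end are provided.
fake_trainer = SimpleNamespace(
    check_val_every_n_epoch=1,
    num_val_batches=[],      # no validation batches
    val_check_interval=1.0,
)

ckpt = ModelCheckpoint(every_n_train_steps=1000, every_n_epochs=None)

# Prints True: with zero validation batches the method returns True,
# so on_train_epoch_end still saves despite the step-based configuration.
print(ckpt._should_save_on_train_epoch_end(fake_trainer))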

What version are you seeing the problem on?

master

How to reproduce the bug

Run Lightning with the following callback and a relatively small dataset (e.g., one with only 100 batches per epoch); a minimal end-to-end script is sketched after the callback below.


ModelCheckpoint(
    monitor="train_loss",
    every_n_train_steps=10,
    dirpath="checkpoints",
    filename="{step}-{train_loss:.3f}",
    save_top_k=1,
    mode="min",
    save_last=True,
)
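A minimal end-to-end reproduction (a sketch only; the toy module, dataset size, and trainer settings are illustrative assumptions and not taken from the original report):

import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# 100 batches per epoch, so every_n_train_steps=10 should drive all saves;
# no val_dataloaders are passed, which is the condition that triggers the bug.
dataset = TensorDataset(torch.randn(100, 32), torch.randn(100, 1))
train_loader = DataLoader(dataset, batch_size=1)

ckpt_cb = ModelCheckpoint(
    monitor="train_loss",
    every_n_train_steps=10,
    dirpath="checkpoints",
    filename="{step}-{train_loss:.3f}",
    save_top_k=1,
    mode="min",
    save_last=True,
)

trainer = pl.Trainer(max_epochs=3, callbacks=[ckpt_cb], logger=False)
trainer.fit(ToyModel(), train_loader)
# Observed: checkpoints/last.ckpt is also rewritten at the end of every epoch.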

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A800 80GB PCIe
    - NVIDIA A800 80GB PCIe
    - NVIDIA A800 80GB PCIe
    - NVIDIA A800 80GB PCIe
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.3.0
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.3.0
    - torch: 2.3.1+cu121
    - torchaudio: 2.3.1+cu121
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1+cu121
  • Packages:
    - absl-py: 2.1.0
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - attrs: 23.2.0
    - babel: 2.15.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - certifi: 2024.6.2
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - deepspeed: 0.14.4
    - defusedxml: 0.7.1
    - dill: 0.3.8
    - distlib: 0.3.8
    - einops: 0.8.0
    - executing: 2.0.1
    - fastjsonschema: 2.20.0
    - filelock: 3.15.1
    - flash-attn: 2.5.9.post1
    - fonttools: 4.53.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2024.6.0
    - grpcio: 1.64.1
    - h11: 0.14.0
    - hjson: 3.1.0
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - huggingface-hub: 0.23.4
    - hydra-core: 1.3.2
    - identify: 2.5.36
    - idna: 3.7
    - ipykernel: 6.29.4
    - ipython: 8.25.0
    - isoduration: 20.11.0
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - json5: 0.9.25
    - jsonpointer: 3.0.0
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter-client: 8.6.2
    - jupyter-core: 5.7.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.1
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.2
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.2
    - kiwisolver: 1.4.5
    - lightning: 2.3.0
    - lightning-utilities: 0.11.2
    - markdown: 3.6
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mistune: 3.0.2
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - ninja: 1.11.1.1
    - nodeenv: 1.9.1
    - notebook-shim: 0.2.4
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-ml-py: 12.555.43
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.5.40
    - nvidia-nvtx-cu12: 12.1.105
    - omegaconf: 2.3.0
    - overrides: 7.7.0
    - packaging: 24.1
    - pandas: 2.2.2
    - pandocfilters: 1.5.1
    - parso: 0.8.4
    - pexpect: 4.9.0
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 4.2.2
    - pre-commit: 3.7.1
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.47
    - protobuf: 4.25.3
    - psutil: 6.0.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pycparser: 2.22
    - pydantic: 2.7.4
    - pydantic-core: 2.18.4
    - pyfunctional: 1.5.0
    - pygments: 2.18.0
    - pyparsing: 3.1.2
    - python-dateutil: 2.9.0.post0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.3.0
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - referencing: 0.35.1
    - regex: 2024.5.15
    - requests: 2.32.3
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rpds-py: 0.18.1
    - safetensors: 0.4.3
    - send2trash: 1.8.3
    - setuptools: 69.5.1
    - six: 1.16.0
    - sniffio: 1.3.1
    - soupsieve: 2.5
    - stack-data: 0.6.3
    - sympy: 1.12.1
    - tabulate: 0.9.0
    - tensorboard: 2.17.0
    - tensorboard-data-server: 0.7.2
    - tensorboardx: 2.6.2.2
    - termcolor: 2.4.0
    - terminado: 0.18.1
    - tinycss2: 1.3.0
    - tokenizers: 0.19.1
    - torch: 2.3.1+cu121
    - torchaudio: 2.3.1+cu121
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1+cu121
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - transformers: 4.42.3
    - triton: 2.3.1
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.12.2
    - tzdata: 2024.1
    - uri-template: 1.3.0
    - urllib3: 2.2.2
    - virtualenv: 20.26.2
    - wcwidth: 0.2.13
    - webcolors: 24.6.0
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.0.3
    - wheel: 0.43.0
    - yarl: 1.9.4
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-67-generic
    - version: #74~20.04.1-Ubuntu SMP Wed Feb 22 14:52:34 UTC 2023

More info

No response
