Description
Bug description
When ModelCheckpoint is enabled and no validation set is given, the trainer still saves a model checkpoint at the end of every epoch.
I have set every_n_train_steps=1000 and every_n_epochs=None. According to the ModelCheckpoint API docs, the expected behavior is that checkpoints are saved only every 1000 steps, not at the end of each epoch.
This slows down training, especially when the training dataset is small. It has also cost me 10 TBW (terabytes written) of SSD endurance.
A quick workaround is to pass save_on_train_epoch_end=False to ModelCheckpoint.
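The workaround can be sketched as follows, reusing the callback arguments from the reproduction section of this report and adding save_on_train_epoch_end=False. The names checkpoint_kwargs and checkpoint_callback are illustrative only, and the guarded import is there just so the snippet also runs where Lightning is not installed.

```python
# Callback arguments from the reproduction, plus the workaround flag.
checkpoint_kwargs = dict(
    monitor="train_loss",
    every_n_train_steps=10,
    dirpath="checkpoints",
    filename="{step}-{train_loss:.3f}",
    save_top_k=1,
    mode="min",
    save_last=True,
    # Explicitly opt out of saving at the end of each training epoch.
    save_on_train_epoch_end=False,
)

try:
    from lightning.pytorch.callbacks import ModelCheckpoint
except ImportError:  # Lightning not installed: leave the callback unset
    ModelCheckpoint = None

checkpoint_callback = (
    ModelCheckpoint(**checkpoint_kwargs) if ModelCheckpoint is not None else None
)
# Pass it to the trainer as usual, e.g. Trainer(callbacks=[checkpoint_callback]).
```

Step-based saving via every_n_train_steps should be unaffected by this flag, since it only controls the epoch-end decision.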
I believe the bug arises because the _should_save_on_train_epoch_end check at L320 does not prevent the save.
pytorch-lightning/src/lightning/pytorch/callbacks/model_checkpoint.py
Lines 318 to 324 in bf25167
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Save a checkpoint at the end of the training epoch."""
    if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
        monitor_candidates = self._monitor_candidates(trainer)
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)
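To make the control flow of this hook easier to follow, here is a toy model of it, not the real ModelCheckpoint. It rests on two assumptions that are not shown in the quoted snippet: that every_n_epochs=None is normalized to 0 internally when every_n_train_steps is set, and that the "last" save only writes when save_last is enabled. The class EpochEndModel and its fields are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class EpochEndModel:
    """Toy stand-in for the quoted hook; not Lightning's implementation."""

    every_n_epochs: int  # assumption: 0 when only every_n_train_steps is set
    save_last: bool
    saves: list = field(default_factory=list)

    def on_train_epoch_end(self, current_epoch: int, should_save: bool) -> None:
        # `should_save` stands in for the combined guard on the outer `if`.
        if not should_save:
            return
        # The top-k save is gated by every_n_epochs ...
        if self.every_n_epochs >= 1 and (current_epoch + 1) % self.every_n_epochs == 0:
            self.saves.append(("topk", current_epoch))
        # ... but the "last" save sits outside that gate.
        if self.save_last:
            self.saves.append(("last", current_epoch))


model = EpochEndModel(every_n_epochs=0, save_last=True)
for epoch in range(3):
    model.on_train_epoch_end(epoch, should_save=True)
print(model.saves)  # [('last', 0), ('last', 1), ('last', 2)]
```

Under these assumptions, a "last" checkpoint is written at every epoch end whenever the outer guard passes, even though every_n_epochs disables the top-k save.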
The logic at L432-437 then triggers the bug when no validation dataset is given.
pytorch-lightning/src/lightning/pytorch/callbacks/model_checkpoint.py
Lines 423 to 441 in bf25167
def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
    if self._save_on_train_epoch_end is not None:
        return self._save_on_train_epoch_end

    # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
    # so let's not enforce saving at every training epoch end
    if trainer.check_val_every_n_epoch != 1:
        return False

    # no validation means save on train epoch end
    num_val_batches = (
        sum(trainer.num_val_batches) if isinstance(trainer.num_val_batches, list) else trainer.num_val_batches
    )
    if num_val_batches == 0:
        return True

    # if the user runs validation multiple times per training epoch, then we run after validation
    # instead of on train epoch end
    return trainer.val_check_interval == 1.0
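The decision logic above can be exercised in isolation. The sketch below copies it into a free function and feeds it a stand-in trainer built with SimpleNamespace; the function name, the stand-in object, and the attribute values (assumed to match the Trainer defaults when no validation data is given) are illustrative, not part of Lightning.

```python
from types import SimpleNamespace
from typing import Optional


def should_save_on_train_epoch_end(trainer, save_on_train_epoch_end: Optional[bool] = None) -> bool:
    """Standalone copy of the quoted decision logic, for experimentation only."""
    # An explicit user setting always wins.
    if save_on_train_epoch_end is not None:
        return save_on_train_epoch_end
    if trainer.check_val_every_n_epoch != 1:
        return False
    n = trainer.num_val_batches
    num_val_batches = sum(n) if isinstance(n, list) else n
    # No validation batches: fall back to saving on train epoch end.
    if num_val_batches == 0:
        return True
    return trainer.val_check_interval == 1.0


# Stand-in for a trainer with no validation dataloader (assumed default values).
no_val_trainer = SimpleNamespace(
    check_val_every_n_epoch=1,
    num_val_batches=[],
    val_check_interval=1.0,
)

print(should_save_on_train_epoch_end(no_val_trainer))         # True
print(should_save_on_train_epoch_end(no_val_trainer, False))  # False
```

With the flag left at None and no validation batches, the function returns True, which matches the behavior reported here; passing False short-circuits the check, which is why the workaround helps.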
What version are you seeing the problem on?
master
How to reproduce the bug
Run Lightning with the following callback on a relatively small dataset (e.g., one with only 100 batches per epoch):
ModelCheckpoint(
    monitor="train_loss",
    every_n_train_steps=10,
    dirpath="checkpoints",
    filename="{step}-{train_loss:.3f}",
    save_top_k=1,
    mode="min",
    save_last=True,
)
Error messages and logs
No response
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA A800 80GB PCIe
- NVIDIA A800 80GB PCIe
- NVIDIA A800 80GB PCIe
- NVIDIA A800 80GB PCIe
- available: True
- version: 12.1
- Lightning:
- lightning: 2.3.0
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.3.0
- torch: 2.3.1+cu121
- torchaudio: 2.3.1+cu121
- torchmetrics: 1.4.0.post0
- torchvision: 0.18.1+cu121
- Packages:
- absl-py: 2.1.0
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.4.0
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- asttokens: 2.4.1
- async-lru: 2.0.4
- attrs: 23.2.0
- babel: 2.15.0
- beautifulsoup4: 4.12.3
- bleach: 6.1.0
- certifi: 2024.6.2
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.2
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- debugpy: 1.8.1
- decorator: 5.1.1
- deepspeed: 0.14.4
- defusedxml: 0.7.1
- dill: 0.3.8
- distlib: 0.3.8
- einops: 0.8.0
- executing: 2.0.1
- fastjsonschema: 2.20.0
- filelock: 3.15.1
- flash-attn: 2.5.9.post1
- fonttools: 4.53.0
- fqdn: 1.5.1
- frozenlist: 1.4.1
- fsspec: 2024.6.0
- grpcio: 1.64.1
- h11: 0.14.0
- hjson: 3.1.0
- httpcore: 1.0.5
- httpx: 0.27.0
- huggingface-hub: 0.23.4
- hydra-core: 1.3.2
- identify: 2.5.36
- idna: 3.7
- ipykernel: 6.29.4
- ipython: 8.25.0
- isoduration: 20.11.0
- jedi: 0.19.1
- jinja2: 3.1.4
- json5: 0.9.25
- jsonpointer: 3.0.0
- jsonschema: 4.22.0
- jsonschema-specifications: 2023.12.1
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.1
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.2.2
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.2
- kiwisolver: 1.4.5
- lightning: 2.3.0
- lightning-utilities: 0.11.2
- markdown: 3.6
- markupsafe: 2.1.5
- matplotlib: 3.9.0
- matplotlib-inline: 0.1.7
- mistune: 3.0.2
- mpmath: 1.3.0
- multidict: 6.0.5
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.3
- ninja: 1.11.1.1
- nodeenv: 1.9.1
- notebook-shim: 0.2.4
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-ml-py: 12.555.43
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.5.40
- nvidia-nvtx-cu12: 12.1.105
- omegaconf: 2.3.0
- overrides: 7.7.0
- packaging: 24.1
- pandas: 2.2.2
- pandocfilters: 1.5.1
- parso: 0.8.4
- pexpect: 4.9.0
- pillow: 10.3.0
- pip: 24.0
- platformdirs: 4.2.2
- pre-commit: 3.7.1
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.47
- protobuf: 4.25.3
- psutil: 6.0.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- pycparser: 2.22
- pydantic: 2.7.4
- pydantic-core: 2.18.4
- pyfunctional: 1.5.0
- pygments: 2.18.0
- pyparsing: 3.1.2
- python-dateutil: 2.9.0.post0
- python-json-logger: 2.0.7
- pytorch-lightning: 2.3.0
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- referencing: 0.35.1
- regex: 2024.5.15
- requests: 2.32.3
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rpds-py: 0.18.1
- safetensors: 0.4.3
- send2trash: 1.8.3
- setuptools: 69.5.1
- six: 1.16.0
- sniffio: 1.3.1
- soupsieve: 2.5
- stack-data: 0.6.3
- sympy: 1.12.1
- tabulate: 0.9.0
- tensorboard: 2.17.0
- tensorboard-data-server: 0.7.2
- tensorboardx: 2.6.2.2
- termcolor: 2.4.0
- terminado: 0.18.1
- tinycss2: 1.3.0
- tokenizers: 0.19.1
- torch: 2.3.1+cu121
- torchaudio: 2.3.1+cu121
- torchmetrics: 1.4.0.post0
- torchvision: 0.18.1+cu121
- tornado: 6.4.1
- tqdm: 4.66.4
- traitlets: 5.14.3
- transformers: 4.42.3
- triton: 2.3.1
- types-python-dateutil: 2.9.0.20240316
- typing-extensions: 4.12.2
- tzdata: 2024.1
- uri-template: 1.3.0
- urllib3: 2.2.2
- virtualenv: 20.26.2
- wcwidth: 0.2.13
- webcolors: 24.6.0
- webencodings: 0.5.1
- websocket-client: 1.8.0
- werkzeug: 3.0.3
- wheel: 0.43.0
- yarl: 1.9.4
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.9
- release: 5.15.0-67-generic
- version: #74~20.04.1-Ubuntu SMP Wed Feb 22 14:52:34 UTC 2023
More info
No response