Commit 17d2d78

SkafteNicki authored and lantiga committed
Tuner cleanup on error (#21162)
* Make sure temp checkpoints are cleaned up on failed tuning
* add testing
* changelog

---------

Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit e1e2534)
1 parent 509f562 commit 17d2d78
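In outline, both tuner entry points save a temporary checkpoint, run their search, and only afterwards restore and delete that file; before this commit, an exception during the search skipped the cleanup and left a `.lr_find_*.ckpt` or `.scale_batch_size_*.ckpt` file behind in `default_root_dir`. The diffs below move the restore/removal into a `finally` block. A condensed, hypothetical sketch of the before/after shape (a plain temp file and a stub `run_search` stand in for the real Lightning calls):

```python
import os
import tempfile


def run_search(ckpt_path):
    """Hypothetical stand-in for the batch-size scaling loop or the LR sweep."""
    raise RuntimeError("simulated crash during tuning")


def tune_before(ckpt_path):
    # Old shape: the cleanup lines are unreachable when run_search() raises.
    open(ckpt_path, "w").close()      # save temporary checkpoint
    result = run_search(ckpt_path)    # may raise, leaving the file behind
    os.remove(ckpt_path)              # restore/remove, skipped on error
    return result


def tune_after(ckpt_path):
    # New shape: restore/removal sit in `finally`, so they always run.
    open(ckpt_path, "w").close()      # save temporary checkpoint
    try:
        return run_search(ckpt_path)  # may still raise
    finally:
        os.remove(ckpt_path)          # cleanup happens on success and on failure


if __name__ == "__main__":
    path = os.path.join(tempfile.gettempdir(), ".tuner_cleanup_demo.ckpt")
    for tune in (tune_before, tune_after):
        try:
            tune(path)
        except RuntimeError:
            pass
        print(tune.__name__, "left a file behind:", os.path.exists(path))
        if os.path.exists(path):
            os.remove(path)
```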

File tree: 5 files changed (+154, -44 lines)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -33,6 +33,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
 
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
+
+---
+
 ## [2.5.4] - 2025-08-29
 
 ### Fixed
```
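In user-facing terms, the entry above means that a crash inside `Tuner.scale_batch_size` / `Tuner.lr_find` (or the `BatchSizeFinder` / `LearningRateFinder` callbacks) should no longer leave `.scale_batch_size_*.ckpt` or `.lr_find_*.ckpt` files in `default_root_dir`. A rough reproduction sketch that mirrors the new tests further down; it assumes `BoringModel` is importable from `lightning.pytorch.demos.boring_classes`, otherwise substitute your own `LightningModule`:

```python
import glob
import os

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateFinder
from lightning.pytorch.demos.boring_classes import BoringModel  # assumption: demo model shipped with Lightning


class FailingModel(BoringModel):
    """Raises partway through tuning to simulate a crash."""

    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3  # attribute the LR finder can update
        self.steps = 0

    def training_step(self, batch, batch_idx):
        self.steps += 1
        if self.steps >= 2:
            raise RuntimeError("boom")
        return super().training_step(batch, batch_idx)


if __name__ == "__main__":
    root = "lr_find_cleanup_demo"
    trainer = Trainer(
        default_root_dir=root,
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        callbacks=[LearningRateFinder(num_training_steps=5)],
    )
    try:
        trainer.fit(FailingModel())
    except RuntimeError:
        pass
    # With the fix, no temporary tuner checkpoint survives the crash.
    print(glob.glob(os.path.join(root, ".lr_find_*.ckpt")))
```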

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 16 additions & 13 deletions
```diff
@@ -76,24 +76,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
-    elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        elif mode == "binsearch":
+            new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
 
-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()
 
-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as ex:
+        raise ex
+    finally:
+        __scale_batch_restore_params(trainer, params)
 
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
 
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
 
     return new_size
 
```
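Two aspects of the new control flow in `_scale_batch_size` are worth spelling out: the `except Exception as ex: raise ex` branch is a transparent re-raise, so callers still see the original error, and the `finally` block restores the tuner's parameter overrides, re-enables the progress bar, and restores/removes the temporary checkpoint before that error propagates. A tiny, generic ordering sketch (plain Python, not the Lightning API):

```python
events = []


def scale(fail=True):
    events.append("save temp checkpoint")
    try:
        events.append("run scaling trials")
        if fail:
            raise RuntimeError("OOM during trial")
        events.append("log final batch size")
    except Exception as ex:
        raise ex  # transparent re-raise: the caller still gets the original error
    finally:
        # runs before the exception reaches the caller
        events.append("restore params, re-enable progress bar")
        events.append("restore and remove temp checkpoint")


try:
    scale()
except RuntimeError as err:
    events.append(f"caller sees: {err}")

print("\n".join(events))
# save temp checkpoint
# run scaling trials
# restore params, re-enable progress bar
# restore and remove temp checkpoint
# caller sees: OOM during trial
```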

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 36 additions & 31 deletions
```diff
@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as ex:
+        raise ex
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
```
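Besides the cleanup, the hunk above introduces an `lr_finder_finished` flag that is set only after the results have been broadcast, and the later attribute update is gated on `if update_attr and lr_finder_finished:`, so a learning-rate suggestion is applied only when the sweep ran to completion. A minimal, generic sketch of that guard (hypothetical names, not the Lightning implementation):

```python
class ToyModule:
    learning_rate = 1e-3


def restore_state(module):
    """Placeholder for undoing the temporary changes made for the sweep."""


def crashing_sweep():
    raise RuntimeError("diverged")


def find_lr(module, sweep, update_attr=True):
    finished = False
    suggestion = None
    try:
        suggestion = sweep()  # may raise partway through the sweep
        finished = True       # only set once the sweep completed
    finally:
        restore_state(module)  # cleanup always runs, as in the diff above

    # mirrors `if update_attr and lr_finder_finished:` from the hunk above
    if update_attr and finished and suggestion is not None:
        module.learning_rate = suggestion
    return suggestion


m = ToyModule()
try:
    find_lr(m, crashing_sweep)
except RuntimeError:
    pass
print(m.learning_rate)  # 0.001: a failed sweep leaves the attribute alone

find_lr(m, lambda: 3e-4)
print(m.learning_rate)  # 0.0003: a completed sweep updates it
```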

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 50 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@ def __init__(self):
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
```
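To run just this new regression test locally, the usual pytest selection by node ID should work from the repository root (path taken from the diff header above), for example via `pytest.main`:

```python
import pytest

# Assumes the pytorch-lightning repo root as the working directory and test requirements installed.
raise SystemExit(
    pytest.main(["tests/tests_pytorch/tuner/test_lr_finder.py::test_lr_finder_checkpoint_cleanup_on_error", "-v"])
)
```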

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 47 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import os
 from copy import deepcopy
@@ -486,3 +487,49 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
 
     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != steps_per_trial
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.batch_size = 2
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    model = FailingModel()
+    batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[batch_size_finder],
+    )
+
+    # Check no scale_batch_size checkpoint files exist initially
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially"
+
+    # Run batch size scaler and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no scale_batch_size checkpoint files are left behind
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )
```
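Both new tests share the `checkpoint_cleanup_on_error` suffix, so a keyword filter can exercise the lr-finder and batch-size variants together (same assumptions as above about the working directory and installed test dependencies):

```python
import pytest

# -k matches test_lr_finder_checkpoint_cleanup_on_error and
# test_scale_batch_size_checkpoint_cleanup_on_error by substring.
raise SystemExit(pytest.main(["tests/tests_pytorch/tuner/", "-k", "checkpoint_cleanup_on_error", "-v"]))
```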
