🐛 Proposals
This is not so much a bug report as an RFC to clarify the `ModelCheckpoint` callback arguments:
- `save_last`: to me, this means that whenever we save a checkpoint, we also save a checkpoint with the filename `"last.ckpt"`. This provides a pre-determined checkpoint name, which is very helpful for resuming from failures. Importantly, it should not determine when checkpoints are saved. Currently it's easy to confuse this parameter to mean "save the checkpoint after the last epoch," which I think should be split out as a separate argument. This distinction would also clarify the typing and validation: there's no need for it to be an `Optional[bool]`; either we save a checkpoint as `"last.ckpt"` or not. So it could be a regular `bool`.
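To make the intent concrete, here is a minimal sketch of the proposed semantics. `CheckpointSaver` is a hypothetical stand-in, not Lightning's actual implementation: `save_last` is a plain `bool` that only controls whether a fixed-name copy is written, never *when* saving happens.

```python
import os
import pickle


class CheckpointSaver:
    """Sketch of the proposed ``save_last`` semantics (hypothetical class,
    not Lightning's real ModelCheckpoint)."""

    def __init__(self, dirpath: str, save_last: bool = False):
        self.dirpath = dirpath
        self.save_last = save_last  # a plain bool, not Optional[bool]

    def save(self, state: dict, filename: str) -> None:
        # The caller decides *when* to save; save_last only adds a copy.
        with open(os.path.join(self.dirpath, filename), "wb") as f:
            pickle.dump(state, f)
        if self.save_last:
            # Fixed, pre-determined name -- easy to resume from after a failure.
            with open(os.path.join(self.dirpath, "last.ckpt"), "wb") as f:
                pickle.dump(state, f)
```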
- There's an inefficiency right now where we generate the checkpoint dict twice if `save_last=True`. For techniques like ZeRO that deal with sharded optimizer states, each checkpoint dict creation triggers communications across all ranks. Instead, we should gather the checkpoint dict once, and then save to different file paths accordingly (cc @justusschock @awaelchli @akihironitta @rohitgr7 @carmocca @ninginthecloud @jjenniferdai @SeanNaren @blefaudeux).
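A sketch of the fix, assuming that building the checkpoint dict is the expensive, communication-triggering step; `build_state` here is a hypothetical stand-in for the real gather:

```python
import pickle
from typing import Callable, Dict, List


def save_checkpoint_to_paths(build_state: Callable[[], Dict], paths: List[str]) -> None:
    """Build the checkpoint dict exactly once, then fan it out to every
    target path (e.g. the top-k filename *and* ``last.ckpt``).

    For sharded setups like ZeRO, ``build_state`` is the step that triggers
    collective communication across ranks, so it must not run once per file.
    """
    state = build_state()  # expensive: runs exactly once
    for path in paths:
        with open(path, "wb") as f:
            pickle.dump(state, f)
```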
- `save_top_k`: since `monitor` is `None` by default, this should force `save_top_k` to be `-1`. The counterargument is that this can cause storage concerns. But I think this is easily correctable on the user side: configure `save_top_k` + `monitor`.
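A sketch of the proposed defaulting rule; `resolve_save_top_k` is an illustrative helper name, not an actual Lightning function:

```python
from typing import Optional


def resolve_save_top_k(monitor: Optional[str], save_top_k: Optional[int] = None) -> int:
    """Proposed rule: with no monitored metric there is nothing to rank
    checkpoints by, so keep all of them (-1). Users worried about storage
    opt in to ranking by configuring both save_top_k and monitor."""
    if monitor is None:
        # Nothing to rank by: keep every checkpoint.
        return -1
    return 1 if save_top_k is None else save_top_k
```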
- `period`: we should rename this to `every_n_epochs`. This opens up extensions for checkpointing after `every_n_steps` during training and checkpointing after a specified time interval. With those extensions in mind, `period` is ambiguous. Another request here is to change the default filename from `"{epoch}"` to `"{epoch}-{global_step}"` to better support mid-epoch checkpointing.