Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion composer/algorithms/ema/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def match(self, event: Event, state: State) -> bool:
if event in [Event.BATCH_CHECKPOINT, Event.EPOCH_CHECKPOINT] and self.ema_started:
checkpoint_savers = [cb for cb in state.callbacks if isinstance(cb, CheckpointSaver)]
for checkpoint_saver in checkpoint_savers:
if checkpoint_saver.checkpoint_save_interval(state, event) is True:
if checkpoint_saver.save_interval(state, event) is True:
return True

# Otherwise, always run on some events after ema has started
Expand Down
103 changes: 49 additions & 54 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State,
and at the end of training.

Returns:
Callable[[State, Event], bool]: A function that can be passed as the ``checkpoint_save_interval``
Callable[[State, Event], bool]: A function that can be passed as the ``save_interval``
argument into the :class:`.CheckpointSaver`.
"""
if isinstance(interval, str):
Expand All @@ -55,7 +55,7 @@ def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State,
raise NotImplementedError(
f'Unknown checkpointing interval: {interval.unit}. Must be TimeUnit.EPOCH or TimeUnit.BATCH.')

def checkpoint_save_interval(state: State, event: Event):
def save_interval(state: State, event: Event):
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

Expand All @@ -75,15 +75,15 @@ def checkpoint_save_interval(state: State, event: Event):

return False

return checkpoint_save_interval
return save_interval


class CheckpointSaver(Callback): # noqa: D101
__doc__ = f"""Callback to save checkpoints.

.. note::

If the ``checkpoint_save_path`` argument is specified when constructing the :class:`.Trainer`, then the :class:`.CheckpointSaver`
If the ``folder`` argument is specified when constructing the :class:`.Trainer`, then the :class:`.CheckpointSaver`
callback need not be constructed manually. However, for advanced checkpointing use cases
(such as saving a weights-only checkpoint at one interval and the full training state
at another interval), instance(s) of this :class:`.CheckpointSaver` callback can be specified in the
Expand All @@ -99,16 +99,16 @@ class CheckpointSaver(Callback): # noqa: D101

>>> trainer = Trainer(..., callbacks=[
... CheckpointSaver(
... checkpoint_save_path='{{run_name}}/checkpoints',
... checkpoint_filename="ep{{epoch}}-ba{{batch}}-rank{{rank}}",
... latest_checkpoint_filename="latest-rank{{rank}}",
... checkpoint_save_interval="1ep",
... folder='{{run_name}}/checkpoints',
... filename="ep{{epoch}}-ba{{batch}}-rank{{rank}}",
... latest_filename="latest-rank{{rank}}",
... save_interval="1ep",
... weights_only=False,
... )
... ])

Args:
checkpoint_save_path (str, optional): Format string for the checkpoint_save_path where checkpoints will be saved.
folder (str, optional): Format string for the save_folder where checkpoints will be saved.
Default: ``'{{run_name}}/checkpoints'``.

The following format variables are available:
Expand All @@ -120,10 +120,10 @@ class CheckpointSaver(Callback): # noqa: D101
When training with multiple devices (i.e. GPUs), ensure that ``'{{rank}}'`` appears in the format.
Otherwise, multiple processes may attempt to write to the same file.

checkpoint_filename (str, optional): A format string describing how to name checkpoints.
filename (str, optional): A format string describing how to name checkpoints.
Default: ``'ep{{epoch}}-ba{{batch}}-rank{{rank}}.pt'``.

Checkpoints will be saved approximately to ``{{checkpoint_save_path}}/{{checkpoint_filename.format(...)}}``.
Checkpoints will be saved approximately to ``{{folder}}/{{filename.format(...)}}``.

The following format variables are available:

Expand All @@ -136,7 +136,7 @@ class CheckpointSaver(Callback): # noqa: D101

* When using DeepSpeed, each rank will save a checkpoint file in tarball format. DeepSpeed
requires tarball format, as it saves model and optimizer states in separate files.
Ensure that ``'{{rank}}'`` appears within the ``checkpoint_filename``. Otherwise, multiple ranks
Ensure that ``'{{rank}}'`` appears within the ``filename``. Otherwise, multiple ranks
may attempt to write to the same file(s), leading to corrupted checkpoints. If no tarball file
extension is specified, ``'.tar'`` will be used.

Expand All @@ -152,7 +152,7 @@ class CheckpointSaver(Callback): # noqa: D101
Consider the following scenario where:

* The :attr:`~.State.run_name` is ``'awesome-training-run'``
* The default ``checkpoint_save_path='{{run_name}}/checkpoints'`` is used.
* The default ``folder='{{run_name}}/checkpoints'`` is used.
* The default ``name='ep{{epoch}}-ba{{batch}}-rank{{rank}}'`` is used.
* The current epoch count is ``1``.
* The current batch count is ``42``.
Expand All @@ -175,15 +175,15 @@ class CheckpointSaver(Callback): # noqa: D101

.. seealso:: :doc:`Artifact Logging</trainer/artifact_logging>` for notes for file artifact logging.

The same format variables for ``checkpoint_filename`` are available.
The same format variables for ``filename`` are available.

Leading slashes (``'/'``) will be stripped.

To disable logging trace files as file artifacts, set this parameter to ``None``.
latest_checkpoint_filename (str, optional): A format string for a symlink which points to the last saved checkpoint.
latest_filename (str, optional): A format string for a symlink which points to the last saved checkpoint.
Default: ``'latest-rank{{rank}}.pt'``.

Symlinks will be created approximately at ``{{checkpoint_save_path}}/{{latest_checkpoint_filename.format(...)}}``.
Symlinks will be created approximately at ``{{folder}}/{{latest_filename.format(...)}}``.

The same format variables as for ``name`` are available.

Expand All @@ -192,9 +192,9 @@ class CheckpointSaver(Callback): # noqa: D101
Consider the following scenario, where:

* The :attr:`~.State.run_name` is 'awesome-training-run'
* The default ``checkpoint_save_path='{{run_name}}/checkpoints'`` is used.
* The default ``folder='{{run_name}}/checkpoints'`` is used.
* The default ``name='ep{{epoch}}-ba{{batch}}-rank{{rank}}'`` is used.
* The default ``latest_checkpoint_filename='latest-rank{{rank}}'`` is used.
* The default ``latest_filename='latest-rank{{rank}}'`` is used.
* The current epoch count is ``1``.
* The current batch count is ``42``.

Expand All @@ -220,21 +220,21 @@ class CheckpointSaver(Callback): # noqa: D101
Default: ``'{{run_name}}/checkpoints/latest-rank{{rank}}"``.

Whenever a new checkpoint is saved, a symlink artifact is created or updated to point to the latest checkpoint's ``artifact_name``.
The artifact name will be determined by this format string. This parameter has no effect if ``latest_checkpoint_filename`` or ``artifact_name`` is ``None``.
The artifact name will be determined by this format string. This parameter has no effect if ``latest_filename`` or ``artifact_name`` is ``None``.

.. seealso:: :doc:`Artifact Logging</trainer/artifact_logging>` for notes for file artifact logging.

The same format variables for ``checkpoint_filename`` are available.
The same format variables for ``filename`` are available.

Leading slashes (``'/'``) will be stripped.

To disable symlinks in logger, set this parameter to ``None``.

overwrite (bool, optional): Whether existing checkpoints should be overridden.
If ``False`` (the default), then the ``checkpoint_save_path`` must not exist or must not contain checkpoints which may conflict
If ``False`` (the default), then the ``folder`` must not exist or must not contain checkpoints which may conflict
with the current run. Default: ``False``.

checkpoint_save_interval (Time | str | int | (State, Event) -> bool): A :class:`.Time`, time-string, integer (in epochs),
save_interval (Time | str | int | (State, Event) -> bool): A :class:`.Time`, time-string, integer (in epochs),
or a function that takes (state, event) and returns a boolean whether a checkpoint should be saved.

If an integer, checkpoints will be saved every n epochs.
Expand Down Expand Up @@ -281,27 +281,26 @@ class CheckpointSaver(Callback): # noqa: D101

def __init__(
self,
checkpoint_save_path: str = '{run_name}/checkpoints',
checkpoint_filename: str = 'ep{epoch}-ba{batch}-rank{rank}.pt',
folder: str = '{run_name}/checkpoints',
filename: str = 'ep{epoch}-ba{batch}-rank{rank}.pt',
artifact_name: Optional[str] = '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}',
latest_checkpoint_filename: Optional[str] = 'latest-rank{rank}.pt',
latest_filename: Optional[str] = 'latest-rank{rank}.pt',
latest_artifact_name: Optional[str] = '{run_name}/checkpoints/latest-rank{rank}',
checkpoint_save_interval: Union[Time, str, int, Callable[[State, Event], bool]] = '1ep',
save_interval: Union[Time, str, int, Callable[[State, Event], bool]] = '1ep',
*,
overwrite: bool = False,
num_checkpoints_to_keep: int = -1,
weights_only: bool = False,
):
if not callable(checkpoint_save_interval):
checkpoint_save_interval = checkpoint_periodically(checkpoint_save_interval)
self.checkpoint_save_interval = checkpoint_save_interval
if not callable(save_interval):
save_interval = checkpoint_periodically(save_interval)
self.save_interval = save_interval
self.last_checkpoint_batch: Optional[Time] = None

self.checkpoint_save_path = checkpoint_save_path
self.folder = folder

self.checkpoint_filename = PartialFilePath(checkpoint_filename.lstrip('/'), checkpoint_save_path)
self.latest_checkpoint_filename = PartialFilePath(latest_checkpoint_filename.lstrip('/'),
checkpoint_save_path) if latest_checkpoint_filename else None
self.filename = PartialFilePath(filename.lstrip('/'), folder)
self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), folder) if latest_filename else None

self.artifact_name = PartialFilePath(artifact_name) if artifact_name else None
self.latest_artifact_name = PartialFilePath(latest_artifact_name) if latest_artifact_name else None
Expand All @@ -312,34 +311,31 @@ def __init__(
self.weights_only = weights_only

def init(self, state: State, logger: Logger) -> None:
checkpoint_save_path = format_name_with_dist(self.checkpoint_save_path, state.run_name)
os.makedirs(checkpoint_save_path, exist_ok=True)
folder = format_name_with_dist(self.folder, state.run_name)
os.makedirs(folder, exist_ok=True)

def fit_start(self, state: State, logger: Logger) -> None:
if not self.overwrite:
# checks that checkpoint_save_path contains no files with a timestamp after the current timestamp,
# checks that save_folder contains no files with a timestamp after the current timestamp,
# which has potential for future conflicts.
checkpoint_save_path = format_name_with_dist(self.checkpoint_save_path, state.run_name)
ensure_folder_has_no_conflicting_files(checkpoint_save_path, self.checkpoint_filename.filename,
state.timestamp)
folder = format_name_with_dist(self.folder, state.run_name)
ensure_folder_has_no_conflicting_files(folder, self.filename.filename, state.timestamp)

dist.barrier() # holds all ranks until checkpoint_save_path check is done
dist.barrier() # holds all ranks until folder check is done

if is_model_deepspeed(state.model) and self.weights_only:
raise NotImplementedError('weights_only=True is not supported when using DeepSpeed.')

def batch_checkpoint(self, state: State, logger: Logger):
if self.checkpoint_save_interval(
state, Event.BATCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
if self.save_interval(state, Event.BATCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(
state,
logger,
self.get_log_level(state, default=LogLevel.BATCH),
)

def epoch_checkpoint(self, state: State, logger: Logger):
if self.checkpoint_save_interval(
state, Event.EPOCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
if self.save_interval(state, Event.EPOCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(
state,
logger,
Expand All @@ -362,30 +358,29 @@ def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel):

is_deepspeed = is_model_deepspeed(state.model)

if is_deepspeed and '{rank}' not in self.checkpoint_filename.filename:
raise ValueError(
f'Save checkpoint_filename {self.checkpoint_filename.filename} must have {{rank}} for deepspeed.')
if is_deepspeed and '{rank}' not in self.filename.filename:
raise ValueError(f'Save filename {self.filename.filename} must have {{rank}} for deepspeed.')

# save the checkpoint to the filename
checkpoint_filename = self.checkpoint_filename.format(state, is_deepspeed)
filename = self.filename.format(state, is_deepspeed)

saved_path = checkpoint.save_checkpoint(
state=state,
filename=checkpoint_filename,
filename=filename,
weights_only=self.weights_only,
)

if not saved_path: # not all ranks save
return

if self.latest_checkpoint_filename is not None:
symlink = self.latest_checkpoint_filename.format(state, is_deepspeed)
if self.latest_filename is not None:
symlink = self.latest_filename.format(state, is_deepspeed)
os.makedirs(os.path.dirname(symlink), exist_ok=True)
try:
os.remove(symlink)
except FileNotFoundError:
pass
os.symlink(os.path.relpath(checkpoint_filename, os.path.dirname(symlink)), symlink)
os.symlink(os.path.relpath(filename, os.path.dirname(symlink)), symlink)

# if artifact name provided, upload the checkpoint
if self.artifact_name is not None:
Expand All @@ -396,7 +391,7 @@ def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel):

logger.file_artifact(log_level=log_level,
artifact_name=artifact_name,
file_path=checkpoint_filename,
file_path=filename,
overwrite=self.overwrite)

if self.latest_artifact_name is not None:
Expand All @@ -416,7 +411,7 @@ def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel):
overwrite=True,
)

self.saved_checkpoints.append(checkpoint_filename)
self.saved_checkpoints.append(filename)

if self.num_checkpoints_to_keep >= 0:
self._rotate_checkpoints()
Expand Down
Loading