
Commit 12cbd86

Authored by Kirthi Shankar Sivamani
[PyTorch] Fix backward compatibility with checkpoint API (#740)
* Fix backward compatibility with checkpoint API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* review comments and fix lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent df1b16d commit 12cbd86
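
For readers hitting the deprecated path, here is a minimal, untested sketch of the two calling conventions this commit keeps working. The `torch.nn.Linear` layer, the tensor shapes, and the `my_rng_state_tracker` callable are illustrative assumptions only, not part of the commit; the keyword names match the `kwargs.pop(...)` calls in the diff below.

    # Illustrative sketch only; assumes Transformer Engine is installed.
    # The layer and input shapes are placeholders, not taken from the commit.
    import torch
    from transformer_engine.pytorch.distributed import checkpoint

    layer = torch.nn.Linear(16, 16)
    inp = torch.randn(4, 16, requires_grad=True)

    # Recommended form: configuration options passed as keyword arguments.
    out = checkpoint(
        layer,
        inp,
        distribute_saved_activations=False,
        tp_group=None,
        get_rng_state_tracker=None,
        use_reentrant=True,
    )

    # Deprecated positional form, still accepted after this commit but now
    # emitting a DeprecationWarning. Per the new assertions it expects a bool,
    # a callable, and an optional process group before the tensor inputs, e.g.:
    #   checkpoint(layer, False, my_rng_state_tracker, None, inp)
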

File tree

1 file changed (+27, -10 lines)


transformer_engine/pytorch/distributed.py

Lines changed: 27 additions & 10 deletions
@@ -516,13 +516,40 @@ def checkpoint(
     kwargs : dict
              dictionary of string keys for keyword arguments to :attr:`function`.
     """
+    only_tensor_args = True
+    for arg in args:
+        if not isinstance(arg, torch.Tensor):
+            only_tensor_args = False
+            break
+
     # Pop out te.distributed.checkpoint() arguments
     global _USE_REENTRANT_ACTIVATION_RECOMPUTE
     _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
     distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
     tp_group = kwargs.pop("tp_group", None)
     get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

+    # Ensure backward compatibility.
+    if not only_tensor_args:
+        warnings.warn(
+            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
+            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
+            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
+            DeprecationWarning, stacklevel=2,
+        )
+        assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
+        assert (
+            isinstance(args[0], bool) and callable(args[1])
+            and isinstance(args[2], None | dist_group_type)
+        ), "Incorrect arguments for deprecated `checkpoint` API."
+        for arg in args[3:]:
+            assert (
+                isinstance(arg, None | torch.Tensor)
+            ), f"Expected tensor argument, found {type(arg)}."
+
+        distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]  # pylint: disable=unbalanced-tuple-unpacking
+        args = args[3:]
+
     # Trigger the native PyTorch checkpoint if:
     # 1. `function` is a `torch.nn.Module`
     # AND
@@ -555,16 +582,6 @@ def checkpoint(
         assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
         tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

-    # Make sure at least one tensor input has `requires_grad=True`
-    input_requires_grad = False
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and arg.requires_grad:
-            input_requires_grad = True
-            break
-    assert input_requires_grad, (
-        "`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
-    )
-
     return _CheckpointFunction.apply(
         function,
         distribute_saved_activations,
0 commit comments
