@@ -516,13 +516,40 @@ def checkpoint(
     kwargs : dict
              dictionary of string keys for keyword arguments to :attr:`function`.
     """
+    only_tensor_args = True
+    for arg in args:
+        if not isinstance(arg, torch.Tensor):
+            only_tensor_args = False
+            break
+
     # Pop out te.distributed.checkpoint() arguments
     global _USE_REENTRANT_ACTIVATION_RECOMPUTE
     _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
     distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
     tp_group = kwargs.pop("tp_group", None)
     get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
 
+    # Ensure backward compatibility.
+    if not only_tensor_args:
+        warnings.warn(
+            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
+            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
+            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
+            DeprecationWarning, stacklevel=2,
+        )
+        assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
+        assert (
+            isinstance(args[0], bool) and callable(args[1])
+            and isinstance(args[2], None | dist_group_type)
+        ), "Incorrect arguments for deprecated `checkpoint` API."
+        for arg in args[3:]:
+            assert (
+                isinstance(arg, None | torch.Tensor)
+            ), f"Expected tensor argument, found {type(arg)}."
+
+        distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]  # pylint: disable=unbalanced-tuple-unpacking
+        args = args[3:]
+
     # Trigger the native PyTorch checkpoint if:
     # 1. `function` is a `torch.nn.Module`
     # AND
@@ -555,16 +582,6 @@ def checkpoint(
         assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
         tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group
 
-    # Make sure at least one tensor input has `requires_grad=True`
-    input_requires_grad = False
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and arg.requires_grad:
-            input_requires_grad = True
-            break
-    assert input_requires_grad, (
-        "`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
-    )
-
     return _CheckpointFunction.apply(
         function,
         distribute_saved_activations,
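
For reference, here is a minimal usage sketch (not part of the diff) contrasting the deprecated positional call with the keyword-argument call this change steers users toward. The module, shapes, and config values are hypothetical; it assumes TransformerEngine is installed and a CUDA device is available.

```python
# Minimal sketch, assuming TransformerEngine with CUDA; `layer`, the shapes,
# and the config values below are made up for illustration.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)  # hypothetical module to checkpoint
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

# Deprecated style (now emits a DeprecationWarning): the config triple
# (distribute_saved_activations, get_rng_state_tracker, tp_group) sits
# positionally between `function` and the tensor inputs, e.g.
#   te.distributed.checkpoint(layer, False, rng_tracker, tp_group, inp)

# Preferred style: tensors stay positional, config moves to keywords.
out = te.distributed.checkpoint(
    layer,
    inp,
    distribute_saved_activations=False,
    tp_group=None,
    get_rng_state_tracker=None,
    use_reentrant=True,
)
out.sum().backward()
```

Note also the second hunk: the hard assertion that at least one tensor input has `requires_grad=True` is removed, so a reentrant checkpoint call over fully frozen inputs no longer fails at call time.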