Commit 3c1817f

nelyahu, loadams, and hwchen2017 authored
Reland perf fix for nan inf check (#7184)
Replace the previous nan/inf detection, which masked via logical ops, with torch.where and an in-place masked-fill helper.

---------

Signed-off-by: Nadav Elyahu <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Hongwei Chen <[email protected]>
1 parent 79ff162 commit 3c1817f

File tree

3 files changed: +16 -20 lines

deepspeed/runtime/utils.py

Lines changed: 9 additions & 2 deletions
@@ -823,6 +823,14 @@ def get_only_unique_item(items):
     return unique_item


+def mask_nan_or_inf_with_val_inplace(input, device=None, val=-1.):
+    norm_is_inf = input.isinf()
+    norm_is_nan = input.isnan()
+    inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+    err = torch.tensor(-1.0, device=device, dtype=torch.float)
+    input.masked_fill_(inf_or_nan, err)
+
+
 def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
     """Get norm of an iterable of tensors.

@@ -897,8 +905,7 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
         dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
     total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)

-    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
-    total_norm.masked_fill_(inf_or_nan, -1)
+    mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device)

     return total_norm
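
For reference, the new helper masks non-finite entries of a tensor in place; note that it accepts a val argument but the body fills with the hardcoded -1.0. A minimal usage sketch (illustrative only, assuming this commit is applied):

import torch
from deepspeed.runtime.utils import mask_nan_or_inf_with_val_inplace

# nan/inf entries are overwritten in place with -1.0
norm = torch.tensor([1.5, float('inf'), float('nan')])
mask_nan_or_inf_with_val_inplace(norm, device=norm.device)
print(norm)  # tensor([ 1.5000, -1.0000, -1.0000])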

deepspeed/runtime/zero/stage3.py

Lines changed: 3 additions & 8 deletions
@@ -19,7 +19,7 @@
 from deepspeed.utils.torch import register_grad_hook
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

@@ -1453,12 +1453,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):

         total_norm = total_norm_cuda[0]**(1. / norm_type)

-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
-
-        err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device)

         return total_norm.cpu()

@@ -1815,7 +1810,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
         inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

         err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        total_norm = torch.where(inf_or_nan, err, total_norm)

         return total_norm
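
A standalone sketch of the torch.where swap (illustrative, not part of the commit). Besides replacing two multiplies and an add with a single select, it sidesteps an IEEE quirk of the removed blend: 0 * inf is nan, so the old expression could propagate nan on the masked branch.

import torch

total_norm = torch.tensor(float('inf'))
inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
err = torch.tensor(-1.0, dtype=torch.float)

# Removed form: the logical_not() branch evaluates 0 * inf, which is nan
old = inf_or_nan * err + inf_or_nan.logical_not() * total_norm  # tensor(nan)
# New form: selects err directly, no arithmetic on the masked value
new = torch.where(inf_or_nan, err, total_norm)  # tensor(-1.)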

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 4 additions & 10 deletions
@@ -12,7 +12,7 @@
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
-                                     align_dense_tensors, all_gather_dp_groups)
+                                     align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace)
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.ops.adam import DeepSpeedCPUAdam

@@ -1722,11 +1722,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

         total_norm = total_norm.pow(1. / norm_type)

-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-
-        if norm_is_inf or norm_is_nan:
-            total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)

         return total_norm

@@ -1984,10 +1980,8 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-
-            # handle total_norm invalid value -1
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale

             for grad in grad_groups_flat:
                 if isinstance(grad, list):
