@@ -12,7 +12,7 @@
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
-                                     align_dense_tensors, all_gather_dp_groups)
+                                     align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace)
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -1722,11 +1722,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
 
         total_norm = total_norm.pow(1. / norm_type)
 
-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-
-        if norm_is_inf or norm_is_nan:
-            total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
 
         return total_norm
 
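For context, here is a minimal sketch of what `mask_nan_or_inf_with_val_inplace` plausibly does, reconstructed from the branch it replaces; the actual implementation in `deepspeed.runtime.utils` may differ in signature and default fill value:

```python
import torch

def mask_nan_or_inf_with_val_inplace(tensor: torch.Tensor, device=None, val: float = -1.0) -> torch.Tensor:
    """Sketch only: overwrite NaN/Inf entries of `tensor` in place with `val`.

    -1.0 mirrors the sentinel the removed branch assigned to an invalid norm.
    `device` is accepted to match the call site above; the in-place fill itself
    does not need it.
    """
    invalid = tensor.isinf().logical_or(tensor.isnan())
    tensor.masked_fill_(invalid, val)
    return tensor

# Example: an infinite norm is replaced by the -1 sentinel in place.
norm = torch.tensor(float("inf"))
mask_nan_or_inf_with_val_inplace(norm)  # norm is now tensor(-1.)
```

Expressing the check as a tensor-level masked fill also avoids the Python-level `if norm_is_inf or norm_is_nan` test, which forces a host-device synchronization when `total_norm` lives on a CUDA device.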
@@ -1984,10 +1980,8 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-
-            # handle total_norm invalid value -1
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale
 
         for grad in grad_groups_flat:
             if isinstance(grad, list):
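The second hunk follows the same pattern: `torch.clamp(clip, min=1.0)` reproduces the old `if clip > 1` branch, including the invalid-norm case, where the -1 sentinel keeps `clip` below 1 and leaves the scale at `loss_scale`, presumably without a data-dependent Python branch on a tensor. A self-contained sketch of the rewritten scaling logic, with illustrative names rather than the optimizer's actual attributes:

```python
import torch

def combined_scale(total_norm: torch.Tensor, loss_scale: float, clip_grad: float) -> torch.Tensor:
    # Sketch of the post-change unscale/clip computation. clamp(min=1.0) makes
    # clipping a no-op when the gradient norm is already within bounds or when
    # total_norm carries the -1 "invalid" sentinel.
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    return clip * loss_scale

# An in-range norm leaves the scale at loss_scale; an oversized norm raises it.
print(combined_scale(torch.tensor(2.0), loss_scale=8.0, clip_grad=1.0))   # tensor(8.)
print(combined_scale(torch.tensor(64.0), loss_scale=8.0, clip_grad=1.0))  # ~tensor(64.)
```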