# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.ops.adam import DeepSpeedCPUAdam
import torch


class ZenFlowCPUAdam(DeepSpeedCPUAdam):
    """ZenFlow variant of ``DeepSpeedCPUAdam``.

    ``overlap_step`` selects which update path ``step()`` is bound to:
    ``_sequential_step`` (the standard CPU Adam update) or ``_parallel_step``
    (the overlapped update that works on double-buffered gradients and
    optimizer states selected by ``now_state``).
    """

    def __init__(self, *args, overlap_step=False, **kwargs):
        super(ZenFlowCPUAdam, self).__init__(*args, **kwargs)
        self.overlap_step = overlap_step
        # Bind step() once so callers use the same entry point in both modes.
        if not self.overlap_step:
            print("ZenFlowCPUAdam initialized with normal step.")
            self.step = self._sequential_step
        else:
            print("ZenFlowCPUAdam initialized with overlap step.")
            self.step = self._parallel_step

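    # NOTE: Both update paths below drive the same compiled CPU Adam kernel
    # (``self.ds_opt_adam.adam_update``) inherited from ``DeepSpeedCPUAdam``.
    # They differ in where gradients and optimizer states are read from, and
    # in that the step count is supplied explicitly as ``step_id`` by the
    # caller rather than tracked per parameter inside the optimizer state.
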
    @torch.no_grad()
    def _sequential_step(self, step_id, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            step_id (int): current optimizer step, used for bias correction.
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    # use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)

                state['step'] = step_id
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

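    # Overlapped path: parameters managed by ZenFlow are expected to carry
    # ``overlap_grad`` (a pair of gradient buffers) and ``stale_param`` (a
    # shared-memory copy of the parameter), and the optimizer keeps two copies
    # of exp_avg / exp_avg_sq. ``now_state`` picks which buffer set this call
    # reads and updates, so the update for one buffer can proceed while the
    # other buffer is in use (hence the "overlap step" mode).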
    @torch.no_grad()
    def _parallel_step(self, step_id, now_state, group_info, closure=None):
        """Update the model parameters using the double-buffered state.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            step_id (int): current optimizer step, used for bias correction.
            now_state (int): index (0 or 1) of the gradient/optimizer-state
                buffers to update.
            group_info (dict): hyperparameters to apply for this update
                (``lr``, ``betas``, ``eps``, ``weight_decay``,
                ``bias_correction``).
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        stale_param = None

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):
                assert p.data.is_shared(), "param.data must be in shared memory"
                if not hasattr(p, 'overlap_grad'):
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    # use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
                    # keep one copy of each optimizer state per gradient buffer
                    exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    state['exp_avg'] = [exp_avg, exp_avg.clone()]
                    state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()]

                state['step'] = step_id
                beta1, beta2 = group_info['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2,
                                             group_info['eps'], group_info['weight_decay'],
                                             group_info['bias_correction'], p.data, p.overlap_grad[now_state].data,
                                             state['exp_avg'][now_state], state['exp_avg_sq'][now_state])
                # publish the updated values into the parameter's stale_param buffer
                p.stale_param.data.copy_(p.data.clone())
        return loss
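

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the optimizer).
# In real training this optimizer is created and driven by the DeepSpeed
# ZeRO-Offload / ZenFlow engine and users simply call ``engine.step()``; the
# direct calls below only show the interface of the sequential path. The tiny
# model and hyperparameters are arbitrary examples.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 4)  # parameters stay on CPU, as the asserts require
    optimizer = ZenFlowCPUAdam(model.parameters(), lr=1e-3, overlap_step=False)

    for step_id in range(1, 4):
        loss = model(torch.randn(2, 8)).sum()
        loss.backward()
        optimizer.step(step_id)  # step() is bound to _sequential_step
        optimizer.zero_grad()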