
Commit 1d7b90a

Antlera, JoshWoo2003, tohtana, sfc-gh-truwase, loadams authored
Add Zenflow code for Stage 1 & 2 (#7391)
This PR adds ZenFlow, an importance-aware offloaded training framework for DeepSpeed ZeRO. ZenFlow enables multi-step overlap between computation and communication during offloaded training, improving GPU utilization and reducing stalls.

Highlights:
- New ZenFlow optimizers (ZenFlowCPUAdam, ZenFlowSelectiveAdamW)
- ZenFlowZeroOptimizer for ZeRO Stage 1/2 integration
- Configurable via ZenFlowConfig, integrated with DeepSpeedZeroConfig
- Unit tests and documentation included

Note: This PR focuses on Stage 1 and 2 integration. Stage 3 support will be introduced in a follow-up PR.

---------

Signed-off-by: Tingfeng Lan <[email protected]>
Signed-off-by: Yusen Wu <[email protected]>
Signed-off-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Yusen Wu <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Guokai Ma <[email protected]>
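For orientation, here is a rough, hypothetical sketch of how ZenFlow might be switched on from the user side by passing a config dict to deepspeed.initialize. The "zenflow" block and its placement under "zero_optimization" are assumptions inferred from the ZenFlowConfig / DeepSpeedZeroConfig integration mentioned above; only overlap_step also appears in the optimizer code in this commit, so treat the ZenFlowConfig definition added by this PR as the authoritative schema.

# Hypothetical sketch only: the "zenflow" key and its fields are assumptions
# based on the PR description (ZenFlowConfig integrated with DeepSpeedZeroConfig),
# not taken verbatim from this diff.
import torch
import deepspeed

model = torch.nn.Linear(16, 16)  # placeholder model for illustration

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},   # ZenFlow targets offloaded training
        "zenflow": {
            "overlap_step": True   # mirrors the flag accepted by ZenFlowCPUAdam below
        }
    }
}

# Typically run under the deepspeed launcher; initialize() consumes the dict directly.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)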
1 parent 33cd945 commit 1d7b90a

File tree

15 files changed: 2601 additions & 15 deletions


deepspeed/ops/adam/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
-# Copyright (c) Microsoft Corporation.
+# Copyright (c) DeepSpeed Team.
 # SPDX-License-Identifier: Apache-2.0
 
 # DeepSpeed Team
 
 from .cpu_adam import DeepSpeedCPUAdam
 from .fused_adam import FusedAdam
+from .zenflow_cpu_adam import ZenFlowCPUAdam
+from .zenflow_torch_adam import ZenFlowSelectiveAdamW
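With these two re-exports in place, the new optimizers are importable straight from the package namespace alongside the existing ones (ZenFlowSelectiveAdamW lives in zenflow_torch_adam, which is not shown in this excerpt):

# Import check for the exports added above.
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam, ZenFlowCPUAdam, ZenFlowSelectiveAdamW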
deepspeed/ops/adam/zenflow_cpu_adam.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.ops.adam import DeepSpeedCPUAdam
import torch


class ZenFlowCPUAdam(DeepSpeedCPUAdam):

    def __init__(self, *args, overlap_step=False, **kwargs):
        super(ZenFlowCPUAdam, self).__init__(*args, **kwargs)
        self.overlap_step = overlap_step
        if not self.overlap_step:
            print("ZenFlowCPUAdam initialized with normal step.")
            self.step = self._sequential_step
        else:
            print("ZenFlowCPUAdam initialized with overlap step.")
            self.step = self._parallel_step

    @torch.no_grad()
    def _sequential_step(self, step_id, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)

                state['step'] = step_id
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def _parallel_step(self, step_id, now_state, group_info, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        stale_param = None

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):
                assert p.data.is_shared(), "param.data must be in shared memory"
                if not hasattr(p, 'overlap_grad'):
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    # print("creating", flush=True)
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
                    exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    state['exp_avg'] = [exp_avg, exp_avg.clone()]
                    state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()]

                state['step'] = step_id
                beta1, beta2 = group_info['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2,
                                             group_info['eps'], group_info['weight_decay'],
                                             group_info['bias_correction'], p.data, p.overlap_grad[now_state].data,
                                             state['exp_avg'][now_state], state['exp_avg_sq'][now_state])
                p.stale_param.data.copy_(p.data.clone())
        return loss
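To make the _parallel_step calling convention concrete, below is a minimal, hypothetical driver. In real training these attributes are expected to be wired up by ZenFlowZeroOptimizer (added elsewhere in this PR) and step() is invoked through the engine; setting overlap_grad, stale_param, and group_info by hand here is purely for illustration, and the snippet assumes DeepSpeed's CPU Adam extension builds on the machine.

# Illustrative only: mimics the contract _parallel_step expects from each parameter.
import torch
from deepspeed.ops.adam import ZenFlowCPUAdam

p = torch.nn.Parameter(torch.randn(1024, dtype=torch.float, device='cpu'))
p.data.share_memory_()                                   # _parallel_step asserts shared memory
p.overlap_grad = [torch.zeros_like(p.data),              # double-buffered gradient slots,
                  torch.zeros_like(p.data)]              # indexed by now_state (0 or 1)
p.stale_param = torch.zeros_like(p.data)                 # receives a copy of the updated weights

opt = ZenFlowCPUAdam([p], lr=1e-3, overlap_step=True)    # binds step -> _parallel_step

group_info = {                                           # hyperparameters read by _parallel_step
    'lr': 1e-3,
    'betas': (0.9, 0.999),
    'eps': 1e-8,
    'weight_decay': 0.0,
    'bias_correction': True,
}

p.overlap_grad[0].copy_(torch.randn_like(p.data))        # pretend a gradient buffer has arrived
opt.step(step_id=1, now_state=0, group_info=group_info)  # Adam update from buffer 0 on the CPU

Note the double buffering: both overlap_grad and the exp_avg / exp_avg_sq optimizer states are kept as two-element lists indexed by now_state, so one slot can be updated on the CPU while the other is being filled by the next round of communication, which is how the commit overlaps the optimizer step with gradient transfer.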
