Merged

49 commits
3309b49
Add ZenFlow optimizers (zero stage 1&2) for ZeRO integration
Antlera Jun 25, 2025
4e9fe2a
Add ZenFlowConfig for optimizer configuration
Antlera Jun 25, 2025
cac5703
Add ZenFlow (zero stage 1&2) integration in DeepSpeedEngine
Antlera Jun 25, 2025
0e9a0c9
Add unit tests for ZenFlowConfig
Antlera Jun 25, 2025
3353e34
Fix initialization and update logic for ZenFlow optimizers
Antlera Jun 26, 2025
28cdf89
Add unit tests for ZenFlowSelectiveAdamW optimizer
Antlera Jun 26, 2025
f534d5e
Add ZenFlow tutorial documentation
Antlera Jun 27, 2025
80ad488
Format code
Antlera Jun 27, 2025
9c05ccb
Fix check_grad_overflow parameter in ZenFlowZeroOptimizer
Antlera Jun 27, 2025
da80ff7
Refactor ZenFlowZeroOptimizer methods to include communication data type
Antlera Jun 27, 2025
417932a
Merge remote-tracking branch 'upstream/master' into zenflow_zero1_2
Antlera Jun 28, 2025
fee24ff
Refactor ZenFlow integration in DeepSpeedEngine
Antlera Jun 28, 2025
a528fd4
Refactor ZenFlow function callings in DeepSpeedEngine
Antlera Jun 28, 2025
9aac3c0
Merge branch 'master' into zenflow_zero1_2
tohtana Jun 30, 2025
f7bc35d
Fix bugs in ZenFlow + ZeRO Stage 1 and gradient reduction logic
JoshWoo2003 Jul 1, 2025
3638d78
Add unit tests for ZenFlow with ZeRO Stage 1 and 2
JoshWoo2003 Jul 1, 2025
fad8498
Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into z…
Antlera Jul 1, 2025
6d68330
Refactor ZenFlow integration using separate engine file
Antlera Jul 2, 2025
913f9a7
Fix missing `[comm_dtype]` and format code
Antlera Jul 2, 2025
6b8c82a
Merge branch 'master' into zenflow_zero1_2
tohtana Jul 2, 2025
bce0a7f
Update CPUADAM core range calculation in zenflow_stage_1_and_2.py
Antlera Jul 3, 2025
6f51348
Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…
Antlera Jul 3, 2025
0ef3faf
Fix bugs in ZenFlow unit tests
JoshWoo2003 Jul 9, 2025
a623556
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Jul 13, 2025
b898eaf
Merge remote-tracking branch 'origin/zenflow_zero1_2' into clr_branch…
Antlera Jul 16, 2025
e2a2b81
Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…
Antlera Jul 16, 2025
8d6b6f3
Fix: Add PyTorch version check for ZenFlow configuration
Antlera Jul 16, 2025
1e70efa
Merge branch 'master' into zenflow_zero1_2
tohtana Jul 16, 2025
891ac09
Enhance ZenFlow compatibility checks for PyTorch version
Antlera Jul 16, 2025
0d7d086
Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…
Antlera Jul 16, 2025
da902eb
Merge branch 'master' into zenflow_zero1_2
loadams Aug 1, 2025
d2d1a06
Fix bugs in ZenFlow unit tests when using CPU Torch
JoshWoo2003 Aug 1, 2025
4cb3178
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Aug 2, 2025
4d1db6d
Merge branch 'master' into zenflow_zero1_2
tjruwase Aug 2, 2025
f3b2276
Added TODO comments to indicate the need for removing ZenFlow-specifi…
Antlera Aug 2, 2025
e48622c
Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…
Antlera Aug 2, 2025
bbb6f74
Fix formatting in test_zf.py
Antlera Aug 2, 2025
9f4fb58
Update docs/_tutorials/zenflow.md
Antlera Aug 2, 2025
df70150
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Aug 4, 2025
29c5f28
Merge branch 'master' into zenflow_zero1_2
Antlera Aug 10, 2025
dc505a7
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Aug 10, 2025
938e8a3
Fix copyrights.
Antlera Aug 10, 2025
8951fa0
Remove CUDA specific code.
Antlera Aug 10, 2025
ef166ad
Merge branch 'master' into zenflow_zero1_2
Antlera Aug 11, 2025
b849af6
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Aug 11, 2025
0593fc2
Merge branch 'master' into zenflow_zero1_2
Antlera Aug 12, 2025
671c568
Merge branch 'master' into zenflow_zero1_2
tjruwase Aug 12, 2025
8ed876c
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Aug 15, 2025
900491e
Merge branch 'master' into zenflow_zero1_2
sfc-gh-truwase Aug 15, 2025
2 changes: 2 additions & 0 deletions deepspeed/ops/adam/__init__.py
@@ -5,3 +5,5 @@

from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
from .zenflow_cpu_adam import ZenFlowCPUAdam
from .zenflow_torch_adam import ZenFlowSelectiveAdamW
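
With this change, both new optimizers are re-exported from the deepspeed.ops.adam package; a minimal import sketch (names taken from the diff above):

    from deepspeed.ops.adam import ZenFlowCPUAdam, ZenFlowSelectiveAdamW
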
138 changes: 138 additions & 0 deletions deepspeed/ops/adam/zenflow_cpu_adam.py
@@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.ops.adam import DeepSpeedCPUAdam
import torch


class ZenFlowCPUAdam(DeepSpeedCPUAdam):

    def __init__(self, *args, overlap_step=False, **kwargs):
        super(ZenFlowCPUAdam, self).__init__(*args, **kwargs)
        self.overlap_step = overlap_step
        if not self.overlap_step:
            print("ZenFlowCPUAdam initialized with normal step.")
            self.step = self._sequential_step
        else:
            print("ZenFlowCPUAdam initialized with overlap step.")
            self.step = self._parallel_step

    @torch.no_grad()
    def _sequential_step(self, step_id, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)

                state['step'] = step_id
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def _parallel_step(self, step_id, now_state, group_info, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        stale_param = None

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):
                assert p.data.is_shared(), "param.data must be in shared memory"
                if not hasattr(p, 'overlap_grad'):
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    # print("creating", flush=True)
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
                    exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    state['exp_avg'] = [exp_avg, exp_avg.clone()]
                    state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()]

                state['step'] = step_id
                beta1, beta2 = group_info['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2,
                                             group_info['eps'], group_info['weight_decay'],
                                             group_info['bias_correction'], p.data, p.overlap_grad[now_state].data,
                                             state['exp_avg'][now_state], state['exp_avg_sq'][now_state])
                p.stale_param.data.copy_(p.data.clone())
        return loss
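
For context, here is a minimal usage sketch of ZenFlowCPUAdam as defined above. It assumes the DeepSpeed CPU Adam extension is available and that parameters live on CPU; the toy model, learning rate, and step id are illustrative only:

    import torch
    from deepspeed.ops.adam import ZenFlowCPUAdam

    model = torch.nn.Linear(16, 4)  # toy CPU module (illustrative)
    optimizer = ZenFlowCPUAdam(model.parameters(), lr=1e-3, overlap_step=False)

    loss = model(torch.randn(8, 16)).sum()  # toy forward pass
    loss.backward()
    optimizer.step(1)  # overlap_step=False binds step to _sequential_step(step_id=1)
    optimizer.zero_grad()

With overlap_step=True, step is instead bound to _parallel_step, which also takes now_state (selecting one of the two double-buffered exp_avg/exp_avg_sq slots) and a group_info dict of hyperparameters, and reads gradients from p.overlap_grad rather than p.grad.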