Add Zenflow code for Stage 1 & 2 #7391
Merged
Commits (49)
3309b49  Add ZenFlow optimizers (zero stage 1&2) for ZeRO integration  (Antlera)
4e9fe2a  Add ZenFlowConfig for optimizer configuration  (Antlera)
cac5703  Add ZenFlow (zero stage 1&2) integration in DeepSpeedEngine  (Antlera)
0e9a0c9  Add unit tests for ZenFlowConfig  (Antlera)
3353e34  Fix initialization and update logic for ZenFlow optimizers  (Antlera)
28cdf89  Add unit tests for ZenFlowSelectiveAdamW optimizer  (Antlera)
f534d5e  Add ZenFlow tutorial documentation  (Antlera)
80ad488  Format code  (Antlera)
9c05ccb  Fix check_grad_overflow parameter in ZenFlowZeroOptimizer  (Antlera)
da80ff7  Refactor ZenFlowZeroOptimizer methods to include communication data type  (Antlera)
417932a  Merge remote-tracking branch 'upstream/master' into zenflow_zero1_2  (Antlera)
fee24ff  Refactor ZenFlow integration in DeepSpeedEngine  (Antlera)
a528fd4  Refactor ZenFlow function callings in DeepSpeedEngine  (Antlera)
9aac3c0  Merge branch 'master' into zenflow_zero1_2  (tohtana)
f7bc35d  Fix bugs in ZenFlow + ZeRO Stage 1 and gradient reduction logic  (JoshWoo2003)
3638d78  Add unit tests for ZenFlow with ZeRO Stage 1 and 2  (JoshWoo2003)
fad8498  Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into z…  (Antlera)
6d68330  Refactor ZenFlow integration using seperate engine file  (Antlera)
913f9a7  Fix missing `[comm_dtype]` and format code  (Antlera)
6b8c82a  Merge branch 'master' into zenflow_zero1_2  (tohtana)
bce0a7f  Update CPUADAM core range calculation in zenflow_stage_1_and_2.py  (Antlera)
6f51348  Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…  (Antlera)
0ef3faf  Fix bugs in ZenFlow unit tests  (JoshWoo2003)
a623556  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
b898eaf  Merge remote-tracking branch 'origin/zenflow_zero1_2' into clr_branch…  (Antlera)
e2a2b81  Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…  (Antlera)
8d6b6f3  Fix: Add PyTorch version check for ZenFlow configuration  (Antlera)
1e70efa  Merge branch 'master' into zenflow_zero1_2  (tohtana)
891ac09  Enhance ZenFlow compatibility checks for PyTorch version  (Antlera)
0d7d086  Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…  (Antlera)
da902eb  Merge branch 'master' into zenflow_zero1_2  (loadams)
d2d1a06  Fix bugs in ZenFlow unit tests when using CPU Torch  (JoshWoo2003)
4cb3178  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
4d1db6d  Merge branch 'master' into zenflow_zero1_2  (tjruwase)
f3b2276  Added TODO comments to indicate the need for removing ZenFlow-specifi…  (Antlera)
e48622c  Merge branch 'zenflow_zero1_2' of github.com:Antlera/DeepSpeed into c…  (Antlera)
bbb6f74  Fix formatting in test_zf.py  (Antlera)
9f4fb58  Update docs/_tutorials/zenflow.md  (Antlera)
df70150  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
29c5f28  Merge branch 'master' into zenflow_zero1_2  (Antlera)
dc505a7  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
938e8a3  Fix copyrights.  (Antlera)
8951fa0  Remove CUDA specific code.  (Antlera)
ef166ad  Merge branch 'master' into zenflow_zero1_2  (Antlera)
b849af6  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
0593fc2  Merge branch 'master' into zenflow_zero1_2  (Antlera)
671c568  Merge branch 'master' into zenflow_zero1_2  (tjruwase)
8ed876c  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
900491e  Merge branch 'master' into zenflow_zero1_2  (sfc-gh-truwase)
@@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.ops.adam import DeepSpeedCPUAdam
import torch


class ZenFlowCPUAdam(DeepSpeedCPUAdam):

    def __init__(self, *args, overlap_step=False, **kwargs):
        super(ZenFlowCPUAdam, self).__init__(*args, **kwargs)
        self.overlap_step = overlap_step
        if not self.overlap_step:
            print("ZenFlowCPUAdam initialized with normal step.")
            self.step = self._sequential_step
        else:
            print("ZenFlowCPUAdam initialized with overlap step.")
            self.step = self._parallel_step

    @torch.no_grad()
    def _sequential_step(self, step_id, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)

                state['step'] = step_id
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def _parallel_step(self, step_id, now_state, group_info, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        stale_param = None

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):
                assert p.data.is_shared(), "param.data must be in shared memory"
                if not hasattr(p, 'overlap_grad'):
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    # print("creating", flush=True)
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
                    exp_avg = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    exp_avg_sq = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    state['exp_avg'] = [exp_avg, exp_avg.clone()]
                    state['exp_avg_sq'] = [exp_avg_sq, exp_avg_sq.clone()]

                state['step'] = step_id
                beta1, beta2 = group_info['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group_info['lr'], beta1, beta2,
                                             group_info['eps'], group_info['weight_decay'],
                                             group_info['bias_correction'], p.data, p.overlap_grad[now_state].data,
                                             state['exp_avg'][now_state], state['exp_avg_sq'][now_state])
                p.stale_param.data.copy_(p.data.clone())
        return loss
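To make the two step modes concrete, here is a minimal usage sketch; it is not part of the PR. It assumes DeepSpeed is installed with the CPU Adam extension available, and the import path, the hand-attached overlap_grad/stale_param attributes, and the group_info dictionary are illustrative assumptions inferred from what _parallel_step reads; in an actual run the ZenFlow engine wires these up.

# Illustrative sketch only, not from the PR.
import torch
from deepspeed.ops.adam.zenflow_cpu_adam import ZenFlowCPUAdam  # hypothetical import path

model = torch.nn.Linear(32, 8)  # CPUAdam requires parameters on the CPU

# Normal mode: step is bound to _sequential_step(step_id, closure=None).
opt = ZenFlowCPUAdam(model.parameters(), lr=1e-3, overlap_step=False)
model(torch.randn(4, 32)).sum().backward()
opt.step(1)  # the ZenFlow engine supplies the step_id

# Overlap mode: step is bound to _parallel_step(step_id, now_state, group_info, ...).
# The ZenFlow engine is expected to place parameters in shared memory and attach the
# double-buffered overlap_grad and the stale_param copy; done by hand here only to
# show what _parallel_step reads and writes.
model2 = torch.nn.Linear(32, 8)
opt2 = ZenFlowCPUAdam(model2.parameters(), lr=1e-3, overlap_step=True)
for p in model2.parameters():
    p.data.share_memory_()                                       # asserted by _parallel_step
    p.overlap_grad = [torch.zeros_like(p), torch.zeros_like(p)]  # gradient buffers indexed by now_state
    p.stale_param = p.data.clone()                               # receives the freshly updated weights
group_info = {'betas': (0.9, 0.999), 'lr': 1e-3, 'eps': 1e-8,
              'weight_decay': 0.0, 'bias_correction': True}
opt2.step(1, now_state=0, group_info=group_info)

The choice of step mode is fixed at construction time by rebinding self.step, so the calling code in the engine can invoke optimizer.step(...) uniformly regardless of whether overlap is enabled.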