Commit fee24ff
Refactor ZenFlow integration in DeepSpeedEngine
- Move `_configure_zenflow` logic to a standalone `configure_zenflow()` function in `zenflow_utils.py`
- Refactor ZenFlow placement to decouple it from ZeRO internals

Signed-off-by: Tingfeng Lan <[email protected]>

1 parent 417932a, commit fee24ff

10 files changed: +162, -100 lines
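For orientation before the per-file diffs: ZenFlow is driven by a ZenFlowConfig block in the DeepSpeed config. Below is a minimal sketch of how a user might enable it, assuming the "zenflow" key nests under "zero_optimization" (an assumption inferred from deepspeed/runtime/zero/config.py importing ZenFlowConfig in this commit; the key name itself is not shown in this diff):

# Hypothetical user-facing config sketch; the "zenflow" nesting is assumed,
# and the field names mirror ZenFlowConfig as added below.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "zenflow": {
            "topk_ratio": 0.1,          # keep the top 10% of gradient columns
            "select_strategy": "auto",  # reselect once per epoch (see configure_zenflow)
            "overlap_step": True,       # overlap CPU optimizer step with compute
        },
    },
}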

deepspeed/runtime/engine.py

Lines changed: 4 additions & 1 deletion
@@ -27,11 +27,12 @@
 from deepspeed.runtime.utils import see_memory_usage, DummyOptim
 from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
-from deepspeed.runtime.zero.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer
+from deepspeed.runtime.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION
+from deepspeed.runtime.zenflow.zenflow_utils import configure_zenflow

 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer

@@ -334,6 +335,8 @@ def __init__(self,
         if self.torch_autocast_enabled():
             init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules())

+        configure_zenflow(self)
+
         if has_optimizer:
             self._configure_optimizer(optimizer, model_parameters)
             self._configure_lr_scheduler()
deepspeed/runtime/zenflow/__init__.py

File renamed without changes (moved from deepspeed/runtime/zero/zenflow/__init__.py).
deepspeed/runtime/zenflow/zenflow_config.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from pydantic import Field, model_validator
from typing import Optional, Union

from deepspeed.runtime.config_utils import DeepSpeedConfigModel


class ZenFlowConfig(DeepSpeedConfigModel):
    """Configuration options for ZenFlow optimization module."""

    topk_ratio: float = Field(0.1, ge=0.0, le=1.0)
    """Ratio of top-k important gradient columns to retain (range: 0.0 to 1.0)."""

    select_strategy: str = "auto"
    """Strategy for selecting important gradient indices.
    Options: "auto", "step", or "epoch"."""

    select_interval: Union[str, int] = "auto"
    """Interval at which to reselect important gradient indices.
    Can be "auto" or a fixed integer step/epoch interval."""

    update_interval: Union[str, int] = "auto"
    """Interval for applying accumulated unimportant gradients to model parameters.
    Can be "auto" or a fixed integer step interval."""

    overlap_step: bool = False
    """Whether to overlap CPU-side optimizer steps with forward/backward computation."""

    offload: bool = False
    """Whether to offload selective optimizer states to CPU to save memory."""

    auto_ratio: float = Field(0.99, ge=0.0, le=1.0)
    """Threshold used in the "auto" strategy to determine update_interval."""

    full_warm_up_rounds: int = 0
    """Number of initial rounds during which all gradients are fully updated (no selection)."""

    steps_per_epoch: Optional[int] = Field(
        default=None,
        description="Number of steps per epoch. This field is initialized during execution and should not be set by users.",
        exclude=True)

    @model_validator(mode="after")
    def validate_fields(self):
        if self.select_strategy not in ["auto", "step", "epoch"]:
            raise ValueError('select_strategy must be one of "auto", "step", or "epoch"')

        if isinstance(self.select_interval, str) and self.select_interval != "auto":
            raise ValueError('If select_interval is a string, it must be "auto"')

        if isinstance(self.update_interval, str) and self.update_interval != "auto":
            raise ValueError('If update_interval is a string, it must be "auto"')

        if not isinstance(self.full_warm_up_rounds, int):
            raise ValueError('full_warm_up_rounds must be an integer')

        return self
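Since ZenFlowConfig is a DeepSpeedConfigModel (a pydantic model), the validator above runs at construction time. A minimal sketch of what it accepts and rejects, assuming only the class as added in this commit:

from pydantic import ValidationError
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig

# Explicit step-based selection: reselect important columns every 10 steps,
# apply accumulated unimportant gradients every 4 steps.
cfg = ZenFlowConfig(topk_ratio=0.2, select_strategy="step", select_interval=10, update_interval=4)
print(cfg.select_interval, cfg.update_interval)  # 10 4

# validate_fields rejects unknown strategies; pydantic surfaces the raised
# ValueError as a ValidationError.
try:
    ZenFlowConfig(select_strategy="sometimes")
except ValidationError as err:
    print(err)  # includes: select_strategy must be one of "auto", "step", or "epoch"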
deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py

File renamed without changes (moved from deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py).
deepspeed/runtime/zenflow/zenflow_utils.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from deepspeed.runtime.engine import DeepSpeedEngine


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    transposed_tensors = [t.transpose(0, 1).contiguous() if t.dim() == 2 else t for t in tensors]
    return torch._C._nn.flatten_dense_tensors(transposed_tensors)


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors]
    unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors)
    return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat]


def configure_zenflow(engine: "DeepSpeedEngine") -> None:
    zenflow_config = engine.zenflow_config()
    if zenflow_config is None:
        engine.zenflow = False
        return

    engine.zenflow = True
    select_strategy = zenflow_config.select_strategy

    if select_strategy == 'auto':
        select_strategy = "epoch"
        if isinstance(zenflow_config.select_interval, int):
            raise Warning(
                "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten."
            )
        engine.select_interval = 1
    else:
        if isinstance(zenflow_config.select_interval, str):
            raise ValueError("If don't use auto select strategy, select_interval must be a number.")
        engine.select_interval = zenflow_config.select_interval

    if isinstance(zenflow_config.update_interval, str):
        engine.auto_update = True
        engine.update_interval = 0
    else:
        engine.auto_update = False
        engine.update_interval = int(zenflow_config.update_interval)

    if select_strategy == 'epoch':
        zenflow_config.steps_per_epoch = len(engine.training_dataloader)
        engine.select_interval = engine.select_interval * len(engine.training_dataloader)

    if not engine.auto_update and engine.select_interval != 0 and engine.select_interval < engine.update_interval:
        raise ValueError("Select interval must be greater or equal to update interval")

    engine.overlap_step = zenflow_config.overlap_step

    engine.full_warm_up_rounds = zenflow_config.full_warm_up_rounds

    engine._config.gradient_accumulation_steps = engine.update_interval
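Two behaviors of the new module are easy to check in isolation. First, _flatten_dense_tensors transposes 2-D tensors before flattening, so each column of a weight matrix becomes a contiguous run in the flat buffer (convenient for per-column top-k selection), and _unflatten_dense_tensors undoes this exactly. Second, configure_zenflow touches only a handful of engine attributes, so its effect can be traced with a stand-in object. A sketch, where SimpleNamespace is a hypothetical substitute for a real DeepSpeedEngine:

import torch
from types import SimpleNamespace
from deepspeed.runtime.zenflow.zenflow_utils import (_flatten_dense_tensors,
                                                     _unflatten_dense_tensors,
                                                     configure_zenflow)
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig

# Column-major flattening: columns of `w` become contiguous runs in `flat`.
w = torch.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
b = torch.tensor([7.0, 8.0])
flat = _flatten_dense_tensors([w, b])
print(flat)                            # tensor([0., 3., 1., 4., 2., 5., 7., 8.])
w2, b2 = _unflatten_dense_tensors(flat, [w, b])
assert torch.equal(w2, w) and torch.equal(b2, b)  # exact round trip

# Tracing configure_zenflow with a stand-in "engine" (hypothetical; a real
# DeepSpeedEngine exposes these same attributes).
cfg = ZenFlowConfig(select_strategy="step", select_interval=100, update_interval=4)
engine = SimpleNamespace(
    zenflow_config=lambda: cfg,
    training_dataloader=range(500),    # only len() is used, for the "epoch" strategy
    _config=SimpleNamespace(gradient_accumulation_steps=1))
configure_zenflow(engine)
print(engine.zenflow, engine.select_interval, engine.update_interval)  # True 100 4
print(engine._config.gradient_accumulation_steps)                      # 4

Note the last line: configure_zenflow overwrites gradient_accumulation_steps with update_interval, which is how unimportant-gradient accumulation is coupled to the engine's accumulation schedule.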

deepspeed/runtime/zero/config.py

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@
 from pydantic import Field, model_validator
 from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
 from deepspeed.utils import logger
-from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum, ZenFlowConfig
+from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
+from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig

 # ZeRO optimization. By default, this optimization is not enabled.
 # Users have to configure the desired optimization (0 means disabled) in params.json as below example:

deepspeed/runtime/zero/offload_config.py

Lines changed: 1 addition & 54 deletions
@@ -6,7 +6,7 @@
 from enum import Enum
 from pathlib import Path
 from pydantic import Field, model_validator
-from typing import Optional, Union
+from typing import Optional

 from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int

@@ -100,59 +100,6 @@ def set_pipeline(self):
         return self


-class ZenFlowConfig(DeepSpeedConfigModel):
-    """Configuration options for ZenFlow optimization module."""
-    ... (the remaining deleted lines are the class body verbatim as it now
-    appears in deepspeed/runtime/zenflow/zenflow_config.py above)
-
 class OffloadStateTypeEnum(str, Enum):
     """ Enum for internal buffer types """
     optim_states = "optim_states"

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,7 @@
 from typing import List, Dict

 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from deepspeed.runtime.zero.zenflow import zenflow_utils
+from deepspeed.runtime.zenflow import zenflow_utils

 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler

@@ -1933,6 +1933,8 @@ def _optimizer_step(self, group_no):
         if self.torch_autocast_gradscaler:
             self.torch_autocast_gradscaler.step(self.optimizer)
             self.torch_autocast_gradscaler.update()
+        elif self.zenflow:
+            self.zenflow_cpu_optimizer_step(group_no)
         else:
             self.optimizer.step()
         self.optimizer.param_groups = original_param_groups

deepspeed/runtime/zero/zenflow/zenflow_utils.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

tests/unit/runtime/zenflow/test_zf_config.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@
 from pydantic import ValidationError

 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig, ZeroStageEnum
-from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig, ZenFlowConfig
+from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig
+from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig


def test_stage_enum_accepts_int_and_enum():
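The moved import should be behavior-preserving; a hypothetical companion check in the same style (the test name and body below are illustrative, not part of the commit):

import pytest
from pydantic import ValidationError
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig

def test_zenflow_rejects_non_auto_string_interval():
    # Any string other than "auto" is invalid for select_interval.
    with pytest.raises(ValidationError):
        ZenFlowConfig(select_interval="weekly")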
