
Commit 32a9ff9

Antlera and JoshWoo2003 committed
Refactor ZenFlowZeroOptimizer methods to include communication data type
- Updated methods to accept communication_data_type as a parameter for better handling of IPG buckets.
- Removed debug print statements to clean up the code.

Signed-off-by: Tingfeng Lan <[email protected]>
Co-authored-by: Yusen Wu <[email protected]>
1 parent 18bd0ae commit 32a9ff9
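
For readers following the refactor, here is a minimal sketch of the pattern this commit moves to. The `IPGBucket` shape below is an assumption inferred from the diff (it exposes `params` and `has_moe_params`); it is a stand-in, not the actual DeepSpeed class definition:

```python
# Hedged sketch of the per-dtype IPG bucket lookup introduced by this commit.
# IPGBucket is an assumed stand-in inferred from the diff; the real class in
# deepspeed/runtime/zero may carry more state.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch


@dataclass
class IPGBucket:
    # (group index, index within group, param id) triples, matching the
    # `for i, param_idx_in_group, param_id in bucket.params` loops in the diff.
    params: List[Tuple[int, int, int]] = field(default_factory=list)
    has_moe_params: bool = False


# Buckets keyed by communication dtype, replacing a single shared
# self.params_in_ipg_bucket list.
ipg_buckets: Dict[torch.dtype, IPGBucket] = {
    torch.bfloat16: IPGBucket(),
    torch.float32: IPGBucket(),
}


def average_tensor(tensor: torch.Tensor, communication_data_type: torch.dtype) -> None:
    # Each method now selects the bucket that matches the dtype it was given.
    bucket = ipg_buckets[communication_data_type]
    for i, param_idx_in_group, param_id in bucket.params:
        ...  # reduction logic touches only params staged under this dtype
```

The same lookup also replaces the single `self.ipg_bucket_has_moe_params` flag with per-bucket `has_moe_params` state, as seen in the `average_tensor` hunk below.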

File tree

1 file changed (+17, -16)


deepspeed/runtime/zero/zenflow/zenflow_stage_1_and_2.py

Lines changed: 17 additions & 16 deletions
```diff
@@ -110,12 +110,8 @@ def __init__(self,
     @classmethod
     def create(cls, zenflow_config):
         if zenflow_config.overlap_step:
-            # print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-            print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerParallel")
             return ZenFlowZeroOptimizerParallel
         else:
-            # print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-            print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerSequential")
             return ZenFlowZeroOptimizerSequential

     def _configure_zenflow(self, zenflow_config):
@@ -182,7 +178,7 @@ def sync_fp32_param_from_gpu(self):
             fp32_partition.copy_(bit16_partitions[partition_id].to(dtype=fp32_partition.dtype,
                                                                    device=fp32_partition.device))

-    def update_selected_channels(self, tensor, total_size):
+    def update_selected_channels(self, tensor, total_size, communication_data_type):
         curr_size = 0
         curr_index_buffer_size = 0
         rank_and_offsets = []
@@ -194,7 +190,8 @@ def update_selected_channels(self, tensor, total_size):
         self.index_buffer = torch.empty(total_size, dtype=torch.int32, device='cuda')

         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]

             if len(param.shape) == 1:
@@ -255,7 +252,7 @@ def update_selected_channels(self, tensor, total_size):
             index_slice = self.index_buffer.narrow(0, offset, num_select)
             dist.broadcast(index_slice, src=src_rank, group=process_group)

-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]

             if len(param.shape) == 1:
@@ -281,15 +278,15 @@ def update_selected_channels(self, tensor, total_size):

         self.index_buffer = None

-    def process_selected_fp32_groups_grad(self, tensor, total_size):
+    def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_data_type):
         """
         Process gradients for selected columns in FP32 groups

         Args:
             param: The parameter to process
             param_id: ID of the parameter
         """
-        print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!process_selected_fp32_groups_grad")
+
         curr_size = 0
         curr_grad_buffer_size = 0
         curr_sum_buffer_size = 0
@@ -309,7 +306,8 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
         group_to_paramlist = {}

         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]

             if not hasattr(param, 'selected_indices'):
@@ -389,7 +387,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
             sum_slice = self.sum_buffer.narrow(0, sum_offset, sum_num)
             dist.broadcast(sum_slice, src=src_rank, group=process_group)

-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]

             selected_grad = None
@@ -450,7 +448,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
         if self.auto_update:
             self.sum_buffer = None

-    def average_tensor(self, tensor):
+    def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype):
         if self.overlap_comm:
             stream = self.reduction_stream
             if not get_accelerator().resolves_data_dependency():
@@ -478,12 +476,13 @@ def average_tensor(self, tensor):

         process_group = self.dp_process_group
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]

             process_group = self.dp_process_group

-            if self.ipg_bucket_has_moe_params:
+            if bucket.has_moe_params:
                 process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                     param) else self.dp_process_group

@@ -546,12 +545,14 @@ def average_tensor(self, tensor):
         for bucket_key in buckets:
             if self.use_multi_rank_bucket_allreduce:
                 self.allreduce_and_scatter(buckets[bucket_key],
+                                           communication_data_type,
                                            numel_per_bucket=self.reduce_bucket_size,
                                            divide=False,
                                            process_group=bucket_key)
             else:
                 dst, process_group = bucket_key
                 self.allreduce_no_retain(buckets[bucket_key],
+                                         communication_data_type,
                                          numel_per_bucket=self.reduce_bucket_size,
                                          rank=dst,
                                          divide=False,
@@ -560,15 +561,15 @@ def average_tensor(self, tensor):
         if self.is_zenflow_select_boundary():
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
             # print("update selected")
-            self.update_selected_channels(tensor, curr_column_size)
+            self.update_selected_channels(tensor, curr_column_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
         elif self.zenflow:
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()

         if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
             self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).start()
-            self.process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size)
+            self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop()

     def backward(self, loss, retain_graph=False):
```

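As a usage note, a hypothetical call site under the new signatures might look like the sketch below; the `flush_ipg_buckets` helper and the `buffer` attribute are illustrative assumptions, not code from this commit:

```python
# Hypothetical driver showing why the dtype is threaded through: each staged
# bucket is flushed under its own communication dtype. `buffer` is an assumed
# attribute holding the flattened gradients staged for that dtype.
def flush_ipg_buckets(optimizer) -> None:
    for comm_dtype, bucket in optimizer.ipg_buckets.items():
        if not bucket.params:
            continue  # nothing staged under this dtype
        optimizer.average_tensor(bucket.buffer, comm_dtype)
```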