@@ -110,12 +110,8 @@ def __init__(self,
     @classmethod
     def create(cls, zenflow_config):
         if zenflow_config.overlap_step:
-            # print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-            print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerParallel")
             return ZenFlowZeroOptimizerParallel
         else:
-            # print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-            print("No!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ZenFlowZeroOptimizerSequential")
             return ZenFlowZeroOptimizerSequential
 
     def _configure_zenflow(self, zenflow_config):
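The factory above only chooses between the two optimizer classes based on `zenflow_config.overlap_step`; the debug prints it drops were pure noise. A self-contained sketch of the same dispatch pattern, using stub classes rather than the real DeepSpeed types:

```python
# Standalone illustration of the factory dispatch above; the stub classes
# stand in for the real ZenFlow optimizers and nothing here imports DeepSpeed.
class ZenFlowZeroOptimizerParallel:
    """Stub: optimizer step overlapped with backward."""

class ZenFlowZeroOptimizerSequential:
    """Stub: optimizer step run sequentially."""

class ZenFlowZeroOptimizer:
    @classmethod
    def create(cls, zenflow_config):
        # Dispatch purely on the overlap_step flag, as in the diff.
        if zenflow_config.overlap_step:
            return ZenFlowZeroOptimizerParallel
        return ZenFlowZeroOptimizerSequential
```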
@@ -182,7 +178,7 @@ def sync_fp32_param_from_gpu(self):
             fp32_partition.copy_(bit16_partitions[partition_id].to(dtype=fp32_partition.dtype,
                                                                    device=fp32_partition.device))
 
-    def update_selected_channels(self, tensor, total_size):
+    def update_selected_channels(self, tensor, total_size, communication_data_type):
         curr_size = 0
         curr_index_buffer_size = 0
         rank_and_offsets = []
@@ -194,7 +190,8 @@ def update_selected_channels(self, tensor, total_size):
             self.index_buffer = torch.empty(total_size, dtype=torch.int32, device='cuda')
 
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             if len(param.shape) == 1:
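This hunk is the heart of the change: the flat `self.params_in_ipg_bucket` list is replaced by per-dtype buckets looked up via `self.ipg_buckets[communication_data_type]`. The diff only reveals two attributes of a bucket, `params` (tuples of group index, index within the group, and param id) and, later in `average_tensor`, `has_moe_params`; a hypothetical container with just that surface, inferred from those accesses alone:

```python
from dataclasses import dataclass, field

import torch

@dataclass
class IPGBucket:
    """Hypothetical sketch of the per-dtype bucket this diff indexes into;
    the real class lives elsewhere in DeepSpeed's ZeRO stage 1/2 code."""
    # (group index i, param index within bit16_groups[i], unique param id)
    params: list[tuple[int, int, int]] = field(default_factory=list)
    has_moe_params: bool = False

# One bucket per communication dtype, matching self.ipg_buckets[communication_data_type].
ipg_buckets = {torch.bfloat16: IPGBucket(), torch.float16: IPGBucket()}
```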
@@ -255,7 +252,7 @@ def update_selected_channels(self, tensor, total_size):
             index_slice = self.index_buffer.narrow(0, offset, num_select)
             dist.broadcast(index_slice, src=src_rank, group=process_group)
 
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             if len(param.shape) == 1:
@@ -281,15 +278,15 @@ def update_selected_channels(self, tensor, total_size):
 
         self.index_buffer = None
 
-    def process_selected_fp32_groups_grad(self, tensor, total_size):
+    def _process_selected_fp32_groups_grad(self, tensor, total_size, communication_data_type):
         """
         Process gradients for selected columns in FP32 groups
 
         Args:
             param: The parameter to process
             param_id: ID of the parameter
         """
-        print("Yes!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!process_selected_fp32_groups_grad")
+
         curr_size = 0
         curr_grad_buffer_size = 0
         curr_sum_buffer_size = 0
@@ -309,7 +306,8 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
         group_to_paramlist = {}
 
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             if not hasattr(param, 'selected_indices'):
@@ -389,7 +387,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
             sum_slice = self.sum_buffer.narrow(0, sum_offset, sum_num)
             dist.broadcast(sum_slice, src=src_rank, group=process_group)
 
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             selected_grad = None
@@ -450,7 +448,7 @@ def process_selected_fp32_groups_grad(self, tensor, total_size):
         if self.auto_update:
             self.sum_buffer = None
 
-    def average_tensor(self, tensor):
+    def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype):
         if self.overlap_comm:
             stream = self.reduction_stream
             if not get_accelerator().resolves_data_dependency():
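`average_tensor` now takes the bucket's communication dtype explicitly (and gains type hints), so whatever drives the reduction must pass it alongside the flattened gradient tensor. A hedged sketch of such a call site; the loop, the `buffer` attribute, and the function name are illustrative, not DeepSpeed's actual reduce path:

```python
def reduce_all_buckets(optimizer) -> None:
    """Illustrative driver: reduce every dtype-keyed IPG bucket in turn."""
    for communication_data_type, bucket in optimizer.ipg_buckets.items():
        if not bucket.params:  # nothing pending for this dtype
            continue
        # Treating 'buffer' as the bucket's flattened gradient tensor is an assumption.
        optimizer.average_tensor(bucket.buffer, communication_data_type)
```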
@@ -478,12 +476,13 @@ def average_tensor(self, tensor):
 
         process_group = self.dp_process_group
         # count = 0
-        for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+        bucket = self.ipg_buckets[communication_data_type]
+        for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
             process_group = self.dp_process_group
 
-            if self.ipg_bucket_has_moe_params:
+            if bucket.has_moe_params:
                 process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                     param) else self.dp_process_group
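The MoE check moves from a global `self.ipg_bucket_has_moe_params` flag onto the bucket itself; the per-parameter group selection is otherwise unchanged. The same rule, condensed into a helper for clarity (`is_moe_param` is DeepSpeed's MoE utility; the helper itself is illustrative):

```python
def pick_process_group(optimizer, bucket, param):
    """Illustrative: expert parameters reduce within their expert-parallel
    group; everything else uses the regular data-parallel group."""
    if bucket.has_moe_params and is_moe_param(param):
        return optimizer.expert_dp_process_group[param.group_name]
    return optimizer.dp_process_group
```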
@@ -546,12 +545,14 @@ def average_tensor(self, tensor):
         for bucket_key in buckets:
             if self.use_multi_rank_bucket_allreduce:
                 self.allreduce_and_scatter(buckets[bucket_key],
+                                           communication_data_type,
                                            numel_per_bucket=self.reduce_bucket_size,
                                            divide=False,
                                            process_group=bucket_key)
             else:
                 dst, process_group = bucket_key
                 self.allreduce_no_retain(buckets[bucket_key],
+                                         communication_data_type,
                                          numel_per_bucket=self.reduce_bucket_size,
                                          rank=dst,
                                          divide=False,
@@ -560,15 +561,15 @@ def average_tensor(self, tensor):
         if self.is_zenflow_select_boundary():
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
             # print("update selected")
-            self.update_selected_channels(tensor, curr_column_size)
+            self.update_selected_channels(tensor, curr_column_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
         elif self.zenflow:
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
 
         if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
             self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).start()
-            self.process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size)
+            self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_PROCESS_TIMER).stop()
 
     def backward(self, loss, retain_graph=False):
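Putting the tail of `average_tensor` together: on a selection boundary the selected channels are re-chosen, and once past warm-up the selected FP32 gradients are processed every step, both now keyed by the bucket's dtype. A condensed, illustrative restatement of that control flow (timer calls omitted; the sizes are locals of `average_tensor`, passed in here so the sketch stands alone):

```python
def zenflow_tail(self, tensor, curr_column_size, curr_selected_reduce_size,
                 communication_data_type):
    """Illustrative summary of the ZenFlow branch at the end of average_tensor."""
    if self.is_zenflow_select_boundary():
        # Periodically re-pick which channels get eager optimizer updates.
        self.update_selected_channels(tensor, curr_column_size, communication_data_type)
    if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
        # Past warm-up, fold the selected columns' grads into the FP32 groups.
        self._process_selected_fp32_groups_grad(tensor, curr_selected_reduce_size,
                                                communication_data_type)
```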