@@ -162,20 +162,20 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
     if cpu_offload is not None:
         raise ValueError('FSDP CPU Offload not supported yet.')
 
-    mixed_precision = fsdp_config.get('mixed_precision', 'default')
+    mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT').upper()
     if isinstance(mixed_precision, dict):
         param_dtype = get_torch_dtype(mixed_precision.get('param_dtype', 'float32'))
         reduce_dtype = get_torch_dtype(mixed_precision.get('reduce_dtype', 'float32'))
         buffer_dtype = get_torch_dtype(mixed_precision.get('buffer_dtype', 'float32'))
-    elif mixed_precision == 'full':
+    elif mixed_precision == 'FULL':
         param_dtype = torch.float32
         reduce_dtype = torch.float32
         buffer_dtype = torch.float32
-    elif mixed_precision == 'default':
+    elif mixed_precision == 'DEFAULT':
         param_dtype = torch.float32
         reduce_dtype = get_torch_dtype(precision)
         buffer_dtype = torch.float32
-    elif mixed_precision == 'pure':
+    elif mixed_precision == 'PURE':
         param_dtype = get_torch_dtype(precision)
         reduce_dtype = get_torch_dtype(precision)
         buffer_dtype = get_torch_dtype(precision)
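The dtype triple selected above is presumably packed into torch's MixedPrecision dataclass further down in this function (the wiring sits outside this hunk, and hunk 3 below does pass a `mixed_precision=` kwarg to FSDP), so take this as a minimal sketch rather than composer's exact code:

    # Hedged sketch: how param/reduce/buffer dtypes typically feed FSDP.
    # `param_dtype`, `reduce_dtype`, `buffer_dtype` come from the branch above.
    from torch.distributed.fsdp import MixedPrecision

    mixed_precision = MixedPrecision(
        param_dtype=param_dtype,    # dtype of flattened parameters during compute
        reduce_dtype=reduce_dtype,  # dtype used for gradient reduce-scatter/all-reduce
        buffer_dtype=buffer_dtype,  # dtype of module buffers (e.g. norm running stats)
    )

Note that 'DEFAULT' keeps parameters and buffers in float32 and only lowers the gradient-reduction dtype to the training precision, while 'PURE' lowers all three.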
@@ -194,7 +194,7 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
         'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
     }
     backward_prefetch = backward_prefetch_map[fsdp_config.get('backward_prefetch', 'BACKWARD_POST').upper()]
-    min_params = int(float(fsdp_config.get('min_params', 1e8)))
+    min_params = int(float(fsdp_config.get('min_params', 1e9)))
     activation_checkpointing = fsdp_config.get('activation_checkpointing', False)
     activation_cpu_offload = fsdp_config.get('activation_cpu_offload', False)
 
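`min_params` drives FSDP's size-based auto-wrapping: a module becomes its own FSDP unit once it holds at least that many unwrapped parameters, so raising the default from 1e8 to 1e9 makes the default wrapping coarser. A minimal sketch of such a policy, assuming the `_auto_wrap_policy(module, recurse, unwrapped_params)` signature visible in the next hunk header (the header is truncated there, so the `int` annotation and `bool` return type are assumptions):

    import torch

    def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool:
        # Always recurse into children; wrap a module as its own FSDP unit
        # only once it holds at least `min_params` parameters (captured from
        # the enclosing prepare_fsdp_module scope).
        if recurse:
            return True
        return unwrapped_params >= min_params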
@@ -238,6 +238,7 @@ def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, unwrapped_params:
         mixed_precision=mixed_precision,
         backward_prefetch=backward_prefetch,
         param_init_fn=_param_init_fn,
+        device_id=torch.cuda.current_device(),
     )
 
     # Activation Checkpointing
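The new `device_id` argument tells FSDP which device to use while it initializes and shards the wrapped module; passing `torch.cuda.current_device()` keeps that work on the local GPU rather than falling back to slower CPU-side initialization. A hedged sketch of the call this hunk modifies (only the four kwargs shown in the diff are confirmed; the positional argument and `auto_wrap_policy` are assumptions):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    fsdp_model = FSDP(
        module,                               # assumed: the module being wrapped
        auto_wrap_policy=_auto_wrap_policy,   # assumed: the size-based policy above
        mixed_precision=mixed_precision,
        backward_prefetch=backward_prefetch,
        param_init_fn=_param_init_fn,
        device_id=torch.cuda.current_device(),  # new in this commit
    )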