@@ -162,20 +162,20 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
     if cpu_offload is not None:
         raise ValueError('FSDP CPU Offload not supported yet.')
 
-    mixed_precision = fsdp_config.get('mixed_precision', 'default')
+    mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT').upper()
     if isinstance(mixed_precision, dict):
         param_dtype = get_torch_dtype(mixed_precision.get('param_dtype', 'float32'))
         reduce_dtype = get_torch_dtype(mixed_precision.get('reduce_dtype', 'float32'))
         buffer_dtype = get_torch_dtype(mixed_precision.get('buffer_dtype', 'float32'))
-    elif mixed_precision == 'full':
+    elif mixed_precision == 'FULL':
         param_dtype = torch.float32
         reduce_dtype = torch.float32
         buffer_dtype = torch.float32
-    elif mixed_precision == 'default':
+    elif mixed_precision == 'DEFAULT':
         param_dtype = torch.float32
         reduce_dtype = get_torch_dtype(precision)
         buffer_dtype = torch.float32
-    elif mixed_precision == 'pure':
+    elif mixed_precision == 'PURE':
         param_dtype = get_torch_dtype(precision)
         reduce_dtype = get_torch_dtype(precision)
         buffer_dtype = get_torch_dtype(precision)
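The change uppercases the `mixed_precision` string so that `'pure'`, `'Pure'`, and `'PURE'` all select the same branch. Below is a minimal, self-contained sketch of that resolution logic for experimenting outside the trainer. `resolve_mixed_precision` and the trimmed-down `get_torch_dtype` mapping are illustrative stand-ins rather than composer's actual API, and this sketch applies `.upper()` only after the dict check, since a dict-valued config has no `.upper()`.

```python
import torch


def get_torch_dtype(dtype):
    """Trimmed-down stand-in for a dtype-name lookup (illustrative mapping only)."""
    if isinstance(dtype, torch.dtype):
        return dtype
    return {
        'fp32': torch.float32, 'float32': torch.float32,
        'fp16': torch.float16, 'float16': torch.float16, 'amp': torch.float16,
        'bf16': torch.bfloat16, 'bfloat16': torch.bfloat16,
    }[str(dtype).lower()]


def resolve_mixed_precision(fsdp_config, precision='amp'):
    """Resolve (param_dtype, reduce_dtype, buffer_dtype) from an fsdp_config dict."""
    mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT')
    if isinstance(mixed_precision, dict):
        # Explicit per-dtype override.
        param_dtype = get_torch_dtype(mixed_precision.get('param_dtype', 'float32'))
        reduce_dtype = get_torch_dtype(mixed_precision.get('reduce_dtype', 'float32'))
        buffer_dtype = get_torch_dtype(mixed_precision.get('buffer_dtype', 'float32'))
    else:
        # String modes are case-insensitive after the .upper() change.
        mixed_precision = mixed_precision.upper()
        if mixed_precision == 'FULL':
            param_dtype = reduce_dtype = buffer_dtype = torch.float32
        elif mixed_precision == 'DEFAULT':
            param_dtype = buffer_dtype = torch.float32
            reduce_dtype = get_torch_dtype(precision)
        elif mixed_precision == 'PURE':
            param_dtype = reduce_dtype = buffer_dtype = get_torch_dtype(precision)
        else:
            raise ValueError(f'Unexpected mixed_precision: {mixed_precision}')
    return param_dtype, reduce_dtype, buffer_dtype


# 'pure', 'Pure', and 'PURE' now all resolve identically:
print(resolve_mixed_precision({'mixed_precision': 'pure'}, precision='bf16'))
# -> (torch.bfloat16, torch.bfloat16, torch.bfloat16)
```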