Skip to content

Commit e444327

Browse files
committed
make mixed_precision types upper case
1 parent 3b40a8c commit e444327

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

composer/trainer/dist_strategy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,20 +162,20 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
162162
if cpu_offload is not None:
163163
raise ValueError('FSDP CPU Offload not supported yet.')
164164

165-
mixed_precision = fsdp_config.get('mixed_precision', 'default')
165+
mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT').upper()
166166
if isinstance(mixed_precision, dict):
167167
param_dtype = get_torch_dtype(mixed_precision.get('param_dtype', 'float32'))
168168
reduce_dtype = get_torch_dtype(mixed_precision.get('reduce_dtype', 'float32'))
169169
buffer_dtype = get_torch_dtype(mixed_precision.get('buffer_dtype', 'float32'))
170-
elif mixed_precision == 'full':
170+
elif mixed_precision == 'FULL':
171171
param_dtype = torch.float32
172172
reduce_dtype = torch.float32
173173
buffer_dtype = torch.float32
174-
elif mixed_precision == 'default':
174+
elif mixed_precision == 'DEFAULT':
175175
param_dtype = torch.float32
176176
reduce_dtype = get_torch_dtype(precision)
177177
buffer_dtype = torch.float32
178-
elif mixed_precision == 'pure':
178+
elif mixed_precision == 'PURE':
179179
param_dtype = get_torch_dtype(precision)
180180
reduce_dtype = get_torch_dtype(precision)
181181
buffer_dtype = get_torch_dtype(precision)

0 commit comments

Comments
 (0)