Commit 4fd1f34

Abhi/fsdp bugfix 0 11 (#1623)
1 parent 56f0e33 commit 4fd1f34

File tree

1 file changed: +6 -5 lines changed


composer/trainer/dist_strategy.py

Lines changed: 6 additions & 5 deletions
@@ -162,20 +162,20 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
     if cpu_offload is not None:
         raise ValueError('FSDP CPU Offload not supported yet.')
 
-    mixed_precision = fsdp_config.get('mixed_precision', 'default')
+    mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT').upper()
     if isinstance(mixed_precision, dict):
         param_dtype = get_torch_dtype(mixed_precision.get('param_dtype', 'float32'))
         reduce_dtype = get_torch_dtype(mixed_precision.get('reduce_dtype', 'float32'))
         buffer_dtype = get_torch_dtype(mixed_precision.get('buffer_dtype', 'float32'))
-    elif mixed_precision == 'full':
+    elif mixed_precision == 'FULL':
         param_dtype = torch.float32
         reduce_dtype = torch.float32
         buffer_dtype = torch.float32
-    elif mixed_precision == 'default':
+    elif mixed_precision == 'DEFAULT':
         param_dtype = torch.float32
         reduce_dtype = get_torch_dtype(precision)
         buffer_dtype = torch.float32
-    elif mixed_precision == 'pure':
+    elif mixed_precision == 'PURE':
         param_dtype = get_torch_dtype(precision)
         reduce_dtype = get_torch_dtype(precision)
         buffer_dtype = get_torch_dtype(precision)
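
For readers skimming the diff: the mixed_precision string now resolves case-insensitively to a (param_dtype, reduce_dtype, buffer_dtype) triple. Below is a minimal standalone sketch of that resolution, not Composer's actual helper; resolve_mixed_precision is a hypothetical stand-in for the inline if/elif chain above, and torch.bfloat16 stands in for get_torch_dtype(precision).

import torch

def resolve_mixed_precision(mode: str, precision_dtype: torch.dtype):
    # Mirrors the FULL / DEFAULT / PURE branches in the hunk above;
    # .upper() is what makes 'full', 'Full', and 'FULL' equivalent.
    mode = mode.upper()
    if mode == 'FULL':
        return torch.float32, torch.float32, torch.float32  # param, reduce, buffer
    if mode == 'DEFAULT':
        # params and buffers stay fp32; gradients reduce in the training precision
        return torch.float32, precision_dtype, torch.float32
    if mode == 'PURE':
        return precision_dtype, precision_dtype, precision_dtype
    raise ValueError(f'Unknown mixed_precision mode: {mode}')

for spelling in ('full', 'Default', 'PURE'):
    print(spelling, resolve_mixed_precision(spelling, torch.bfloat16))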
@@ -194,7 +194,7 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
         'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
     }
     backward_prefetch = backward_prefetch_map[fsdp_config.get('backward_prefetch', 'BACKWARD_POST').upper()]
-    min_params = int(float(fsdp_config.get('min_params', 1e8)))
+    min_params = int(float(fsdp_config.get('min_params', 1e9)))
     activation_checkpointing = fsdp_config.get('activation_checkpointing', False)
     activation_cpu_offload = fsdp_config.get('activation_cpu_offload', False)
 
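For reference, a hypothetical fsdp_config dict exercising the keys this hunk reads; the values are illustrative, not Composer defaults beyond what the diff itself shows. With min_params omitted, the fallback is now 1e9 parameters rather than 1e8.

# Hypothetical user-supplied fsdp_config; keys mirror the ones read above.
fsdp_config = {
    'mixed_precision': 'default',          # any casing accepted after this commit
    'backward_prefetch': 'backward_post',  # upper-cased before the map lookup
    'activation_checkpointing': False,
    'activation_cpu_offload': False,
    # 'min_params' intentionally omitted
}

min_params = int(float(fsdp_config.get('min_params', 1e9)))
print(min_params)  # 1000000000: smaller modules are left unwrapped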
@@ -238,6 +238,7 @@ def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, unwrapped_params:
             mixed_precision=mixed_precision,
             backward_prefetch=backward_prefetch,
             param_init_fn=_param_init_fn,
+            device_id=torch.cuda.current_device(),
         )
 
         # Activation Checkpointing
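
The new device_id argument tells PyTorch's FullyShardedDataParallel (available since PyTorch 1.12) to materialize and shard the module directly on the given CUDA device rather than on whatever device the module currently occupies. A minimal sketch of the call shape, assuming a process group is already initialized and CUDA is available; everything except device_id is illustrative.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group('nccl') has already run and
# each rank has called torch.cuda.set_device(local_rank) beforehand.
model = torch.nn.Linear(16, 16)
fsdp_model = FSDP(
    model,
    device_id=torch.cuda.current_device(),  # shard/initialize on the local GPU
)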
