Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions composer/trainer/dist_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,20 @@ def sync_hook(*args):
param_pointer_to_param_name = {id(p): n for n, p in model.named_parameters()}
param_name_to_group_num = {}
group_num_to_param_group_info = {}
for group_num, group in enumerate(optim.param_groups):
for param in group['params']:
param_name_to_group_num[param_pointer_to_param_name[id(param)]] = group_num
for group_num in range(len(optim.param_groups)):
# Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
# group = optim.param_groups[group_num]
for param_num in range(len(optim.param_groups[group_num]['params'])):
# Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
# param = optim.param_groups[group_num]['params'][param_num]
param_name_to_group_num[param_pointer_to_param_name[id(
optim.param_groups[group_num]['params'][param_num])]] = group_num

# this includes optimizer-specific values like lr, eps
# this will be used as the kwargs for the optim param groups later
optimizer_specific_group_info = {k: v for k, v in group.items() if k != 'params'}
optimizer_specific_group_info = {
k: v for k, v in optim.param_groups[group_num].items() if k != 'params'
}
group_num_to_param_group_info[group_num] = optimizer_specific_group_info
else:
optimizer_specific_info = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
Expand Down