51 changes: 32 additions & 19 deletions timm/optim/_optim_factory.py
@@ -234,6 +234,8 @@ def create_optimizer(
foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True,
layer_decay: Optional[float] = None,
layer_decay_min_scale: Optional[float] = None,
layer_decay_no_opt_scale: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any,
) -> torch.optim.Optimizer:
@@ -248,6 +250,8 @@ def create_optimizer(
foreach: Enable/disable foreach operation
weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
layer_decay: Layer-wise learning rate decay
layer_decay_min_scale: Minimum layer scale factor clamp value
layer_decay_no_opt_scale: Layer scale below which optimization is disabled
param_group_fn: Optional custom parameter grouping function
**kwargs: Additional optimizer-specific arguments

@@ -273,6 +277,8 @@ def create_optimizer(
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
weight_decay_exclude_1d=weight_decay_exclude_1d,
min_scale=layer_decay_min_scale,
no_opt_scale=layer_decay_no_opt_scale,
)
weight_decay = 0.
elif weight_decay and weight_decay_exclude_1d:
@@ -1140,6 +1146,8 @@ def create_optimizer_v2(
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
layer_decay_min_scale: float = 0.0,
layer_decay_no_opt_scale: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any,
) -> torch.optim.Optimizer:
@@ -1215,31 +1223,36 @@ def create_optimizer_v2(
foreach=foreach,
weight_decay_exclude_1d=filter_bias_and_bn,
layer_decay=layer_decay,
layer_decay_min_scale=layer_decay_min_scale,
layer_decay_no_opt_scale=layer_decay_no_opt_scale,
param_group_fn=param_group_fn,
**kwargs
)
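For reference, a minimal sketch of how a caller might exercise the two new arguments through create_optimizer_v2, assuming the signature from this diff; the model choice and values below are placeholders, not part of the PR:

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('vit_small_patch16_224', pretrained=False)
optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=5e-4,
    weight_decay=0.05,
    layer_decay=0.75,
    layer_decay_min_scale=0.01,     # per-layer lr scale is clamped to at least 1% of the base lr
    layer_decay_no_opt_scale=1e-4,  # layers whose scale falls below this are frozen entirely
)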


def optimizer_kwargs(cfg):
""" cfg/argparse to kwargs helper
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
"""
kwargs = dict(
opt=cfg.opt,
lr=cfg.lr,
weight_decay=cfg.weight_decay,
momentum=cfg.momentum,
)
if getattr(cfg, 'opt_eps', None) is not None:
kwargs['eps'] = cfg.opt_eps
if getattr(cfg, 'opt_betas', None) is not None:
kwargs['betas'] = cfg.opt_betas
if getattr(cfg, 'layer_decay', None) is not None:
kwargs['layer_decay'] = cfg.layer_decay
if getattr(cfg, 'opt_args', None) is not None:
kwargs.update(cfg.opt_args)
if getattr(cfg, 'opt_foreach', None) is not None:
kwargs['foreach'] = cfg.opt_foreach
"""Convert argparse-style `cfg` object to kwargs for an optimizer factory."""
kwargs = {
'opt': cfg.opt,
'lr': cfg.lr,
'weight_decay': cfg.weight_decay,
'momentum': cfg.momentum,
}
if (eps := getattr(cfg, 'opt_eps', None)) is not None:
kwargs['eps'] = eps
if (betas := getattr(cfg, 'opt_betas', None)) is not None:
kwargs['betas'] = betas
if (layer_decay := getattr(cfg, 'layer_decay', None)) is not None:
kwargs['layer_decay'] = layer_decay
if (ld_min := getattr(cfg, 'layer_decay_min_scale', None)) is not None:
kwargs['layer_decay_min_scale'] = ld_min
if (ld_no_opt := getattr(cfg, 'layer_decay_no_opt_scale', None)) is not None:
kwargs['layer_decay_no_opt_scale'] = ld_no_opt
if (opt_args := getattr(cfg, 'opt_args', None)) is not None:
kwargs.update(opt_args)
if (foreach := getattr(cfg, 'opt_foreach', None)) is not None:
kwargs['foreach'] = foreach

return kwargs


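A small sketch of how the new fields flow through optimizer_kwargs(), assuming the updated function above; the SimpleNamespace stands in for the parsed train.py arguments and its values are made up:

from types import SimpleNamespace
from timm.optim import optimizer_kwargs

cfg = SimpleNamespace(
    opt='adamw',
    lr=5e-4,
    weight_decay=0.05,
    momentum=0.9,
    layer_decay=0.75,
    layer_decay_min_scale=0.01,
    layer_decay_no_opt_scale=1e-4,
)
# Optional fields (opt_eps, opt_betas, opt_args, opt_foreach) are read via getattr
# and simply skipped when absent or None.
print(optimizer_kwargs(cfg))
# {'opt': 'adamw', 'lr': 0.0005, 'weight_decay': 0.05, 'momentum': 0.9,
#  'layer_decay': 0.75, 'layer_decay_min_scale': 0.01, 'layer_decay_no_opt_scale': 0.0001}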
12 changes: 9 additions & 3 deletions timm/optim/_param_groups.py
@@ -73,6 +73,8 @@ def param_groups_layer_decay(
weight_decay_exclude_1d: bool = True,
layer_decay: float = .75,
end_layer_decay: Optional[float] = None,
min_scale: float = 0.,
no_opt_scale: Optional[float] = None,
verbose: bool = False,
):
"""
@@ -91,7 +93,7 @@
layer_map = auto_group_layers(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
layer_scales = list(max(min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers))

for name, param in model.named_parameters():
if not param.requires_grad:
@@ -106,10 +108,14 @@
this_decay = weight_decay

layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)
this_scale = layer_scales[layer_id]
if no_opt_scale and this_scale < no_opt_scale:
# if the calculated layer scale is below this, exclude from optimization
param.requires_grad = False
continue

group_name = "layer_%d_%s" % (layer_id, g_decay)
if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
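To make the interaction of the two knobs concrete, a worked sketch of the scale computation added above, with made-up numbers; the real code applies the same threshold check per parameter inside param_groups_layer_decay:

layer_decay = 0.75
num_layers = 12
min_scale = 0.05      # clamp floor for the per-layer lr scale
no_opt_scale = 0.06   # freeze threshold

layer_max = num_layers - 1
layer_scales = [max(min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
# Without the clamp, layer 0 would get 0.75**11 ~= 0.042; min_scale lifts it to 0.05.

frozen = [i for i, s in enumerate(layer_scales) if no_opt_scale and s < no_opt_scale]
# frozen == [0, 1]: layers scaled to 0.05 and ~0.056 fall below 0.06, so their
# parameters get requires_grad = False and are left out of the optimizer.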
4 changes: 4 additions & 0 deletions train.py
@@ -206,6 +206,10 @@
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
group.add_argument('--layer-decay-min-scale', type=float, default=0,
help='layer-wise lr decay minimum scale clamp (default: 0)')
group.add_argument('--layer-decay-no-opt-scale', type=float, default=None,
help='layer-wise lr decay no optimization scale (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)

# Learning rate schedule parameters
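Finally, a standalone sketch (not part of the PR) of how the new train.py flags parse; the parser below only reproduces the three layer-decay arguments from the diff, and the sample command line is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--layer-decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')
parser.add_argument('--layer-decay-min-scale', type=float, default=0,
                    help='layer-wise lr decay minimum scale clamp (default: 0)')
parser.add_argument('--layer-decay-no-opt-scale', type=float, default=None,
                    help='layer-wise lr decay no optimization scale (default: None)')

args = parser.parse_args(
    '--layer-decay 0.75 --layer-decay-min-scale 0.01 --layer-decay-no-opt-scale 1e-4'.split()
)
assert args.layer_decay_min_scale == 0.01
assert args.layer_decay_no_opt_scale == 1e-4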