From 83709aea18fad8201c9a290125ceac47f7a755e2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 5 Jul 2025 07:11:43 -0700
Subject: [PATCH] Add a min layer-decay scale clamp, and no optimization
 threshold to exclude groups from optimization

---
 timm/optim/_optim_factory.py | 51 ++++++++++++++++++++++--------------
 timm/optim/_param_groups.py  | 12 ++++++---
 train.py                     |  4 +++
 3 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 6a4db6b827..559f91e6fa 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -234,6 +234,8 @@ def create_optimizer(
         foreach: Optional[bool] = None,
         weight_decay_exclude_1d: bool = True,
         layer_decay: Optional[float] = None,
+        layer_decay_min_scale: Optional[float] = None,
+        layer_decay_no_opt_scale: Optional[float] = None,
         param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
 ) -> torch.optim.Optimizer:
@@ -248,6 +250,8 @@ def create_optimizer(
         foreach: Enable/disable foreach operation
         weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
         layer_decay: Layer-wise learning rate decay
+        layer_decay_min_scale: Minimum layer scale factor clamp value
+        layer_decay_no_opt_scale: Layer scale below which optimization is disabled
         param_group_fn: Optional custom parameter grouping function
         **kwargs: Additional optimizer-specific arguments
 
@@ -273,6 +277,8 @@ def create_optimizer(
             layer_decay=layer_decay,
             no_weight_decay_list=no_weight_decay,
             weight_decay_exclude_1d=weight_decay_exclude_1d,
+            min_scale=layer_decay_min_scale,
+            no_opt_scale=layer_decay_no_opt_scale,
         )
         weight_decay = 0.
     elif weight_decay and weight_decay_exclude_1d:
@@ -1140,6 +1146,8 @@ def create_optimizer_v2(
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
+        layer_decay_min_scale: float = 0.0,
+        layer_decay_no_opt_scale: Optional[float] = None,
         param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
 ) -> torch.optim.Optimizer:
@@ -1215,31 +1223,36 @@
         foreach=foreach,
         weight_decay_exclude_1d=filter_bias_and_bn,
         layer_decay=layer_decay,
+        layer_decay_min_scale=layer_decay_min_scale,
+        layer_decay_no_opt_scale=layer_decay_no_opt_scale,
         param_group_fn=param_group_fn,
         **kwargs
     )
 
 
 def optimizer_kwargs(cfg):
-    """ cfg/argparse to kwargs helper
-    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
-    """
-    kwargs = dict(
-        opt=cfg.opt,
-        lr=cfg.lr,
-        weight_decay=cfg.weight_decay,
-        momentum=cfg.momentum,
-    )
-    if getattr(cfg, 'opt_eps', None) is not None:
-        kwargs['eps'] = cfg.opt_eps
-    if getattr(cfg, 'opt_betas', None) is not None:
-        kwargs['betas'] = cfg.opt_betas
-    if getattr(cfg, 'layer_decay', None) is not None:
-        kwargs['layer_decay'] = cfg.layer_decay
-    if getattr(cfg, 'opt_args', None) is not None:
-        kwargs.update(cfg.opt_args)
-    if getattr(cfg, 'opt_foreach', None) is not None:
-        kwargs['foreach'] = cfg.opt_foreach
-    return kwargs
+    """Convert argparse-style `cfg` object to kwargs for an optimizer factory."""
+    kwargs = {
+        'opt': cfg.opt,
+        'lr': cfg.lr,
+        'weight_decay': cfg.weight_decay,
+        'momentum': cfg.momentum,
+    }
+    if (eps := getattr(cfg, 'opt_eps', None)) is not None:
+        kwargs['eps'] = eps
+    if (betas := getattr(cfg, 'opt_betas', None)) is not None:
+        kwargs['betas'] = betas
+    if (layer_decay := getattr(cfg, 'layer_decay', None)) is not None:
+        kwargs['layer_decay'] = layer_decay
+    if (ld_min := getattr(cfg, 'layer_decay_min_scale', None)) is not None:
+        kwargs['layer_decay_min_scale'] = ld_min
+    if (ld_no_opt := getattr(cfg, 'layer_decay_no_opt_scale', None)) is not None:
+        kwargs['layer_decay_no_opt_scale'] = ld_no_opt
+    if (opt_args := getattr(cfg, 'opt_args', None)) is not None:
+        kwargs.update(opt_args)
+    if (foreach := getattr(cfg, 'opt_foreach', None)) is not None:
+        kwargs['foreach'] = foreach
+    return kwargs
diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py
index a756c5e0c0..5048736ece 100644
--- a/timm/optim/_param_groups.py
+++ b/timm/optim/_param_groups.py
@@ -73,6 +73,8 @@ def param_groups_layer_decay(
         weight_decay_exclude_1d: bool = True,
         layer_decay: float = .75,
         end_layer_decay: Optional[float] = None,
+        min_scale: float = 0.,
+        no_opt_scale: Optional[float] = None,
         verbose: bool = False,
 ):
     """
@@ -91,7 +93,7 @@ def param_groups_layer_decay(
         layer_map = auto_group_layers(model)
     num_layers = max(layer_map.values()) + 1
     layer_max = num_layers - 1
-    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
+    layer_scales = list(max(min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers))
 
     for name, param in model.named_parameters():
         if not param.requires_grad:
@@ -106,10 +108,14 @@ def param_groups_layer_decay(
             this_decay = weight_decay
 
         layer_id = layer_map.get(name, layer_max)
-        group_name = "layer_%d_%s" % (layer_id, g_decay)
+        this_scale = layer_scales[layer_id]
+        if no_opt_scale and this_scale < no_opt_scale:
+            # if the calculated layer scale is below this, exclude from optimization
+            param.requires_grad = False
+            continue
 
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
         if group_name not in param_groups:
-            this_scale = layer_scales[layer_id]
             param_group_names[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,
diff --git a/train.py b/train.py
index ca48538771..6641052274 100755
--- a/train.py
+++ b/train.py
@@ -206,6 +206,10 @@
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')
 group.add_argument('--layer-decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')
+group.add_argument('--layer-decay-min-scale', type=float, default=0,
+                   help='layer-wise lr decay minimum scale clamp (default: 0)')
+group.add_argument('--layer-decay-no-opt-scale', type=float, default=None,
+                   help='layer-wise lr decay no optimization scale (default: None)')
 group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
 
 # Learning rate schedule parameters