diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index f4cfeb7060..6a4db6b827 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -532,6 +532,132 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
         registry.register(opt)
 
 
+def _register_corrected_decay_optimizers(registry: OptimizerRegistry) -> None:
+    """Register corrected weight decay optimizer variants"""
+    corrected_optimizers = [
+        OptimInfo(
+            name='adamc',
+            opt_class=AdamWLegacy,
+            description='AdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='nadamc',
+            opt_class=NAdamW,
+            description='NAdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='sgdc',
+            opt_class=SGDW,
+            description='SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='adoptc',
+            opt_class=Adopt,
+            description='Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
+            defaults={'decoupled': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='lambcd',
+            opt_class=Lamb,
+            description='LAMB with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='kronc',
+            opt_class=Kron,
+            description='PSGD Kron with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_momentum=True,
+            defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='lionc',
+            opt_class=Lion,
+            description='Lion with corrected weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='lapropc',
+            opt_class=LaProp,
+            description='LaProp with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='rmsproptfc',
+            opt_class=RMSpropTF,
+            description='RMSprop TF-style with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'decoupled_decay': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='adafactorbvc',
+            opt_class=AdafactorBigVision,
+            description='Adafactor Big Vision with corrected weight decay (lr²/max_lr or lr/max_lr scaling)',
+            defaults={'corrected_weight_decay': True}
+        ),
+    ]
+    for opt in corrected_optimizers:
+        registry.register(opt)
+
+    # Cautious + corrected variants
+    cautious_corrected = [
+        OptimInfo(
+            name='cadamc',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='cadoptc',
+            opt_class=Adopt,
+            description='Cautious Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
+            defaults={'decoupled': True, 'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='cnadamc',
+            opt_class=NAdamW,
+            description='Cautious NAdamW with corrected weight decay (lr²/max_lr scaling)',
+            has_betas=True,
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='csgdc',
+            opt_class=SGDW,
+            description='Cautious SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='clionc',
+            opt_class=Lion,
+            description='Cautious Lion with corrected weight decay (lr²/max_lr scaling)',
+            has_eps=False,
+            has_betas=True,
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+        OptimInfo(
+            name='cadafactorbvc',
+            opt_class=AdafactorBigVision,
+            description='Cautious Adafactor Big Vision with corrected weight decay',
+            defaults={'caution': True, 'corrected_weight_decay': True}
+        ),
+    ]
+    for opt in cautious_corrected:
+        registry.register(opt)
+
+
 def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
     cautious_optimizers = [
         OptimInfo(
@@ -896,6 +1022,7 @@ def _register_default_optimizers() -> None:
     _register_apex_optimizers(default_registry)
     _register_bnb_optimizers(default_registry)
     _register_cautious_optimizers(default_registry)
+    _register_corrected_decay_optimizers(default_registry)
 
     # Register aliases
     default_registry.register_alias('nesterov', 'sgd')
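The names registered above resolve through timm's optimizer factory like any other entry. A minimal sketch of how the new variants would be selected, assuming the current `create_optimizer_v2` / `list_optimizers` entry points in `timm.optim` (the model and hyper-parameters below are placeholders):

    import torch.nn as nn
    from timm.optim import create_optimizer_v2, list_optimizers

    model = nn.Linear(10, 10)

    # 'adamc' maps to AdamWLegacy with corrected_weight_decay=True per the registry above
    opt = create_optimizer_v2(model, opt='adamc', lr=1e-3, weight_decay=0.1)

    # the corrected variants should appear alongside the existing registered names
    assert 'adamc' in list_optimizers() and 'csgdc' in list_optimizers()
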
diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
index 298d43bb7e..9afd130f9d 100644
--- a/timm/optim/adafactor_bv.py
+++ b/timm/optim/adafactor_bv.py
@@ -4,6 +4,10 @@
 Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
 
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 Adaptation and PyTorch modifications by Ross Wightman
 """
 
 from typing import List, Optional, Tuple, Union
@@ -68,6 +72,7 @@ def __init__(
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
             *,
             foreach: Optional[bool] = False,
     ):
@@ -94,6 +99,7 @@ def __init__(
             clipping_threshold=clipping_threshold,
             unscaled_wd=unscaled_wd,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -102,6 +108,7 @@ def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
             group.setdefault('foreach', None)
             for p in group['params']:
                 p_state = self.state.get(p, {})
@@ -197,6 +204,7 @@ def step(self, closure=None):
                 clipping_threshold=group['clipping_threshold'],
                 unscaled_wd=group['unscaled_wd'],
                 caution=group['caution'],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
             )
 
         return loss
@@ -222,6 +230,7 @@ def _single_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
         caution: bool,
+        max_lr: Optional[float],
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -286,10 +295,18 @@ def _single_tensor_adafactor(
         if weight_decay != 0:
             if unscaled_wd:
                 # match big vision impl, 'fully decoupled' decay w/o LR scaling
-                param.mul_(1. - weight_decay)
+                if max_lr is None:
+                    param.mul_(1. - weight_decay)
+                else:
+                    # corrected weight decay: scale by lr / max_lr
+                    param.mul_(1. - (lr / max_lr) * weight_decay)
             else:
                 # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
-                param.mul_(1. - lr * weight_decay)
+                if max_lr is None:
+                    param.mul_(1. - lr * weight_decay)
+                else:
+                    # corrected weight decay: scale by lr^2 / max_lr
+                    param.mul_(1. - (lr ** 2 / max_lr) * weight_decay)
 
         # Update parameters
         param.add_(update, alpha=-1.0)
@@ -315,6 +332,7 @@ def _multi_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
         caution: bool,
+        max_lr: Optional[float],
 ):
     # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
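AdafactorBigVision is the one optimizer in this patch with two decay paths, so the correction differs by branch: with `unscaled_wd` the multiplier becomes `lr / max_lr`, otherwise `lr**2 / max_lr`. A standalone sketch of the resulting decay factor, illustrative only and not the timm implementation:

    def adafactor_decay_factor(lr, weight_decay, unscaled_wd=False, max_lr=None):
        if unscaled_wd:
            # big-vision 'fully decoupled' decay: no lr scaling, or lr/max_lr when corrected
            wd_scale = 1.0 if max_lr is None else lr / max_lr
        else:
            # pytorch-style decoupled decay: lr scaling, or lr**2/max_lr when corrected
            wd_scale = lr if max_lr is None else lr ** 2 / max_lr
        return 1.0 - wd_scale * weight_decay

    # with lr annealed to half of max_lr, the applied decay halves (unscaled_wd=True)
    # or quarters (unscaled_wd=False) relative to its value at lr == max_lr
    print(adafactor_decay_factor(5e-4, 0.1, unscaled_wd=True, max_lr=1e-3))   # 0.95
    print(adafactor_decay_factor(5e-4, 0.1, unscaled_wd=False, max_lr=1e-3))  # 0.999975
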
diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py
index 07299ad63e..537420309f 100644
--- a/timm/optim/adamw.py
+++ b/timm/optim/adamw.py
@@ -1,6 +1,10 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master
 
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
@@ -31,6 +35,7 @@ class AdamWLegacy(Optimizer):
         amsgrad: whether to use the AMSGrad variant of this algorithm
             from the paper `On the Convergence of Adam and Beyond`
         caution: apply caution when using AdamW
+        corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr)
     """
 
     def __init__(
@@ -42,6 +47,7 @@ def __init__(
             weight_decay: float = 1e-2,
             amsgrad: bool = False,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -58,6 +64,7 @@ def __init__(
             weight_decay=weight_decay,
             amsgrad=amsgrad,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
         )
         super(AdamWLegacy, self).__init__(params, defaults)
 
@@ -66,6 +73,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -86,7 +94,8 @@ def step(self, closure=None):
                     continue
 
                 # Perform stepweight decay
-                p.data.mul_(1 - group['lr'] * group['weight_decay'])
+                wd_scale = group['lr'] if not group['corrected_weight_decay'] else group['lr'] ** 2 / self.defaults['lr']
+                p.data.mul_(1 - wd_scale * group['weight_decay'])
 
                 # Perform optimization step
                 grad = p.grad
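The files below repeat the substitution that AdamWLegacy shows above: the per-step multiplier `1 - lr * weight_decay` becomes `1 - (lr**2 / max_lr) * weight_decay`, with `max_lr` read from `self.defaults['lr']`, i.e. the learning rate the optimizer was constructed with. A minimal standalone sketch of that scaling (not timm code):

    def decay_factor(lr, weight_decay, max_lr=None):
        # max_lr=None reproduces standard decoupled decay; otherwise the corrected form
        wd_scale = lr if max_lr is None else lr ** 2 / max_lr
        return 1.0 - wd_scale * weight_decay

    assert decay_factor(1e-3, 0.05) == 1.0 - 1e-3 * 0.05
    # once a schedule has annealed lr to 10% of max_lr, the corrected decay applied
    # per step is 10x weaker than uncorrected decay at that same lr
    assert abs(decay_factor(1e-4, 0.05, max_lr=1e-3) - (1.0 - 1e-5 * 0.05)) < 1e-12
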
diff --git a/timm/optim/adopt.py b/timm/optim/adopt.py
index 6192990ed6..4acfc2d11d 100644
--- a/timm/optim/adopt.py
+++ b/timm/optim/adopt.py
@@ -10,6 +10,10 @@
     title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
     year = {2024}
 }
+
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
 """
 
 from typing import cast, List, Optional, Tuple, Union
@@ -66,6 +70,7 @@ def __init__(
             clip_exp: Optional[float] = 0.333,
             weight_decay: float = 0.0,
             decoupled: bool = False,
+            corrected_weight_decay: bool = False,
             *,
             caution: bool = False,
             foreach: Optional[bool] = False,
@@ -98,6 +103,7 @@ def __init__(
             weight_decay=weight_decay,
             clip_exp=clip_exp,
             decoupled=decoupled,
+            corrected_weight_decay=corrected_weight_decay,
             caution=caution,
             maximize=maximize,
             foreach=foreach,
@@ -115,6 +121,7 @@ def __setstate__(self, state):
             group.setdefault("differentiable", False)
             group.setdefault("clip_exp", None)
             group.setdefault("caution", False)
+            group.setdefault("corrected_weight_decay", False)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -222,6 +229,7 @@ def step(self, closure=None):
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
                 clip_exp=group["clip_exp"],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 caution=group["caution"],
@@ -251,6 +259,7 @@ def _single_tensor_adopt(
         lr: Union[float, Tensor],
         weight_decay: float,
         clip_exp: Optional[float],
+        max_lr: Optional[float],
         decoupled: bool,
         eps: float,
         caution: bool,
@@ -299,7 +308,8 @@ def _single_tensor_adopt(
             continue
 
         if weight_decay != 0 and decoupled:
-            param.add_(param, alpha=-lr * weight_decay)
+            wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
+            param.add_(param, alpha=-wd_scale * weight_decay)
 
         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
         normed_grad = grad.div(denom)
@@ -336,6 +346,7 @@ def _multi_tensor_adopt(
         lr: Union[float, Tensor],
         weight_decay: float,
         clip_exp: Optional[float],
+        max_lr: Optional[float],
         decoupled: bool,
         eps: float,
         caution: bool,
@@ -410,7 +421,8 @@ def _multi_tensor_adopt(
             continue
 
         if weight_decay != 0 and decoupled:
-            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+            wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
+            torch._foreach_add_(device_params, device_params, alpha=-wd_scale * weight_decay)
 
         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
         torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
@@ -460,6 +472,7 @@ def adopt(
         lr: Union[float, Tensor],
         weight_decay: float,
         clip_exp: Optional[float],
+        max_lr: Optional[float],
         decoupled: bool,
         eps: float,
         caution: bool,
@@ -498,6 +511,7 @@ def adopt(
         lr=lr,
         weight_decay=weight_decay,
         clip_exp=clip_exp,
+        max_lr=max_lr,
        decoupled=decoupled,
         eps=eps,
         caution=caution,
diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 23fa8235b6..2d1f03496e 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -11,6 +11,10 @@
 
 The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
 
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 This `timm` impl
 * works with a wider variety of torch versions
 * fixes some checkpoint save/restore (resume issues)
@@ -94,6 +98,7 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
+        corrected_weight_decay: apply corrected weight decay when using decoupled_decay (lr**2 / max_lr)
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
         flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
         flatten_end_dim: End of flatten range, defaults to -1.
@@ -117,6 +122,7 @@ def __init__(
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
+            corrected_weight_decay: bool = False,
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
@@ -147,6 +153,7 @@ def __init__(
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
+            corrected_weight_decay=corrected_weight_decay,
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
@@ -171,6 +178,11 @@ def __init__(
         self._precond_grad = _precond_grad
         self._balance_Q = _balance_Q
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('corrected_weight_decay', False)
+
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
@@ -353,7 +365,11 @@ def step(self, closure=None):
                     weight_decay = 2 * self.rng.random() * weight_decay
 
                 if group["decoupled_decay"]:
-                    p.mul_(1. - group["lr"] * weight_decay)
+                    if group['corrected_weight_decay']:
+                        wd_scale = group["lr"] ** 2 / self.defaults['lr']
+                    else:
+                        wd_scale = group["lr"]
+                    p.mul_(1. - wd_scale * weight_decay)
                 else:
                     pre_grad.add_(p, alpha=weight_decay)
 
diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
index fa86757441..13e2900439 100644
--- a/timm/optim/lamb.py
+++ b/timm/optim/lamb.py
@@ -10,6 +10,10 @@
 In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA
 and has been tested on TPU.
 
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 Original copyrights for above sources are below.
 
 Modifications Copyright 2021 Ross Wightman
@@ -79,6 +83,8 @@ class Lamb(Optimizer):
         trust_clip: Enable LAMBC trust ratio clipping.
         always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter.
         caution: Apply caution.
+        decoupled_decay: apply decoupled weight decay
+        corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr) when using decoupled_decay
     """
 
     def __init__(
@@ -95,6 +101,7 @@ def __init__(
             always_adapt: bool = False,
             caution: bool = False,
             decoupled_decay: bool = False,
+            corrected_weight_decay: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -108,6 +115,7 @@ def __init__(
             always_adapt=always_adapt,
             caution=caution,
             decoupled_decay=decoupled_decay,
+            corrected_weight_decay=corrected_weight_decay,
         )
         super().__init__(params, defaults)
 
@@ -116,6 +124,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault('caution', False)
             group.setdefault('decoupled_decay', False)
+            group.setdefault('corrected_weight_decay', False)
 
     def _get_clip_grad_norm(self):
         max_grad_norm = self.defaults['max_grad_norm']
@@ -203,7 +212,11 @@ def step(self, closure=None):
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
                     if group.get('decoupled_decay', False):
-                        p.add_(p, alpha=-group['lr'] * weight_decay)
+                        if group['corrected_weight_decay']:
+                            wd_scale = group['lr'] ** 2 / self.defaults['lr']
+                        else:
+                            wd_scale = group['lr']
+                        p.add_(p, alpha=-wd_scale * weight_decay)
                     else:
                         update.add_(p, alpha=weight_decay)
 
diff --git a/timm/optim/laprop.py b/timm/optim/laprop.py
index a17c81e6ad..79e7ecb6f7 100644
--- a/timm/optim/laprop.py
+++ b/timm/optim/laprop.py
@@ -11,6 +11,10 @@
     year={2020}
 }
 
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 """
 from typing import Tuple
 
@@ -33,6 +37,7 @@ def __init__(
             eps: float = 1e-15,
             weight_decay: float = 0.,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -48,9 +53,16 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
         )
         super(LaProp, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -116,6 +128,10 @@ def step(self, closure=None):
                 p.add_(exp_avg, alpha=-step_size)
 
                 if group['weight_decay'] != 0:
-                    p.add_(p, alpha=-(group['lr'] * group['weight_decay']))
+                    if group['corrected_weight_decay']:
+                        wd_scale = group['lr'] ** 2 / self.defaults['lr']
+                    else:
+                        wd_scale = group['lr']
+                    p.add_(p, alpha=-wd_scale * group['weight_decay'])
 
         return loss
\ No newline at end of file
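The `__setstate__` additions here and in the other files matter for resuming: an optimizer state dict saved before this change has no `corrected_weight_decay` key in its param groups, and `setdefault` backfills it on load. A sketch of that round trip, assuming `LaProp` is exported from `timm.optim` and relying on `load_state_dict` invoking `__setstate__` as the base `torch.optim.Optimizer` does:

    import torch
    from timm.optim import LaProp

    p = torch.nn.Parameter(torch.zeros(4))
    opt = LaProp([p], lr=4e-4)

    state = opt.state_dict()
    state['param_groups'][0].pop('corrected_weight_decay')  # emulate a pre-patch checkpoint
    opt.load_state_dict(state)                              # __setstate__ restores the default
    assert opt.param_groups[0]['corrected_weight_decay'] is False
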
diff --git a/timm/optim/lion.py b/timm/optim/lion.py
index 1860723203..5368206ee4 100644
--- a/timm/optim/lion.py
+++ b/timm/optim/lion.py
@@ -1,6 +1,10 @@
 """ Lion Optimizer
 Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675
 Original Impl: https://github.com/google/automl/tree/master/lion
+
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
 """
 # Copyright 2023 Google Research. All Rights Reserved.
 #
@@ -34,17 +38,19 @@ def __init__(
             betas: Tuple[float, float] = (0.9, 0.99),
             weight_decay: float = 0.0,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
     ):
         """Initialize the hyperparameters.
 
         Args:
-          params: iterable of parameters to optimize or dicts defining parameter groups
-          lr: learning rate
-          betas: coefficients used for computing running averages of gradient and its square
-          weight_decay: weight decay coefficient
-          caution: apply caution
+            params: iterable of parameters to optimize or dicts defining parameter groups
+            lr: learning rate
+            betas: coefficients used for computing running averages of gradient and its square
+            weight_decay: weight decay coefficient
+            caution: apply caution
+            corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr)
         """
 
         if not 0.0 <= lr:
@@ -58,6 +64,7 @@ def __init__(
             betas=betas,
             weight_decay=weight_decay,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
             foreach=foreach,
             maximize=maximize,
         )
@@ -67,6 +74,7 @@ def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
             group.setdefault('maximize', False)
             group.setdefault('foreach', None)
 
@@ -75,10 +83,10 @@ def step(self, closure=None):
         """Performs a single optimization step.
 
         Args:
-          closure: A closure that reevaluates the model and returns the loss.
+            closure: A closure that reevaluates the model and returns the loss.
 
         Returns:
-          the loss.
+            the loss.
         """
         loss = None
         if closure is not None:
@@ -118,6 +126,7 @@ def step(self, closure=None):
                 caution=group['caution'],
                 maximize=group['maximize'],
                 foreach=group['foreach'],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
             )
 
         return loss
@@ -137,6 +146,7 @@ def lion(
         lr: float,
         weight_decay: float,
         caution: bool,
+        max_lr: Optional[float] = None,
 ):
     r"""Functional API that performs Lion algorithm computation.
     """
@@ -165,6 +175,7 @@ def lion(
         weight_decay=weight_decay,
         caution=caution,
         maximize=maximize,
+        max_lr=max_lr,
     )
 
 
@@ -179,6 +190,7 @@ def _single_tensor_lion(
         weight_decay: float,
         caution: bool,
         maximize: bool,
+        max_lr: Optional[float],
 ):
     for i, param in enumerate(params):
         grad = grads[i] if not maximize else -grads[i]
@@ -190,7 +202,8 @@ def _single_tensor_lion(
             param = torch.view_as_real(param)
 
         # Perform stepweight decay
-        param.mul_(1 - lr * weight_decay)
+        wd_scale = lr if max_lr is None else lr ** 2 / max_lr
+        param.mul_(1 - wd_scale * weight_decay)
 
         # Weight update
         update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
@@ -218,6 +231,7 @@ def _multi_tensor_lion(
         weight_decay: float,
         caution: bool,
         maximize: bool,
+        max_lr: Optional[float],
 ):
     if len(params) == 0:
         return
@@ -230,7 +244,8 @@ def _multi_tensor_lion(
     params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
 
     # Perform stepweight decay
-    torch._foreach_mul_(params, 1 - lr * weight_decay)
+    wd_scale = lr if max_lr is None else lr ** 2 / max_lr
+    torch._foreach_mul_(params, 1 - wd_scale * weight_decay)
 
     # Weight update
     updates = torch._foreach_mul(exp_avgs, beta1)
diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py
index d9933026c6..2a9df0390e 100644
--- a/timm/optim/nadamw.py
+++ b/timm/optim/nadamw.py
@@ -3,6 +3,10 @@
 Based on simplified algorithm in https://github.com/mlcommons/algorithmic-efficiency/tree/main/baselines/nadamw
 
 Added multi-tensor (foreach) path.
+
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
 """
 import math
 from typing import List, Optional, Tuple
@@ -32,6 +36,7 @@ class NAdamW(torch.optim.Optimizer):
         eps: term added to the denominator to improve numerical stability
         weight_decay: weight decay coefficient
         caution: enable caution
+        corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr)
     """
 
     def __init__(
@@ -42,6 +47,7 @@ def __init__(
             eps: float = 1e-8,
             weight_decay: float = 1e-2,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
             capturable: bool = False,
@@ -62,6 +68,7 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
             foreach=foreach,
             maximize=maximize,
             capturable=capturable,
@@ -77,6 +84,7 @@ def __setstate__(self, state):
                 s['step'] = torch.tensor(float(s['step']))
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -137,6 +145,7 @@ def step(self, closure=None):
                 caution=group['caution'],
                 maximize=group['maximize'],
                 capturable=group['capturable'],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
             )
 
         return loss
@@ -158,6 +167,7 @@ def nadamw(
         eps: float,
         caution: bool,
         maximize: bool,
+        max_lr: Optional[float],
 ) -> None:
     r"""Functional API that performs NAdamW algorithm computation.
     See NAdamW class for details.
@@ -194,6 +204,7 @@ def nadamw(
         caution=caution,
         maximize=maximize,
         capturable=capturable,
+        max_lr=max_lr,
     )
 
 
@@ -211,7 +222,8 @@ def _single_tensor_nadamw(
         eps: float,
         caution: bool,
         maximize: bool,
-        capturable: bool
+        capturable: bool,
+        max_lr: Optional[float],
 ):
     for i, param in enumerate(params):
 
@@ -224,7 +236,8 @@ def _single_tensor_nadamw(
         step_t += 1
 
         # Perform stepweight decay.
-        param.mul_(1. - lr * weight_decay)
+        wd_scale = lr if max_lr is None else lr ** 2 / max_lr
+        param.mul_(1. - wd_scale * weight_decay)
 
         # Decay the first and second moment running average coefficient.
         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
@@ -293,6 +306,7 @@ def _multi_tensor_nadamw(
         caution: bool,
         maximize: bool,
         capturable: bool,
+        max_lr: Optional[float],
 ):
     if len(params) == 0:
         return
@@ -314,7 +328,8 @@ def _multi_tensor_nadamw(
         torch._foreach_add_(state_steps, 1)
 
     # Perform stepweight decay
-    torch._foreach_mul_(params, 1 - lr * weight_decay)
+    wd_scale = lr if max_lr is None else lr ** 2 / max_lr
+    torch._foreach_mul_(params, 1 - wd_scale * weight_decay)
 
     # Decay the first and second moment running average coefficient
     torch._foreach_mul_(exp_avgs, beta1)
diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py
index 07b0279c85..6177090c7a 100644
--- a/timm/optim/rmsprop_tf.py
+++ b/timm/optim/rmsprop_tf.py
@@ -4,6 +4,10 @@
 https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
 Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
 
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
 Modifications Copyright 2021 Ross Wightman
 """
 
@@ -39,6 +43,7 @@ class RMSpropTF(Optimizer):
         centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
         weight_decay: weight decay (L2 penalty) (default: 0)
         decoupled_decay: decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr) when decoupled_decay is True
         lr_in_momentum: learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow
         caution: apply caution
     """
@@ -53,6 +58,7 @@ def __init__(
             momentum: float = 0.,
             centered: bool = False,
             decoupled_decay: bool = False,
+            corrected_weight_decay: bool = False,
             lr_in_momentum: bool = True,
             caution: bool = False,
     ):
@@ -75,6 +81,7 @@ def __init__(
             centered=centered,
             weight_decay=weight_decay,
             decoupled_decay=decoupled_decay,
+            corrected_weight_decay=corrected_weight_decay,
             lr_in_momentum=lr_in_momentum,
             caution=caution,
         )
@@ -86,6 +93,7 @@ def __setstate__(self, state):
             group.setdefault('momentum', 0)
             group.setdefault('centered', False)
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -125,7 +133,11 @@ def step(self, closure=None):
                 if group['weight_decay'] != 0:
                     if group['decoupled_decay']:
-                        p.mul_(1. - group['lr'] * group['weight_decay'])
+                        if group['corrected_weight_decay']:
+                            wd_scale = group['lr'] ** 2 / self.defaults['lr']
+                        else:
+                            wd_scale = group['lr']
+                        p.mul_(1. - wd_scale * group['weight_decay'])
                     else:
                         grad = grad.add(p, alpha=group['weight_decay'])
 
diff --git a/timm/optim/sgdw.py b/timm/optim/sgdw.py
index b771c43c67..da67eb3fa8 100644
--- a/timm/optim/sgdw.py
+++ b/timm/optim/sgdw.py
@@ -1,3 +1,11 @@
+""" SGD with decoupled weight-decay.
+
+References for added functionality:
+    Cautious Optimizers: https://arxiv.org/abs/2411.16085
+    Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
+
+Hacked together by Ross Wightman
+"""
 from typing import List, Optional
 
 import torch
@@ -25,6 +33,7 @@ def __init__(
             nesterov: bool = False,
             *,
             caution: bool = False,
+            corrected_weight_decay: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
             differentiable: bool = False,
@@ -43,6 +52,7 @@ def __init__(
             weight_decay=weight_decay,
             nesterov=nesterov,
             caution=caution,
+            corrected_weight_decay=corrected_weight_decay,
             maximize=maximize,
             foreach=foreach,
             differentiable=differentiable,
@@ -55,6 +65,7 @@ def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('corrected_weight_decay', False)
             group.setdefault('nesterov', False)
             group.setdefault('maximize', False)
             group.setdefault('foreach', None)
@@ -113,6 +124,7 @@ def step(self, closure=None):
                 maximize=group['maximize'],
                 has_sparse_grad=has_sparse_grad,
                 foreach=group['foreach'],
+                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
             )
 
             # update momentum_buffers in state
@@ -138,7 +150,8 @@ def sgdw(
         dampening: float,
         nesterov: bool,
         caution: bool,
-        maximize: bool
+        maximize: bool,
+        max_lr: Optional[float] = None
 ):
     r"""Functional API that performs SGD algorithm computation.
 
@@ -175,6 +188,7 @@ def sgdw(
         caution=caution,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize,
+        max_lr=max_lr,
     )
 
 
@@ -190,12 +204,14 @@ def _single_tensor_sgdw(
         nesterov: bool,
         caution: bool,
         maximize: bool,
-        has_sparse_grad: bool
+        has_sparse_grad: bool,
+        max_lr: Optional[float]
 ):
     for i, param in enumerate(params):
         grad = grads[i] if not maximize else -grads[i]
 
-        param.mul_(1. - lr * weight_decay)
+        wd_scale = lr if max_lr is None else lr ** 2 / max_lr
+        param.mul_(1. - wd_scale * weight_decay)
 
         if momentum != 0:
             buf = momentum_buffer_list[i]
@@ -234,7 +250,8 @@ def _multi_tensor_sgdw(
         nesterov: bool,
         caution: bool,
         maximize: bool,
-        has_sparse_grad: bool
+        has_sparse_grad: bool,
+        max_lr: Optional[float]
 ):
     if len(params) == 0:
         return
@@ -247,7 +264,8 @@ def _multi_tensor_sgdw(
         if maximize:
             device_grads = torch._foreach_neg(device_grads)
 
-        torch._foreach_mul_(params, 1. - lr * weight_decay)
+        wd_scale = lr if max_lr is None else lr ** 2 / max_lr
+        torch._foreach_mul_(params, 1. - wd_scale * weight_decay)
 
         if momentum != 0:
             bufs = []
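Taken together, end-to-end use of one of the new variants might look like the fragment below. This is an illustrative sketch, assuming the current `create_optimizer_v2` factory and `CosineLRScheduler` APIs; the model, shapes, and hyper-parameters are placeholders:

    import torch
    from timm.optim import create_optimizer_v2
    from timm.scheduler import CosineLRScheduler

    model = torch.nn.Linear(16, 4)
    # 'sgdc' -> SGDW with nesterov and corrected_weight_decay; the construction-time
    # lr (0.5) is what step() later reads back as self.defaults['lr'], i.e. max_lr
    opt = create_optimizer_v2(model, opt='sgdc', lr=0.5, momentum=0.9, weight_decay=1e-4)
    sched = CosineLRScheduler(opt, t_initial=100)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    for epoch in range(100):
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()       # decay multiplier uses lr**2 / 0.5, not lr, per the hunks above
        opt.zero_grad()
        sched.step(epoch + 1)  # timm schedulers update group['lr'] from the epoch index
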