timm/optim/_optim_factory.py (127 additions, 0 deletions)
@@ -532,6 +532,132 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
registry.register(opt)


def _register_corrected_decay_optimizers(registry: OptimizerRegistry) -> None:
"""Register corrected weight decay optimizer variants"""
corrected_optimizers = [
OptimInfo(
name='adamc',
opt_class=AdamWLegacy,
description='AdamW with corrected weight decay (lr²/max_lr scaling)',
has_betas=True,
defaults={'corrected_weight_decay': True}
),
OptimInfo(
name='nadamc',
opt_class=NAdamW,
description='NAdamW with corrected weight decay (lr²/max_lr scaling)',
has_betas=True,
defaults={'corrected_weight_decay': True}
),
OptimInfo(
name='sgdc',
opt_class=SGDW,
description='SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='adoptc',
opt_class=Adopt,
description='Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
defaults={'decoupled': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='lambcd',
opt_class=Lamb,
description='LAMB with corrected decoupled weight decay (lr²/max_lr scaling)',
has_betas=True,
defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='kronc',
opt_class=Kron,
description='PSGD Kron with corrected decoupled weight decay (lr²/max_lr scaling)',
has_momentum=True,
defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='lionc',
opt_class=Lion,
description='Lion with corrected weight decay (lr²/max_lr scaling)',
has_eps=False,
has_betas=True,
defaults={'corrected_weight_decay': True}
),
OptimInfo(
name='lapropc',
opt_class=LaProp,
description='LaProp with corrected weight decay (lr²/max_lr scaling)',
has_betas=True,
defaults={'corrected_weight_decay': True}
),
OptimInfo(
name='rmsproptfc',
opt_class=RMSpropTF,
description='RMSprop TF-style with corrected decoupled weight decay (lr²/max_lr scaling)',
has_momentum=True,
defaults={'alpha': 0.9, 'decoupled_decay': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='adafactorbvc',
opt_class=AdafactorBigVision,
description='Adafactor Big Vision with corrected weight decay (lr²/max_lr or lr/max_lr scaling)',
defaults={'corrected_weight_decay': True}
),
]
for opt in corrected_optimizers:
registry.register(opt)

# Cautious + corrected variants
cautious_corrected = [
OptimInfo(
name='cadamc',
opt_class=AdamWLegacy,
description='Cautious AdamW with corrected weight decay (lr²/max_lr scaling)',
has_betas=True,
defaults={'caution': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='cadoptc',
opt_class=Adopt,
description='Cautious Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
defaults={'decoupled': True, 'caution': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='cnadamc',
opt_class=NAdamW,
description='Cautious NAdamW with corrected weight decay (lr²/max_lr scaling)',
has_betas=True,
defaults={'caution': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='csgdc',
opt_class=SGDW,
description='Cautious SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True, 'caution': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='clionc',
opt_class=Lion,
description='Cautious Lion with corrected weight decay (lr²/max_lr scaling)',
has_eps=False,
has_betas=True,
defaults={'caution': True, 'corrected_weight_decay': True}
),
OptimInfo(
name='cadafactorbvc',
opt_class=AdafactorBigVision,
description='Cautious Adafactor Big Vision with corrected weight decay (lr²/max_lr or lr/max_lr scaling)',
defaults={'caution': True, 'corrected_weight_decay': True}
),
]
for opt in cautious_corrected:
registry.register(opt)


def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
cautious_optimizers = [
OptimInfo(
@@ -896,6 +1022,7 @@ def _register_default_optimizers() -> None:
_register_apex_optimizers(default_registry)
_register_bnb_optimizers(default_registry)
_register_cautious_optimizers(default_registry)
_register_corrected_decay_optimizers(default_registry)

# Register aliases
default_registry.register_alias('nesterov', 'sgd')
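A quick usage sketch (not part of the diff): once registered, the new names should be usable through timm's optimizer factory like any other entry. The model, hyperparameters, and the name filter below are illustrative assumptions.

```python
# Illustrative sketch, assuming a timm build with this PR applied.
import torch.nn as nn
from timm.optim import create_optimizer_v2, list_optimizers

model = nn.Linear(16, 16)  # stand-in model

# 'adamc' = AdamWLegacy with corrected_weight_decay=True (see registry above);
# the constructor lr also serves as max_lr for the lr**2 / max_lr scaling.
optimizer = create_optimizer_v2(model, opt='adamc', lr=1e-3, weight_decay=0.05)

# The corrected-decay names registered above end in 'c' (adamc, nadamc, sgdc, ...).
print([name for name in list_optimizers() if name.endswith('c')])
```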
timm/optim/adafactor_bv.py (20 additions, 2 deletions)
@@ -4,6 +4,10 @@

Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560

References for added functionality:
Cautious Optimizers: https://arxiv.org/abs/2411.16085
Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285

Adaptation and PyTorch modifications by Ross Wightman
"""
from typing import List, Optional, Tuple, Union
@@ -68,6 +72,7 @@ def __init__(
clipping_threshold: Optional[float] = None,
unscaled_wd: bool = False,
caution: bool = False,
corrected_weight_decay: bool = False,
*,
foreach: Optional[bool] = False,
):
@@ -94,6 +99,7 @@ def __init__(
clipping_threshold=clipping_threshold,
unscaled_wd=unscaled_wd,
caution=caution,
corrected_weight_decay=corrected_weight_decay,
foreach=foreach,
)
super().__init__(params, defaults)
@@ -102,6 +108,7 @@ def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('corrected_weight_decay', False)
group.setdefault('foreach', None)
for p in group['params']:
p_state = self.state.get(p, {})
@@ -197,6 +204,7 @@ def step(self, closure=None):
clipping_threshold=group['clipping_threshold'],
unscaled_wd=group['unscaled_wd'],
caution=group['caution'],
max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
)

return loss
@@ -222,6 +230,7 @@ def _single_tensor_adafactor(
clipping_threshold: Optional[float],
unscaled_wd: bool,
caution: bool,
max_lr: Optional[float],
):
for i, param in enumerate(params):
grad = grads[i]
@@ -286,10 +295,18 @@ def _single_tensor_adafactor(
if weight_decay != 0:
if unscaled_wd:
# match big vision impl, 'fully decoupled' decay w/o LR scaling
-param.mul_(1. - weight_decay)
+if max_lr is None:
+    param.mul_(1. - weight_decay)
+else:
+    # corrected weight decay: scale by lr / max_lr
+    param.mul_(1. - (lr / max_lr) * weight_decay)
else:
# match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
-param.mul_(1. - lr * weight_decay)
+if max_lr is None:
+    param.mul_(1. - lr * weight_decay)
+else:
+    # corrected weight decay: scale by lr^2 / max_lr
+    param.mul_(1. - (lr ** 2 / max_lr) * weight_decay)

# Update parameters
param.add_(update, alpha=-1.0)
@@ -315,6 +332,7 @@ def _multi_tensor_adafactor(
clipping_threshold: Optional[float],
unscaled_wd: bool,
caution: bool,
max_lr: Optional[float],
):
# FIXME TODO
assert False, 'multi-tensor fn (foreach=True) not implemented yet'
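To make the scaling concrete, here is a small standalone sketch (an assumption, not code from the diff) of how the decay multiplier in _single_tensor_adafactor behaves as the scheduled lr drops below the base lr, which plays the role of max_lr:

```python
# Standalone sketch of the decay multipliers used above; base_lr stands in for
# max_lr (i.e. self.defaults['lr']) and the numbers are illustrative.
base_lr = 1e-3
weight_decay = 0.05

def decay_factor(lr: float, corrected: bool, unscaled_wd: bool = False) -> float:
    if unscaled_wd:
        # 'fully decoupled' big vision style decay, no lr scaling by default
        scale = (lr / base_lr) if corrected else 1.0
    else:
        # typical decoupled (AdamW-style) decay, scaled by lr
        scale = (lr ** 2 / base_lr) if corrected else lr
    return 1.0 - scale * weight_decay

# At lr == base_lr both variants agree; as lr decays, the corrected variant
# shrinks the applied decay quadratically rather than linearly.
for lr in (1e-3, 5e-4, 1e-4):
    print(lr, decay_factor(lr, corrected=False), decay_factor(lr, corrected=True))
```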
timm/optim/adamw.py (10 additions, 1 deletion)
@@ -1,6 +1,10 @@
""" AdamW Optimizer
Impl copied from PyTorch master

References for added functionality:
Cautious Optimizers: https://arxiv.org/abs/2411.16085
Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285

NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
"""
import math
@@ -31,6 +35,7 @@ class AdamWLegacy(Optimizer):
amsgrad: whether to use the AMSGrad variant of this algorithm
from the paper `On the Convergence of Adam and Beyond`
caution: apply caution when using AdamW
corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr)
"""

def __init__(
@@ -42,6 +47,7 @@ def __init__(
weight_decay: float = 1e-2,
amsgrad: bool = False,
caution: bool = False,
corrected_weight_decay: bool = False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@@ -58,6 +64,7 @@ def __init__(
weight_decay=weight_decay,
amsgrad=amsgrad,
caution=caution,
corrected_weight_decay=corrected_weight_decay,
)
super(AdamWLegacy, self).__init__(params, defaults)

@@ -66,6 +73,7 @@ def __setstate__(self, state):
for group in self.param_groups:
group.setdefault('amsgrad', False)
group.setdefault('caution', False)
group.setdefault('corrected_weight_decay', False)

@torch.no_grad()
def step(self, closure=None):
@@ -86,7 +94,8 @@ def step(self, closure=None):
continue

# Perform stepweight decay
-p.data.mul_(1 - group['lr'] * group['weight_decay'])
+wd_scale = group['lr'] if not group['corrected_weight_decay'] else group['lr'] ** 2 / self.defaults['lr']
+p.data.mul_(1 - wd_scale * group['weight_decay'])

# Perform optimization step
grad = p.grad
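A direct-constructor sketch (assumptions: module path and toy tensors) showing where max_lr comes from in this class: self.defaults['lr'] is fixed at construction, so a scheduler that later lowers group['lr'] yields wd_scale = lr**2 / base_lr.

```python
# Sketch only; the parameter, gradient, and scheduler step are stand-ins.
import torch
from timm.optim.adamw import AdamWLegacy

params = [torch.nn.Parameter(torch.randn(4, 4))]
opt = AdamWLegacy(params, lr=1e-3, weight_decay=0.05, corrected_weight_decay=True)

params[0].grad = torch.zeros_like(params[0])
opt.param_groups[0]['lr'] = 5e-4   # pretend a scheduler has decayed the lr
opt.step()                         # wd_scale = (5e-4)**2 / 1e-3 = 2.5e-4, not 5e-4
```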
timm/optim/adopt.py (16 additions, 2 deletions)
@@ -10,6 +10,10 @@
title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
year = {2024}
}

References for added functionality:
Cautious Optimizers: https://arxiv.org/abs/2411.16085
Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
"""
from typing import cast, List, Optional, Tuple, Union

@@ -66,6 +70,7 @@ def __init__(
clip_exp: Optional[float] = 0.333,
weight_decay: float = 0.0,
decoupled: bool = False,
corrected_weight_decay: bool = False,
*,
caution: bool = False,
foreach: Optional[bool] = False,
@@ -98,6 +103,7 @@ def __init__(
weight_decay=weight_decay,
clip_exp=clip_exp,
decoupled=decoupled,
corrected_weight_decay=corrected_weight_decay,
caution=caution,
maximize=maximize,
foreach=foreach,
@@ -115,6 +121,7 @@ def __setstate__(self, state):
group.setdefault("differentiable", False)
group.setdefault("clip_exp", None)
group.setdefault("caution", False)
group.setdefault("corrected_weight_decay", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -222,6 +229,7 @@ def step(self, closure=None):
lr=group["lr"],
weight_decay=group["weight_decay"],
clip_exp=group["clip_exp"],
max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
decoupled=group["decoupled"],
eps=group["eps"],
caution=group["caution"],
@@ -251,6 +259,7 @@ def _single_tensor_adopt(
lr: Union[float, Tensor],
weight_decay: float,
clip_exp: Optional[float],
max_lr: Optional[float],
decoupled: bool,
eps: float,
caution: bool,
@@ -299,7 +308,8 @@ def _single_tensor_adopt(
continue

if weight_decay != 0 and decoupled:
-param.add_(param, alpha=-lr * weight_decay)
+wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
+param.add_(param, alpha=-wd_scale * weight_decay)

denom = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad = grad.div(denom)
@@ -336,6 +346,7 @@ def _multi_tensor_adopt(
lr: Union[float, Tensor],
weight_decay: float,
clip_exp: Optional[float],
max_lr: Optional[float],
decoupled: bool,
eps: float,
caution: bool,
@@ -410,7 +421,8 @@ def _multi_tensor_adopt(
continue

if weight_decay != 0 and decoupled:
-torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
+torch._foreach_add_(device_params, device_params, alpha=-wd_scale * weight_decay)

exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
@@ -460,6 +472,7 @@ def adopt(
lr: Union[float, Tensor],
weight_decay: float,
clip_exp: Optional[float],
max_lr: Optional[float],
decoupled: bool,
eps: float,
caution: bool,
@@ -498,6 +511,7 @@ def adopt(
lr=lr,
weight_decay=weight_decay,
clip_exp=clip_exp,
max_lr=max_lr,
decoupled=decoupled,
eps=eps,
caution=caution,
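One detail worth a sketch (again an assumption, not part of the diff): in ADOPT the corrected scaling only touches the decoupled-decay branch (weight_decay != 0 and decoupled), which matches the 'adoptc'/'cadoptc' registry defaults that enable decoupled alongside corrected_weight_decay.

```python
# Sketch with toy tensors; ADOPT uses its first step to initialize its
# second-moment state, so the decoupled decay (and its lr**2 / max_lr
# scaling) only applies from the second step onward.
import torch
from timm.optim.adopt import Adopt

params = [torch.nn.Parameter(torch.randn(8))]
opt = Adopt(params, lr=1e-3, weight_decay=0.05, decoupled=True, corrected_weight_decay=True)

params[0].grad = torch.randn(8)
for _ in range(2):
    opt.step()
```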