diff --git a/README.md b/README.md
index 84e3cdfe21..125428d28a 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,11 @@

 ## What's New

+## Jan 27, 2025
+* Add Kron Optimizer (PSGD w/ Kronecker-factored preconditioner)
+  * Code from https://github.com/evanatyourservice/kron_torch
+  * See also https://sites.google.com/site/lixilinx/home/psgd
+
 ## Jan 19, 2025
 * Fix loading of LeViT safetensor weights, remove conversion code which should have been deactivated
 * Add 'SO150M' ViT weights trained with SBB recipes, decent results, but not optimal shape for ImageNet-12k/1k pretrain/ft
@@ -461,6 +466,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory metho
 * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) - https://arxiv.org/abs/2006.08217
 * `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
 * `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853
+* `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd
 * `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
 * `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839
 * `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 0d540f66a3..6de679f65c 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -290,7 +290,7 @@ def _build_params_dict_single(weight, bias, **kwargs):
     return [dict(params=bias, **kwargs)]


-@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
+@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*', 'kron*')))
 def test_optim_factory(optimizer):
     assert issubclass(get_optimizer_class(optimizer, bind_defaults=False), torch.optim.Optimizer)

@@ -386,6 +386,14 @@ def test_adam(optimizer):
     _test_model(optimizer, dict(lr=5e-2))


+@pytest.mark.parametrize('optimizer', ['kron'])
+def test_kron(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=1e-3))
+
+
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
     _test_rosenbrock(
diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 0ec0eecae0..f4cfeb7060 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -23,6 +23,7 @@
 from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
+from .kron import Kron
 from .lamb import Lamb
 from .laprop import LaProp
 from .lars import Lars
@@ -693,6 +694,19 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             second_order=True,
         ),
+        OptimInfo(
+            name='kron',
+            opt_class=Kron,
+            description='PSGD optimizer with Kronecker-factored preconditioner',
+            has_momentum=True,
+        ),
+        OptimInfo(
+            name='kronw',
+            opt_class=Kron,
+            description='PSGD optimizer with Kronecker-factored preconditioner and decoupled weight decay',
+            has_momentum=True,
+            defaults={'decoupled_decay': True}
+        ),
         OptimInfo(
             name='laprop',
             opt_class=LaProp,
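For reference, a minimal sketch of how the two names registered above can be used through the existing `timm.optim` factory. The `kron` / `kronw` names and the `decoupled_decay` default follow the registrations in this diff; the model and hyperparameters are purely illustrative.

```python
import torch
import timm
from timm.optim import create_optimizer_v2

# Any nn.Module works; a small timm model is used only for illustration.
model = timm.create_model('resnet18', pretrained=False)

# 'kron' maps to the Kron class registered above.
opt = create_optimizer_v2(model, opt='kron', lr=1e-3)

# 'kronw' enables the decoupled (AdamW-style) weight decay default.
opt_w = create_optimizer_v2(model, opt='kronw', lr=1e-3, weight_decay=1e-2)

x = torch.randn(2, 3, 224, 224)
model(x).sum().backward()
opt.step()
opt.zero_grad()
```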
diff --git a/timm/optim/kron.py b/timm/optim/kron.py
new file mode 100644
index 0000000000..e01c9885be
--- /dev/null
+++ b/timm/optim/kron.py
@@ -0,0 +1,527 @@
+""" PyTorch Implementation of the Kron (PSGD) optimizer
+
+This is a PSGD optimizer using a Kronecker-factored preconditioner.
+
+This impl was adapted from https://github.com/evanatyourservice/kron_torch
+by Evan Walters, licensed CC-BY-4.0.
+
+Contributions to the above were also made by
+* Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
+* Omead Pooladzandi https://github.com/opooladz
+
+The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
+
+This `timm` impl
+* works with a wider variety of torch versions
+* fixes some checkpoint save/restore (resume) issues
+* adds a decoupled weight-decay option
+* has some refactoring, cleanup of args, default/group items
+* warns about not having opt_einsum (unusable without it)
+
+"""
+import logging
+import string
+import random
+import warnings
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
+
+try:
+    torch._dynamo.config.cache_size_limit = 1_000_000
+    has_dynamo = True
+except AttributeError:
+    has_dynamo = False
+
+from ._types import ParamsT
+
+_logger = logging.getLogger(__name__)
+
+
+def precond_update_prob_schedule(
+        n: float,
+        max_prob: float = 1.0,
+        min_prob: float = 0.03,
+        decay: float = 0.001,
+        flat_start: float = 500,
+) -> torch.Tensor:
+    """Anneal preconditioner update probability during beginning of training.
+
+    PSGD benefits from more preconditioner updates at the beginning of training,
+    but once the preconditioner is learned the update probability can drop low.
+
+    This schedule is an exponential anneal with a flat start. Default settings keep
+    update probability at 1.0 for 500 steps then exponentially anneal down to
+    `min_prob` by 4000 steps. Default settings work very well for most models and
+    training regimes.
+    """
+    n = torch.tensor(n, dtype=torch.float32)
+    prob = max_prob * torch.exp(-decay * (n - flat_start))
+    prob.clamp_(min=min_prob, max=max_prob)
+
+    return prob
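A quick illustration of what the default schedule above produces (a small sketch; it only relies on the function as defined in this file, and the printed values are approximate).

```python
from timm.optim.kron import precond_update_prob_schedule

# Probability of a preconditioner update as training progresses.
for step in (0, 500, 1000, 2000, 4000, 10000):
    p = precond_update_prob_schedule(step)
    print(f"step {step:>5}: update prob ~ {p.item():.3f}")

# Expected trend with the defaults: ~1.0 through the flat start (first ~500
# steps), then exponential decay (~0.61 at 1000, ~0.22 at 2000), bottoming
# out at min_prob (0.03) around step ~4000.
```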
+
+
+class Kron(torch.optim.Optimizer):
+    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
+
+    Args:
+        params: Iterable of parameters to optimize or dicts defining parameter groups.
+        lr: Learning rate.
+        momentum: Momentum parameter.
+        weight_decay: Weight decay.
+        preconditioner_update_probability: Probability of updating the preconditioner.
+            If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular: Max size for a dim's preconditioner to be triangular.
+        min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
+        memory_save_mode: None, 'one_diag', 'smart_one_diag', or 'all_diag'. None is the default
+            and sets all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, 'smart_one_diag' sets the largest dim to be
+            diagonal only if it is strictly larger than the second largest dim, and 'all_diag'
+            sets all preconditioners to be diagonal.
+        momentum_into_precond_update: Whether to send momentum into the preconditioner
+            update instead of raw gradients.
+        precond_lr: Learning rate for preconditioner updates.
+        precond_init_scale: Initial scale of the preconditioners.
+        mu_dtype: Dtype of the momentum accumulator.
+        precond_dtype: Dtype of the preconditioner.
+        decoupled_decay: AdamW style decoupled weight decay.
+        flatten_dim: Flatten dim >= 2 instead of relying on expressions.
+        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work.
+    """
+
+    def __init__(
+            self,
+            params: ParamsT,
+            lr: float = 0.001,
+            momentum: float = 0.9,
+            weight_decay: float = 0.0,
+            preconditioner_update_probability: Optional[Union[Callable, float]] = None,
+            max_size_triangular: int = 2048,
+            min_ndim_triangular: int = 2,
+            memory_save_mode: Optional[str] = None,
+            momentum_into_precond_update: bool = True,
+            precond_lr: float = 0.1,
+            precond_init_scale: float = 1.0,
+            mu_dtype: Optional[torch.dtype] = None,
+            precond_dtype: Optional[torch.dtype] = None,
+            decoupled_decay: bool = False,
+            flatten_dim: bool = False,
+            deterministic: bool = False,
+    ):
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= momentum < 1.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            weight_decay=weight_decay,
+            preconditioner_update_probability=preconditioner_update_probability,
+            max_size_triangular=max_size_triangular,
+            min_ndim_triangular=min_ndim_triangular,
+            memory_save_mode=memory_save_mode,
+            momentum_into_precond_update=momentum_into_precond_update,
+            precond_lr=precond_lr,
+            precond_init_scale=precond_init_scale,
+            mu_dtype=mu_dtype,
+            precond_dtype=precond_dtype,
+            decoupled_decay=decoupled_decay,
+            flatten_dim=flatten_dim,
+        )
+        super(Kron, self).__init__(params, defaults)
+
+        self._param_exprs = {}  # cache for einsum expr
+        self._tiny = torch.finfo(torch.bfloat16).tiny
+        self.rng = random.Random(1337)
+        if deterministic:
+            # Use a Generator to try to be more deterministic across resume (save/load)
+            self.torch_rng = torch.Generator().manual_seed(1337)
+        else:
+            self.torch_rng = None
+
+        # make compile optional (for bwd compat)
+        if has_dynamo:
+            self._calc_A_and_conjB = torch.compile(_calc_A_and_conjB, fullgraph=True, dynamic=False)
+            self._q_terms = torch.compile(_q_terms, fullgraph=True, dynamic=False)
+            self._precond_grad = torch.compile(_precond_grad, fullgraph=True, dynamic=False)
+            self._balance_Q = torch.compile(_balance_Q, fullgraph=True, dynamic=False)
+        else:
+            self._calc_A_and_conjB = _calc_A_and_conjB
+            self._q_terms = _q_terms
+            self._precond_grad = _precond_grad
+            self._balance_Q = _balance_Q
+
+    def __getstate__(self):
+        _dict = super().__getstate__()
+        _dict["rng"] = self.rng
+        _dict["torch_rng"] = self.torch_rng
+        return _dict
+
+    def state_dict(self) -> Dict[str, Any]:
+        # Get the optimizer's state dict
+        optimizer_state = super().state_dict()
+
+        # Add the generator state
+        optimizer_state['rng_state'] = self.rng.getstate()
+        if self.torch_rng is not None:
+            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
+
+        return optimizer_state
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # Extract and remove the RNG state from the state dict
+        rng_states = {}
+        if 'rng_state' in state_dict:
+            rng_states['rng_state'] = state_dict.pop('rng_state')
+        if 'torch_rng_state' in state_dict:
+            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
+
+        # Load the optimizer state
+        super().load_state_dict(state_dict)
+        state_dict.update(rng_states)  # add back
+
+        # Restore the RNG state if it exists
+        if 'rng_state' in rng_states:
+            self.rng.setstate(rng_states['rng_state'])
+        if 'torch_rng_state' in rng_states:
+            self.torch_rng.set_state(rng_states['torch_rng_state'])
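Because `state_dict()` / `load_state_dict()` above also carry the RNG state, resuming is a plain PyTorch round-trip. A minimal sketch (the file name and toy parameter are illustrative only):

```python
import torch
from timm.optim.kron import Kron

w = torch.nn.Parameter(torch.randn(16, 8))
opt = Kron([w], lr=1e-3, deterministic=True)

w.sum().backward()  # toy loss, just to populate grads
opt.step()

# Save optimizer state (includes the 'rng_state' / 'torch_rng_state' entries added above).
torch.save({'optimizer': opt.state_dict()}, 'kron_ckpt.pth')

# ... later, rebuild and restore; with deterministic=True the preconditioner
# update sampling resumes from the saved generator state.
opt2 = Kron([w], lr=1e-3, deterministic=True)
opt2.load_state_dict(torch.load('kron_ckpt.pth')['optimizer'])
```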
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._param_exprs = {}
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        total_momentum_size = 0
+        total_momentum_mb = 0
+        total_precond_size = 0
+        total_precond_mb = 0
+
+        for group in self.param_groups:
+            mu_dtype = group.get("mu_dtype")
+            precond_dtype = group.get("precond_dtype", torch.float32)
+            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+            update_prob = group.get("preconditioner_update_probability", None)
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                state = self.state[p]
+                if group['flatten_dim']:
+                    grad = grad.view(grad.size(0), -1)
+
+                if len(state) == 0:
+                    state["step"] = 0
+                    state["update_counter"] = 0
+                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
+                    # init Q and einsum expressions on first step
+                    state["Q"], exprs = _init_Q_exprs(
+                        grad,
+                        group["precond_init_scale"],
+                        group["max_size_triangular"],
+                        group["min_ndim_triangular"],
+                        group["memory_save_mode"],
+                        dtype=precond_dtype,
+                    )
+                    self._param_exprs[p] = exprs
+
+                    # Accumulate sizes for log
+                    momentum_size = state["momentum_buffer"].numel()
+                    momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
+                    total_momentum_size += momentum_size
+                    total_momentum_mb += momentum_mb
+
+                    precond_size = sum(q.numel() for q in state["Q"])
+                    precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
+                    total_precond_size += precond_size
+                    total_precond_mb += precond_mb
+                elif p not in self._param_exprs:
+                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
+                    exprs = _init_Q_exprs(
+                        grad,
+                        group["precond_init_scale"],
+                        group["max_size_triangular"],
+                        group["min_ndim_triangular"],
+                        group["memory_save_mode"],
+                        dtype=precond_dtype,
+                        init_q=False,
+                    )
+                    self._param_exprs[p] = exprs
+                else:
+                    # retrieve cached expressions
+                    exprs = self._param_exprs[p]
+
+                # update preconditioners all together deterministically
+                if update_prob is None:
+                    update_prob = precond_update_prob_schedule
+                if callable(update_prob):
+                    update_prob = update_prob(state["step"])
+                state["update_counter"] += 1
+                do_update = state["update_counter"] >= 1 / update_prob
+                if do_update:
+                    state["update_counter"] = 0
+
+                state["step"] += 1
+
+                # Update momentum buffer
+                beta = group["momentum"]
+                bias_correction = 1 - beta ** state["step"]
+                momentum_buffer = state["momentum_buffer"]
+                momentum_buffer.mul_(beta).add_(grad, alpha=1 - beta)
+
+                # Restore momentum dtype
+                if mu_dtype is not None:
+                    momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
+                debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)
+
+                # Balance preconditioners roughly every 100 updates
+                balance = self.rng.random() < 0.01 and do_update
+                if grad.dim() > 1 and balance:
+                    self._balance_Q(state["Q"])
+
+                # Update preconditioner
+                if do_update:
+                    exprA, exprGs, _ = exprs
+                    Q = state["Q"]
+                    if self.torch_rng is None:
+                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    else:
+                        # Restoring generator state to device is messy. For now,
+                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
+                        # FIXME Need a better approach
+                        V = torch.randn(
+                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
+                        V = V.to(debiased_momentum.device)
+                    G = debiased_momentum if momentum_into_precond_update else grad
+
+                    A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
+
+                    terms = self._q_terms(exprGs, A, conjB)
+
+                    for q, (term1, term2) in zip(Q, terms):
+                        tmp = term1 - term2
+                        tmp *= group["precond_lr"]
+                        if q.dim() < 2:
+                            tmp *= q
+                            tmp /= (term1 + term2).norm(float("inf")) + self._tiny
+                        else:
+                            tmp = torch.triu(tmp)
+                            tmp /= _norm_lower_bound(term1 + term2) + self._tiny
+                            tmp @= q
+                        q.sub_(tmp)
+
+                # Precondition gradients
+                pre_grad = self._precond_grad(
+                    state["Q"],
+                    exprs,
+                    debiased_momentum,
+                ).to(dtype=p.dtype)
+
+                # RMS of pre_grad should be 1.0, so let's cap at 1.1
+                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
+                if group['flatten_dim']:
+                    pre_grad = pre_grad.view(p.shape)
+
+                # Apply weight decay
+                if group["weight_decay"] != 0:
+                    if group["decoupled_decay"]:
+                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                    else:
+                        pre_grad.add_(p, alpha=group["weight_decay"])
+
+                # Update parameters
+                p.add_(pre_grad, alpha=-group["lr"])
+
+        if total_momentum_size > 0:
+            _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
+            _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")
+
+        return loss
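The `preconditioner_update_probability` handling in `step()` accepts either a constant or a callable of the step count. A small sketch of both (the values chosen here are arbitrary):

```python
import torch
from timm.optim.kron import Kron, precond_update_prob_schedule

w = torch.nn.Parameter(torch.randn(32, 16))

# Fixed probability: update the preconditioner on roughly 10% of steps.
opt_const = Kron([w], lr=1e-3, preconditioner_update_probability=0.1)

# Custom schedule: any callable taking the step count works; here the default
# schedule is reused with a longer flat start (arbitrary choice).
opt_sched = Kron(
    [w],
    lr=1e-3,
    preconditioner_update_probability=lambda n: precond_update_prob_schedule(n, flat_start=1000),
)
```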
+
+
+def _init_Q_exprs(
+        t,
+        scale,
+        max_size,
+        min_ndim_triangular,
+        memory_save_mode,
+        dtype=None,
+        init_q=True,
+):
+    """For a scalar or tensor t, we initialize its preconditioner Q and
+    reusable einsum expressions for updating Q and preconditioning gradient.
+    """
+    letters = string.ascii_lowercase + string.ascii_uppercase
+
+    dtype = dtype if dtype is not None else t.dtype
+    shape = t.shape
+    Q = []
+    if len(shape) == 0:  # scalar
+        if init_q:
+            Q.append(scale * torch.ones_like(t, dtype=dtype))
+        exprA = ",->"
+        exprGs = [",->"]
+        exprP = ",,->"
+    else:  # tensor
+        if len(shape) > 13:
+            raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")
+
+        scale = scale ** (1 / len(shape))
+
+        if memory_save_mode is None:
+            dim_diag = [False for _ in shape]
+        elif memory_save_mode == "one_diag":
+            rev_sorted_dims = np.argsort(shape)[::-1]
+            dim_diag = [False for _ in shape]
+            dim_diag[rev_sorted_dims[0]] = True
+        elif memory_save_mode == "smart_one_diag":
+            # addition proposed by Lucas Nestler
+            rev_sorted_dims = np.argsort(shape)[::-1]
+            sorted_shape = sorted(shape)
+            dim_diag = [False for _ in shape]
+            if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+                dim_diag[rev_sorted_dims[0]] = True
+        elif memory_save_mode == "all_diag":
+            dim_diag = [True for _ in shape]
+        else:
+            raise ValueError(
+                f"Invalid memory_save_mode: {memory_save_mode}, "
+                "must be one of [None, 'one_diag', 'smart_one_diag', 'all_diag']")
+
+        piece1A, piece2A, piece3A = ([], "", "")
+        exprGs = []
+        piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
+        for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
+            if (
+                    size == 1
+                    or size > max_size
+                    or len(shape) < min_ndim_triangular
+                    or dim_d
+            ):
+                # use diagonal matrix as preconditioner for this dim
+                if init_q:
+                    Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
+
+                piece1A.append(letters[i])
+                piece2A = piece2A + letters[i]
+                piece3A = piece3A + letters[i]
+
+                piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
+                subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
+                exprGs.append(subscripts)
+
+                piece1P.append(letters[i + 13])
+                piece2P.append(letters[i + 13])
+                piece3P = piece3P + letters[i + 13]
+                piece4P = piece4P + letters[i + 13]
+            else:
+                # use triangular matrix as preconditioner for this dim
+                if init_q:
+                    Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
+
+                piece1A.append(letters[i] + letters[i + 13])
+                piece2A = piece2A + letters[i + 13]
+                piece3A = piece3A + letters[i]
+
+                piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
+                piece2 = "".join([letters[i + 26] if j == i else letters[j] for j in range(len(shape))])
+                subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
+                exprGs.append(subscripts)
+
+                a, b, c = (letters[i], letters[i + 13], letters[i + 26])
+                piece1P.append(a + b)
+                piece2P.append(a + c)
+                piece3P = piece3P + c
+                piece4P = piece4P + b
+
+        exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
+        exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
+
+    exprGs = tuple(exprGs)
+    if init_q:
+        return [Q, (exprA, exprGs, exprP)]
+    else:
+        return exprA, exprGs, exprP
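To make the expression strings above concrete, a small sketch of what `_init_Q_exprs` produces for an ordinary 2D weight with the default (all-triangular) settings. The expected strings in the comments are what the letter bookkeeping works out to for a 4x3 parameter; they are illustrative, not part of the diff.

```python
import torch
from timm.optim.kron import _init_Q_exprs

t = torch.randn(4, 3)
Q, (exprA, exprGs, exprP) = _init_Q_exprs(
    t, scale=1.0, max_size=2048, min_ndim_triangular=2, memory_save_mode=None)

print([tuple(q.shape) for q in Q])  # one (4, 4) and one (3, 3) triangular factor
print(exprA)   # expected 'an,bo,no->ab'        : A = Q0 @ G @ Q1.T
print(exprGs)  # expected ('nb,Ab->nA', 'ao,aB->oB') : per-dim statistics of A / conjB
print(exprP)   # expected 'an,bo,aA,bB,AB->no'  : applies (Q0^T Q0) and (Q1^T Q1) to G (real case)
```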
"," + piece1 + "->" + letters[i + 13] + exprGs.append(subscripts) + + piece1P.append(letters[i + 13]) + piece2P.append(letters[i + 13]) + piece3P = piece3P + letters[i + 13] + piece4P = piece4P + letters[i + 13] + else: + # use triangular matrix as preconditioner for this dim + if init_q: + Q.append(scale * torch.eye(size, dtype=dtype, device=t.device)) + + piece1A.append(letters[i] + letters[i + 13]) + piece2A = piece2A + letters[i + 13] + piece3A = piece3A + letters[i] + + piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))]) + piece2 = "".join([letters[i + 26] if j == i else letters[j] for j in range(len(shape))]) + subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26] + exprGs.append(subscripts) + + a, b, c = (letters[i], letters[i + 13], letters[i + 26]) + piece1P.append(a + b) + piece2P.append(a + c) + piece3P = piece3P + c + piece4P = piece4P + b + + exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A + exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P + + exprGs = tuple(exprGs) + if init_q: + return [Q, (exprA, exprGs, exprP)] + else: + return exprA, exprGs, exprP + + +def _lb(A, max_abs): + A = A / max_abs + aa = torch.real(A * A.conj()) + value0, i = torch.max(torch.sum(aa, dim=0), 0) + value1, j = torch.max(torch.sum(aa, dim=1), 0) + if value0 > value1: + x = A[:, i].conj() @ A + return max_abs * torch.linalg.vector_norm((x / torch.linalg.vector_norm(x)) @ A.H) + else: + x = A @ A[j].conj() + return max_abs * torch.linalg.vector_norm(A.H @ (x / torch.linalg.vector_norm(x))) + + +def _norm_lower_bound(A): + """Cheap lower bound for the spectral norm of A.""" + max_abs = A.norm(float("inf")) + return torch.where(max_abs > 0, _lb(A, max_abs), max_abs) + + +def _solve_triangular_right(X, A): + """X @ inv(A)""" + orig_dtype = X.dtype + X = X.to(dtype=torch.float32) + A = A.to(dtype=torch.float32) + out = torch.linalg.solve_triangular(A, X.reshape(-1, X.size(-1)), upper=True, left=False).reshape_as(X) + return out.to(dtype=orig_dtype) + + +def _balance_Q(Q_in): + norms = torch.stack([q.norm(float("inf")) for q in Q_in]) + geometric_mean = norms.prod() ** (1 / len(Q_in)) + norms = geometric_mean / norms + for i, q in enumerate(Q_in): + q.mul_(norms[i]) + + +def _precond_grad(Q, exprs, G): + """Precondition gradient G with preconditioner Q.""" + return torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G) + + +def _calc_A_and_conjB(exprA, G, Q, V): + A = torch.einsum(exprA, *Q, G) + order = G.dim() + p = tuple(range(order)) + conjB = torch.permute(V.conj(), p[1:] + p[:1]) + for i, q in enumerate(Q): + conjB = conjB / q if q.dim() < 2 else _solve_triangular_right(conjB, q) + if i < order - 1: + conjB = torch.transpose(conjB, i, order - 1) + return A, conjB + + +def _q_terms(exprGs, A, conjB): + terms = [] + for exprG in exprGs: + term1 = torch.einsum(exprG, A, A.conj()) + term2 = torch.einsum(exprG, conjB.conj(), conjB) + terms.append((term1, term2)) + return terms