# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.optim import FairseqOptimizer, register_optimizer


@register_optimizer('lamb')
class FairseqLAMB(FairseqOptimizer):
    """LAMB optimizer."""

    def __init__(self, args, params):
        super().__init__(args)
        try:
            from apex.optimizers import FusedLAMB
            self._optimizer = FusedLAMB(params, **self.optimizer_config)
        except ImportError:
            raise ImportError('Please install apex to use LAMB optimizer')

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': eval(self.args.lamb_betas),
            'eps': self.args.lamb_eps,
            'weight_decay': self.args.weight_decay,
        }
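

# --- Usage sketch (illustrative; not part of the original fairseq module) ---
# A minimal example of constructing the optimizer by hand, assuming apex is
# built with its CUDA extensions and a GPU is available (FusedLAMB runs fused
# CUDA kernels). The Namespace fields mirror the flags registered in
# add_args(); the concrete hyperparameter values below are hypothetical.
if __name__ == '__main__':
    import argparse

    import torch

    args = argparse.Namespace(
        lr=[0.0016],                # fairseq keeps --lr as a list of floats
        lamb_betas='(0.9, 0.999)',  # parsed with eval() in optimizer_config
        lamb_eps=1e-8,
        weight_decay=0.01,
    )
    model = torch.nn.Linear(16, 16).cuda()
    optimizer = FairseqLAMB(args, model.parameters())
    print(optimizer.optimizer_config)

    # One toy update: LAMB rescales the Adam-style step by a per-layer trust
    # ratio (parameter norm over update norm), which is what keeps large-batch
    # training stable at large learning rates.
    loss = model(torch.randn(4, 16, device='cuda')).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()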