
Commit f75411a

Myle Ott authored and facebook-github-bot committed
Add FusedLAMB optimizer
Summary:
Pull Request resolved: fairinternal/fairseq-py#971
Differential Revision: D19265752
Pulled By: myleott
fbshipit-source-id: 062748d3b44ef3627d35d65dccef8659709c2b5e
1 parent 1e324a5 commit f75411a

File tree

1 file changed: +46 −0 lines changed


fairseq/optim/fused_lamb.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.optim import FairseqOptimizer, register_optimizer


@register_optimizer('lamb')
class FairseqLAMB(FairseqOptimizer):
    """LAMB optimizer."""

    def __init__(self, args, params):
        super().__init__(args)
        try:
            from apex.optimizers import FusedLAMB
            self._optimizer = FusedLAMB(params, **self.optimizer_config)
        except ImportError:
            raise ImportError('Please install apex to use LAMB optimizer')

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': eval(self.args.lamb_betas),
            'eps': self.args.lamb_eps,
            'weight_decay': self.args.weight_decay,
        }
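
Because the file is self-contained, a quick way to see how it fits together is to build the optimizer from parsed arguments. The sketch below is not part of the commit: it assumes a CUDA-enabled install of NVIDIA apex (which provides apex.optimizers.FusedLAMB), and it fills in args.lr by hand, since --lr is registered elsewhere in fairseq and stored as a list.

# Hypothetical usage sketch, not part of the commit. Requires NVIDIA apex;
# without it, FairseqLAMB raises the ImportError shown above.
import argparse

import torch

from fairseq.optim.fused_lamb import FairseqLAMB

parser = argparse.ArgumentParser()
FairseqLAMB.add_args(parser)
args = parser.parse_args(['--lamb-eps', '1e-6', '--weight-decay', '0.01'])
args.lr = [1e-3]  # --lr is registered elsewhere in fairseq and parsed as a list

model = torch.nn.Linear(8, 8).cuda()  # apex's fused kernels expect CUDA tensors
optimizer = FairseqLAMB(args, model.parameters())
print(optimizer.optimizer_config)
# {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-06, 'weight_decay': 0.01}

In normal training, the @register_optimizer('lamb') decorator makes this class selectable with --optimizer lamb from fairseq-train, and optimizer_config is what lets the flags above override the optimizer state stored in a checkpoint.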
