
Commit ba5f829

0xjc authored and facebook-github-bot committed
Parameterized criterions (#808)
Summary: Support criterions with parameters, such as the AutoSegmentationCriterion (ASG) used in wav2letter, which has a transition matrix parameter. This is needed to integrate wav2letter's ASG into PySpeech. With this diff, parameters in criterions will be:
(1) updated by optimizers, with a configurable learning rate;
(2) saved and loaded from checkpoints, preserving backward compatibility for criterions without parameters;
(3) synchronized across nodes in distributed training.

Pull Request resolved: fairinternal/fairseq-py#808
Reviewed By: jcai1
Differential Revision: D16934097
Pulled By: okhonko
fbshipit-source-id: 121ec9382459385c6f9cbef3a8274bec1a434038
1 parent a2f5361 commit ba5f829

15 files changed: +79 −39 lines
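The summary above refers to criterions that own trainable parameters, ASG's transition matrix being the motivating case. As a rough illustration (not part of this commit, and assuming the FairseqCriterion/register_criterion API of this era), such a criterion might look like this:

import torch
import torch.nn as nn

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('toy_asg')
class ToyASGCriterion(FairseqCriterion):
    """Hypothetical criterion holding a trainable label-transition matrix,
    in the spirit of wav2letter's ASG loss (illustration only)."""

    def __init__(self, args, task):
        super().__init__(args, task)
        num_labels = len(task.target_dictionary)
        # This nn.Parameter is exactly what the rest of the commit teaches
        # fairseq to optimize, checkpoint, and synchronize across workers.
        self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels))

    def forward(self, model, sample, reduce=True):
        # The actual ASG loss computation is omitted; only the presence of a
        # parameter matters for the changes in this diff.
        raise NotImplementedError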

fairseq/checkpoint_utils.py

Lines changed: 3 additions & 0 deletions
@@ -222,6 +222,7 @@ def save_state(
     filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
     num_updates, optim_history=None, extra_state=None,
 ):
+    from fairseq import utils
     if optim_history is None:
         optim_history = []
     if extra_state is None:
@@ -239,6 +240,8 @@ def save_state(
         ],
         'extra_state': extra_state,
     }
+    if utils.has_parameters(criterion):
+        state_dict['criterion'] = criterion.state_dict()
     if not args.no_save_optimizer_state:
         state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
     torch_persistent_save(state_dict, filename)
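The new guard calls utils.has_parameters, whose definition is not shown in this excerpt. A minimal sketch of such a check, assuming the criterion behaves like a torch.nn.Module:

def has_parameters(module):
    # Assumed helper: report whether the module owns at least one parameter,
    # so parameterless criterions keep producing checkpoints without a
    # 'criterion' entry (preserving backward compatibility).
    try:
        next(module.parameters())
        return True
    except StopIteration:
        return False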

fairseq/models/distributed_fairseq_model.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@

 import inspect

-from torch.nn import parallel
+import torch.nn as nn

 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 from fairseq.models import BaseFairseqModel
@@ -25,9 +25,9 @@ def DistributedFairseqModel(args, model):
         model (BaseFairseqModel): model to wrap
     """
     # determine which DDP class to extend
-    assert isinstance(model, BaseFairseqModel)
+    assert isinstance(model, nn.Module)
     if args.ddp_backend == 'c10d':
-        ddp_class = parallel.DistributedDataParallel
+        ddp_class = nn.parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
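Loosening the assertion from BaseFairseqModel to nn.Module means the wrapper can now accept any module, including a criterion, so criterion parameters get gradient-synchronized like model weights. A hedged usage sketch (the surrounding training setup is assumed, not shown in this commit):

from fairseq.models.distributed_fairseq_model import DistributedFairseqModel

# Hypothetical call site: `args`, `model`, and `criterion` come from the
# trainer. Wrapping the criterion only makes sense when it actually has
# trainable parameters to synchronize.
model = DistributedFairseqModel(args, model)
if any(p.requires_grad for p in criterion.parameters()):
    criterion = DistributedFairseqModel(args, criterion)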

fairseq/optim/__init__.py

Lines changed: 1 addition & 6 deletions
@@ -19,18 +19,13 @@
 ]


-_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
+build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
     '--optimizer',
     base_class=FairseqOptimizer,
     default='nag',
 )


-def build_optimizer(args, params, *extra_args, **extra_kwargs):
-    params = list(filter(lambda p: p.requires_grad, params))
-    return _build_optimizer(args, params, *extra_args, **extra_kwargs)
-
-
 # automatically import any Python files in the optim/ directory
 for file in os.listdir(os.path.dirname(__file__)):
     if file.endswith('.py') and not file.startswith('_'):
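With the wrapper gone, build_optimizer is the registry builder itself, and filtering out frozen parameters becomes the caller's responsibility. A hedged sketch of what a call site might now do, combining model and criterion parameters:

import itertools

from fairseq import optim

# Hypothetical call site: gather every trainable parameter, including any
# owned by the criterion, before handing them to the registry's builder.
params = list(
    filter(
        lambda p: p.requires_grad,
        itertools.chain(model.parameters(), criterion.parameters()),
    )
)
optimizer = optim.build_optimizer(args, params)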

fairseq/optim/adadelta.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 @register_optimizer('adadelta')
 class Adadelta(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

     @staticmethod
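The same one-line change repeats in the optimizer files below: subclasses stop forwarding params to the base class and pass them only to the wrapped torch optimizer. A rough sketch of a custom optimizer written against the new signature (the registry name and config values are made up):

import torch

from fairseq.optim import FairseqOptimizer, register_optimizer


@register_optimizer('toy_sgd')
class ToySGD(FairseqOptimizer):

    def __init__(self, args, params):
        # The base class no longer stores `params`; it recovers them later
        # from the wrapped optimizer's param_groups.
        super().__init__(args)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # Assumes the usual fairseq convention of args.lr being a list.
        return {'lr': self.args.lr[0], 'momentum': 0.0}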

fairseq/optim/adafactor.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 @register_optimizer('adafactor')
 class FairseqAdafactor(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adafactor(params, **self.optimizer_config)

     @staticmethod

fairseq/optim/adagrad.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 @register_optimizer('adagrad')
 class Adagrad(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

     @staticmethod

fairseq/optim/adam.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 class FairseqAdam(FairseqOptimizer):

     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         if torch.cuda.is_available():
             try:
                 from apex.optimizers import FusedAdam as _FusedAdam  # noqa

fairseq/optim/adamax.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 @register_optimizer('adamax')
 class FairseqAdamax(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adamax(params, **self.optimizer_config)

     @staticmethod

fairseq/optim/bmuf.py

Lines changed: 2 additions & 3 deletions
@@ -19,11 +19,10 @@ class FairseqBMUF(FairseqOptimizer):
     model-update filtering
     """

-    def __init__(self, args, params, optimizer):
+    def __init__(self, args, optimizer):

-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = optimizer
-        self.params = params
         self._num_updates = 0
         self.sync_iter = self.args.global_sync_iter
         self.block_momentum = self.args.block_momentum

fairseq/optim/fairseq_optimizer.py

Lines changed: 10 additions & 5 deletions
@@ -10,10 +10,9 @@

 class FairseqOptimizer(object):

-    def __init__(self, args, params):
+    def __init__(self, args):
         super().__init__()
         self.args = args
-        self.params = list(params)

     @staticmethod
     def add_args(parser):
@@ -39,6 +38,13 @@ def optimizer_config(self):
         """
         raise NotImplementedError

+    @property
+    def params(self):
+        """Return an iterable of the parameters held by the optimizer."""
+        for param_group in self.optimizer.param_groups:
+            for p in param_group['params']:
+                yield p
+
     def __getstate__(self):
         return self._optimizer.__getstate__()

@@ -93,9 +99,8 @@ def step(self, closure=None):

     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                p.grad = None
+        for p in self.params:
+            p.grad = None
         self.optimizer.zero_grad()

     @property
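Since params is now derived from the wrapped optimizer's param_groups, code that touches every optimized tensor no longer cares whether a parameter came from the model or from a criterion. A small usage sketch (hypothetical call site; `optimizer` is a FairseqOptimizer instance):

import torch

# Clip gradients over every parameter the optimizer holds, model- and
# criterion-owned alike.
grad_norm = torch.nn.utils.clip_grad_norm_(list(optimizer.params), max_norm=25.0)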
