
Commit 27568a7

Myle Ott authored and facebook-github-bot committed
Merge TracingCompliantTransformer and regular Transformer, fix NAT tests
Summary: Pull Request resolved: fairinternal/fairseq-py#899

Differential Revision: D18373060
Pulled By: myleott
fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
1 parent 2a9b4ec commit 27568a7

14 files changed (+551, -1191 lines)

fairseq/criterions/nat_loss.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def mean_ds(x: Tensor, dim=None) -> Tensor:
         if masks is not None:
             outputs, targets = outputs[masks], targets[masks]

-        if not masks.any():
+        if masks is not None and not masks.any():
             nll_loss = torch.tensor(0)
             loss = nll_loss
         else:
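Note on the fix: `and` short-circuits in Python, so `masks.any()` is now only evaluated when `masks` is an actual tensor; the old check raised an `AttributeError` whenever the criterion was called without a mask. A minimal standalone sketch of the corrected control flow (the `masks` values here are illustrative, not taken from the test suite):

    import torch

    for masks in (None, torch.tensor([False, False]), torch.tensor([True, False])):
        # Old: `if not masks.any():` -> AttributeError when masks is None.
        # New: the None test runs first, so masks.any() is only called
        # on a real tensor.
        if masks is not None and not masks.any():
            nll_loss = torch.tensor(0)  # nothing left to score after masking
        else:
            pass  # compute the usual NLL loss here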

fairseq/iterative_refinement_generator.py

Lines changed: 38 additions & 18 deletions
@@ -3,11 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import namedtuple
+
 import torch
+
 from fairseq import utils
-from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
-from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
-from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
+
+
+DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
+    'output_tokens',
+    'output_scores',
+    'attn',
+    'step',
+    'max_step',
+])


 class IterativeRefinementGenerator(object):

@@ -88,6 +97,8 @@ def generate_batched_itr(

     @torch.no_grad()
     def generate(self, models, sample, prefix_tokens=None):
+        from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+        from fairseq.models.nonautoregressive_ensembles import EnsembleLevT

         if len(models) == 1:
             # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.

@@ -110,7 +121,7 @@ def generate(self, models, sample, prefix_tokens=None):

         # initialize buffers (very model specific, with length prediction or not)
         prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
-        prev_output_tokens = prev_decoder_out[0].clone()
+        prev_output_tokens = prev_decoder_out.output_tokens.clone()

         finalized = [[] for _ in range(bsz)]

@@ -150,8 +161,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
                 "max_ratio": self.max_ratio,
                 "decoding_format": self.decoding_format,
             }
-            prev_decoder_out[3] = step
-            prev_decoder_out[4] = self.max_iter + 1
+            prev_decoder_out = prev_decoder_out._replace(
+                step=step,
+                max_step=self.max_iter + 1,
+            )

             decoder_out = model.forward_decoder(
                 prev_decoder_out, encoder_out, **decoder_options

@@ -160,24 +173,26 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
             if self.adaptive:
                 # terminate if there is a loop
                 terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
+                    prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
+                )
+                decoder_out = decoder_out._replace(
+                    output_tokens=out_tokens,
+                    output_scores=out_scores,
+                    attn=out_attn,
                 )
-                decoder_out[0] = out_tokens
-                decoder_out[1] = out_scores
-                decoder_out[2] = out_attn

             else:
-                terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
+                terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()

             if step == self.max_iter:  # reach last iteration, terminate
                 terminated.fill_(1)

             # collect finalized sentences
             finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out[0][terminated]
-            finalized_scores = decoder_out[1][terminated]
+            finalized_tokens = decoder_out.output_tokens[terminated]
+            finalized_scores = decoder_out.output_scores[terminated]
             finalized_attn = (
-                None if decoder_out[2] is None else decoder_out[2][terminated]
+                None if decoder_out.attn is None else decoder_out.attn[terminated]
             )

             for i in range(finalized_idxs.size(0)):

@@ -194,10 +209,15 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
                     break

             # for next step
-            prev_decoder_out = _skip(decoder_out, ~terminated)
-            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
-            sent_idxs = _skip(sent_idxs, ~terminated)
+            not_terminated = ~terminated
+            prev_decoder_out = decoder_out._replace(
+                output_tokens=decoder_out.output_tokens[not_terminated],
+                output_scores=decoder_out.output_scores[not_terminated],
+                attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
+            )
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
+            sent_idxs = sent_idxs[not_terminated]

-            prev_output_tokens = prev_decoder_out[0].clone()
+            prev_output_tokens = prev_decoder_out.output_tokens.clone()

         return finalized
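Two changes drive this file. First, the `LevenshteinTransformerModel` and `EnsembleLevT` imports move inside `generate()`, the usual way to break an import cycle once these models and the generator end up in the same dependency chain. Second, the decoder state becomes the `DecoderOut` namedtuple, so positional indexing (`decoder_out[0]`, `prev_decoder_out[3] = step`) gives way to named fields and `_replace`, which returns a new tuple instead of mutating in place. A self-contained sketch of the pattern, with made-up tensor values:

    from collections import namedtuple

    import torch

    DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
        'output_tokens', 'output_scores', 'attn', 'step', 'max_step',
    ])

    state = DecoderOut(
        output_tokens=torch.tensor([[2, 5, 7], [2, 9, 7]]),
        output_scores=torch.zeros(2, 3),
        attn=None,
        step=0,
        max_step=10,
    )

    # _replace builds a new namedtuple; the old `state` is unchanged.
    state = state._replace(step=state.step + 1)

    # Dropping finished sentences, as in the "for next step" block above:
    # boolean-mask each per-sentence tensor and rebuild the state.
    terminated = torch.tensor([False, True])
    alive = ~terminated
    state = state._replace(
        output_tokens=state.output_tokens[alive],
        output_scores=state.output_scores[alive],
        attn=state.attn[alive] if state.attn is not None else None,
    )
    assert state.output_tokens.size(0) == 1 and state.step == 1

The encoder state, meanwhile, is now filtered through `model.encoder.reorder_encoder_out(...)` instead of the removed `script_skip_tensor_list` helper; `reorder_encoder_out` is the same hook beam search uses to shuffle encoder states, here fed the indices of the unfinished sentences via `not_terminated.nonzero().squeeze()`.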

fairseq/models/cmlm_transformer.py

Lines changed: 10 additions & 6 deletions
@@ -10,9 +10,9 @@
 arXiv preprint arXiv:1904.09324 (2019).
 """

-from fairseq.utils import new_arange
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.nonautoregressive_transformer import NATransformerModel
+from fairseq.utils import new_arange


 def _skeptical_unmasking(output_scores, output_masks, p):

@@ -55,11 +55,11 @@ def forward(

     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):

-        step = decoder_out["step"]
-        max_step = decoder_out["max_step"]
+        step = decoder_out.step
+        max_step = decoder_out.max_step

-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

         # execute the decoder
         output_masks = output_tokens.eq(self.unk)

@@ -78,7 +78,11 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
         output_tokens.masked_fill_(skeptical_mask, self.unk)
         output_scores.masked_fill_(skeptical_mask, 0.0)

-        return {"output_tokens": output_tokens, "output_scores": output_scores}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )


 @register_model_architecture("cmlm_transformer", "cmlm_transformer")
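For context, `forward_decoder` here is one mask-predict iteration: predict every position currently holding `self.unk`, then "skeptically" re-mask the lowest-scoring predictions so later iterations can revise them, with the re-masked fraction shrinking as `step` approaches `max_step`. A simplified sketch of that re-masking step (not the actual `_skeptical_unmasking` implementation; `unk` stands in for the mask token index):

    import torch

    def skeptical_unmask(output_tokens, output_scores, unk, p):
        # Re-mask roughly the fraction p of lowest-scoring positions per
        # sentence so the next iteration can re-predict them.
        num_to_mask = max(1, int(output_tokens.size(1) * p))
        lowest = output_scores.topk(num_to_mask, dim=1, largest=False).indices
        return (
            output_tokens.scatter(1, lowest, unk),
            output_scores.scatter(1, lowest, 0.0),
        )

    tokens = torch.tensor([[4, 9, 6, 8]])
    scores = torch.tensor([[0.9, 0.2, 0.8, 0.4]])
    step, max_step, unk = 1, 10, 3
    tokens, scores = skeptical_unmask(tokens, scores, unk, 1 - step / max_step)
    # tokens -> [[4, 3, 3, 3]]: the three weakest predictions are re-masked.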

fairseq/models/insertion_transformer.py

Lines changed: 12 additions & 7 deletions
@@ -6,14 +6,15 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from fairseq.utils import new_arange
+
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.levenshtein_transformer import (
     LevenshteinTransformerDecoder,
     LevenshteinTransformerModel,
 )
 from fairseq.models.transformer import Linear, TransformerModel
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.utils import new_arange


 class NegativeDistanceScore(object):

@@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):

 @register_model("insertion_transformer")
 class InsertionTransformerModel(LevenshteinTransformerModel):
-    def __init__(self, encoder, decoder):
-        super().__init__(encoder, decoder)
+    def __init__(self, args, encoder, decoder):
+        super().__init__(args, encoder, decoder)

     @staticmethod
     def add_args(parser):

@@ -169,8 +170,8 @@ def forward_decoder(
         self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
     ):

-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores
         # TODO: decoding for InsertionTransformer
         word_ins_out = self.decoder.forward_word_ins(
             output_tokens, encoder_out=encoder_out

@@ -187,7 +188,11 @@ def forward_decoder(
         cut_off = output_tokens.ne(self.pad).sum(1).max()
         output_tokens = output_tokens[:, :cut_off]
         output_scores = output_scores[:, :cut_off]
-        return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )


 class InsertionTransformerDecoder(LevenshteinTransformerDecoder):

@@ -206,7 +211,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         self.label_tau = getattr(args, "label_tau", None)

     def forward_word_ins(self, prev_output_tokens, encoder_out=None):
-        features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
+        features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
         features = self.pool_out(
             torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
         )
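Besides the dict-to-namedtuple return value, two smaller fixes land here: the constructor now threads `args` through to `LevenshteinTransformerModel.__init__`, matching the updated base-class signature, and `features, _ = self.extract_features(...)` becomes `self.extract_features(...)[0]`. Both forms select the first element of the returned `(features, extra)` pair; the indexed form avoids a tuple-unpacking assignment, which is presumably friendlier to the tracing work this commit merges. A trivial stand-in (a stub, not the real decoder) showing the equivalence:

    def extract_features_stub():
        # Stand-in for TransformerDecoder.extract_features, which
        # returns a (features, extra) pair.
        return "features", {"attn": None}

    features, _ = extract_features_stub()   # old: tuple unpacking
    features = extract_features_stub()[0]   # new: index the result
    assert features == "features"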

fairseq/models/iterative_nonautoregressive_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.nonautoregressive_transformer import NATransformerModel