@@ -3,11 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import namedtuple
+
 import torch
+
 from fairseq import utils
-from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
-from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
-from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
+
+
+DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
+    'output_tokens',
+    'output_scores',
+    'attn',
+    'step',
+    'max_step',
+])


 class IterativeRefinementGenerator(object):
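The core of this change: the decoder state moves from a plain indexable sequence (`prev_decoder_out[0]`, `prev_decoder_out[3]`, ...) to the `DecoderOut` namedtuple above, so fields are read by name and updated functionally. A minimal sketch of that access pattern, with invented field values, assuming nothing beyond standard `collections.namedtuple` semantics:

```python
from collections import namedtuple

import torch

DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
    'output_tokens', 'output_scores', 'attn', 'step', 'max_step',
])

# Toy state for a batch of two sentences (values invented).
state = DecoderOut(
    output_tokens=torch.zeros(2, 1, dtype=torch.long),
    output_scores=torch.zeros(2, 1),
    attn=None,
    step=0,
    max_step=10,
)

# namedtuples are immutable: instead of `state[3] = step`, ._replace()
# builds a new tuple that shares every field not explicitly overridden.
state = state._replace(step=state.step + 1)
assert state.step == 1
assert state.output_tokens.shape == (2, 1)
```

This is why every in-place assignment in the hunks below (`decoder_out[0] = out_tokens`, etc.) becomes a `._replace(...)` call that rebinds the whole state.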
@@ -88,6 +97,8 @@ def generate_batched_itr(

     @torch.no_grad()
     def generate(self, models, sample, prefix_tokens=None):
+        from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+        from fairseq.models.nonautoregressive_ensembles import EnsembleLevT

         if len(models) == 1:
             # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
@@ -110,7 +121,7 @@ def generate(self, models, sample, prefix_tokens=None):

         # initialize buffers (very model specific, with length prediction or not)
         prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
-        prev_output_tokens = prev_decoder_out[0].clone()
+        prev_output_tokens = prev_decoder_out.output_tokens.clone()

         finalized = [[] for _ in range(bsz)]

@@ -150,8 +161,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
                 "max_ratio": self.max_ratio,
                 "decoding_format": self.decoding_format,
             }
-            prev_decoder_out[3] = step
-            prev_decoder_out[4] = self.max_iter + 1
+            prev_decoder_out = prev_decoder_out._replace(
+                step=step,
+                max_step=self.max_iter + 1,
+            )

             decoder_out = model.forward_decoder(
                 prev_decoder_out, encoder_out, **decoder_options
@@ -160,24 +173,26 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
             if self.adaptive:
                 # terminate if there is a loop
                 terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
+                    prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
+                )
+                decoder_out = decoder_out._replace(
+                    output_tokens=out_tokens,
+                    output_scores=out_scores,
+                    attn=out_attn,
                 )
-                decoder_out[0] = out_tokens
-                decoder_out[1] = out_scores
-                decoder_out[2] = out_attn

             else:
-                terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
+                terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()

             if step == self.max_iter:  # reach last iteration, terminate
                 terminated.fill_(1)

             # collect finalized sentences
             finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out[0][terminated]
-            finalized_scores = decoder_out[1][terminated]
+            finalized_tokens = decoder_out.output_tokens[terminated]
+            finalized_scores = decoder_out.output_scores[terminated]
             finalized_attn = (
-                None if decoder_out[2] is None else decoder_out[2][terminated]
+                None if decoder_out.attn is None else decoder_out.attn[terminated]
             )

             for i in range(finalized_idxs.size(0)):
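Around these hunks the loop partitions the batch with a boolean mask: rows flagged `terminated` are emitted as finalized hypotheses under their original sentence ids, and only the surviving rows continue to the next refinement step. A small self-contained sketch of that masking, with invented tensors:

```python
import torch

output_tokens = torch.tensor([[2, 7, 7, 2],
                              [2, 5, 9, 2],
                              [2, 4, 4, 2]])
sent_idxs = torch.arange(3)
terminated = torch.tensor([True, False, True])

# Rows whose flag is set are copied out with their original sentence ids...
finalized_idxs = sent_idxs[terminated]        # tensor([0, 2])
finalized_tokens = output_tokens[terminated]  # rows 0 and 2

# ...and the rest of the batch shrinks for the next refinement step.
not_terminated = ~terminated
output_tokens = output_tokens[not_terminated]  # only row 1 survives
sent_idxs = sent_idxs[not_terminated]          # tensor([1])
assert output_tokens.shape == (1, 4)
```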
@@ -194,10 +209,15 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
                 break

             # for next step
-            prev_decoder_out = _skip(decoder_out, ~terminated)
-            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
-            sent_idxs = _skip(sent_idxs, ~terminated)
+            not_terminated = ~terminated
+            prev_decoder_out = decoder_out._replace(
+                output_tokens=decoder_out.output_tokens[not_terminated],
+                output_scores=decoder_out.output_scores[not_terminated],
+                attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
+            )
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
+            sent_idxs = sent_idxs[not_terminated]

-            prev_output_tokens = prev_decoder_out[0].clone()
+            prev_output_tokens = prev_decoder_out.output_tokens.clone()

         return finalized
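Here the removed `script_skip_tensor_list` helper is replaced by the encoder's own `reorder_encoder_out`, which gathers the surviving batch entries by index so the encoder state shrinks in lockstep with the decoder state. A sketch of the gather such a reorder performs, assuming the usual fairseq time-major encoder layout (`T x B x C`); the tensor itself is invented:

```python
import torch

T, B, C = 5, 4, 8
encoder_states = torch.randn(T, B, C)
not_terminated = torch.tensor([True, False, True, True])

# nonzero().squeeze() turns the boolean mask into gather indices,
# matching the call site above.
new_order = not_terminated.nonzero().squeeze()  # tensor([0, 2, 3])

# index_select along the batch dimension keeps only unfinished sentences,
# mirroring the row-masking applied to the decoder state.
encoder_states = encoder_states.index_select(1, new_order)
assert encoder_states.shape == (T, 3, C)
```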