Skip to content

Commit cd2555a

Browse files
liuchen9494 and facebook-github-bot
authored and committed
build_generator api changes for the scripted SequenceGenerator (#697)
Summary: Pull Request resolved: pytorch/translate#697 Pull Request resolved: #1922 Pull Request resolved: fairinternal/fairseq-py#1117 We are planning to deprecate the original SequenceGenerator and use the ScriptSequenceGenerator in the Fairseq. Due to the change of scripted Sequence Generator constructor, I change `build_generator` interface in Fairseq, pyspeech and pytorch translate. Reviewed By: myleott Differential Revision: D20683836 fbshipit-source-id: d01d891ebd067fe44291d3a0a784935edaf66acd
1 parent f20dc23 commit cd2555a

File tree

10 files changed

+14
-12
lines changed

10 files changed

+14
-12
lines changed

examples/speech_recognition/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def main(args):
208208

209209
# Initialize generator
210210
gen_timer = meters.StopwatchMeter()
211-
generator = task.build_generator(args)
211+
generator = task.build_generator(models, args)
212212

213213
num_sentences = 0
214214

examples/speech_recognition/tasks/speech_recognition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def load_dataset(self, split, combine=False, **kwargs):
108108
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
109109
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
110110

111-
def build_generator(self, args):
111+
def build_generator(self, models, args):
112112
w2l_decoder = getattr(args, "w2l_decoder", None)
113113
if w2l_decoder == "viterbi":
114114
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
@@ -119,7 +119,7 @@ def build_generator(self, args):
119119

120120
return W2lKenLMDecoder(args, self.target_dictionary)
121121
else:
122-
return super().build_generator(args)
122+
return super().build_generator(models, args)
123123

124124
@property
125125
def target_dictionary(self):

fairseq/hub_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def generate(
157157
gen_args.beam = beam
158158
for k, v in kwargs.items():
159159
setattr(gen_args, k, v)
160-
generator = self.task.build_generator(gen_args)
160+
generator = self.task.build_generator(self.models, gen_args)
161161

162162
results = []
163163
for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):

fairseq/models/bart/hub_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool
115115
gen_args.beam = beam
116116
for k, v in kwargs.items():
117117
setattr(gen_args, k, v)
118-
generator = self.task.build_generator(gen_args)
118+
generator = self.task.build_generator([self.model], gen_args)
119119
translations = self.task.inference_step(
120120
generator,
121121
[self.model],

fairseq/tasks/fairseq_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def build_criterion(self, args):
225225

226226
return criterions.build_criterion(args, self)
227227

228-
def build_generator(self, args):
228+
def build_generator(self, models, args):
229229
if getattr(args, "score_reference", False):
230230
from fairseq.sequence_scorer import SequenceScorer
231231

fairseq/tasks/translation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths):
261261
return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
262262

263263
def build_model(self, args):
264+
model = super().build_model(args)
264265
if getattr(args, 'eval_bleu', False):
265266
assert getattr(args, 'eval_bleu_detok', None) is not None, (
266267
'--eval-bleu-detok is required if using --eval-bleu; '
@@ -274,8 +275,8 @@ def build_model(self, args):
274275
))
275276

276277
gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
277-
self.sequence_generator = self.build_generator(Namespace(**gen_args))
278-
return super().build_model(args)
278+
self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
279+
return model
279280

280281
def valid_step(self, sample, model, criterion):
281282
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

fairseq/tasks/translation_from_pretrained_bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
7979
append_source_id=True
8080
)
8181

82-
def build_generator(self, args):
82+
def build_generator(self, models, args):
8383
if getattr(args, 'score_reference', False):
8484
from fairseq.sequence_scorer import SequenceScorer
8585
return SequenceScorer(

fairseq/tasks/translation_lev.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def _full_mask(target_tokens):
126126
else:
127127
raise NotImplementedError
128128

129-
def build_generator(self, args):
129+
def build_generator(self, models, args):
130+
# add models input to match the API for SequenceGenerator
130131
from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
131132
return IterativeRefinementGenerator(
132133
self.target_dictionary,

fairseq_cli/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _main(args, output_file):
111111

112112
# Initialize generator
113113
gen_timer = StopwatchMeter()
114-
generator = task.build_generator(args)
114+
generator = task.build_generator(models, args)
115115

116116
# Handle tokenization and BPE
117117
tokenizer = encoders.build_tokenizer(args)

fairseq_cli/interactive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def main(args):
112112
model.cuda()
113113

114114
# Initialize generator
115-
generator = task.build_generator(args)
115+
generator = task.build_generator(models, args)
116116

117117
# Handle tokenization and BPE
118118
tokenizer = encoders.build_tokenizer(args)

0 commit comments

Comments
 (0)