Skip to content

Commit c9a7c06

Browse files
Myle Ottfacebook-github-bot
authored andcommitted
Remove unused FairseqTask.grad_denom and FairseqCriterion.grad_denom
Summary: Pull Request resolved: fairinternal/fairseq-py#979 Differential Revision: D19347309 Pulled By: myleott fbshipit-source-id: ee6fc6111994d35e95ca0ef4cd8f5932e621d84b
1 parent 9da6207 commit c9a7c06

File tree

3 files changed

+2
-10
lines changed

3 files changed

+2
-10
lines changed

fairseq/criterions/fairseq_criterion.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,3 @@ def forward(self, model, sample, reduce=True):
3737
def aggregate_logging_outputs(logging_outputs):
3838
"""Aggregate logging outputs from data parallel training."""
3939
raise NotImplementedError
40-
41-
@staticmethod
42-
def grad_denom(sample_sizes):
43-
"""Compute the gradient denominator for a set of sample sizes."""
44-
return sum(sample_sizes)

fairseq/tasks/fairseq_task.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@ def update_step(self, num_updates):
308308
learning rate update of each step"""
309309
pass
310310

311-
def grad_denom(self, sample_sizes, criterion):
312-
return criterion.__class__.grad_denom(sample_sizes)
313-
314311
def aggregate_logging_outputs(self, logging_outputs, criterion):
315312
return criterion.__class__.aggregate_logging_outputs(logging_outputs)
316313

fairseq/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def maybe_no_sync():
417417
logging_output = self.task.aggregate_logging_outputs(
418418
logging_outputs, self.get_criterion()
419419
)
420-
sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
420+
sample_size = sum(sample_sizes)
421421

422422
if not all(k in logging_output for k in ["ntokens", "nsentences"]):
423423
raise Exception(
@@ -559,7 +559,7 @@ def valid_step(self, sample, raise_oom=False):
559559
logging_output = self.task.aggregate_logging_outputs(
560560
logging_output, self.get_criterion()
561561
)
562-
sample_size = self.task.grad_denom(sample_size, self.get_criterion())
562+
sample_size = sum(sample_size)
563563

564564
# update meters for validation
565565
ntokens = logging_output.get("ntokens", 0)

0 commit comments

Comments
 (0)