Description
Feature request
prediction_step in trainer.py doesn't currently pass num_items_in_batch to the compute_loss function:
transformers/src/transformers/trainer.py, line 4900 at 869735d:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
This seems misaligned, because training_step does:
transformers/src/transformers/trainer.py, line 4019 at 869735d:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
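For illustration, a minimal sketch of what the aligned call could look like; the variable names assumed to be available inside prediction_step are hypothetical, since num_items_in_batch is not currently computed there:

```python
# Sketch only, not the actual upstream code. Assumes prediction_step would
# obtain num_items_in_batch the same way training_step does.

# training_step already forwards the count:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

# prediction_step currently drops it:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

# Proposed: forward the same value so a custom compute_loss sees consistent
# arguments at train and eval time.
loss, outputs = self.compute_loss(
    model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)
```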
Motivation
The Trainer's training_step function uses get_batch_samples to calculate num_items_in_batch, which is used to scale the loss. Because prediction_step does not pass this value, a user with a custom loss function has to account for num_items_in_batch being None at eval time but not at train time, which is confusing (see the sketch below). Having both the train and predict steps compute num_items_in_batch the same way keeps loss metrics accurately logged and directly comparable between training and evaluation.
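To make the pain point concrete, here is a hedged sketch of the kind of custom loss a user has to write today. The subclass name, the shifted cross-entropy, and the fallback count of non-padding labels are illustrative assumptions, not the exact upstream implementation; the compute_loss signature follows recent Trainer versions where num_items_in_batch is accepted.

```python
import torch.nn.functional as F
from transformers import Trainer


class CustomLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs["labels"]
        outputs = model(**inputs)
        logits = outputs.logits

        # Shift so each position predicts the next token, then sum (not
        # average) the per-token loss so we can normalize it ourselves.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="sum",
        )

        if num_items_in_batch is None:
            # Hit at eval/predict time today because prediction_step does not
            # pass num_items_in_batch: fall back to counting the non-padding
            # labels in this batch only (an approximation of what
            # training_step gets via get_batch_samples).
            num_items_in_batch = (shift_labels != -100).sum()

        loss = loss / num_items_in_batch
        return (loss, outputs) if return_outputs else loss
```

If prediction_step forwarded num_items_in_batch, the None branch (and the train/eval inconsistency it papers over) would no longer be needed.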
Your contribution
I'm happy to submit a PR for this unless there's a strong reason why prediction_step shouldn't pass num_items_in_batch to compute_loss.