@@ -5,12 +5,11 @@
 import tensorflow as tf
 from tensorflow.python.eager import context as tf_context
 
-from keras.src import backend as backend_module
 from keras.src import callbacks as callbacks_module
 from keras.src import metrics as metrics_module
-from keras.src import ops as ops_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.losses import loss as loss_module
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
@@ -66,7 +65,8 @@ def train_step(self, data):
             training=True,
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
+            loss_module.unscale_loss_for_distribution(loss),
+            sample_weight=tf.shape(tree.flatten(x)[0])[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
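
Note: the same substitution appears in test_step below. The loss fed to the
optimizer stays scaled for distribution, while the value recorded by
_loss_tracker is unscaled so the reported metric matches the single-replica
value. A minimal sketch of what unscale_loss_for_distribution is assumed to
do, as the inverse of the 1.0 / num_replicas scaling visible in the code
removed from _aggregate_additional_loss further down (the real helper lives
in keras.src.losses.loss and may differ in detail):

import tensorflow as tf

def unscale_loss_for_distribution(value):
    # Hypothetical reconstruction: multiply the distribution-scaled loss
    # back up by the replica count before it is reported as a metric.
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
        value = value * tf.cast(num_replicas, value.dtype)
    return value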
@@ -93,7 +93,8 @@ def test_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
+            loss_module.unscale_loss_for_distribution(loss),
+            sample_weight=tf.shape(tree.flatten(x)[0])[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
@@ -710,17 +711,8 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None):
             self._symbolic_build(data_batch=data_batch)
 
     def _aggregate_additional_loss(self, loss):
-        if not backend_module.is_float_dtype(loss.dtype):
-            loss = ops_module.cast(loss, dtype=backend_module.floatx())
-        loss = ops_module.sum(loss)
-
-        # Scales the loss by the number of replicas in the strategy.
-        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
-        if num_replicas > 1:
-            loss = ops_module.multiply(
-                loss, ops_module.cast(1.0 / num_replicas, loss.dtype)
-            )
-        return loss
+        loss = super()._aggregate_additional_loss(loss)
+        return loss_module.scale_loss_for_distribution(loss)
 
 
 class TFEpochIterator(EpochIterator):
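
The refactor delegates the float cast and summation to the base class
(super()._aggregate_additional_loss) and moves the replica scaling into a
shared helper. A hedged sketch of that helper, reconstructed from the inline
code removed above (the actual implementation in keras.src.losses.loss may
differ):

import tensorflow as tf

def scale_loss_for_distribution(value):
    # Scale the loss by 1 / num_replicas so that summing per-replica
    # gradients across the strategy yields the gradient of the
    # global-batch loss.
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
        value = value * tf.cast(1.0 / num_replicas, value.dtype)
    return value

Pairing this with unscale_loss_for_distribution in train_step/test_step keeps
gradient scaling and metric reporting consistent under tf.distribute.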