@@ -90,7 +90,7 @@ def call(self, x):
90
90
}
91
91
92
92
93
- class ListModel (Trainer , layers .Layer ):
93
+ class ListInputModel (Trainer , layers .Layer ):
94
94
def __init__ (self , units ):
95
95
layers .Layer .__init__ (self )
96
96
Trainer .__init__ (self )
@@ -110,6 +110,25 @@ def call(self, x):
110
110
return self .dense_1 (x [0 ]) + self .dense_2 (x [1 ])
111
111
112
112
113
+ class ListOutputModel (Trainer , layers .Layer ):
114
+ def __init__ (self , units ):
115
+ layers .Layer .__init__ (self )
116
+ Trainer .__init__ (self )
117
+ self .dense_1 = layers .Dense (
118
+ units ,
119
+ use_bias = False ,
120
+ kernel_initializer = initializers .Ones (),
121
+ )
122
+ self .dense_2 = layers .Dense (
123
+ units ,
124
+ use_bias = False ,
125
+ kernel_initializer = initializers .Ones (),
126
+ )
127
+
128
+ def call (self , x ):
129
+ return [self .dense_1 (x ), self .dense_2 (x )]
130
+
131
+
113
132
class TrainingTestingLayer (Trainer , layers .Layer ):
114
133
def __init__ (self , ** kwargs ):
115
134
layers .Layer .__init__ (self , ** kwargs )
@@ -265,8 +284,8 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
265
284
self .assertIn ("mean_squared_error" , history )
266
285
self .assertAllClose (
267
286
history ["mean_squared_error" ],
268
- [14.402393 , 10.991339 , 8.388159 ],
269
- atol = 6.1051628e-1 ,
287
+ [14.5 , 11.5 , 8.5 ],
288
+ atol = 0.6 , # TODO: abnormal results for certain configs.
270
289
)
271
290
272
291
@parameterized .named_parameters (
@@ -1164,7 +1183,7 @@ def metrics_one(y_true, y_pred):
1164
1183
1165
1184
@pytest .mark .requires_trainable_backend
1166
1185
def test_nested_inputs (self ):
1167
- model = ListModel (units = 2 )
1186
+ model = ListInputModel (units = 2 )
1168
1187
out = model ([np .ones ((3 , 2 )), np .ones ((3 , 3 ))])
1169
1188
self .assertEqual (tuple (out .shape ), (3 , 2 ))
1170
1189
model .compile (optimizer = "sgd" , loss = "mse" , metrics = ["mse" ])
@@ -1420,6 +1439,93 @@ def compute_loss(
1420
1439
history = model .fit (x , y )
1421
1440
self .assertGreater (history .history ["custom" ][0 ], 0.0 )
1422
1441
1442
+ @pytest .mark .requires_trainable_backend
1443
+ def test_loss_weights (self ):
1444
+ epochs = 3
1445
+ batch_size = 20
1446
+ dataset_size = batch_size * 2
1447
+
1448
+ # Single output case.
1449
+ model = ExampleModel (units = 3 )
1450
+ model .compile (
1451
+ optimizer = optimizers .SGD (),
1452
+ loss = losses .MeanSquaredError (),
1453
+ metrics = [metrics .MeanSquaredError ()],
1454
+ loss_weights = 0.2 ,
1455
+ )
1456
+ x = np .ones ((dataset_size , 4 ))
1457
+ y = np .zeros ((dataset_size , 3 ))
1458
+ history = model .fit (
1459
+ x ,
1460
+ y ,
1461
+ batch_size = batch_size ,
1462
+ epochs = epochs ,
1463
+ )
1464
+ history = history .history
1465
+ self .assertIn ("loss" , history )
1466
+ self .assertAllClose (
1467
+ history ["loss" ],
1468
+ [3.182979 , 3.115617 , 3.049681 ],
1469
+ atol = 1e-3 ,
1470
+ )
1471
+
1472
+ # Dict output case.
1473
+ model = StructModel (units = 3 )
1474
+ model .compile (
1475
+ optimizer = optimizers .SGD (),
1476
+ loss = {
1477
+ "y_one" : losses .MeanSquaredError (),
1478
+ "y_two" : losses .MeanSquaredError (),
1479
+ },
1480
+ metrics = {
1481
+ "y_one" : metrics .MeanSquaredError (),
1482
+ "y_two" : metrics .MeanSquaredError (),
1483
+ },
1484
+ loss_weights = {"y_one" : 0.1 , "y_two" : 0.2 },
1485
+ )
1486
+ x1 = np .ones ((dataset_size , 4 ))
1487
+ x2 = np .ones ((dataset_size , 4 ))
1488
+ y1 = np .zeros ((dataset_size , 3 ))
1489
+ y2 = np .zeros ((dataset_size , 3 ))
1490
+ history = model .fit (
1491
+ {"x_one" : x1 , "x_two" : x2 },
1492
+ {"y_one" : y1 , "y_two" : y2 },
1493
+ batch_size = batch_size ,
1494
+ epochs = epochs ,
1495
+ )
1496
+ history = history .history
1497
+ self .assertIn ("loss" , history )
1498
+ self .assertAllClose (
1499
+ history ["loss" ],
1500
+ [4.778718 , 4.694403 , 4.611693 ],
1501
+ atol = 1e-3 ,
1502
+ )
1503
+
1504
+ # List output case.
1505
+ model = ListOutputModel (units = 3 )
1506
+ model .compile (
1507
+ optimizer = optimizers .SGD (),
1508
+ loss = [losses .MeanSquaredError (), losses .MeanSquaredError ()],
1509
+ metrics = [metrics .MeanSquaredError (), metrics .MeanSquaredError ()],
1510
+ loss_weights = [0.1 , 0.2 ],
1511
+ )
1512
+ x = np .ones ((dataset_size , 4 ))
1513
+ y1 = np .zeros ((dataset_size , 3 ))
1514
+ y2 = np .zeros ((dataset_size , 3 ))
1515
+ history = model .fit (
1516
+ x ,
1517
+ [y1 , y2 ],
1518
+ batch_size = batch_size ,
1519
+ epochs = epochs ,
1520
+ )
1521
+ history = history .history
1522
+ self .assertIn ("loss" , history )
1523
+ self .assertAllClose (
1524
+ history ["loss" ],
1525
+ [4.778718 , 4.694403 , 4.611693 ],
1526
+ atol = 1e-3 ,
1527
+ )
1528
+
1423
1529
1424
1530
class TrainerDistributeTest (testing .TestCase ):
1425
1531
@pytest .mark .skipif (
0 commit comments