Commit 7caec08

Fix issue with using a callable as a learning rate (#19484)
* the fix
* added test
1 parent: c73a116

2 files changed: +10 -1 lines changed


keras/optimizers/base_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -567,7 +567,7 @@ def _get_current_learning_rate(self):
         ):
             return self._learning_rate(self.iterations)
         elif callable(self._learning_rate):
-            return self._learning_rate(self.iterations)
+            return self._learning_rate(self.iterations)
+            return self._learning_rate()
         return self._learning_rate
 
     def _filter_empty_gradients(self, grads, vars):
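For context, a minimal sketch of the dispatch this hunk fixes: a LearningRateSchedule is called with the current step count, a plain zero-argument callable is invoked with no arguments (passing `iterations` to it was the bug), and anything else is returned as a fixed value. The `resolve_learning_rate` helper below is illustrative only, not a Keras API.

# Hypothetical sketch of the fixed resolution order; not the Keras source.
from keras.optimizers.schedules import LearningRateSchedule

def resolve_learning_rate(learning_rate, iterations):
    if isinstance(learning_rate, LearningRateSchedule):
        # Schedules (e.g. ExponentialDecay) are called with the step count.
        return learning_rate(iterations)
    elif callable(learning_rate):
        # A plain callable such as `lambda: 1e-4` takes no arguments;
        # before this commit it was incorrectly passed `iterations`.
        return learning_rate()
    # Otherwise it is a constant (float or backend variable).
    return learning_rate

print(resolve_learning_rate(lambda: 1e-4, iterations=0))  # 0.0001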

keras/optimizers/optimizer_test.py

Lines changed: 9 additions & 0 deletions
@@ -243,3 +243,12 @@ def test_tf_checkpointing(self):
         checkpoint.restore(save_path)
         pred = model.predict(x)
         self.assertAllClose(pred, ref_pred, atol=1e-5)
+
+    def test_callable_learning_rate(self):
+        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])
+        optimizer = optimizers.AdamW(learning_rate=lambda: 0.0001)
+        self.assertAllClose(optimizer.iterations, 0)
+        optimizer.apply_gradients([(grads, v)])
+        self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]], atol=1e-4)
+        self.assertAllClose(optimizer.iterations, 1)
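At the user level, the fixed behavior means a zero-argument callable can be passed directly as an optimizer's learning rate, exactly as the test does. A minimal usage sketch (the optimizer, model, and value below are arbitrary examples, not taken from the commit):

import keras

# The optimizer calls the lambda each step to obtain the current rate.
optimizer = keras.optimizers.AdamW(learning_rate=lambda: 1e-4)

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer=optimizer, loss="mse")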
