From fab6abfff53be5bec3ef19d0c6275e7897291d4e Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Wed, 24 May 2023 22:14:20 -0700
Subject: [PATCH] Simplify / improve jax trainer a bit

---
 keras_core/backend/jax/trainer.py             |  3 --
 .../schedules/learning_rate_schedule_test.py  | 31 +++++++++----------
 2 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py
index 80c222e3f..7d3fb973e 100644
--- a/keras_core/backend/jax/trainer.py
+++ b/keras_core/backend/jax/trainer.py
@@ -222,9 +222,6 @@ class JAXTrainer(base_trainer.Trainer):
             metrics_variables,
         ) = state

-        # Update `iterations` var (necessary to make LR schedules work)
-        # It's always the first variable tracked by the optimizer.
-        self.optimizer.iterations.assign(optimizer_variables[0])
         # Setting _jax_state enables callbacks to force a state sync
         # if they need to.
         self._jax_state = {
diff --git a/keras_core/optimizers/schedules/learning_rate_schedule_test.py b/keras_core/optimizers/schedules/learning_rate_schedule_test.py
index fba465042..d0ecc797d 100644
--- a/keras_core/optimizers/schedules/learning_rate_schedule_test.py
+++ b/keras_core/optimizers/schedules/learning_rate_schedule_test.py
@@ -5,7 +5,6 @@ import math
 import numpy as np

 from keras_core import backend
-from keras_core import callbacks
 from keras_core import layers
 from keras_core import optimizers
 from keras_core import testing
@@ -15,7 +14,13 @@ from keras_core.optimizers import schedules

 class TestFitLRSchedulesFlow(testing.TestCase):
     def test_fit_lr_correctness(self):
-        model = Sequential([layers.Dense(3)])
+        model = Sequential(
+            [
+                layers.Dense(
+                    2, kernel_initializer="ones", bias_initializer="ones"
+                )
+            ]
+        )
         optimizer = optimizers.Adam(
             learning_rate=schedules.ExponentialDecay(
                 initial_learning_rate=0.05, decay_steps=1, decay_rate=0.9
@@ -24,21 +29,15 @@ class TestFitLRSchedulesFlow(testing.TestCase):
         self.assertEqual(len(optimizer.variables), 1)
         self.assertEqual(optimizer.variables[0], 0)

-        class LRTracker(callbacks.Callback):
-            def __init__(self, optimizer):
-                self.optimizer = optimizer
-                self.logs = []
-
-            def on_batch_end(self, *args, **kwargs):
-                self.logs.append(float(self.optimizer.learning_rate))
-
         model.compile(optimizer=optimizer, loss="mse")
-        x = np.random.random((16, 2))
-        y = np.random.random((16, 3))
-        tracker = LRTracker(optimizer)
-        model.fit(x, y, epochs=1, batch_size=4, callbacks=[tracker])
-        self.assertEqual(optimizer.variables[0], 4)
-        self.assertAllClose(tracker.logs, [0.045, 0.0405, 0.03645, 0.032805])
+        x = np.arange(32).reshape((16, 2))
+        y = np.arange(32).reshape((16, 2))
+        history = model.fit(x, y, epochs=3, batch_size=4, shuffle=False)
+        self.assertEqual(optimizer.variables[0], 4 * 3)
+        self.assertAllClose(
+            history.history["loss"],
+            [230.79457092285156, 128.30319213867188, 79.33648681640625],
+        )


 class ExponentialDecayTest(testing.TestCase):
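
Reviewer note (illustrative sketch, not part of the patch): the rewritten test
drives the LR schedule purely through the optimizer's iteration counter, which
is why it asserts optimizer.variables[0] == 4 * 3 after fitting 16 samples with
batch_size=4 for 3 epochs. Assuming the standard ExponentialDecay formula
lr(step) = initial_learning_rate * decay_rate ** (step / decay_steps), the
schedule values at the relevant step counts can be checked standalone:

    # Illustrative only; mirrors the schedule parameters used in the test.
    initial_learning_rate = 0.05
    decay_steps = 1
    decay_rate = 0.9

    def lr_at_step(step):
        # Non-staircase ExponentialDecay: continuous exponential decay.
        return initial_learning_rate * decay_rate ** (step / decay_steps)

    # 16 samples / batch_size 4 = 4 steps per epoch; 3 epochs = 12 steps,
    # matching the expected value of optimizer.variables[0].
    print(lr_at_step(0))   # 0.05 (initial learning rate)
    print(lr_at_step(12))  # ~0.0141 (learning rate at the end of training)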