Simplify / improve jax trainer a bit

This commit is contained in:
Francois Chollet 2023-05-24 22:14:20 -07:00
parent 186cbf6b7c
commit fab6abfff5
2 changed files with 15 additions and 19 deletions

@@ -222,9 +222,6 @@ class JAXTrainer(base_trainer.Trainer):
metrics_variables,
) = state
# Update `iterations` var (necessary to make LR schedules work)
# It's always the first variable tracked by the optimizer.
self.optimizer.iterations.assign(optimizer_variables[0])
# Setting _jax_state enables callbacks to force a state sync
# if they need to.
self._jax_state = {

@@ -5,7 +5,6 @@ import math
import numpy as np
from keras_core import backend
from keras_core import callbacks
from keras_core import layers
from keras_core import optimizers
from keras_core import testing
@@ -15,7 +14,13 @@ from keras_core.optimizers import schedules
class TestFitLRSchedulesFlow(testing.TestCase):
def test_fit_lr_correctness(self):
model = Sequential([layers.Dense(3)])
model = Sequential(
[
layers.Dense(
2, kernel_initializer="ones", bias_initializer="ones"
)
]
)
optimizer = optimizers.Adam(
learning_rate=schedules.ExponentialDecay(
initial_learning_rate=0.05, decay_steps=1, decay_rate=0.9
@@ -24,21 +29,15 @@ class TestFitLRSchedulesFlow(testing.TestCase):
self.assertEqual(len(optimizer.variables), 1)
self.assertEqual(optimizer.variables[0], 0)
class LRTracker(callbacks.Callback):
def __init__(self, optimizer):
self.optimizer = optimizer
self.logs = []
def on_batch_end(self, *args, **kwargs):
self.logs.append(float(self.optimizer.learning_rate))
model.compile(optimizer=optimizer, loss="mse")
x = np.random.random((16, 2))
y = np.random.random((16, 3))
tracker = LRTracker(optimizer)
model.fit(x, y, epochs=1, batch_size=4, callbacks=[tracker])
self.assertEqual(optimizer.variables[0], 4)
self.assertAllClose(tracker.logs, [0.045, 0.0405, 0.03645, 0.032805])
x = np.arange(32).reshape((16, 2))
y = np.arange(32).reshape((16, 2))
history = model.fit(x, y, epochs=3, batch_size=4, shuffle=False)
self.assertEqual(optimizer.variables[0], 4 * 3)
self.assertAllClose(
history.history["loss"],
[230.79457092285156, 128.30319213867188, 79.33648681640625],
)
class ExponentialDecayTest(testing.TestCase):