Simplify / improve jax trainer a bit
This commit is contained in:
parent
186cbf6b7c
commit
fab6abfff5
@ -222,9 +222,6 @@ class JAXTrainer(base_trainer.Trainer):
|
||||
metrics_variables,
|
||||
) = state
|
||||
|
||||
# Update `iterations` var (necessary to make LR schedules work)
|
||||
# It's always the first variable tracked by the optimizer.
|
||||
self.optimizer.iterations.assign(optimizer_variables[0])
|
||||
# Setting _jax_state enables callbacks to force a state sync
|
||||
# if they need to.
|
||||
self._jax_state = {
|
||||
|
@ -5,7 +5,6 @@ import math
|
||||
import numpy as np
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import callbacks
|
||||
from keras_core import layers
|
||||
from keras_core import optimizers
|
||||
from keras_core import testing
|
||||
@ -15,7 +14,13 @@ from keras_core.optimizers import schedules
|
||||
|
||||
class TestFitLRSchedulesFlow(testing.TestCase):
|
||||
def test_fit_lr_correctness(self):
|
||||
model = Sequential([layers.Dense(3)])
|
||||
model = Sequential(
|
||||
[
|
||||
layers.Dense(
|
||||
2, kernel_initializer="ones", bias_initializer="ones"
|
||||
)
|
||||
]
|
||||
)
|
||||
optimizer = optimizers.Adam(
|
||||
learning_rate=schedules.ExponentialDecay(
|
||||
initial_learning_rate=0.05, decay_steps=1, decay_rate=0.9
|
||||
@ -24,21 +29,15 @@ class TestFitLRSchedulesFlow(testing.TestCase):
|
||||
self.assertEqual(len(optimizer.variables), 1)
|
||||
self.assertEqual(optimizer.variables[0], 0)
|
||||
|
||||
class LRTracker(callbacks.Callback):
|
||||
def __init__(self, optimizer):
|
||||
self.optimizer = optimizer
|
||||
self.logs = []
|
||||
|
||||
def on_batch_end(self, *args, **kwargs):
|
||||
self.logs.append(float(self.optimizer.learning_rate))
|
||||
|
||||
model.compile(optimizer=optimizer, loss="mse")
|
||||
x = np.random.random((16, 2))
|
||||
y = np.random.random((16, 3))
|
||||
tracker = LRTracker(optimizer)
|
||||
model.fit(x, y, epochs=1, batch_size=4, callbacks=[tracker])
|
||||
self.assertEqual(optimizer.variables[0], 4)
|
||||
self.assertAllClose(tracker.logs, [0.045, 0.0405, 0.03645, 0.032805])
|
||||
x = np.arange(32).reshape((16, 2))
|
||||
y = np.arange(32).reshape((16, 2))
|
||||
history = model.fit(x, y, epochs=3, batch_size=4, shuffle=False)
|
||||
self.assertEqual(optimizer.variables[0], 4 * 3)
|
||||
self.assertAllClose(
|
||||
history.history["loss"],
|
||||
[230.79457092285156, 128.30319213867188, 79.33648681640625],
|
||||
)
|
||||
|
||||
|
||||
class ExponentialDecayTest(testing.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user