Ensure that JAX model state is current inside callbacks

This commit is contained in:
Francois Chollet 2023-05-24 21:30:49 -07:00
parent 5c40f0def4
commit a9f4ccd0d7
7 changed files with 134 additions and 67 deletions

@ -288,6 +288,7 @@ class JAXTrainer(base_trainer.Trainer):
if getattr(self, "_eval_epoch_iterator", None) is not None: if getattr(self, "_eval_epoch_iterator", None) is not None:
del self._eval_epoch_iterator del self._eval_epoch_iterator
callbacks.on_train_end(logs=training_logs) callbacks.on_train_end(logs=training_logs)
self._jax_state = None
return self.history return self.history
def evaluate( def evaluate(
@ -452,7 +453,7 @@ class JAXTrainer(base_trainer.Trainer):
logs = self._pythonify_logs(self.get_metrics_result()) logs = self._pythonify_logs(self.get_metrics_result())
callbacks.on_test_end(logs) callbacks.on_test_end(logs)
self._jax_state = None
if return_dict: if return_dict:
return logs return logs
return self._flatten_metrics_in_order(logs) return self._flatten_metrics_in_order(logs)
@ -565,7 +566,7 @@ class JAXTrainer(base_trainer.Trainer):
) )
def jax_state_sync(self): def jax_state_sync(self):
if not hasattr(self, "_jax_state"): if not getattr(self, "_jax_state", None):
return return
trainable_variables = self._jax_state.get("trainable_variables", None) trainable_variables = self._jax_state.get("trainable_variables", None)

@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
def dot(x, y): def dot(x, y):
x, y = convert_to_tensor(x), convert_to_tensor(y) x, y = convert_to_tensor(x), convert_to_tensor(y)
if x.ndim == 0 or y.ndim == 0: if len(x.shape) == 0 or len(y.shape) == 0:
return torch.multiply(x, y) return torch.multiply(x, y)
return torch.matmul(x, y) return torch.matmul(x, y)
@ -327,7 +327,7 @@ def expm1(x):
def flip(x, axis=None): def flip(x, axis=None):
x = convert_to_tensor(x) x = convert_to_tensor(x)
if axis is None: if axis is None:
axis = tuple(range(x.ndim)) axis = tuple(range(len(x.shape)))
if isinstance(axis, int): if isinstance(axis, int):
axis = (axis,) axis = (axis,)
return torch.flip(x, dims=axis) return torch.flip(x, dims=axis)
@ -341,21 +341,23 @@ def floor(x):
def full(shape, fill_value, dtype=None): def full(shape, fill_value, dtype=None):
dtype = to_torch_dtype(dtype) dtype = to_torch_dtype(dtype)
if hasattr(fill_value, "__len__"): if hasattr(fill_value, "__len__"):
fill_value = convert_to_tensor(fill_value) raise NotImplementedError(
reps = shape[-1] // fill_value.shape[-1] # channels-last reduction "`torch.full()` only accepts scalars for `fill_value`. "
reps_by_dim = (*shape[:-1], reps) f"Received: fill_value={fill_value}"
return torch.tile(fill_value, reps_by_dim) )
# TODO: implement conversion of shape into repetitions for `torch.tile``
# return torch.tile(fill_value, reps)
return torch.full(size=shape, fill_value=fill_value, dtype=dtype) return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
def full_like(x, fill_value, dtype=None): def full_like(x, fill_value, dtype=None):
dtype = to_torch_dtype(dtype) dtype = to_torch_dtype(dtype)
if hasattr(fill_value, "__len__"): if hasattr(fill_value, "__len__"):
fill_value = convert_to_tensor(fill_value) raise NotImplementedError(
reps_by_dim = tuple( "`torch.full()` only accepts scalars for `fill_value`."
[x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
) )
return torch.tile(fill_value, reps_by_dim) # TODO: implement conversion of shape into repetitions for `torch.tile``
# return torch.tile(fill_value, reps)
x = convert_to_tensor(x) x = convert_to_tensor(x)
return torch.full_like(input=x, fill_value=fill_value, dtype=dtype) return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
@ -431,18 +433,6 @@ def linspace(
dtype = to_torch_dtype(dtype) dtype = to_torch_dtype(dtype)
if endpoint is False: if endpoint is False:
stop = stop - ((stop - start) / num) stop = stop - ((stop - start) / num)
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
stop = cast(stop, dtype) if endpoint is False and dtype else stop
steps = torch.arange(num, dtype=dtype) / (num - 1)
# reshape `steps` to allow for broadcasting
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# increments from `start` to `stop` in each dimension
linspace = start[None] + steps * (stop - start)[None]
else:
linspace = torch.linspace( linspace = torch.linspace(
start=start, start=start,
end=stop, end=stop,
@ -505,19 +495,6 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
dtype = to_torch_dtype(dtype) dtype = to_torch_dtype(dtype)
if endpoint is False: if endpoint is False:
stop = stop - ((stop - start) / num) stop = stop - ((stop - start) / num)
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
stop = cast(stop, dtype) if endpoint is False and dtype else stop
steps = torch.arange(num, dtype=dtype) / (num - 1)
# reshape `steps` to allow for broadcasting
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# increments from `start` to `stop` in each dimension
linspace = start[None] + steps * (stop - start)[None]
logspace = base**linspace
else:
logspace = torch.logspace( logspace = torch.logspace(
start=start, start=start,
end=stop, end=stop,
@ -738,7 +715,14 @@ def tile(x, repeats):
def trace(x, offset=None, axis1=None, axis2=None): def trace(x, offset=None, axis1=None, axis2=None):
x = convert_to_tensor(x) x = convert_to_tensor(x)
return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1) # TODO: implement support for these arguments
# API divergence between `np.trace()` and `torch.trace()`
if offset or axis1 or axis2:
raise NotImplementedError(
"Arguments not supported by `torch.trace: "
f"offset={offset}, axis1={axis1}, axis2={axis2}"
)
return torch.trace(x)
def tri(N, M=None, k=0, dtype="float32"): def tri(N, M=None, k=0, dtype="float32"):

@ -1,3 +1,4 @@
from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
@ -65,13 +66,25 @@ class Callback:
def __init__(self): def __init__(self):
self.validation_data = None self.validation_data = None
self.model = None self._model = None
def set_params(self, params): def set_params(self, params):
self.params = params self.params = params
def set_model(self, model): def set_model(self, model):
self.model = model self._model = model
@property
def model(self):
if backend.backend() == "jax" and hasattr(
self._model, "jax_state_sync"
):
# With JAX, by default the model state is not
# attached to the model in the middle of an
# epoch. We have to force a sync before
# accessing model state for e.g. checkpointing.
self._model.jax_state_sync()
return self._model
def on_batch_begin(self, batch, logs=None): def on_batch_begin(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_begin`.""" """A backwards compatibility alias for `on_train_batch_begin`."""

@ -70,7 +70,7 @@ class CallbackList(Callback):
callback.set_params(params) callback.set_params(params)
def set_model(self, model): def set_model(self, model):
self.model = model super().set_model(model)
if self._history: if self._history:
model.history = self._history model.history = self._history
for callback in self.callbacks: for callback in self.callbacks:

@ -0,0 +1,29 @@
import numpy as np
from keras_core import models
from keras_core import testing
from keras_core.callbacks.callback import Callback
class CallbackTest(testing.TestCase):
def test_model_state_is_current_on_epoch_end(self):
class TestModel(models.Model):
def __init__(self):
super().__init__()
self.iterations = self.add_variable(
shape=(), initializer="zeros", trainable=False
)
def call(self, inputs):
self.iterations.assign(self.iterations + 1)
return inputs
class CBK(Callback):
def on_batch_end(self, batch, logs):
assert np.int32(self.model.iterations) == batch + 1
model = TestModel()
model.compile(optimizer="sgd", loss="mse")
x = np.random.random((8, 1))
y = np.random.random((8, 1))
model.fit(x, y, callbacks=[CBK()], batch_size=2)

@ -54,7 +54,7 @@ class EarlyStoppingTest(testing.TestCase):
for patience in cases: for patience in cases:
stopper = callbacks.EarlyStopping(monitor="loss", patience=patience) stopper = callbacks.EarlyStopping(monitor="loss", patience=patience)
stopper.model = models.Sequential() stopper.set_model(models.Sequential())
stopper.model.compile(loss="mse", optimizer="sgd") stopper.model.compile(loss="mse", optimizer="sgd")
stopper.on_train_begin() stopper.on_train_begin()
@ -130,7 +130,7 @@ class EarlyStoppingTest(testing.TestCase):
early_stop = callbacks.EarlyStopping( early_stop = callbacks.EarlyStopping(
monitor="val_loss", patience=2, restore_best_weights=True monitor="val_loss", patience=2, restore_best_weights=True
) )
early_stop.model = DummyModel() early_stop.set_model(DummyModel())
losses = [0.2, 0.15, 0.1, 0.11, 0.12] losses = [0.2, 0.15, 0.1, 0.11, 0.12]
# The best configuration is in the epoch 2 (loss = 0.1000). # The best configuration is in the epoch 2 (loss = 0.1000).
epochs_trained = 0 epochs_trained = 0
@ -153,7 +153,7 @@ class EarlyStoppingTest(testing.TestCase):
baseline=0.5, baseline=0.5,
restore_best_weights=True, restore_best_weights=True,
) )
early_stop.model = DummyModel() early_stop.set_model(DummyModel())
losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73] losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
# The best configuration is in the epoch 2 (loss = 0.7000). # The best configuration is in the epoch 2 (loss = 0.7000).
epochs_trained = 0 epochs_trained = 0

@ -1713,16 +1713,27 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.array(knp.full_like(x, 2, dtype="float32")), np.array(knp.full_like(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"), np.full_like(x, 2, dtype="float32"),
) )
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2)) self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2))
self.assertAllClose( self.assertAllClose(
np.array(knp.FullLike()(x, 2, dtype="float32")), np.array(knp.FullLike()(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"), np.full_like(x, 2, dtype="float32"),
) )
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`."
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_like_without_torch(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose( self.assertAllClose(
np.array(knp.FullLike()(x, np.ones([2, 3]))), np.array(knp.FullLike()(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])), np.full_like(x, np.ones([2, 3])),
@ -1823,6 +1834,13 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.linspace(0, 10, 5, endpoint=False), np.linspace(0, 10, 5, endpoint=False),
) )
# TODO: torch.linspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.linspace` has no support for array start/stop.",
)
def test_linspace_without_torch(self):
start = np.zeros([2, 3, 4]) start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4]) stop = np.ones([2, 3, 4])
self.assertAllClose( self.assertAllClose(
@ -1927,9 +1945,15 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.logspace(0, 10, 5, endpoint=False), np.logspace(0, 10, 5, endpoint=False),
) )
# TODO: torch.logspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.logspace` has no support for array start/stop.",
)
def test_logspace_without_torch(self):
start = np.zeros([2, 3, 4]) start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4]) stop = np.ones([2, 3, 4])
self.assertAllClose( self.assertAllClose(
np.array(knp.logspace(start, stop, 5, base=10)), np.array(knp.logspace(start, stop, 5, base=10)),
np.logspace(start, stop, 5, base=10), np.logspace(start, stop, 5, base=10),
@ -2992,7 +3016,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3]))
self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3]))
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.split` does not support args `offset`, `axis1`, `axis2`",
)
def test_trace(self): def test_trace(self):
# TODO: implement `torch.trace` support for arguments `offset`,
# `axis1`, `axis2` and delete NotImplementedError
x = np.arange(24).reshape([1, 2, 3, 4]) x = np.arange(24).reshape([1, 2, 3, 4])
self.assertAllClose(np.array(knp.trace(x)), np.trace(x)) self.assertAllClose(np.array(knp.trace(x)), np.trace(x))
self.assertAllClose( self.assertAllClose(
@ -3062,15 +3092,25 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose( self.assertAllClose(
np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1) np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1)
) )
self.assertAllClose(
np.array(knp.full([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),
)
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
self.assertAllClose( self.assertAllClose(
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
) )
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`."
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_without_torch(self):
self.assertAllClose(
np.array(knp.full([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),
)
self.assertAllClose( self.assertAllClose(
np.array(knp.Full()([2, 3], np.array([1, 4, 5]))), np.array(knp.Full()([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])), np.full([2, 3], np.array([1, 4, 5])),