Ensure that JAX model state is current inside callbacks

Francois Chollet 2023-05-24 21:30:49 -07:00
parent 5c40f0def4
commit a9f4ccd0d7
7 changed files with 134 additions and 67 deletions

@@ -288,6 +288,7 @@ class JAXTrainer(base_trainer.Trainer):
if getattr(self, "_eval_epoch_iterator", None) is not None:
del self._eval_epoch_iterator
callbacks.on_train_end(logs=training_logs)
self._jax_state = None
return self.history
def evaluate(
@@ -452,7 +453,7 @@ class JAXTrainer(base_trainer.Trainer):
logs = self._pythonify_logs(self.get_metrics_result())
callbacks.on_test_end(logs)
self._jax_state = None
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
@@ -565,7 +566,7 @@ class JAXTrainer(base_trainer.Trainer):
)
def jax_state_sync(self):
if not hasattr(self, "_jax_state"):
if not getattr(self, "_jax_state", None):
return
trainable_variables = self._jax_state.get("trainable_variables", None)
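Taken together, the trainer hunks implement one pattern: while fit() or evaluate() runs, the JAX loop threads model state functionally through _jax_state, and jax_state_sync() writes it back onto the live variables on demand; clearing _jax_state once the loop ends (plus the getattr truthiness check above) keeps a stale dict from ever being synced. A simplified sketch of the sync, assuming Keras variables with assign() (not the exact trainer code; the real method continues analogously for other state entries):

def jax_state_sync(self):
    # No-op outside a running fit()/evaluate() loop.
    state = getattr(self, "_jax_state", None)
    if not state:
        return
    # Write functional state back onto the variables so callbacks
    # (checkpointing, early stopping) observe current values.
    trainable_variables = state.get("trainable_variables", None)
    if trainable_variables is not None:
        for ref_v, v in zip(self.trainable_variables, trainable_variables):
            ref_v.assign(v)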

@@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
def dot(x, y):
x, y = convert_to_tensor(x), convert_to_tensor(y)
if x.ndim == 0 or y.ndim == 0:
if len(x.shape) == 0 or len(y.shape) == 0:
return torch.multiply(x, y)
return torch.matmul(x, y)
@@ -327,7 +327,7 @@ def expm1(x):
def flip(x, axis=None):
x = convert_to_tensor(x)
if axis is None:
axis = tuple(range(x.ndim))
axis = tuple(range(len(x.shape)))
if isinstance(axis, int):
axis = (axis,)
return torch.flip(x, dims=axis)
@@ -341,21 +341,23 @@ def floor(x):
def full(shape, fill_value, dtype=None):
dtype = to_torch_dtype(dtype)
if hasattr(fill_value, "__len__"):
fill_value = convert_to_tensor(fill_value)
reps = shape[-1] // fill_value.shape[-1] # channels-last reduction
reps_by_dim = (*shape[:-1], reps)
return torch.tile(fill_value, reps_by_dim)
raise NotImplementedError(
"`torch.full()` only accepts scalars for `fill_value`. "
f"Received: fill_value={fill_value}"
)
# TODO: implement conversion of shape into repetitions for `torch.tile`
# return torch.tile(fill_value, reps)
return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
def full_like(x, fill_value, dtype=None):
dtype = to_torch_dtype(dtype)
if hasattr(fill_value, "__len__"):
fill_value = convert_to_tensor(fill_value)
reps_by_dim = tuple(
[x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
raise NotImplementedError(
"`torch.full()` only accepts scalars for `fill_value`."
)
return torch.tile(fill_value, reps_by_dim)
# TODO: implement conversion of shape into repetitions for `torch.tile`
# return torch.tile(fill_value, reps)
x = convert_to_tensor(x)
return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
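The TODO in `full` and `full_like` boils down to broadcasting an array `fill_value` up to the requested shape. One possible sketch of that conversion, assuming each dimension of `fill_value` evenly divides the target dimension (the helper name is hypothetical, not part of this commit):

import torch

def _full_with_array_fill(shape, fill_value):
    # Emulate np.full's broadcasting of an array fill_value via torch.tile.
    fill_value = torch.as_tensor(fill_value)
    # Left-pad fill_value's shape with singleton dims to match rank.
    pad = len(shape) - fill_value.ndim
    fill_value = fill_value.reshape((1,) * pad + tuple(fill_value.shape))
    # Repetitions per dim: target extent divided by current extent.
    reps = tuple(s // f for s, f in zip(shape, fill_value.shape))
    return torch.tile(fill_value, reps)

# _full_with_array_fill((2, 3), [1, 4, 5])
# -> tensor([[1, 4, 5], [1, 4, 5]])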
@@ -428,27 +430,15 @@ def linspace(
"torch.linspace does not support an `axis` argument. "
f"Received axis={axis}"
)
dtype = to_torch_dtype(dtype)
dtype = to_torch_dtype(dtype)
if endpoint is False:
stop = stop - ((stop - start) / num)
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
stop = cast(stop, dtype) if endpoint is False and dtype else stop
steps = torch.arange(num, dtype=dtype) / (num - 1)
# reshape `steps` to allow for broadcasting
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# increments from `start` to `stop` in each dimension
linspace = start[None] + steps * (stop - start)[None]
else:
linspace = torch.linspace(
start=start,
end=stop,
steps=num,
dtype=dtype,
)
linspace = torch.linspace(
start=start,
end=stop,
steps=num,
dtype=dtype,
)
if retstep is True:
return (linspace, num)
return linspace
@@ -505,26 +495,13 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
dtype = to_torch_dtype(dtype)
if endpoint is False:
stop = stop - ((stop - start) / num)
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
stop = cast(stop, dtype) if endpoint is False and dtype else stop
steps = torch.arange(num, dtype=dtype) / (num - 1)
# reshape `steps` to allow for broadcasting
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# increments from `start` to `stop` in each dimension
linspace = start[None] + steps * (stop - start)[None]
logspace = base**linspace
else:
logspace = torch.logspace(
start=start,
end=stop,
steps=num,
base=base,
dtype=dtype,
)
logspace = torch.logspace(
start=start,
end=stop,
steps=num,
base=base,
dtype=dtype,
)
return logspace
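The `linspace` and `logspace` hunks remove the same workaround, which the test TODOs further down ask to restore: `torch.linspace` and `torch.logspace` accept only scalar endpoints, so array endpoints need a manual broadcast. A sketch of that manual version, assuming `start` and `stop` are same-shaped tensors (the function name is hypothetical):

import torch

def _logspace_with_array_endpoints(start, stop, num, base=10.0):
    start = torch.as_tensor(start, dtype=torch.float32)
    stop = torch.as_tensor(stop, dtype=torch.float32)
    # Fractions 0, 1/(num-1), ..., 1, with a leading "samples" axis
    # so they broadcast against start/stop.
    steps = torch.arange(num, dtype=torch.float32) / (num - 1)
    steps = steps.reshape((num,) + (1,) * start.ndim)
    linspace = start[None] + steps * (stop - start)[None]
    return base ** linspace  # drop the exponentiation for linspace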
@@ -738,7 +715,14 @@ def tile(x, repeats):
def trace(x, offset=None, axis1=None, axis2=None):
x = convert_to_tensor(x)
return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1)
# TODO: implement support for these arguments
# API divergence between `np.trace()` and `torch.trace()`
if offset or axis1 or axis2:
raise NotImplementedError(
"Arguments not supported by `torch.trace: "
f"offset={offset}, axis1={axis1}, axis2={axis2}"
)
return torch.trace(x)
def tri(N, M=None, k=0, dtype="float32"):
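For reference, the line removed from `trace` already shows how the unsupported arguments could be emulated: reduce over the matching diagonal. Note torch names the keywords `dim1`/`dim2` where numpy uses `axis1`/`axis2`:

import numpy as np
import torch

x = torch.arange(24).reshape(1, 2, 3, 4)
# Equivalent of np.trace with offset=0, axis1=0, axis2=1:
t = torch.diagonal(x, offset=0, dim1=0, dim2=1).sum(dim=-1)
assert np.allclose(t.numpy(), np.trace(np.arange(24).reshape(1, 2, 3, 4)))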

@@ -1,3 +1,4 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
@@ -65,13 +66,25 @@ class Callback:
def __init__(self):
self.validation_data = None
self.model = None
self._model = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
self._model = model
@property
def model(self):
if backend.backend() == "jax" and hasattr(
self._model, "jax_state_sync"
):
# With JAX, by default the model state is not
# attached to the model in the middle of an
# epoch. We have to force a sync before
# accessing model state for e.g. checkpointing.
self._model.jax_state_sync()
return self._model
def on_batch_begin(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_begin`."""

@@ -70,7 +70,7 @@ class CallbackList(Callback):
callback.set_params(params)
def set_model(self, model):
self.model = model
super().set_model(model)
if self._history:
model.history = self._history
for callback in self.callbacks:

@@ -0,0 +1,29 @@
import numpy as np
from keras_core import models
from keras_core import testing
from keras_core.callbacks.callback import Callback
class CallbackTest(testing.TestCase):
def test_model_state_is_current_on_epoch_end(self):
class TestModel(models.Model):
def __init__(self):
super().__init__()
self.iterations = self.add_variable(
shape=(), initializer="zeros", trainable=False
)
def call(self, inputs):
self.iterations.assign(self.iterations + 1)
return inputs
class CBK(Callback):
def on_batch_end(self, batch, logs):
assert np.int32(self.model.iterations) == batch + 1
model = TestModel()
model.compile(optimizer="sgd", loss="mse")
x = np.random.random((8, 1))
y = np.random.random((8, 1))
model.fit(x, y, callbacks=[CBK()], batch_size=2)

@@ -54,7 +54,7 @@ class EarlyStoppingTest(testing.TestCase):
for patience in cases:
stopper = callbacks.EarlyStopping(monitor="loss", patience=patience)
stopper.model = models.Sequential()
stopper.set_model(models.Sequential())
stopper.model.compile(loss="mse", optimizer="sgd")
stopper.on_train_begin()
@@ -130,7 +130,7 @@ class EarlyStoppingTest(testing.TestCase):
early_stop = callbacks.EarlyStopping(
monitor="val_loss", patience=2, restore_best_weights=True
)
early_stop.model = DummyModel()
early_stop.set_model(DummyModel())
losses = [0.2, 0.15, 0.1, 0.11, 0.12]
# The best configuration is in the epoch 2 (loss = 0.1000).
epochs_trained = 0
@@ -153,7 +153,7 @@ class EarlyStoppingTest(testing.TestCase):
baseline=0.5,
restore_best_weights=True,
)
early_stop.model = DummyModel()
early_stop.set_model(DummyModel())
losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
# The best configuration is in the epoch 2 (loss = 0.7000).
epochs_trained = 0

@@ -1713,16 +1713,27 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.array(knp.full_like(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"),
)
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2))
self.assertAllClose(
np.array(knp.FullLike()(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"),
)
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`.
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_like_without_torch(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose(
np.array(knp.FullLike()(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
@@ -1823,6 +1834,13 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.linspace(0, 10, 5, endpoint=False),
)
# TODO: torch.linspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.linspace` has no support for array start/stop.",
)
def test_linspace_without_torch(self):
start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
self.assertAllClose(
@@ -1927,9 +1945,15 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.logspace(0, 10, 5, endpoint=False),
)
# TODO: torch.logspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.logspace` has no support for array start/stop.",
)
def test_logspace_without_torch(self):
start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
self.assertAllClose(
np.array(knp.logspace(start, stop, 5, base=10)),
np.logspace(start, stop, 5, base=10),
@@ -2992,7 +3016,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3]))
self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3]))
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.split` does not support args `offset`, `axis1`, `axis2`",
)
def test_trace(self):
# TODO: implement `torch.trace` support for arguments `offset`,
# `axis1`, `axis2` and delete NotImplementedError
x = np.arange(24).reshape([1, 2, 3, 4])
self.assertAllClose(np.array(knp.trace(x)), np.trace(x))
self.assertAllClose(
@@ -3062,15 +3092,25 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(
np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1)
)
self.assertAllClose(
np.array(knp.full([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),
)
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
self.assertAllClose(
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
)
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`.
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_without_torch(self):
self.assertAllClose(
np.array(knp.full([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),
)
self.assertAllClose(
np.array(knp.Full()([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),