Ensure that JAX model state is current inside callbacks
parent 5c40f0def4
commit a9f4ccd0d7
@@ -288,6 +288,7 @@ class JAXTrainer(base_trainer.Trainer):
         if getattr(self, "_eval_epoch_iterator", None) is not None:
             del self._eval_epoch_iterator
         callbacks.on_train_end(logs=training_logs)
+        self._jax_state = None
         return self.history

     def evaluate(
@@ -452,7 +453,7 @@ class JAXTrainer(base_trainer.Trainer):

         logs = self._pythonify_logs(self.get_metrics_result())
         callbacks.on_test_end(logs)
-
+        self._jax_state = None
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
@@ -565,7 +566,7 @@ class JAXTrainer(base_trainer.Trainer):
         )

     def jax_state_sync(self):
-        if not hasattr(self, "_jax_state"):
+        if not getattr(self, "_jax_state", None):
            return

        trainable_variables = self._jax_state.get("trainable_variables", None)
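The switch from `hasattr` to `getattr` matters because `fit()` and `evaluate()` now leave `_jax_state` set to `None` instead of unset; under the old guard, a cleared state would pass the check and the subsequent `self._jax_state.get(...)` would fail on `None`. A minimal illustration (hypothetical `probe` object, not from the commit):

import types

probe = types.SimpleNamespace()
probe._jax_state = None
hasattr(probe, "_jax_state")         # True  -> old guard would fall through
getattr(probe, "_jax_state", None)   # None  -> new guard returns early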
@@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):

 def dot(x, y):
     x, y = convert_to_tensor(x), convert_to_tensor(y)
-    if x.ndim == 0 or y.ndim == 0:
+    if len(x.shape) == 0 or len(y.shape) == 0:
         return torch.multiply(x, y)
     return torch.matmul(x, y)
@@ -327,7 +327,7 @@ def expm1(x):
 def flip(x, axis=None):
     x = convert_to_tensor(x)
     if axis is None:
-        axis = tuple(range(x.ndim))
+        axis = tuple(range(len(x.shape)))
     if isinstance(axis, int):
         axis = (axis,)
     return torch.flip(x, dims=axis)
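Both of the hunks above swap `x.ndim` for `len(x.shape)`; for `torch.Tensor` inputs the two expressions are equivalent, so the change is one of style/compatibility rather than behavior:

import torch

x = torch.zeros(2, 3)
assert x.ndim == len(x.shape) == 2  # identical on torch tensors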
@@ -341,21 +341,23 @@ def floor(x):
 def full(shape, fill_value, dtype=None):
     dtype = to_torch_dtype(dtype)
     if hasattr(fill_value, "__len__"):
-        fill_value = convert_to_tensor(fill_value)
-        reps = shape[-1] // fill_value.shape[-1]  # channels-last reduction
-        reps_by_dim = (*shape[:-1], reps)
-        return torch.tile(fill_value, reps_by_dim)
+        raise NotImplementedError(
+            "`torch.full()` only accepts scalars for `fill_value`. "
+            f"Received: fill_value={fill_value}"
+        )
+        # TODO: implement conversion of shape into repetitions for `torch.tile``
+        # return torch.tile(fill_value, reps)
     return torch.full(size=shape, fill_value=fill_value, dtype=dtype)


 def full_like(x, fill_value, dtype=None):
     dtype = to_torch_dtype(dtype)
     if hasattr(fill_value, "__len__"):
-        fill_value = convert_to_tensor(fill_value)
-        reps_by_dim = tuple(
-            [x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
-        )
-        return torch.tile(fill_value, reps_by_dim)
+        raise NotImplementedError(
+            "`torch.full()` only accepts scalars for `fill_value`."
+        )
+        # TODO: implement conversion of shape into repetitions for `torch.tile``
+        # return torch.tile(fill_value, reps)
     x = convert_to_tensor(x)
     return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
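One possible shape-to-repetitions conversion for the TODO above, sketched here under assumptions not in the commit: the helper name `_full_with_tile` is hypothetical, and each target dimension is assumed to be a whole multiple of the corresponding `fill_value` dimension.

import torch

def _full_with_tile(shape, fill_value):
    # Emulate np.full() with a non-scalar fill_value by tiling it up to `shape`.
    fill_value = torch.as_tensor(fill_value)
    # Left-pad with singleton dims so the ranks match.
    while fill_value.dim() < len(shape):
        fill_value = fill_value.unsqueeze(0)
    # Repetitions per dimension (assumes even divisibility).
    reps = tuple(s // f for s, f in zip(shape, fill_value.shape))
    return torch.tile(fill_value, reps)

# _full_with_tile((2, 3), torch.tensor([1, 4, 5])) -> [[1, 4, 5], [1, 4, 5]]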
@@ -428,27 +430,15 @@ def linspace(
         "torch.linspace does not support an `axis` argument. "
         f"Received axis={axis}"
     )
     dtype = to_torch_dtype(dtype)
     if endpoint is False:
         stop = stop - ((stop - start) / num)
-    if hasattr(start, "__len__") and hasattr(stop, "__len__"):
-        start, stop = convert_to_tensor(start), convert_to_tensor(stop)
-        stop = cast(stop, dtype) if endpoint is False and dtype else stop
-        steps = torch.arange(num, dtype=dtype) / (num - 1)
-
-        # reshape `steps` to allow for broadcasting
-        for i in range(start.ndim):
-            steps = steps.unsqueeze(-1)
-
-        # increments from `start` to `stop` in each dimension
-        linspace = start[None] + steps * (stop - start)[None]
-    else:
-        linspace = torch.linspace(
-            start=start,
-            end=stop,
-            steps=num,
-            dtype=dtype,
-        )
+    linspace = torch.linspace(
+        start=start,
+        end=stop,
+        steps=num,
+        dtype=dtype,
+    )
     if retstep is True:
         return (linspace, num)
     return linspace
@@ -505,26 +495,13 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
     dtype = to_torch_dtype(dtype)
     if endpoint is False:
         stop = stop - ((stop - start) / num)
-    if hasattr(start, "__len__") and hasattr(stop, "__len__"):
-        start, stop = convert_to_tensor(start), convert_to_tensor(stop)
-        stop = cast(stop, dtype) if endpoint is False and dtype else stop
-        steps = torch.arange(num, dtype=dtype) / (num - 1)
-
-        # reshape `steps` to allow for broadcasting
-        for i in range(start.ndim):
-            steps = steps.unsqueeze(-1)
-
-        # increments from `start` to `stop` in each dimension
-        linspace = start[None] + steps * (stop - start)[None]
-        logspace = base**linspace
-    else:
-        logspace = torch.logspace(
-            start=start,
-            end=stop,
-            steps=num,
-            base=base,
-            dtype=dtype,
-        )
+    logspace = torch.logspace(
+        start=start,
+        end=stop,
+        steps=num,
+        base=base,
+        dtype=dtype,
+    )
     return logspace
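For reference, the branch removed from both functions was emulating NumPy's array-valued endpoints, where the sample axis is prepended to the broadcast shape:

import numpy as np

start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
out = np.linspace(start, stop, 5)
assert out.shape == (5, 2, 3, 4)  # samples run along the new leading axis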
@@ -738,7 +715,14 @@ def tile(x, repeats):

 def trace(x, offset=None, axis1=None, axis2=None):
     x = convert_to_tensor(x)
-    return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1)
+    # TODO: implement support for these arguments
+    # API divergence between `np.trace()` and `torch.trace()`
+    if offset or axis1 or axis2:
+        raise NotImplementedError(
+            "Arguments not supported by `torch.trace: "
+            f"offset={offset}, axis1={axis1}, axis2={axis2}"
+        )
+    return torch.trace(x)


 def tri(N, M=None, k=0, dtype="float32"):
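Note that `torch.trace` accepts only 2-D inputs and no offset/axis arguments, whereas the removed one-liner built `np.trace` semantics from `torch.diagonal`, which does support them:

import torch

x = torch.arange(9.0).reshape(3, 3)
torch.trace(x)                                  # tensor(12.)
torch.sum(torch.diagonal(x, 0, 0, 1), dim=-1)   # tensor(12.), the removed approach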
@@ -1,3 +1,4 @@
+from keras_core import backend
 from keras_core.api_export import keras_core_export
@@ -65,13 +66,25 @@ class Callback:

     def __init__(self):
         self.validation_data = None
-        self.model = None
+        self._model = None

     def set_params(self, params):
         self.params = params

     def set_model(self, model):
-        self.model = model
+        self._model = model
+
+    @property
+    def model(self):
+        if backend.backend() == "jax" and hasattr(
+            self._model, "jax_state_sync"
+        ):
+            # With JAX, by default the model state is not
+            # attached to the model in the middle of an
+            # epoch. We have to force a sync before
+            # accessing model state for e.g. checkpointing.
+            self._model.jax_state_sync()
+        return self._model

     def on_batch_begin(self, batch, logs=None):
         """A backwards compatibility alias for `on_train_batch_begin`."""
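With the property in place, any callback that reads state through `self.model` sees synced values under JAX. A minimal sketch (this callback is hypothetical, not part of the commit; it assumes variables convert via NumPy, as the commit's own test does with `np.int32(self.model.iterations)`):

import numpy as np

from keras_core.callbacks.callback import Callback

class LogWeightNorm(Callback):
    def on_train_batch_end(self, batch, logs=None):
        # Accessing self.model triggers jax_state_sync() on JAX first.
        total = sum(
            np.abs(np.array(v)).sum() for v in self.model.trainable_variables
        )
        print(f"batch {batch}: L1 weight norm = {total:.4f}")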
@@ -70,7 +70,7 @@ class CallbackList(Callback):
         callback.set_params(params)

     def set_model(self, model):
-        self.model = model
+        super().set_model(model)
         if self._history:
             model.history = self._history
         for callback in self.callbacks:
keras_core/callbacks/callback_test.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import numpy as np
+
+from keras_core import models
+from keras_core import testing
+from keras_core.callbacks.callback import Callback
+
+
+class CallbackTest(testing.TestCase):
+    def test_model_state_is_current_on_epoch_end(self):
+        class TestModel(models.Model):
+            def __init__(self):
+                super().__init__()
+                self.iterations = self.add_variable(
+                    shape=(), initializer="zeros", trainable=False
+                )
+
+            def call(self, inputs):
+                self.iterations.assign(self.iterations + 1)
+                return inputs
+
+        class CBK(Callback):
+            def on_batch_end(self, batch, logs):
+                assert np.int32(self.model.iterations) == batch + 1
+
+        model = TestModel()
+        model.compile(optimizer="sgd", loss="mse")
+        x = np.random.random((8, 1))
+        y = np.random.random((8, 1))
+        model.fit(x, y, callbacks=[CBK()], batch_size=2)
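The test exercises the new property directly: with 8 samples and batch_size=2 there are 4 train batches, and `iterations` increments once per forward pass, so the assertion `np.int32(self.model.iterations) == batch + 1` can only hold if the state the callback reads through `self.model` is current at each `on_batch_end`.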
@@ -54,7 +54,7 @@ class EarlyStoppingTest(testing.TestCase):

         for patience in cases:
             stopper = callbacks.EarlyStopping(monitor="loss", patience=patience)
-            stopper.model = models.Sequential()
+            stopper.set_model(models.Sequential())
             stopper.model.compile(loss="mse", optimizer="sgd")
             stopper.on_train_begin()

@@ -130,7 +130,7 @@ class EarlyStoppingTest(testing.TestCase):
         early_stop = callbacks.EarlyStopping(
             monitor="val_loss", patience=2, restore_best_weights=True
         )
-        early_stop.model = DummyModel()
+        early_stop.set_model(DummyModel())
         losses = [0.2, 0.15, 0.1, 0.11, 0.12]
         # The best configuration is in the epoch 2 (loss = 0.1000).
         epochs_trained = 0

@@ -153,7 +153,7 @@ class EarlyStoppingTest(testing.TestCase):
             baseline=0.5,
             restore_best_weights=True,
         )
-        early_stop.model = DummyModel()
+        early_stop.set_model(DummyModel())
         losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
         # The best configuration is in the epoch 2 (loss = 0.7000).
         epochs_trained = 0
@@ -1713,16 +1713,27 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
             np.array(knp.full_like(x, 2, dtype="float32")),
             np.full_like(x, 2, dtype="float32"),
         )
-        self.assertAllClose(
-            np.array(knp.full_like(x, np.ones([2, 3]))),
-            np.full_like(x, np.ones([2, 3])),
-        )

         self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2))
         self.assertAllClose(
             np.array(knp.FullLike()(x, 2, dtype="float32")),
             np.full_like(x, 2, dtype="float32"),
         )

+    # TODO: implement conversion of shape into repetitions, pass to
+    # `torch.tile`, since `torch.full()` only accepts scalars
+    # for `fill_value`."
+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="`torch.full` only accepts scalars for `fill_value`.",
+    )
+    def test_full_like_without_torch(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        self.assertAllClose(
+            np.array(knp.full_like(x, np.ones([2, 3]))),
+            np.full_like(x, np.ones([2, 3])),
+        )
+
         self.assertAllClose(
             np.array(knp.FullLike()(x, np.ones([2, 3]))),
             np.full_like(x, np.ones([2, 3])),
@@ -1823,6 +1834,13 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
             np.linspace(0, 10, 5, endpoint=False),
         )

+    # TODO: torch.linspace does not support tensor or array
+    # for start/stop, create manual implementation
+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="`torch.linspace` has no support for array start/stop.",
+    )
+    def test_linspace_without_torch(self):
         start = np.zeros([2, 3, 4])
         stop = np.ones([2, 3, 4])
         self.assertAllClose(
@@ -1927,9 +1945,15 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
             np.logspace(0, 10, 5, endpoint=False),
         )

+    # TODO: torch.logspace does not support tensor or array
+    # for start/stop, create manual implementation
+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="`torch.logspace` has no support for array start/stop.",
+    )
+    def test_logspace_without_torch(self):
         start = np.zeros([2, 3, 4])
         stop = np.ones([2, 3, 4])

         self.assertAllClose(
             np.array(knp.logspace(start, stop, 5, base=10)),
             np.logspace(start, stop, 5, base=10),
@@ -2992,7 +3016,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
         self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3]))
         self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3]))

+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="`torch.split` does not support args `offset`, `axis1`, `axis2`",
+    )
     def test_trace(self):
+        # TODO: implement `torch.trace` support for arguments `offset`,
+        # `axis1`, `axis2` and delete NotImplementedError
         x = np.arange(24).reshape([1, 2, 3, 4])
         self.assertAllClose(np.array(knp.trace(x)), np.trace(x))
         self.assertAllClose(
@@ -3062,15 +3092,25 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
         self.assertAllClose(
             np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1)
         )
-        self.assertAllClose(
-            np.array(knp.full([2, 3], np.array([1, 4, 5]))),
-            np.full([2, 3], np.array([1, 4, 5])),
-        )

         self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
         self.assertAllClose(
             np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
         )

+    # TODO: implement conversion of shape into repetitions, pass to
+    # `torch.tile`, since `torch.full()` only accepts scalars
+    # for `fill_value`."
+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="`torch.full` only accepts scalars for `fill_value`.",
+    )
+    def test_full_without_torch(self):
+        self.assertAllClose(
+            np.array(knp.full([2, 3], np.array([1, 4, 5]))),
+            np.full([2, 3], np.array([1, 4, 5])),
+        )
+
         self.assertAllClose(
             np.array(knp.Full()([2, 3], np.array([1, 4, 5]))),
             np.full([2, 3], np.array([1, 4, 5])),