diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index a3e7cb9e9..80c222e3f 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -288,6 +288,7 @@ class JAXTrainer(base_trainer.Trainer): if getattr(self, "_eval_epoch_iterator", None) is not None: del self._eval_epoch_iterator callbacks.on_train_end(logs=training_logs) + self._jax_state = None return self.history def evaluate( @@ -452,7 +453,7 @@ class JAXTrainer(base_trainer.Trainer): logs = self._pythonify_logs(self.get_metrics_result()) callbacks.on_test_end(logs) - + self._jax_state = None if return_dict: return logs return self._flatten_metrics_in_order(logs) @@ -565,7 +566,7 @@ class JAXTrainer(base_trainer.Trainer): ) def jax_state_sync(self): - if not hasattr(self, "_jax_state"): + if not getattr(self, "_jax_state", None): return trainable_variables = self._jax_state.get("trainable_variables", None) diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index c69906d87..53883ca01 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1): def dot(x, y): x, y = convert_to_tensor(x), convert_to_tensor(y) - if x.ndim == 0 or y.ndim == 0: + if len(x.shape) == 0 or len(y.shape) == 0: return torch.multiply(x, y) return torch.matmul(x, y) @@ -327,7 +327,7 @@ def expm1(x): def flip(x, axis=None): x = convert_to_tensor(x) if axis is None: - axis = tuple(range(x.ndim)) + axis = tuple(range(len(x.shape))) if isinstance(axis, int): axis = (axis,) return torch.flip(x, dims=axis) @@ -341,21 +341,23 @@ def floor(x): def full(shape, fill_value, dtype=None): dtype = to_torch_dtype(dtype) if hasattr(fill_value, "__len__"): - fill_value = convert_to_tensor(fill_value) - reps = shape[-1] // fill_value.shape[-1] # channels-last reduction - reps_by_dim = (*shape[:-1], reps) - return torch.tile(fill_value, reps_by_dim) + raise NotImplementedError( + "`torch.full()` only accepts scalars for `fill_value`. " + f"Received: fill_value={fill_value}" + ) + # TODO: implement conversion of shape into repetitions for `torch.tile`` + # return torch.tile(fill_value, reps) return torch.full(size=shape, fill_value=fill_value, dtype=dtype) def full_like(x, fill_value, dtype=None): dtype = to_torch_dtype(dtype) if hasattr(fill_value, "__len__"): - fill_value = convert_to_tensor(fill_value) - reps_by_dim = tuple( - [x.shape[i] // fill_value.shape[i] for i in range(x.ndim)] + raise NotImplementedError( + "`torch.full()` only accepts scalars for `fill_value`." ) - return torch.tile(fill_value, reps_by_dim) + # TODO: implement conversion of shape into repetitions for `torch.tile`` + # return torch.tile(fill_value, reps) x = convert_to_tensor(x) return torch.full_like(input=x, fill_value=fill_value, dtype=dtype) @@ -428,27 +430,15 @@ def linspace( "torch.linspace does not support an `axis` argument. " f"Received axis={axis}" ) - dtype = to_torch_dtype(dtype) + dtype = to_torch_dtype(dtype) if endpoint is False: stop = stop - ((stop - start) / num) - if hasattr(start, "__len__") and hasattr(stop, "__len__"): - start, stop = convert_to_tensor(start), convert_to_tensor(stop) - stop = cast(stop, dtype) if endpoint is False and dtype else stop - steps = torch.arange(num, dtype=dtype) / (num - 1) - - # reshape `steps` to allow for broadcasting - for i in range(start.ndim): - steps = steps.unsqueeze(-1) - - # increments from `start` to `stop` in each dimension - linspace = start[None] + steps * (stop - start)[None] - else: - linspace = torch.linspace( - start=start, - end=stop, - steps=num, - dtype=dtype, - ) + linspace = torch.linspace( + start=start, + end=stop, + steps=num, + dtype=dtype, + ) if retstep is True: return (linspace, num) return linspace @@ -505,26 +495,13 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): dtype = to_torch_dtype(dtype) if endpoint is False: stop = stop - ((stop - start) / num) - if hasattr(start, "__len__") and hasattr(stop, "__len__"): - start, stop = convert_to_tensor(start), convert_to_tensor(stop) - stop = cast(stop, dtype) if endpoint is False and dtype else stop - steps = torch.arange(num, dtype=dtype) / (num - 1) - - # reshape `steps` to allow for broadcasting - for i in range(start.ndim): - steps = steps.unsqueeze(-1) - - # increments from `start` to `stop` in each dimension - linspace = start[None] + steps * (stop - start)[None] - logspace = base**linspace - else: - logspace = torch.logspace( - start=start, - end=stop, - steps=num, - base=base, - dtype=dtype, - ) + logspace = torch.logspace( + start=start, + end=stop, + steps=num, + base=base, + dtype=dtype, + ) return logspace @@ -738,7 +715,14 @@ def tile(x, repeats): def trace(x, offset=None, axis1=None, axis2=None): x = convert_to_tensor(x) - return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1) + # TODO: implement support for these arguments + # API divergence between `np.trace()` and `torch.trace()` + if offset or axis1 or axis2: + raise NotImplementedError( + "Arguments not supported by `torch.trace: " + f"offset={offset}, axis1={axis1}, axis2={axis2}" + ) + return torch.trace(x) def tri(N, M=None, k=0, dtype="float32"): diff --git a/keras_core/callbacks/callback.py b/keras_core/callbacks/callback.py index 8c92c4021..f9b66b693 100644 --- a/keras_core/callbacks/callback.py +++ b/keras_core/callbacks/callback.py @@ -1,3 +1,4 @@ +from keras_core import backend from keras_core.api_export import keras_core_export @@ -65,13 +66,25 @@ class Callback: def __init__(self): self.validation_data = None - self.model = None + self._model = None def set_params(self, params): self.params = params def set_model(self, model): - self.model = model + self._model = model + + @property + def model(self): + if backend.backend() == "jax" and hasattr( + self._model, "jax_state_sync" + ): + # With JAX, by default the model state is not + # attached to the model in the middle of an + # epoch. We have to force a sync before + # accessing model state for e.g. checkpointing. + self._model.jax_state_sync() + return self._model def on_batch_begin(self, batch, logs=None): """A backwards compatibility alias for `on_train_batch_begin`.""" diff --git a/keras_core/callbacks/callback_list.py b/keras_core/callbacks/callback_list.py index 8059fda67..d394a85b4 100644 --- a/keras_core/callbacks/callback_list.py +++ b/keras_core/callbacks/callback_list.py @@ -70,7 +70,7 @@ class CallbackList(Callback): callback.set_params(params) def set_model(self, model): - self.model = model + super().set_model(model) if self._history: model.history = self._history for callback in self.callbacks: diff --git a/keras_core/callbacks/callback_test.py b/keras_core/callbacks/callback_test.py new file mode 100644 index 000000000..c4d5c7a6b --- /dev/null +++ b/keras_core/callbacks/callback_test.py @@ -0,0 +1,29 @@ +import numpy as np + +from keras_core import models +from keras_core import testing +from keras_core.callbacks.callback import Callback + + +class CallbackTest(testing.TestCase): + def test_model_state_is_current_on_epoch_end(self): + class TestModel(models.Model): + def __init__(self): + super().__init__() + self.iterations = self.add_variable( + shape=(), initializer="zeros", trainable=False + ) + + def call(self, inputs): + self.iterations.assign(self.iterations + 1) + return inputs + + class CBK(Callback): + def on_batch_end(self, batch, logs): + assert np.int32(self.model.iterations) == batch + 1 + + model = TestModel() + model.compile(optimizer="sgd", loss="mse") + x = np.random.random((8, 1)) + y = np.random.random((8, 1)) + model.fit(x, y, callbacks=[CBK()], batch_size=2) diff --git a/keras_core/callbacks/early_stopping_test.py b/keras_core/callbacks/early_stopping_test.py index 0e4a2ba40..d9924da47 100644 --- a/keras_core/callbacks/early_stopping_test.py +++ b/keras_core/callbacks/early_stopping_test.py @@ -54,7 +54,7 @@ class EarlyStoppingTest(testing.TestCase): for patience in cases: stopper = callbacks.EarlyStopping(monitor="loss", patience=patience) - stopper.model = models.Sequential() + stopper.set_model(models.Sequential()) stopper.model.compile(loss="mse", optimizer="sgd") stopper.on_train_begin() @@ -130,7 +130,7 @@ class EarlyStoppingTest(testing.TestCase): early_stop = callbacks.EarlyStopping( monitor="val_loss", patience=2, restore_best_weights=True ) - early_stop.model = DummyModel() + early_stop.set_model(DummyModel()) losses = [0.2, 0.15, 0.1, 0.11, 0.12] # The best configuration is in the epoch 2 (loss = 0.1000). epochs_trained = 0 @@ -153,7 +153,7 @@ class EarlyStoppingTest(testing.TestCase): baseline=0.5, restore_best_weights=True, ) - early_stop.model = DummyModel() + early_stop.set_model(DummyModel()) losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73] # The best configuration is in the epoch 2 (loss = 0.7000). epochs_trained = 0 diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py index c509fbe13..323be0b76 100644 --- a/keras_core/operations/numpy_test.py +++ b/keras_core/operations/numpy_test.py @@ -1713,16 +1713,27 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.array(knp.full_like(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) - self.assertAllClose( - np.array(knp.full_like(x, np.ones([2, 3]))), - np.full_like(x, np.ones([2, 3])), - ) self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2)) self.assertAllClose( np.array(knp.FullLike()(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) + + # TODO: implement conversion of shape into repetitions, pass to + # `torch.tile`, since `torch.full()` only accepts scalars + # for `fill_value`." + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.full` only accepts scalars for `fill_value`.", + ) + def test_full_like_without_torch(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose( + np.array(knp.full_like(x, np.ones([2, 3]))), + np.full_like(x, np.ones([2, 3])), + ) + self.assertAllClose( np.array(knp.FullLike()(x, np.ones([2, 3]))), np.full_like(x, np.ones([2, 3])), @@ -1823,6 +1834,13 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.linspace(0, 10, 5, endpoint=False), ) + # TODO: torch.linspace does not support tensor or array + # for start/stop, create manual implementation + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.linspace` has no support for array start/stop.", + ) + def test_linspace_without_torch(self): start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) self.assertAllClose( @@ -1927,9 +1945,15 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.logspace(0, 10, 5, endpoint=False), ) + # TODO: torch.logspace does not support tensor or array + # for start/stop, create manual implementation + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.logspace` has no support for array start/stop.", + ) + def test_logspace_without_torch(self): start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) - self.assertAllClose( np.array(knp.logspace(start, stop, 5, base=10)), np.logspace(start, stop, 5, base=10), @@ -2992,7 +3016,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3])) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.split` does not support args `offset`, `axis1`, `axis2`", + ) def test_trace(self): + # TODO: implement `torch.trace` support for arguments `offset`, + # `axis1`, `axis2` and delete NotImplementedError x = np.arange(24).reshape([1, 2, 3, 4]) self.assertAllClose(np.array(knp.trace(x)), np.trace(x)) self.assertAllClose( @@ -3062,15 +3092,25 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose( np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1) ) - self.assertAllClose( - np.array(knp.full([2, 3], np.array([1, 4, 5]))), - np.full([2, 3], np.array([1, 4, 5])), - ) self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) self.assertAllClose( np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) ) + + # TODO: implement conversion of shape into repetitions, pass to + # `torch.tile`, since `torch.full()` only accepts scalars + # for `fill_value`." + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.full` only accepts scalars for `fill_value`.", + ) + def test_full_without_torch(self): + self.assertAllClose( + np.array(knp.full([2, 3], np.array([1, 4, 5]))), + np.full([2, 3], np.array([1, 4, 5])), + ) + self.assertAllClose( np.array(knp.Full()([2, 3], np.array([1, 4, 5]))), np.full([2, 3], np.array([1, 4, 5])),