diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py index 4bfab22ea..8899b0e3f 100644 --- a/keras_core/backend/jax/numpy.py +++ b/keras_core/backend/jax/numpy.py @@ -49,6 +49,10 @@ def all(x, axis=None, keepdims=False): return jnp.all(x, axis=axis, keepdims=keepdims) +def any(x, axis=None, keepdims=False): + return jnp.any(x, axis=axis, keepdims=keepdims) + + def amax(x, axis=None, keepdims=False): return jnp.amax(x, axis=axis, keepdims=keepdims) diff --git a/keras_core/backend/tensorflow/numpy.py b/keras_core/backend/tensorflow/numpy.py index b82343373..1a66d7f46 100644 --- a/keras_core/backend/tensorflow/numpy.py +++ b/keras_core/backend/tensorflow/numpy.py @@ -50,6 +50,10 @@ def all(x, axis=None, keepdims=False): return tfnp.all(x, axis=axis, keepdims=keepdims) +def any(x, axis=None, keepdims=False): + return tfnp.any(x, axis=axis, keepdims=keepdims) + + def amax(x, axis=None, keepdims=False): return tfnp.amax(x, axis=axis, keepdims=keepdims) diff --git a/keras_core/layers/__init__.py b/keras_core/layers/__init__.py index e9b604fc9..484186a7b 100644 --- a/keras_core/layers/__init__.py +++ b/keras_core/layers/__init__.py @@ -1,8 +1,10 @@ from keras_core.layers.activations.activation import Activation from keras_core.layers.core.dense import Dense from keras_core.layers.core.embedding import Embedding +from keras_core.layers.core.identity import Identity from keras_core.layers.core.input_layer import Input from keras_core.layers.core.input_layer import InputLayer +from keras_core.layers.core.masking import Masking from keras_core.layers.layer import Layer from keras_core.layers.merging.add import Add from keras_core.layers.merging.add import add diff --git a/keras_core/layers/core/identity.py b/keras_core/layers/core/identity.py new file mode 100644 index 000000000..02bad96ef --- /dev/null +++ b/keras_core/layers/core/identity.py @@ -0,0 +1,18 @@ +from keras_core.api_export import keras_core_export +from keras_core.layers.layer import 
Layer + + +@keras_core_export("keras_core.layers.Identity") +class Identity(Layer): + """Identity layer. + + This layer should be used as a placeholder when no operation is to be + performed. The layer just returns its `inputs` argument as output. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def call(self, inputs): + return inputs diff --git a/keras_core/layers/core/identity_test.py b/keras_core/layers/core/identity_test.py new file mode 100644 index 000000000..d1d010b93 --- /dev/null +++ b/keras_core/layers/core/identity_test.py @@ -0,0 +1,17 @@ +from keras_core import layers +from keras_core import testing + + +class IdentityTest(testing.TestCase): + def test_identity_basics(self): + self.run_layer_test( + layers.Identity, + init_kwargs={}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) diff --git a/keras_core/layers/core/masking.py b/keras_core/layers/core/masking.py new file mode 100644 index 000000000..3fc7eecbd --- /dev/null +++ b/keras_core/layers/core/masking.py @@ -0,0 +1,70 @@ +from keras_core import backend +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.layers.layer import Layer + + +@keras_core_export("keras_core.layers.Masking") +class Masking(Layer): + """Masks a sequence by using a mask value to skip timesteps. + + For each timestep in the input tensor (dimension #1 in the tensor), + if all values in the input tensor at that timestep + are equal to `mask_value`, then the timestep will be masked (skipped) + in all downstream layers (as long as they support masking). + + If any downstream layer does not support masking yet receives such + an input mask, an exception will be raised. 
+ + Example: + + Consider a NumPy data array `x` of shape `(samples, timesteps, features)`, + to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you + lack data for these timesteps. You can: + + - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.` + - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer: + + ```python + samples, timesteps, features = 32, 10, 8 + inputs = np.random.random([samples, timesteps, features]).astype(np.float32) + inputs[:, 3, :] = 0. + inputs[:, 5, :] = 0. + + model = keras_core.models.Sequential() + model.add(keras_core.layers.Masking(mask_value=0.)) + model.add(keras_core.layers.LSTM(32)) + output = model(inputs) + # Time steps 3 and 5 will be skipped in the LSTM calculation. + ``` + + Note: in the Keras masking convention, a masked timestep is denoted by + a mask value of `False`, while a non-masked (i.e. usable) timestep + is denoted by a mask value of `True`. + """ + + def __init__(self, mask_value=0.0, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + self.mask_value = mask_value + + def compute_mask(self, inputs, mask=None): + return ops.any(ops.not_equal(inputs, self.mask_value), axis=-1) + + def call(self, inputs): + boolean_mask = ops.any( + ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True + ) + # Set masked outputs to 0 + outputs = inputs * backend.cast(boolean_mask, dtype=inputs.dtype) + # Compute the mask and outputs simultaneously. 
+ outputs._keras_mask = ops.squeeze(boolean_mask, axis=-1) + return outputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = {"mask_value": self.mask_value} + return {**base_config, **config} diff --git a/keras_core/layers/core/masking_test.py b/keras_core/layers/core/masking_test.py new file mode 100644 index 000000000..1b0785a2a --- /dev/null +++ b/keras_core/layers/core/masking_test.py @@ -0,0 +1,53 @@ +import numpy as np + +from keras_core import layers +from keras_core import models +from keras_core import testing + + +class MaskingTest(testing.TestCase): + def test_masking_basics(self): + self.run_layer_test( + layers.Masking, + init_kwargs={"mask_value": 0.0}, + input_shape=(2, 3, 2), + expected_output_shape=(2, 3, 2), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + def test_masking_correctness(self): + x = np.array( + [ + [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]], + [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]], + ] + ) + expected_mask = [[False, True, False], [True, False, True]] + + layer = layers.Masking(mask_value=0.0) + self.assertAllClose(layer.compute_mask(x), expected_mask) + + class TestLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, inputs, mask=None): + assert mask is not None + np.testing.assert_allclose(mask, expected_mask) + return inputs + + model = models.Sequential( + [ + layers.Masking(mask_value=0.0), + TestLayer(), + ] + ) + model(x) diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index 6854ce05e..0f84a9ae0 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -877,6 +877,9 @@ def get_shapes_dict(call_spec): """ shapes_dict = {} for k, v in 
call_spec.tensor_arguments_dict.items(): + if k == "mask" or k.startswith("mask_"): + # Do not include mask tensors in shapes dict + continue if k in call_spec.nested_tensor_argument_names: shapes_dict[f"{k}_shape"] = nest.map_structure( lambda x: backend.standardize_shape(x.shape), v diff --git a/keras_core/operations/numpy.py b/keras_core/operations/numpy.py index 2b1ffedd5..3662f1b6e 100644 --- a/keras_core/operations/numpy.py +++ b/keras_core/operations/numpy.py @@ -283,7 +283,7 @@ class All(Operation): axis=self.axis, keepdims=self.keepdims, ), - dtype=x.dtype, + dtype="bool", ) @@ -293,6 +293,39 @@ def all(x, axis=None, keepdims=False): return backend.numpy.all(x, axis=axis, keepdims=keepdims) +class Any(Operation): + def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.any( + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape( + x.shape, + axis=self.axis, + keepdims=self.keepdims, + ), + dtype="bool", + ) + + +def any(x, axis=None, keepdims=False): + if any_symbolic_tensors((x,)): + return Any(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.any(x, axis=axis, keepdims=keepdims) + + class Amax(Operation): def __init__(self, axis=None, keepdims=False): super().__init__() diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py index 5109b8c77..55ed96af9 100644 --- a/keras_core/operations/numpy_test.py +++ b/keras_core/operations/numpy_test.py @@ -644,6 +644,14 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase): self.assertEqual(knp.all(x, axis=1).shape, (None, 3)) self.assertEqual(knp.all(x, axis=1, keepdims=True).shape, (None, 1, 3)) + def test_any(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.any(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + 
self.assertEqual(knp.any(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.any(x, axis=1, keepdims=True).shape, (None, 1, 3)) + def test_var(self): x = KerasTensor([None, 3]) self.assertEqual(knp.var(x).shape, ()) @@ -1111,6 +1119,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase): x = KerasTensor([2, 3]) self.assertEqual(knp.all(x).shape, ()) + def test_any(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.any(x).shape, ()) + def test_var(self): x = KerasTensor([2, 3]) self.assertEqual(knp.var(x).shape, ()) @@ -2085,6 +2097,22 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): np.all(x, axis=1, keepdims=True), ) + def test_any(self): + x = np.array([[True, False, True], [True, True, True]]) + self.assertAllClose(np.array(knp.any(x)), np.any(x)) + self.assertAllClose(np.array(knp.any(x, axis=1)), np.any(x, axis=1)) + self.assertAllClose( + np.array(knp.any(x, axis=1, keepdims=True)), + np.any(x, axis=1, keepdims=True), + ) + + self.assertAllClose(np.array(knp.Any()(x)), np.any(x)) + self.assertAllClose(np.array(knp.Any(axis=1)(x)), np.any(x, axis=1)) + self.assertAllClose( + np.array(knp.Any(axis=1, keepdims=True)(x)), + np.any(x, axis=1, keepdims=True), + ) + def test_var(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(np.array(knp.var(x)), np.var(x)) diff --git a/keras_core/operations/operation.py b/keras_core/operations/operation.py index ed2db7d5a..76c5406de 100644 --- a/keras_core/operations/operation.py +++ b/keras_core/operations/operation.py @@ -52,8 +52,8 @@ class Operation: except Exception as e: raise RuntimeError( "Could not automatically infer the output shape / dtype of " - "this operation. " - "Please implement the `compute_output_spec` method " + f"operation '{self.name}'. " + "Please implement the `compute_output_spec()` method " f"on your object ({self.__class__.__name__}). " f"Error encountered: {e}" )