Add Masking layer.

This commit is contained in:
Francois Chollet 2023-05-03 19:56:56 -07:00
parent 10ef809a45
commit f14fb4097f
11 changed files with 235 additions and 3 deletions

@ -49,6 +49,10 @@ def all(x, axis=None, keepdims=False):
return jnp.all(x, axis=axis, keepdims=keepdims)
def any(x, axis=None, keepdims=False):
    """Reduce `x` with a logical OR by delegating to `jnp.any`.

    Args:
        x: Input forwarded unchanged to `jnp.any`.
        axis: Axis or axes to reduce over; `None` is passed through as-is.
        keepdims: Forwarded as the `keepdims` argument of `jnp.any`.

    Returns:
        Whatever `jnp.any` returns for the given arguments.
    """
    reduced = jnp.any(x, axis=axis, keepdims=keepdims)
    return reduced
def amax(x, axis=None, keepdims=False):
    """Reduce `x` to its maximum by delegating to `jnp.amax`.

    Args:
        x: Input forwarded unchanged to `jnp.amax`.
        axis: Axis or axes to reduce over; `None` is passed through as-is.
        keepdims: Forwarded as the `keepdims` argument of `jnp.amax`.

    Returns:
        Whatever `jnp.amax` returns for the given arguments.
    """
    maximum = jnp.amax(x, axis=axis, keepdims=keepdims)
    return maximum

@ -50,6 +50,10 @@ def all(x, axis=None, keepdims=False):
return tfnp.all(x, axis=axis, keepdims=keepdims)
def any(x, axis=None, keepdims=False):
    """Reduce `x` with a logical OR by delegating to `tfnp.any`.

    Args:
        x: Input forwarded unchanged to `tfnp.any`.
        axis: Axis or axes to reduce over; `None` is passed through as-is.
        keepdims: Forwarded as the `keepdims` argument of `tfnp.any`.

    Returns:
        Whatever `tfnp.any` returns for the given arguments.
    """
    reduced = tfnp.any(x, axis=axis, keepdims=keepdims)
    return reduced
def amax(x, axis=None, keepdims=False):
    """Reduce `x` to its maximum by delegating to `tfnp.amax`.

    Args:
        x: Input forwarded unchanged to `tfnp.amax`.
        axis: Axis or axes to reduce over; `None` is passed through as-is.
        keepdims: Forwarded as the `keepdims` argument of `tfnp.amax`.

    Returns:
        Whatever `tfnp.amax` returns for the given arguments.
    """
    maximum = tfnp.amax(x, axis=axis, keepdims=keepdims)
    return maximum

@ -1,8 +1,10 @@
from keras_core.layers.activations.activation import Activation
from keras_core.layers.core.dense import Dense
from keras_core.layers.core.embedding import Embedding
from keras_core.layers.core.identity import Identity
from keras_core.layers.core.input_layer import Input
from keras_core.layers.core.input_layer import InputLayer
from keras_core.layers.core.masking import Masking
from keras_core.layers.layer import Layer
from keras_core.layers.merging.add import Add
from keras_core.layers.merging.add import add

@ -0,0 +1,18 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Identity")
class Identity(Layer):
    """Pass-through layer.

    Use it as a stand-in wherever a layer is required but no computation
    should happen: `call` hands back its `inputs` argument untouched.
    """

    def __init__(self, **kwargs):
        # Let the base `Layer` consume all keyword arguments first, then flag
        # this layer as mask-aware.
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        """Return `inputs` as-is."""
        return inputs

@ -0,0 +1,17 @@
from keras_core import layers
from keras_core import testing
class IdentityTest(testing.TestCase):
    """Basic checks for `layers.Identity` via `run_layer_test`."""

    def test_identity_basics(self):
        # The layer is expected to create no weights, seed generators or
        # losses, and to leave the input shape unchanged.
        expected_counts = {
            "expected_num_trainable_weights": 0,
            "expected_num_non_trainable_weights": 0,
            "expected_num_seed_generators": 0,
            "expected_num_losses": 0,
        }
        self.run_layer_test(
            layers.Identity,
            init_kwargs={},
            input_shape=(2, 3),
            expected_output_shape=(2, 3),
            supports_masking=True,
            **expected_counts,
        )

@ -0,0 +1,70 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Masking")
class Masking(Layer):
    """Masks a sequence by using a mask value to skip timesteps.

    For each timestep in the input tensor (dimension #1 in the tensor),
    if all values in the input tensor at that timestep
    are equal to `mask_value`, then the timestep will be masked (skipped)
    in all downstream layers (as long as they support masking).

    If any downstream layer does not support masking yet receives such
    an input mask, an exception will be raised.

    Args:
        mask_value: Value that every entry along the last axis must equal
            for that position to be masked. Defaults to `0.0`.

    Example:

    Consider a NumPy data array `x` of shape `(samples, timesteps, features)`,
    to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you
    lack data for these timesteps. You can:

    - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.`
    - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer:

    ```python
    samples, timesteps, features = 32, 10, 8
    inputs = np.random.random([samples, timesteps, features]).astype(np.float32)
    inputs[:, 3, :] = 0.
    inputs[:, 5, :] = 0.

    model = keras_core.models.Sequential()
    model.add(keras_core.layers.Masking(mask_value=0.))
    model.add(keras_core.layers.LSTM(32))
    output = model(inputs)
    # The time step 3 and 5 will be skipped from LSTM calculation.
    ```

    Note: in the Keras masking convention, a masked timestep is denoted by
    a mask value of `False`, while a non-masked (i.e. usable) timestep
    is denoted by a mask value of `True`.
    """

    def __init__(self, mask_value=0.0, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.mask_value = mask_value

    def compute_mask(self, inputs, mask=None):
        """Return the boolean mask for `inputs`.

        A position is `True` when at least one entry along the last axis
        differs from `mask_value`. The incoming `mask` argument is ignored.
        """
        return ops.any(ops.not_equal(inputs, self.mask_value), axis=-1)

    def call(self, inputs):
        """Zero out masked positions and attach the mask to the output."""
        # keepdims=True keeps the reduced axis so the mask broadcasts against
        # `inputs` in the multiplication below.
        boolean_mask = ops.any(
            ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True
        )
        # Set masked outputs to 0
        outputs = inputs * backend.cast(boolean_mask, dtype=inputs.dtype)
        # Compute the mask and outputs simultaneously.
        mask = ops.squeeze(boolean_mask, axis=-1)
        try:
            outputs._keras_mask = mask
        except AttributeError:
            # Some tensor types do not accept new attributes. Attaching the
            # mask here is best-effort; `compute_mask` yields the same mask.
            pass
        return outputs

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with `mask_value`."""
        base_config = super().get_config()
        config = {"mask_value": self.mask_value}
        return {**base_config, **config}

@ -0,0 +1,53 @@
import numpy as np
from keras_core import layers
from keras_core import models
from keras_core import testing
class MaskingTest(testing.TestCase):
    """Tests for `layers.Masking`: generic layer checks and mask values."""

    def test_masking_basics(self):
        # Masking is expected to create no weights, seed generators or
        # losses, and to leave the input shape unchanged.
        expected_counts = {
            "expected_num_trainable_weights": 0,
            "expected_num_non_trainable_weights": 0,
            "expected_num_seed_generators": 0,
            "expected_num_losses": 0,
        }
        self.run_layer_test(
            layers.Masking,
            init_kwargs={"mask_value": 0.0},
            input_shape=(2, 3, 2),
            expected_output_shape=(2, 3, 2),
            supports_masking=True,
            **expected_counts,
        )

    def test_masking_correctness(self):
        # Positions whose two features are both 0.0 should come out as
        # `False` in the mask; every other position as `True`.
        data = np.array(
            [
                [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]],
                [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]],
            ]
        )
        want_mask = [[False, True, False], [True, False, True]]

        masking_layer = layers.Masking(mask_value=0.0)
        got_mask = masking_layer.compute_mask(data)
        self.assertAllClose(got_mask, want_mask)

        class TestLayer(layers.Layer):
            """Downstream layer that checks the mask it is called with."""

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.supports_masking = True

            def compute_output_shape(self, input_shape):
                return input_shape

            def call(self, inputs, mask=None):
                assert mask is not None
                np.testing.assert_allclose(mask, want_mask)
                return inputs

        # Chain Masking into the checking layer; the assertions inside
        # `TestLayer.call` run when the model is called.
        stack = [layers.Masking(mask_value=0.0), TestLayer()]
        model = models.Sequential(stack)
        model(data)

@ -877,6 +877,9 @@ def get_shapes_dict(call_spec):
"""
shapes_dict = {}
for k, v in call_spec.tensor_arguments_dict.items():
if k == "mask" or k.startswith("mask_"):
# Do not include mask tensors in shapes dict
continue
if k in call_spec.nested_tensor_argument_names:
shapes_dict[f"{k}_shape"] = nest.map_structure(
lambda x: backend.standardize_shape(x.shape), v

@ -283,7 +283,7 @@ class All(Operation):
axis=self.axis,
keepdims=self.keepdims,
),
dtype=x.dtype,
dtype="bool",
)
@ -293,6 +293,39 @@ def all(x, axis=None, keepdims=False):
return backend.numpy.all(x, axis=axis, keepdims=keepdims)
class Any(Operation):
    """Operation form of the `any` reduction.

    Holds the reduction `axis` and `keepdims` settings, runs
    `backend.numpy.any` in `call`, and describes the result as a boolean
    `KerasTensor` in `compute_output_spec`.
    """

    def __init__(self, axis=None, keepdims=False):
        super().__init__()
        # A bare integer axis is stored as a one-element list; any other
        # value is kept exactly as given.
        self.axis = [axis] if isinstance(axis, int) else axis
        self.keepdims = keepdims

    def call(self, x):
        """Apply the reduction to `x` through the backend."""
        reduction_kwargs = {"axis": self.axis, "keepdims": self.keepdims}
        return backend.numpy.any(x, **reduction_kwargs)

    def compute_output_spec(self, x):
        """Return a `KerasTensor` with the reduced shape and dtype "bool"."""
        output_shape = reduce_shape(
            x.shape,
            axis=self.axis,
            keepdims=self.keepdims,
        )
        return KerasTensor(output_shape, dtype="bool")
def any(x, axis=None, keepdims=False):
    """Reduce `x` with `any`, symbolically or through the backend.

    Args:
        x: Input to reduce.
        axis: Axis or axes to reduce over, forwarded unchanged.
        keepdims: Forwarded unchanged to the chosen implementation.

    Returns:
        The result of `Any(...).symbolic_call(x)` when `x` is detected as
        symbolic by `any_symbolic_tensors`, otherwise the result of
        `backend.numpy.any`.
    """
    if not any_symbolic_tensors((x,)):
        return backend.numpy.any(x, axis=axis, keepdims=keepdims)
    symbolic_op = Any(axis=axis, keepdims=keepdims)
    return symbolic_op.symbolic_call(x)
class Amax(Operation):
def __init__(self, axis=None, keepdims=False):
super().__init__()

@ -644,6 +644,14 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
self.assertEqual(knp.all(x, axis=1).shape, (None, 3))
self.assertEqual(knp.all(x, axis=1, keepdims=True).shape, (None, 1, 3))
def test_any(self):
    # Full reduction of a tensor with an unknown leading dimension gives a
    # scalar shape.
    matrix = KerasTensor([None, 3])
    self.assertEqual(knp.any(matrix).shape, ())

    # Reducing axis 1 drops it, or keeps it as size 1 with keepdims=True.
    cube = KerasTensor([None, 3, 3])
    reduced = knp.any(cube, axis=1)
    self.assertEqual(reduced.shape, (None, 3))
    kept = knp.any(cube, axis=1, keepdims=True)
    self.assertEqual(kept.shape, (None, 1, 3))
def test_var(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.var(x).shape, ())
@ -1111,6 +1119,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
x = KerasTensor([2, 3])
self.assertEqual(knp.all(x).shape, ())
def test_any(self):
    # Full reduction of a statically shaped tensor gives a scalar shape.
    matrix = KerasTensor([2, 3])
    reduced = knp.any(matrix)
    self.assertEqual(reduced.shape, ())
def test_var(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.var(x).shape, ())
@ -2085,6 +2097,22 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
np.all(x, axis=1, keepdims=True),
)
def test_any(self):
    # Compare `knp.any` and the `knp.Any` operation against `np.any` for the
    # same three argument sets: full reduction, axis=1, axis=1 with keepdims.
    x = np.array([[True, False, True], [True, True, True]])
    argument_sets = ({}, {"axis": 1}, {"axis": 1, "keepdims": True})

    # Functional form first.
    for kwargs in argument_sets:
        self.assertAllClose(
            np.array(knp.any(x, **kwargs)),
            np.any(x, **kwargs),
        )

    # Then the operation class with the same arguments.
    for kwargs in argument_sets:
        self.assertAllClose(
            np.array(knp.Any(**kwargs)(x)),
            np.any(x, **kwargs),
        )
def test_var(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.var(x)), np.var(x))

@ -52,8 +52,8 @@ class Operation:
except Exception as e:
raise RuntimeError(
"Could not automatically infer the output shape / dtype of "
"this operation. "
"Please implement the `compute_output_spec` method "
f"operation '{self.name}'. "
"Please implement the `compute_output_spec()` method "
f"on your object ({self.__class__.__name__}). "
f"Error encountered: {e}"
)