keras/keras_core/layers/layer_test.py
Matt Watson cc053ac309 Add layers.MultiHeadAttention (#169)
* Add layers.MultiHeadAttention

* Update build/compute_output_signature argument checks

We do not definitively know which arguments are tensor arguments
at any given invocation (e.g. arguments with a None value may be tensor
arguments).

So rather than check that the build signature matches perfectly with
tensor call arguments, we will check the build signature arguments match
with some call argument.
2023-05-16 18:22:42 -07:00

629 lines
22 KiB
Python

import numpy as np
from keras_core import backend
from keras_core import layers
from keras_core import models
from keras_core import operations as ops
from keras_core import testing
class LayerTest(testing.TestCase):
def test_compute_output_spec(self):
# Test that implementing compute_output_shape
# is enough to make compute_output_spec work.
# Case: single output
class TestLayer(layers.Layer):
def call(self, x):
assert False # Should never be called.
def compute_output_shape(self, input_shape):
return input_shape
layer = TestLayer()
self.assertEqual(
layer.compute_output_spec(backend.KerasTensor((2, 3))).shape, (2, 3)
)
# Case: tuple output
class TestLayer(layers.Layer):
def call(self, x):
assert False # Should never be called.
def compute_output_shape(self, input_shape):
return (input_shape, input_shape)
layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, tuple))
self.assertEqual(len(out), 2)
self.assertEqual(out[0].shape, (2, 3))
self.assertEqual(out[1].shape, (2, 3))
# Case: list output
class TestLayer(layers.Layer):
def call(self, x):
assert False # Should never be called.
def compute_output_shape(self, input_shape):
return [input_shape, input_shape]
layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, list))
self.assertEqual(len(out), 2)
self.assertEqual(out[0].shape, (2, 3))
self.assertEqual(out[1].shape, (2, 3))
# Case: dict output
class TestLayer(layers.Layer):
def call(self, x):
assert False # Should never be called.
def compute_output_shape(self, input_shape):
return {"1": input_shape, "2": input_shape}
layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, dict))
self.assertEqual(len(out), 2)
self.assertEqual(out["1"].shape, (2, 3))
self.assertEqual(out["2"].shape, (2, 3))
# Case: nested tuple output
class TestLayer(layers.Layer):
def call(self, x):
assert False # Should never be called.
def compute_output_shape(self, input_shape):
return (
input_shape,
(input_shape, input_shape),
(input_shape, input_shape),
)
layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, tuple))
self.assertEqual(len(out), 3)
self.assertEqual(out[0].shape, (2, 3))
self.assertTrue(isinstance(out[1], tuple))
self.assertEqual(len(out[1]), 2)
self.assertEqual(out[1][0].shape, (2, 3))
self.assertEqual(out[1][1].shape, (2, 3))
self.assertTrue(isinstance(out[2], tuple))
self.assertEqual(len(out[2]), 2)
self.assertEqual(out[2][0].shape, (2, 3))
self.assertEqual(out[2][1].shape, (2, 3))
# Case: nested dict output
class TestLayer(layers.Layer):
def call(self, x):
assert False # Should never be called.
def compute_output_shape(self, input_shape):
return {
"1": input_shape,
"2": {"11": input_shape, "22": input_shape},
}
layer = TestLayer()
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
self.assertTrue(isinstance(out, dict))
self.assertEqual(len(out), 2)
self.assertEqual(out["1"].shape, (2, 3))
self.assertTrue(isinstance(out["2"], dict))
self.assertEqual(len(out["2"]), 2)
self.assertEqual(out["2"]["11"].shape, (2, 3))
self.assertEqual(out["2"]["22"].shape, (2, 3))
def test_positional_arg_error(self):
class SomeLayer(layers.Layer):
def call(self, x, bool_arg):
if bool_arg:
return x
return x + 1
x = backend.KerasTensor(shape=(2, 3), name="x")
with self.assertRaisesRegex(
ValueError, "Only input tensors may be passed as"
):
SomeLayer()(x, True)
# This works
SomeLayer()(x, bool_arg=True)
def test_rng_seed_tracking(self):
class RNGLayer(layers.Layer):
def __init__(self):
super().__init__()
self.seed_gen = backend.random.SeedGenerator(seed=1337)
def call(self, x):
return backend.random.dropout(x, rate=0.5, seed=self.seed_gen)
layer = RNGLayer()
self.assertEqual(layer.variables, [layer.seed_gen.state])
self.assertAllClose(layer.variables[0], [1337, 0])
layer(np.ones((3, 4)))
self.assertAllClose(layer.variables[0], [1337, 1])
# Test tracking in list attributes.
class RNGListLayer(layers.Layer):
def __init__(self):
super().__init__()
self.seed_gens = []
self.seed_gens.append(backend.random.SeedGenerator(seed=1))
self.seed_gens.append(backend.random.SeedGenerator(seed=10))
def call(self, x):
x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[0])
x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[1])
return x
layer = RNGListLayer()
self.assertEqual(
layer.variables,
[layer.seed_gens[0].state, layer.seed_gens[1].state],
)
self.assertAllClose(layer.variables[0], [1, 0])
self.assertAllClose(layer.variables[1], [10, 0])
layer(np.ones((3, 4)))
self.assertAllClose(layer.variables[0], [1, 1])
self.assertAllClose(layer.variables[1], [10, 1])
def test_layer_tracking(self):
class NestedLayer(layers.Layer):
def __init__(self, units):
super().__init__()
self.dense1 = layers.Dense(units)
self.layer_dict = {
"dense2": layers.Dense(units),
}
self.layer_list = [layers.Dense(units)]
self.units = units
def build(self, input_shape):
self.layer_list.append(layers.Dense(self.units))
def call(self, x):
x = self.dense1(x)
x = self.layer_dict["dense2"](x)
x = self.layer_list[0](x)
x = self.layer_list[1](x)
return x
class DoubleNestedLayer(layers.Layer):
def __init__(self, units):
super().__init__()
self.inner_layer = NestedLayer(units)
def call(self, x):
return self.inner_layer(x)
layer = NestedLayer(3)
layer.build((1, 3))
self.assertLen(layer._layers, 4)
layer(np.zeros((1, 3)))
self.assertLen(layer.weights, 8)
layer = DoubleNestedLayer(3)
self.assertLen(layer._layers, 1)
layer(np.zeros((1, 3)))
self.assertLen(layer.inner_layer.weights, 8)
self.assertLen(layer.weights, 8)
def test_build_on_call(self):
class LayerWithUnbuiltState(layers.Layer):
def __init__(self, units):
super().__init__()
self.dense1 = layers.Dense(units)
def call(self, x):
return self.dense1(x)
layer = LayerWithUnbuiltState(2)
layer(backend.KerasTensor((3, 4)))
self.assertLen(layer.weights, 2)
class KwargsLayerWithUnbuiltState(layers.Layer):
def __init__(self, units):
super().__init__()
self.dense1 = layers.Dense(units)
self.dense2 = layers.Dense(units)
def call(self, x1, x2):
return self.dense1(x1) + self.dense2(x2)
layer = KwargsLayerWithUnbuiltState(2)
layer(backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4)))
self.assertLen(layer.weights, 4)
layer = KwargsLayerWithUnbuiltState(2)
layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4)))
self.assertLen(layer.weights, 4)
def test_activity_regularization(self):
class ActivityRegularizer(layers.Layer):
def call(self, x):
return x
layer = ActivityRegularizer(activity_regularizer="l1")
layer(np.ones((1,)))
self.assertLen(layer.losses, 1)
self.assertAllClose(layer.losses[0], 0.01)
# losses are reset upon call
layer(np.ones((1,)))
self.assertLen(layer.losses, 1)
self.assertAllClose(layer.losses[0], 0.01)
# KerasTensors are no op
layer = ActivityRegularizer(activity_regularizer="l1")
layer(layers.Input(batch_shape=(2, 2)))
self.assertLen(layer.losses, 0)
def test_add_loss(self):
class LossLayer(layers.Layer):
def call(self, x):
self.add_loss(ops.sum(x))
return x
layer = LossLayer()
layer(np.ones((1,)))
self.assertLen(layer.losses, 1)
self.assertAllClose(layer.losses[0], 1.0)
# losses are reset upon call
layer = LossLayer()
layer(np.ones((1,)))
self.assertLen(layer.losses, 1)
self.assertAllClose(layer.losses[0], 1.0)
# It works inside a model
model = models.Sequential([layer])
model(np.ones((1,)))
self.assertLen(model.losses, 1)
self.assertAllClose(model.losses[0], 1.0)
def test_training_arg_value_resolution(self):
# Check that even if `training` is not passed
# to an inner layer, the outer value gets propagated
# in __call__.
class TrainingLayer(layers.Layer):
def __init__(self):
super().__init__()
self.dp = layers.Dropout(0.9)
def call(self, x, training=False):
return self.dp(x)
layer = TrainingLayer()
x = np.ones((4, 4))
y = layer(x)
self.assertEqual(ops.min(y), 1)
y = layer(x, training=True)
self.assertEqual(ops.min(y), 0)
# Check that it still works one level deeper.
class WrappedTrainingLayer(layers.Layer):
def __init__(self):
super().__init__()
self.dp = TrainingLayer()
def call(self, x, training=False):
return self.dp(x)
layer = WrappedTrainingLayer()
x = np.ones((4, 4))
y = layer(x)
self.assertEqual(ops.min(y), 1)
y = layer(x, training=True)
self.assertEqual(ops.min(y), 0)
# Check that if `training` is passed
# to an inner layer in call(), the explicitly
# passed value is what the layer sees.
class TrainingLayerExplicit(layers.Layer):
def __init__(self):
super().__init__()
self.dp = layers.Dropout(0.9)
def call(self, x, training=False):
return self.dp(x, training=True)
layer = TrainingLayerExplicit()
x = np.ones((4, 4))
y = layer(x, training=False)
self.assertEqual(ops.min(y), 0)
# Test that layer interruption does not cause
# the call context to linger
class BadLayer(layers.Layer):
def call(self, x, training=False):
raise RuntimeError("oops!")
x = np.ones((4, 4))
layer = BadLayer()
try:
# training=True will be recorded
# in the call context
layer(x, training=True)
except RuntimeError:
pass
layer = TrainingLayer()
# But this layer call should not see it
y = layer(x)
self.assertEqual(ops.min(y), 1)
def test_mixed_precision(self):
x = np.ones((4, 4))
layer = layers.Dense(2, dtype="float16")
y = layer(x)
self.assertEqual(layer.compute_dtype, "float16")
self.assertEqual(layer.variable_dtype, "float16")
self.assertEqual(y.dtype.name, "float16")
layer = layers.Dense(2, dtype="mixed_float16")
y = layer(x)
self.assertEqual(layer.compute_dtype, "float16")
self.assertEqual(layer.variable_dtype, "float32")
self.assertEqual(y.dtype.name, "float16")
self.assertEqual(layer.kernel.dtype, "float32")
def test_masking(self):
class BasicMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x, mask=None):
assert mask is not None
return x
layer = BasicMaskedLayer()
x = backend.numpy.ones((4, 4))
x._keras_mask = backend.numpy.ones((4,))
layer(x)
layer(backend.numpy.ones((4, 4)), mask=backend.numpy.ones((4,)))
class NestedInputMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x, mask=None):
assert isinstance(x, list)
assert len(x) == 2
assert isinstance(mask, list)
assert len(mask) == 2
return x
layer = NestedInputMaskedLayer()
x1 = backend.numpy.ones((4, 4))
x1._keras_mask = backend.numpy.ones((4,))
x2 = backend.numpy.ones((4, 4))
x2._keras_mask = backend.numpy.ones((4,))
layer([x1, x2])
layer(
[backend.numpy.ones((4, 4)), backend.numpy.ones((4, 4))],
mask=[backend.numpy.ones((4,)), backend.numpy.ones((4,))],
)
class PositionalInputsMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x1, x2, x1_mask=None, x2_mask=None):
assert x1_mask is not None
assert x2_mask is not None
return x1 + x2
layer = PositionalInputsMaskedLayer()
layer(x1, x2)
layer(x1=x1, x2=x2)
class PositionalNestedInputsMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x1, x2, x1_mask=None, x2_mask=None):
assert isinstance(x1, tuple)
assert x1_mask is not None
assert x2_mask is not None
assert isinstance(x1_mask, tuple)
return x1[0] + x1[1] + x2
layer = PositionalNestedInputsMaskedLayer()
x1_1 = backend.numpy.ones((4, 4))
x1_1._keras_mask = backend.numpy.ones((4,))
x1_2 = backend.numpy.ones((4, 4))
x1_2._keras_mask = backend.numpy.ones((4,))
x2 = backend.numpy.ones((4, 4))
x2._keras_mask = backend.numpy.ones((4,))
layer((x1_1, x1_2), x2)
layer(x1=(x1_1, x1_2), x2=x2)
def test_stateless_call(self):
class TestLayer(layers.Layer):
def __init__(self):
super().__init__()
self._seed_generator = backend.random.SeedGenerator(1337)
self.ntw = self.add_weight(
shape=(),
initializer="zeros",
trainable=False,
)
self.tw = self.add_weight(
shape=(),
initializer="zeros",
trainable=True,
)
self.built = True
def call(self, x):
x = backend.convert_to_tensor(x, dtype="float32")
self.add_loss(ops.sum(x))
self.ntw.assign(ops.sum(x))
x = x + backend.random.normal(
shape=(), seed=self._seed_generator
)
return x + self.tw + self.ntw
data = np.random.random((3, 4))
layer = TestLayer()
out = layer(data)
layer1 = TestLayer()
out1 = layer1(data)
# Check that the layer is in fact deterministic
self.assertAllClose(out, out1)
# Test stateless_call correctness
layer2 = TestLayer()
trainable_variables = layer2.trainable_variables
non_trainable_variables = layer2.non_trainable_variables
out2, non_trainable_variables = layer2.stateless_call(
trainable_variables, non_trainable_variables, data
)
self.assertAllClose(out1, out2)
self.assertEqual(
len(layer1.non_trainable_variables), len(non_trainable_variables)
)
for ref_v, v in zip(
layer1.non_trainable_variables, non_trainable_variables
):
self.assertAllClose(ref_v, v)
# Test with loss collection
layer3 = TestLayer()
trainable_variables = layer3.trainable_variables
non_trainable_variables = layer3.non_trainable_variables
out3, non_trainable_variables, losses = layer3.stateless_call(
trainable_variables,
non_trainable_variables,
data,
return_losses=True,
)
self.assertAllClose(out1, out3)
for ref_v, v in zip(
layer1.non_trainable_variables, non_trainable_variables
):
self.assertAllClose(ref_v, v)
for ref_loss, loss in zip(layer1.losses, losses):
self.assertAllClose(ref_loss, loss)
def test_trainable_setting(self):
class NonTrainableWeightsLayer(layers.Layer):
def build(self, _):
self.w1 = self.add_weight(
shape=(),
initializer="ones",
trainable=True,
)
self.w2 = self.add_weight(
shape=(),
initializer="ones",
trainable=False,
)
self.seed = backend.random.SeedGenerator(123)
def call(self, inputs):
return inputs
class NestedNonTrainableWeightsLayer(layers.Layer):
def build(self, _):
self.w1 = self.add_weight(
shape=(),
initializer="ones",
trainable=True,
)
self.w2 = self.add_weight(
shape=(),
initializer="ones",
trainable=False,
)
self.nested = NonTrainableWeightsLayer()
self.nested.build(None)
def call(self, inputs):
return inputs
layer = NestedNonTrainableWeightsLayer()
layer.build(None)
self.assertEqual(len(layer.trainable_weights), 2)
self.assertEqual(len(layer.trainable_variables), 2)
self.assertEqual(len(layer.non_trainable_weights), 2)
self.assertEqual(len(layer.non_trainable_variables), 3)
layer.trainable = False
self.assertEqual(len(layer.trainable_weights), 0)
self.assertEqual(len(layer.trainable_variables), 0)
self.assertEqual(len(layer.non_trainable_weights), 4)
self.assertEqual(len(layer.non_trainable_variables), 5)
self.assertFalse(layer.w1.trainable)
self.assertFalse(layer.nested.w1.trainable)
layer.trainable = True
self.assertEqual(len(layer.trainable_weights), 2)
self.assertEqual(len(layer.trainable_variables), 2)
self.assertEqual(len(layer.non_trainable_weights), 2)
self.assertEqual(len(layer.non_trainable_variables), 3)
self.assertTrue(layer.w1.trainable)
self.assertTrue(layer.nested.w1.trainable)
layer = NestedNonTrainableWeightsLayer(trainable=False)
layer.build(None)
self.assertEqual(len(layer.trainable_weights), 0)
self.assertEqual(len(layer.trainable_variables), 0)
self.assertEqual(len(layer.non_trainable_weights), 4)
self.assertEqual(len(layer.non_trainable_variables), 5)
layer.trainable = True
self.assertEqual(len(layer.trainable_weights), 2)
self.assertEqual(len(layer.trainable_variables), 2)
self.assertEqual(len(layer.non_trainable_weights), 2)
self.assertEqual(len(layer.non_trainable_variables), 3)
def test_build_signature_errors(self):
class NoShapeSuffix(layers.Layer):
def build(self, foo_shape, bar):
self._built = True
def call(self, foo, bar):
return foo + bar
class NonMatchingArgument(layers.Layer):
def build(self, foo_shape, baz_shape):
self._built = True
def call(self, foo, bar):
return foo + bar
class MatchingArguments(layers.Layer):
def build(self, foo_shape, bar_shape):
self._built = True
def call(self, foo, bar):
return foo + bar
foo = backend.numpy.ones((4, 4))
bar = backend.numpy.ones((4, 4))
with self.assertRaisesRegex(
ValueError,
r"argument `bar`, which does not end in `_shape`",
):
NoShapeSuffix()(foo, bar)
with self.assertRaisesRegex(
ValueError,
r"`baz_shape`, but `call\(\)` does not have argument `baz`",
):
NonMatchingArgument()(foo, bar)
MatchingArguments()(foo, bar)