# keras/keras_core/layers/layer_test.py
import numpy as np

from keras_core import backend
from keras_core import layers
from keras_core import models
from keras_core import operations as ops
from keras_core import testing


class LayerTest(testing.TestCase):
    def test_positional_arg_error(self):
        class SomeLayer(layers.Layer):
            def call(self, x, bool_arg):
                if bool_arg:
                    return x
                return x + 1

        x = backend.KerasTensor(shape=(2, 3), name="x")
        with self.assertRaisesRegex(
            ValueError, "Only input tensors may be passed as"
        ):
            SomeLayer()(x, True)

        # This works
        SomeLayer()(x, bool_arg=True)
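
    # Note: the assertions in the next test suggest that `SeedGenerator`
    # keeps its state as a length-2 vector of the form [seed, counter],
    # with the counter advancing once per stateful random call. Treat that
    # layout as an implementation detail observed here, not public API.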

    def test_rng_seed_tracking(self):
        class RNGLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.seed_gen = backend.random.SeedGenerator(seed=1337)

            def call(self, x):
                return backend.random.dropout(x, rate=0.5, seed=self.seed_gen)

        layer = RNGLayer()
        self.assertEqual(layer.variables, [layer.seed_gen.state])
        self.assertAllClose(layer.variables[0], [1337, 0])
        layer(np.ones((3, 4)))
        self.assertAllClose(layer.variables[0], [1337, 1])

        # Test tracking in list attributes.
        class RNGListLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.seed_gens = []
                self.seed_gens.append(backend.random.SeedGenerator(seed=1))
                self.seed_gens.append(backend.random.SeedGenerator(seed=10))

            def call(self, x):
                x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[0])
                x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[1])
                return x

        layer = RNGListLayer()
        self.assertEqual(
            layer.variables,
            [layer.seed_gens[0].state, layer.seed_gens[1].state],
        )
        self.assertAllClose(layer.variables[0], [1, 0])
        self.assertAllClose(layer.variables[1], [10, 0])
        layer(np.ones((3, 4)))
        self.assertAllClose(layer.variables[0], [1, 1])
        self.assertAllClose(layer.variables[1], [10, 1])
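
    # Sublayers are auto-tracked whether they are plain attributes or live
    # inside container attributes (dicts, lists), including layers appended
    # after __init__ (e.g. in build()). A hedged sketch of the pattern the
    # next test exercises:
    #
    #     class MyLayer(layers.Layer):  # hypothetical user layer
    #         def __init__(self):
    #             super().__init__()
    #             self.blocks = [layers.Dense(4), layers.Dense(4)]  # tracked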

    def test_layer_tracking(self):
        class NestedLayer(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.dense1 = layers.Dense(units)
                self.layer_dict = {
                    "dense2": layers.Dense(units),
                }
                self.layer_list = [layers.Dense(units)]
                self.units = units

            def build(self, input_shape):
                self.layer_list.append(layers.Dense(self.units))

            def call(self, x):
                x = self.dense1(x)
                x = self.layer_dict["dense2"](x)
                x = self.layer_list[0](x)
                x = self.layer_list[1](x)
                return x

        class DoubleNestedLayer(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.inner_layer = NestedLayer(units)

            def call(self, x):
                return self.inner_layer(x)

        layer = NestedLayer(3)
        layer.build((1, 3))
        # 4 sublayers: dense1, the dict entry, and the two list entries
        # (one of which was appended in build()).
        self.assertLen(layer._layers, 4)
        layer(np.zeros((1, 3)))
        # 8 weights: each of the 4 Dense layers has a kernel and a bias.
        self.assertLen(layer.weights, 8)

        layer = DoubleNestedLayer(3)
        self.assertLen(layer._layers, 1)
        layer(np.zeros((1, 3)))
        self.assertLen(layer.inner_layer.weights, 8)
        self.assertLen(layer.weights, 8)

    def test_build_on_call(self):
        class LayerWithUnbuiltState(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.dense1 = layers.Dense(units)

            def call(self, x):
                return self.dense1(x)

        layer = LayerWithUnbuiltState(2)
        layer(backend.KerasTensor((3, 4)))
        self.assertLen(layer.weights, 2)

        class KwargsLayerWithUnbuiltState(layers.Layer):
            def __init__(self, units):
                super().__init__()
                self.dense1 = layers.Dense(units)
                self.dense2 = layers.Dense(units)

            def call(self, x1, x2):
                return self.dense1(x1) + self.dense2(x2)

        layer = KwargsLayerWithUnbuiltState(2)
        layer(backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4)))
        self.assertLen(layer.weights, 4)

        layer = KwargsLayerWithUnbuiltState(2)
        layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4)))
        self.assertLen(layer.weights, 4)
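
    # `activity_regularizer` accepts a string identifier; "l1" resolves to
    # an L1 penalty whose default factor in Keras is 0.01, which is why a
    # single unit activation of 1.0 yields a loss of 0.01 below.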

    def test_activity_regularization(self):
        class ActivityRegularizer(layers.Layer):
            def call(self, x):
                return x

        layer = ActivityRegularizer(activity_regularizer="l1")
        layer(np.ones((1,)))
        self.assertLen(layer.losses, 1)
        self.assertAllClose(layer.losses[0], 0.01)

        # Losses are reset upon each call; they don't accumulate.
        layer(np.ones((1,)))
        self.assertLen(layer.losses, 1)
        self.assertAllClose(layer.losses[0], 0.01)

        # Calling on KerasTensors (symbolic inputs) is a no-op.
        layer = ActivityRegularizer(activity_regularizer="l1")
        layer(layers.Input(batch_shape=(2, 2)))
        self.assertLen(layer.losses, 0)

    def test_add_loss(self):
        class LossLayer(layers.Layer):
            def call(self, x):
                self.add_loss(ops.sum(x))
                return x

        layer = LossLayer()
        layer(np.ones((1,)))
        self.assertLen(layer.losses, 1)
        self.assertAllClose(layer.losses[0], 1.0)

        # losses are reset upon call
        layer = LossLayer()
        layer(np.ones((1,)))
        self.assertLen(layer.losses, 1)
        self.assertAllClose(layer.losses[0], 1.0)

        # It works inside a model
        model = models.Sequential([layer])
        model(np.ones((1,)))
        self.assertLen(model.losses, 1)
        self.assertAllClose(model.losses[0], 1.0)
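
    # The `training` flag appears to be propagated through nested
    # `__call__`s via a shared call context, so inner layers see the
    # outermost explicit value unless they override it themselves; the
    # next test pins down that behavior.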

    def test_training_arg_value_resolution(self):
        # Check that even if `training` is not passed
        # to an inner layer, the outer value gets propagated
        # in __call__.
        class TrainingLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.dp = layers.Dropout(0.9)

            def call(self, x, training=False):
                return self.dp(x)

        layer = TrainingLayer()
        x = np.ones((4, 4))
        y = layer(x)
        self.assertEqual(ops.min(y), 1)
        y = layer(x, training=True)
        self.assertEqual(ops.min(y), 0)

        # Check that it still works one level deeper.
        class WrappedTrainingLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.dp = TrainingLayer()

            def call(self, x, training=False):
                return self.dp(x)

        layer = WrappedTrainingLayer()
        x = np.ones((4, 4))
        y = layer(x)
        self.assertEqual(ops.min(y), 1)
        y = layer(x, training=True)
        self.assertEqual(ops.min(y), 0)

        # Check that if `training` is passed
        # to an inner layer in call(), the explicitly
        # passed value is what the layer sees.
        class TrainingLayerExplicit(layers.Layer):
            def __init__(self):
                super().__init__()
                self.dp = layers.Dropout(0.9)

            def call(self, x, training=False):
                return self.dp(x, training=True)

        layer = TrainingLayerExplicit()
        x = np.ones((4, 4))
        y = layer(x, training=False)
        self.assertEqual(ops.min(y), 0)

        # Test that layer interruption does not cause
        # the call context to linger
        class BadLayer(layers.Layer):
            def call(self, x, training=False):
                raise RuntimeError("oops!")

        x = np.ones((4, 4))
        layer = BadLayer()
        try:
            # training=True will be recorded
            # in the call context
            layer(x, training=True)
        except RuntimeError:
            pass
        layer = TrainingLayer()
        # But this layer call should not see it
        y = layer(x)
        self.assertEqual(ops.min(y), 1)
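
    # With the "mixed_float16" dtype policy, computations run in float16
    # while variables are stored in float32 (per the assertions below);
    # passing a plain "float16" dtype puts both in float16.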

    def test_mixed_precision(self):
        x = np.ones((4, 4))

        layer = layers.Dense(2, dtype="float16")
        y = layer(x)
        self.assertEqual(layer.compute_dtype, "float16")
        self.assertEqual(layer.variable_dtype, "float16")
        self.assertEqual(y.dtype.name, "float16")

        layer = layers.Dense(2, dtype="mixed_float16")
        y = layer(x)
        self.assertEqual(layer.compute_dtype, "float16")
        self.assertEqual(layer.variable_dtype, "float32")
        self.assertEqual(y.dtype.name, "float16")
        self.assertEqual(layer.kernel.dtype, "float32")
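
    # Masks propagate via the `_keras_mask` attribute attached to input
    # tensors. For a layer with `supports_masking = True`, a mask on input
    # `x` is delivered to `call()` as the `mask` argument, and masks on
    # named inputs `x1`, `x2` arrive as `x1_mask`, `x2_mask`.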

    def test_masking(self):
        class BasicMaskedLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.supports_masking = True

            def call(self, x, mask=None):
                assert mask is not None
                return x

        layer = BasicMaskedLayer()
        x = backend.numpy.ones((4, 4))
        x._keras_mask = backend.numpy.ones((4,))
        layer(x)

        class NestedInputMaskedLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.supports_masking = True

            def call(self, x, mask=None):
                assert isinstance(x, list)
                assert len(x) == 2
                assert isinstance(mask, list)
                assert len(mask) == 2
                return x

        layer = NestedInputMaskedLayer()
        x1 = backend.numpy.ones((4, 4))
        x1._keras_mask = backend.numpy.ones((4,))
        x2 = backend.numpy.ones((4, 4))
        x2._keras_mask = backend.numpy.ones((4,))
        layer([x1, x2])

        class PositionalInputsMaskedLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.supports_masking = True

            def call(self, x1, x2, x1_mask=None, x2_mask=None):
                assert x1_mask is not None
                assert x2_mask is not None
                return x1 + x2

        layer = PositionalInputsMaskedLayer()
        layer(x1, x2)
        layer(x1=x1, x2=x2)

        class PositionalNestedInputsMaskedLayer(layers.Layer):
            def __init__(self):
                super().__init__()
                self.supports_masking = True

            def call(self, x1, x2, x1_mask=None, x2_mask=None):
                assert isinstance(x1, tuple)
                assert x1_mask is not None
                assert x2_mask is not None
                assert isinstance(x1_mask, tuple)
                return x1[0] + x1[1] + x2

        layer = PositionalNestedInputsMaskedLayer()
        x1_1 = backend.numpy.ones((4, 4))
        x1_1._keras_mask = backend.numpy.ones((4,))
        x1_2 = backend.numpy.ones((4, 4))
        x1_2._keras_mask = backend.numpy.ones((4,))
        x2 = backend.numpy.ones((4, 4))
        x2._keras_mask = backend.numpy.ones((4,))
        layer((x1_1, x1_2), x2)
        layer(x1=(x1_1, x1_2), x2=x2)
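

if __name__ == "__main__":
    # Minimal sketch of a direct entry point; this suite is normally run
    # through the test runner, and pytest is assumed to be installed here.
    import pytest

    pytest.main([__file__])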