Add tests for training arg resolution in Layer.

This commit is contained in:
Francois Chollet 2023-04-27 14:10:04 -07:00
parent 2aa4e4887a
commit 2c8e6e19b5

@@ -176,3 +176,54 @@ class LayerTest(testing.TestCase):
model(np.ones((1,)))
self.assertLen(model.losses, 1)
self.assertAllClose(model.losses[0], 1.0)
def test_training_arg_value_resolution(self):
    """Verify how the `training` argument propagates through `__call__`.

    Three cases are exercised:
    1. A `training` value passed to an outer layer is propagated to an
       inner layer even when the outer `call()` does not forward it.
    2. The same propagation works one more nesting level down.
    3. A `training=` value passed explicitly inside `call()` overrides
       the value supplied by the outer `__call__`.
    """

    # Case 1: `call()` does not forward `training`, yet the outer
    # value must reach the inner Dropout via `__call__` propagation.
    class TrainingLayer(layers.Layer):
        def __init__(self):
            super().__init__()
            self.dp = layers.Dropout(0.9)

        def call(self, x, training=False):
            return self.dp(x)

    layer = TrainingLayer()
    x = np.ones((4, 4))
    # Inference mode: dropout is inactive, inputs pass through unchanged.
    y = layer(x)
    self.assertEqual(ops.min(y), 1)
    # Training mode: with rate 0.9 over 16 elements, at least one unit
    # is (overwhelmingly likely) zeroed, so the minimum becomes 0.
    y = layer(x, training=True)
    self.assertEqual(ops.min(y), 0)

    # Case 2: the same propagation check, one nesting level deeper.
    class WrappedTrainingLayer(layers.Layer):
        def __init__(self):
            super().__init__()
            self.dp = TrainingLayer()

        def call(self, x, training=False):
            return self.dp(x)

    layer = WrappedTrainingLayer()
    x = np.ones((4, 4))
    y = layer(x)
    self.assertEqual(ops.min(y), 1)
    y = layer(x, training=True)
    self.assertEqual(ops.min(y), 0)

    # Case 3: an explicit `training=True` passed inside `call()` takes
    # precedence over the `training=False` supplied by the caller.
    class TrainingLayerExplicit(layers.Layer):
        def __init__(self):
            super().__init__()
            self.dp = layers.Dropout(0.9)

        def call(self, x, training=False):
            return self.dp(x, training=True)

    layer = TrainingLayerExplicit()
    x = np.ones((4, 4))
    y = layer(x, training=False)
    self.assertEqual(ops.min(y), 0)