diff --git a/keras_core/layers/layer_test.py b/keras_core/layers/layer_test.py
index 15e434690..5f8d362ea 100644
--- a/keras_core/layers/layer_test.py
+++ b/keras_core/layers/layer_test.py
@@ -176,3 +176,54 @@ class LayerTest(testing.TestCase):
         model(np.ones((1,)))
         self.assertLen(model.losses, 1)
         self.assertAllClose(model.losses[0], 1.0)
+
+    def test_training_arg_value_resolution(self):
+        # Check that even if `training` is not passed
+        # to an inner layer, the outer value gets propagated
+        # in __call__.
+        class TrainingLayer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.dp = layers.Dropout(0.9)
+
+            def call(self, x, training=False):
+                return self.dp(x)
+
+        layer = TrainingLayer()
+        x = np.ones((4, 4))
+        y = layer(x)
+        self.assertEqual(ops.min(y), 1)
+        y = layer(x, training=True)
+        self.assertEqual(ops.min(y), 0)
+
+        # Check that it still works one level deeper.
+        class WrappedTrainingLayer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.dp = TrainingLayer()
+
+            def call(self, x, training=False):
+                return self.dp(x)
+
+        layer = WrappedTrainingLayer()
+        x = np.ones((4, 4))
+        y = layer(x)
+        self.assertEqual(ops.min(y), 1)
+        y = layer(x, training=True)
+        self.assertEqual(ops.min(y), 0)
+
+        # Check that if `training` is passed
+        # to an inner layer in call(), the explicitly
+        # passed value is what the layer sees.
+        class TrainingLayerExplicit(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.dp = layers.Dropout(0.9)
+
+            def call(self, x, training=False):
+                return self.dp(x, training=True)
+
+        layer = TrainingLayerExplicit()
+        x = np.ones((4, 4))
+        y = layer(x, training=False)
+        self.assertEqual(ops.min(y), 0)
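
For context, the behavior these tests exercise is that `Layer.__call__` resolves the effective `training` value before invoking `call`: an explicitly passed argument wins, otherwise the value from the enclosing layer's call is inherited. Below is a minimal sketch of that resolution using a hypothetical module-level stack; `_training_stack` and the simplified `Layer` class are illustrative only, not Keras Core's actual internals (the real resolution also consults the `call` signature's default, omitted here for brevity).

    # Minimal sketch of training-arg propagation; illustrative only,
    # not Keras Core's real implementation.
    _training_stack = []  # hypothetical call-context stack

    class Layer:
        def __call__(self, x, training=None):
            if training is None and _training_stack:
                # Not passed explicitly: inherit the enclosing
                # layer's value, as in TrainingLayer above.
                training = _training_stack[-1]
            _training_stack.append(training)
            try:
                # An explicit `training=` inside `call` (as in
                # TrainingLayerExplicit) bypasses this inheritance,
                # because it reaches the inner __call__ non-None.
                return self.call(x, training=training)
            finally:
                _training_stack.pop()

        def call(self, x, training=None):
            raise NotImplementedError

Under this sketch, calling an outer layer with `training=True` pushes `True` onto the stack, so an inner `Dropout` invoked without `training` inherits `True`, matching the first two test cases; the nesting in `WrappedTrainingLayer` works the same way, one stack frame deeper.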