Add tests for training arg resolution in Layer.
This commit is contained in:
parent
2aa4e4887a
commit
2c8e6e19b5
@ -176,3 +176,54 @@ class LayerTest(testing.TestCase):
|
||||
model(np.ones((1,)))
|
||||
self.assertLen(model.losses, 1)
|
||||
self.assertAllClose(model.losses[0], 1.0)
|
||||
|
||||
def test_training_arg_value_resolution(self):
|
||||
# Check that even if `training` is not passed
|
||||
# to an inner layer, the outer value gets propagated
|
||||
# in __call__.
|
||||
class TrainingLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dp = layers.Dropout(0.9)
|
||||
|
||||
def call(self, x, training=False):
|
||||
return self.dp(x)
|
||||
|
||||
layer = TrainingLayer()
|
||||
x = np.ones((4, 4))
|
||||
y = layer(x)
|
||||
self.assertEqual(ops.min(y), 1)
|
||||
y = layer(x, training=True)
|
||||
self.assertEqual(ops.min(y), 0)
|
||||
|
||||
# Check that it still works one level deeper.
|
||||
class WrappedTrainingLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dp = TrainingLayer()
|
||||
|
||||
def call(self, x, training=False):
|
||||
return self.dp(x)
|
||||
|
||||
layer = WrappedTrainingLayer()
|
||||
x = np.ones((4, 4))
|
||||
y = layer(x)
|
||||
self.assertEqual(ops.min(y), 1)
|
||||
y = layer(x, training=True)
|
||||
self.assertEqual(ops.min(y), 0)
|
||||
|
||||
# Check that if `training` is passed
|
||||
# to an inner layer in call(), the explicitly
|
||||
# passed value is what the layer sees.
|
||||
class TrainingLayerExplicit(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dp = layers.Dropout(0.9)
|
||||
|
||||
def call(self, x, training=False):
|
||||
return self.dp(x, training=True)
|
||||
|
||||
layer = TrainingLayerExplicit()
|
||||
x = np.ones((4, 4))
|
||||
y = layer(x, training=False)
|
||||
self.assertEqual(ops.min(y), 0)
|
||||
|
Loading…
Reference in New Issue
Block a user