Remove debug statement

This commit is contained in:
Francois Chollet 2023-05-24 15:06:15 -07:00
parent c2ad0e7cd2
commit 225252e965
3 changed files with 39 additions and 47 deletions

@ -75,8 +75,8 @@ class BatchNormalization(Layer):
- `training=True`: The layer will normalize its inputs using
the mean and variance of the current batch of inputs.
- `training=False`: The layer will normalize its inputs using
the mean and variance of its moving statistics, learned during
training.
the mean and variance of its moving statistics,
learned during training.
Reference:
@ -102,10 +102,11 @@ class BatchNormalization(Layer):
Note that:
- Setting `trainable` on an model containing other layers will recursively
set the `trainable` value of all inner layers.
- If the value of the `trainable` attribute is changed after calling
`compile()` on a model, the new value doesn't take effect for this model
- Setting `trainable` on an model containing other layers will
recursively set the `trainable` value of all inner layers.
- If the value of the `trainable`
attribute is changed after calling `compile()` on a model,
the new value doesn't take effect for this model
until `compile()` is called again.
"""
@ -188,29 +189,26 @@ class BatchNormalization(Layer):
return input_shape
def call(self, inputs, training=None, mask=None):
broadcast_shape = [1] * len(inputs.shape)
broadcast_shape[self.axis] = inputs.shape[self.axis]
# TODO: support masking during stats computation.
if training and self.trainable:
mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True)
variance = ops.var(inputs, axis=self._reduction_axes, keepdims=True)
mean = ops.mean(inputs, axis=self._reduction_axes)
variance = ops.var(inputs, axis=self._reduction_axes)
outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
mean = ops.squeeze(mean, self._reduction_axes)
variance = ops.squeeze(variance, self._reduction_axes)
self.moving_variance.assign(
self.moving_variance * self.momentum
+ variance * (1.0 - self.momentum)
)
self.moving_mean.assign(
self.moving_mean * self.momentum + mean * (1.0 - self.momentum)
)
else:
moving_mean = ops.reshape(self.moving_mean, broadcast_shape)
moving_variance = ops.reshape(self.moving_variance, broadcast_shape)
outputs = (inputs - moving_mean) / ops.sqrt(
moving_variance + self.epsilon
outputs = (inputs - self.moving_mean) / ops.sqrt(
self.moving_variance + self.epsilon
)
if self.scale:
gamma = ops.reshape(self.gamma, broadcast_shape)
outputs = outputs * gamma
outputs = outputs * self.gamma
if self.center:
beta = ops.reshape(self.beta, broadcast_shape)
outputs = outputs + beta
outputs = outputs + self.beta
return outputs
def get_config(self):

@ -1,11 +1,10 @@
import numpy as np
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class BatchNormalizationTest(testing.TestCase, parameterized.TestCase):
class BatchNormalizationTest(testing.TestCase):
def test_bn_basics(self):
# vector case
self.run_layer_test(
@ -57,35 +56,27 @@ class BatchNormalizationTest(testing.TestCase, parameterized.TestCase):
supports_masking=True,
)
@parameterized.product(
axis=(-1, 1),
input_shape=((5, 2, 3), (5, 3, 3, 2)),
)
def test_correctness(self, axis, input_shape):
def test_correctness(self):
# Training
layer = layers.BatchNormalization(axis=axis, momentum=0)
layer = layers.BatchNormalization(axis=-1, momentum=0.8)
# Random data centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=input_shape)
out = x
for _ in range(3):
out = layer(out, training=True)
x = np.random.normal(loc=5.0, scale=10.0, size=(200, 4, 4, 3))
for _ in range(10):
out = layer(x, training=True)
# Assert the normalization is correct.
broadcast_shape = [1] * len(input_shape)
broadcast_shape[axis] = input_shape[axis]
out -= np.reshape(np.array(layer.beta), broadcast_shape)
out /= np.reshape(np.array(layer.gamma), broadcast_shape)
out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3))
out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3))
reduction_axes = list(range(len(input_shape)))
del reduction_axes[axis]
reduction_axes = tuple(reduction_axes)
self.assertAllClose(np.mean(out, axis=reduction_axes), 0.0, atol=1e-3)
self.assertAllClose(np.std(out, axis=reduction_axes), 1.0, atol=1e-3)
self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
# Inference
inference_out = layer(x, training=False)
training_out = layer(x, training=True)
self.assertNotAllClose(inference_out, training_out)
out = layer(x, training=False)
out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3))
out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3))
self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)
def test_trainable_behavior(self):
layer = layers.BatchNormalization(axis=-1, momentum=0.8, epsilon=1e-7)
@ -114,3 +105,7 @@ class BatchNormalizationTest(testing.TestCase, parameterized.TestCase):
self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
def test_masking_correctness(self):
# TODO
pass

@ -138,7 +138,6 @@ class Optimizer:
dtype=None,
name=None,
):
print("add variable", shape)
self._check_super_called()
initializer = initializers.get(initializer)
variable = backend.Variable(