From 225252e965e3ff6b954d38378eda057439d4ef08 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 24 May 2023 15:06:15 -0700 Subject: [PATCH] Remove debug statement --- .../normalization/batch_normalization.py | 40 ++++++++--------- .../normalization/batch_normalization_test.py | 45 +++++++++---------- keras_core/optimizers/base_optimizer.py | 1 - 3 files changed, 39 insertions(+), 47 deletions(-) diff --git a/keras_core/layers/normalization/batch_normalization.py b/keras_core/layers/normalization/batch_normalization.py index 08eb61c5c..cfc0c0aa4 100644 --- a/keras_core/layers/normalization/batch_normalization.py +++ b/keras_core/layers/normalization/batch_normalization.py @@ -75,8 +75,8 @@ class BatchNormalization(Layer): - `training=True`: The layer will normalize its inputs using the mean and variance of the current batch of inputs. - `training=False`: The layer will normalize its inputs using - the mean and variance of its moving statistics, learned during - training. + the mean and variance of its moving statistics, + learned during training. Reference: @@ -102,11 +102,12 @@ class BatchNormalization(Layer): Note that: - - Setting `trainable` on an model containing other layers will recursively - set the `trainable` value of all inner layers. - - If the value of the `trainable` attribute is changed after calling - `compile()` on a model, the new value doesn't take effect for this model - until `compile()` is called again. + - Setting `trainable` on an model containing other layers will + recursively set the `trainable` value of all inner layers. + - If the value of the `trainable` + attribute is changed after calling `compile()` on a model, + the new value doesn't take effect for this model + until `compile()` is called again. """ def __init__( @@ -188,29 +189,26 @@ class BatchNormalization(Layer): return input_shape def call(self, inputs, training=None, mask=None): - broadcast_shape = [1] * len(inputs.shape) - broadcast_shape[self.axis] = inputs.shape[self.axis] + # TODO: support masking during stats computation. if training and self.trainable: - mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True) - variance = ops.var(inputs, axis=self._reduction_axes, keepdims=True) + mean = ops.mean(inputs, axis=self._reduction_axes) + variance = ops.var(inputs, axis=self._reduction_axes) outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon) - mean = ops.squeeze(mean, self._reduction_axes) - variance = ops.squeeze(variance, self._reduction_axes) + self.moving_variance.assign( + self.moving_variance * self.momentum + + variance * (1.0 - self.momentum) + ) self.moving_mean.assign( self.moving_mean * self.momentum + mean * (1.0 - self.momentum) ) else: - moving_mean = ops.reshape(self.moving_mean, broadcast_shape) - moving_variance = ops.reshape(self.moving_variance, broadcast_shape) - outputs = (inputs - moving_mean) / ops.sqrt( - moving_variance + self.epsilon + outputs = (inputs - self.moving_mean) / ops.sqrt( + self.moving_variance + self.epsilon ) if self.scale: - gamma = ops.reshape(self.gamma, broadcast_shape) - outputs = outputs * gamma + outputs = outputs * self.gamma if self.center: - beta = ops.reshape(self.beta, broadcast_shape) - outputs = outputs + beta + outputs = outputs + self.beta return outputs def get_config(self): diff --git a/keras_core/layers/normalization/batch_normalization_test.py b/keras_core/layers/normalization/batch_normalization_test.py index 49116af45..e08240a48 100644 --- a/keras_core/layers/normalization/batch_normalization_test.py +++ b/keras_core/layers/normalization/batch_normalization_test.py @@ -1,11 +1,10 @@ import numpy as np -from absl.testing import parameterized from keras_core import layers from keras_core import testing -class BatchNormalizationTest(testing.TestCase, parameterized.TestCase): +class BatchNormalizationTest(testing.TestCase): def test_bn_basics(self): # vector case self.run_layer_test( @@ -57,35 +56,27 @@ class BatchNormalizationTest(testing.TestCase, parameterized.TestCase): supports_masking=True, ) - @parameterized.product( - axis=(-1, 1), - input_shape=((5, 2, 3), (5, 3, 3, 2)), - ) - def test_correctness(self, axis, input_shape): + def test_correctness(self): # Training - layer = layers.BatchNormalization(axis=axis, momentum=0) + layer = layers.BatchNormalization(axis=-1, momentum=0.8) # Random data centered on 5.0, variance 10.0 - x = np.random.normal(loc=5.0, scale=10.0, size=input_shape) - out = x - for _ in range(3): - out = layer(out, training=True) + x = np.random.normal(loc=5.0, scale=10.0, size=(200, 4, 4, 3)) + for _ in range(10): + out = layer(x, training=True) - # Assert the normalization is correct. - broadcast_shape = [1] * len(input_shape) - broadcast_shape[axis] = input_shape[axis] - out -= np.reshape(np.array(layer.beta), broadcast_shape) - out /= np.reshape(np.array(layer.gamma), broadcast_shape) + out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3)) + out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3)) - reduction_axes = list(range(len(input_shape))) - del reduction_axes[axis] - reduction_axes = tuple(reduction_axes) - self.assertAllClose(np.mean(out, axis=reduction_axes), 0.0, atol=1e-3) - self.assertAllClose(np.std(out, axis=reduction_axes), 1.0, atol=1e-3) + self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3) + self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3) # Inference - inference_out = layer(x, training=False) - training_out = layer(x, training=True) - self.assertNotAllClose(inference_out, training_out) + out = layer(x, training=False) + out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3)) + out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3)) + + self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1) + self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1) def test_trainable_behavior(self): layer = layers.BatchNormalization(axis=-1, momentum=0.8, epsilon=1e-7) @@ -114,3 +105,7 @@ class BatchNormalizationTest(testing.TestCase, parameterized.TestCase): self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3) self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3) + + def test_masking_correctness(self): + # TODO + pass diff --git a/keras_core/optimizers/base_optimizer.py b/keras_core/optimizers/base_optimizer.py index ccae67cf7..c0fa74dbd 100644 --- a/keras_core/optimizers/base_optimizer.py +++ b/keras_core/optimizers/base_optimizer.py @@ -138,7 +138,6 @@ class Optimizer: dtype=None, name=None, ): - print("add variable", shape) self._check_super_called() initializer = initializers.get(initializer) variable = backend.Variable(