Merge branch 'main' of github.com:keras-team/keras-core

Francois Chollet 2023-05-24 15:06:29 -07:00
parent 225252e965
commit faf3711ee6
2 changed files with 46 additions and 39 deletions

@@ -75,8 +75,8 @@ class BatchNormalization(Layer):
- `training=True`: The layer will normalize its inputs using
the mean and variance of the current batch of inputs.
- `training=False`: The layer will normalize its inputs using
the mean and variance of its moving statistics,
learned during training.
the mean and variance of its moving statistics, learned during
training.
Reference:
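For context, a minimal sketch of the documented `training=True` / `training=False` behavior, mirroring how the test file below calls the layer (exact outputs depend on the moving statistics the layer has accumulated):

```python
import numpy as np
from keras_core import layers

layer = layers.BatchNormalization()
x = np.random.normal(loc=5.0, scale=10.0, size=(32, 3))

# training=True: normalize with this batch's mean/variance and
# update the moving statistics.
y_train = layer(x, training=True)

# training=False: normalize with the stored moving mean/variance instead.
y_infer = layer(x, training=False)
```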
@@ -102,11 +102,10 @@ class BatchNormalization(Layer):
Note that:
- Setting `trainable` on a model containing other layers will
recursively set the `trainable` value of all inner layers.
- If the value of the `trainable`
attribute is changed after calling `compile()` on a model,
the new value doesn't take effect for this model
- Setting `trainable` on a model containing other layers will recursively
set the `trainable` value of all inner layers.
- If the value of the `trainable` attribute is changed after calling
`compile()` on a model, the new value doesn't take effect for this model
until `compile()` is called again.
"""
@@ -189,26 +188,29 @@ class BatchNormalization(Layer):
return input_shape
def call(self, inputs, training=None, mask=None):
# TODO: support masking during stats computation.
broadcast_shape = [1] * len(inputs.shape)
broadcast_shape[self.axis] = inputs.shape[self.axis]
if training and self.trainable:
mean = ops.mean(inputs, axis=self._reduction_axes)
variance = ops.var(inputs, axis=self._reduction_axes)
mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True)
variance = ops.var(inputs, axis=self._reduction_axes, keepdims=True)
outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
self.moving_variance.assign(
self.moving_variance * self.momentum
+ variance * (1.0 - self.momentum)
)
mean = ops.squeeze(mean, self._reduction_axes)
variance = ops.squeeze(variance, self._reduction_axes)
self.moving_mean.assign(
self.moving_mean * self.momentum + mean * (1.0 - self.momentum)
)
else:
outputs = (inputs - self.moving_mean) / ops.sqrt(
self.moving_variance + self.epsilon
moving_mean = ops.reshape(self.moving_mean, broadcast_shape)
moving_variance = ops.reshape(self.moving_variance, broadcast_shape)
outputs = (inputs - moving_mean) / ops.sqrt(
moving_variance + self.epsilon
)
if self.scale:
outputs = outputs * self.gamma
gamma = ops.reshape(self.gamma, broadcast_shape)
outputs = outputs * gamma
if self.center:
outputs = outputs + self.beta
beta = ops.reshape(self.beta, broadcast_shape)
outputs = outputs + beta
return outputs
def get_config(self):
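The updated `call()` above keeps the batch statistics with `keepdims=True`, squeezes them before updating the moving averages, and reshapes the stored per-channel tensors to `broadcast_shape` so an arbitrary `axis` broadcasts correctly. A rough NumPy sketch of the same arithmetic (illustrative only; the `epsilon` and `momentum` values are placeholders):

```python
import numpy as np

axis = 1  # channel axis, e.g. inputs of shape (N, C, L)
x = np.random.normal(loc=5.0, scale=10.0, size=(8, 3, 5))
reduction_axes = tuple(a for a in range(x.ndim) if a != axis)

broadcast_shape = [1] * x.ndim
broadcast_shape[axis] = x.shape[axis]

epsilon, momentum = 1e-3, 0.99
moving_mean = np.zeros(x.shape[axis])
moving_variance = np.ones(x.shape[axis])

# Training branch: batch statistics kept with keepdims so they broadcast.
mean = x.mean(axis=reduction_axes, keepdims=True)
variance = x.var(axis=reduction_axes, keepdims=True)
outputs = (x - mean) / np.sqrt(variance + epsilon)

# Moving statistics are stored squeezed to shape (C,):
# moving = moving * momentum + batch * (1 - momentum)
moving_mean = moving_mean * momentum + mean.squeeze(reduction_axes) * (1 - momentum)
moving_variance = (
    moving_variance * momentum + variance.squeeze(reduction_axes) * (1 - momentum)
)

# Inference branch: reshape the stored statistics to broadcast_shape.
outputs_inference = (x - moving_mean.reshape(broadcast_shape)) / np.sqrt(
    moving_variance.reshape(broadcast_shape) + epsilon
)
```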

@@ -1,10 +1,11 @@
import numpy as np
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class BatchNormalizationTest(testing.TestCase):
class BatchNormalizationTest(testing.TestCase, parameterized.TestCase):
def test_bn_basics(self):
# vector case
self.run_layer_test(
@@ -56,27 +57,35 @@ class BatchNormalizationTest(testing.TestCase):
supports_masking=True,
)
def test_correctness(self):
@parameterized.product(
axis=(-1, 1),
input_shape=((5, 2, 3), (5, 3, 3, 2)),
)
def test_correctness(self, axis, input_shape):
# Training
layer = layers.BatchNormalization(axis=-1, momentum=0.8)
layer = layers.BatchNormalization(axis=axis, momentum=0)
# Random data centered on 5.0, standard deviation 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(200, 4, 4, 3))
for _ in range(10):
out = layer(x, training=True)
x = np.random.normal(loc=5.0, scale=10.0, size=input_shape)
out = x
for _ in range(3):
out = layer(out, training=True)
out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3))
out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3))
# Assert the normalization is correct.
broadcast_shape = [1] * len(input_shape)
broadcast_shape[axis] = input_shape[axis]
out -= np.reshape(np.array(layer.beta), broadcast_shape)
out /= np.reshape(np.array(layer.gamma), broadcast_shape)
self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
reduction_axes = list(range(len(input_shape)))
del reduction_axes[axis]
reduction_axes = tuple(reduction_axes)
self.assertAllClose(np.mean(out, axis=reduction_axes), 0.0, atol=1e-3)
self.assertAllClose(np.std(out, axis=reduction_axes), 1.0, atol=1e-3)
# Inference
out = layer(x, training=False)
out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3))
out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3))
self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)
inference_out = layer(x, training=False)
training_out = layer(x, training=True)
self.assertNotAllClose(inference_out, training_out)
def test_trainable_behavior(self):
layer = layers.BatchNormalization(axis=-1, momentum=0.8, epsilon=1e-7)
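For reference, `parameterized.product` from `absl.testing` runs the decorated test once per combination of the listed arguments, which is how the axis/input_shape grid in `test_correctness` above expands. A toy illustration (the class and test names here are hypothetical and not part of this diff):

```python
from absl.testing import absltest, parameterized

class ProductExampleTest(parameterized.TestCase):
    # Expands into 2 x 2 = 4 test cases, mirroring the
    # axis/input_shape grid used in test_correctness.
    @parameterized.product(axis=(-1, 1), size=(3, 5))
    def test_combinations(self, axis, size):
        self.assertIn(axis, (-1, 1))
        self.assertIn(size, (3, 5))

if __name__ == "__main__":
    absltest.main()
```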
@@ -105,7 +114,3 @@ class BatchNormalizationTest(testing.TestCase):
self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
def test_masking_correctness(self):
# TODO
pass