Merge branch 'main' of github.com:keras-team/keras-core

commit faf3711ee6 (parent 225252e965)
@@ -75,8 +75,8 @@ class BatchNormalization(Layer):
             - `training=True`: The layer will normalize its inputs using
               the mean and variance of the current batch of inputs.
             - `training=False`: The layer will normalize its inputs using
-              the mean and variance of its moving statistics,
-              learned during training.
+              the mean and variance of its moving statistics, learned during
+              training.
 
     Reference:
 
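As a quick illustration of the `training` flag described in the docstring hunk above (a minimal sketch, not part of the commit; it only assumes `keras_core` and NumPy are importable, as in the test file further down):

import numpy as np
from keras_core import layers

# Batch statistics are used when training=True; the moving (EMA) statistics
# accumulated during training are used when training=False.
layer = layers.BatchNormalization(axis=-1, momentum=0.9)
x = np.random.normal(loc=2.0, scale=3.0, size=(32, 8)).astype("float32")
train_out = layer(x, training=True)   # normalized with the batch mean/variance
infer_out = layer(x, training=False)  # normalized with moving_mean/moving_variance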
@@ -102,11 +102,10 @@ class BatchNormalization(Layer):
 
     Note that:
 
-    - Setting `trainable` on an model containing other layers will
-      recursively set the `trainable` value of all inner layers.
-    - If the value of the `trainable`
-      attribute is changed after calling `compile()` on a model,
-      the new value doesn't take effect for this model
+    - Setting `trainable` on a model containing other layers will recursively
+      set the `trainable` value of all inner layers.
+    - If the value of the `trainable` attribute is changed after calling
+      `compile()` on a model, the new value doesn't take effect for this model
       until `compile()` is called again.
     """
 
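A small sketch of the `trainable` note above (hypothetical fine-tuning snippet, not part of the commit). Because `call()` checks `training and self.trainable`, a frozen layer falls back to its moving statistics even when called with `training=True`, and a `trainable` change made after `compile()` only applies once the model is compiled again:

from keras_core import layers

bn = layers.BatchNormalization()
bn.trainable = False  # freeze: the layer now normalizes with its moving statistics
# If this flag is flipped after model.compile(...), call compile() again
# so the change takes effect.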
@@ -189,26 +188,29 @@ class BatchNormalization(Layer):
         return input_shape
 
     def call(self, inputs, training=None, mask=None):
         # TODO: support masking during stats computation.
+        broadcast_shape = [1] * len(inputs.shape)
+        broadcast_shape[self.axis] = inputs.shape[self.axis]
         if training and self.trainable:
-            mean = ops.mean(inputs, axis=self._reduction_axes)
-            variance = ops.var(inputs, axis=self._reduction_axes)
+            mean = ops.mean(inputs, axis=self._reduction_axes, keepdims=True)
+            variance = ops.var(inputs, axis=self._reduction_axes, keepdims=True)
             outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
+            mean = ops.squeeze(mean, self._reduction_axes)
+            variance = ops.squeeze(variance, self._reduction_axes)
             self.moving_variance.assign(
                 self.moving_variance * self.momentum
                 + variance * (1.0 - self.momentum)
             )
             self.moving_mean.assign(
                 self.moving_mean * self.momentum + mean * (1.0 - self.momentum)
             )
         else:
-            outputs = (inputs - self.moving_mean) / ops.sqrt(
-                self.moving_variance + self.epsilon
+            moving_mean = ops.reshape(self.moving_mean, broadcast_shape)
+            moving_variance = ops.reshape(self.moving_variance, broadcast_shape)
+            outputs = (inputs - moving_mean) / ops.sqrt(
+                moving_variance + self.epsilon
             )
         if self.scale:
-            outputs = outputs * self.gamma
+            gamma = ops.reshape(self.gamma, broadcast_shape)
+            outputs = outputs * gamma
         if self.center:
-            outputs = outputs + self.beta
+            beta = ops.reshape(self.beta, broadcast_shape)
+            outputs = outputs + beta
         return outputs
 
     def get_config(self):
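The reshapes added to `call()` above exist so the per-channel statistics broadcast correctly when `axis` is not the last dimension. A minimal NumPy stand-in for the `keras_core.ops` calls (illustrative names and values only, not the layer's actual code):

import numpy as np

x = np.random.normal(size=(5, 3, 4, 4))   # channels on axis=1
axis = 1
moving_mean = np.zeros(x.shape[axis])      # shape (3,)
moving_variance = np.ones(x.shape[axis])   # shape (3,)

# A flat (3,) vector would not broadcast against axis 1 of x, so reshape it
# to (1, 3, 1, 1) first, exactly as the layer now does with broadcast_shape.
broadcast_shape = [1] * x.ndim
broadcast_shape[axis] = x.shape[axis]
outputs = (x - moving_mean.reshape(broadcast_shape)) / np.sqrt(
    moving_variance.reshape(broadcast_shape) + 1e-3
)
print(outputs.shape)  # (5, 3, 4, 4)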
@@ -1,10 +1,11 @@
 import numpy as np
+from absl.testing import parameterized
 
 from keras_core import layers
 from keras_core import testing
 
 
-class BatchNormalizationTest(testing.TestCase):
+class BatchNormalizationTest(testing.TestCase, parameterized.TestCase):
     def test_bn_basics(self):
         # vector case
         self.run_layer_test(
@@ -56,27 +57,35 @@ class BatchNormalizationTest(testing.TestCase):
             supports_masking=True,
         )
 
-    def test_correctness(self):
+    @parameterized.product(
+        axis=(-1, 1),
+        input_shape=((5, 2, 3), (5, 3, 3, 2)),
+    )
+    def test_correctness(self, axis, input_shape):
         # Training
-        layer = layers.BatchNormalization(axis=-1, momentum=0.8)
+        layer = layers.BatchNormalization(axis=axis, momentum=0)
         # Random data centered on 5.0, variance 10.0
-        x = np.random.normal(loc=5.0, scale=10.0, size=(200, 4, 4, 3))
-        for _ in range(10):
-            out = layer(x, training=True)
+        x = np.random.normal(loc=5.0, scale=10.0, size=input_shape)
+        out = x
+        for _ in range(3):
+            out = layer(out, training=True)
 
-        out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3))
-        out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3))
+        # Assert the normalization is correct.
+        broadcast_shape = [1] * len(input_shape)
+        broadcast_shape[axis] = input_shape[axis]
+        out -= np.reshape(np.array(layer.beta), broadcast_shape)
+        out /= np.reshape(np.array(layer.gamma), broadcast_shape)
 
-        self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
-        self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
+        reduction_axes = list(range(len(input_shape)))
+        del reduction_axes[axis]
+        reduction_axes = tuple(reduction_axes)
+        self.assertAllClose(np.mean(out, axis=reduction_axes), 0.0, atol=1e-3)
+        self.assertAllClose(np.std(out, axis=reduction_axes), 1.0, atol=1e-3)
 
         # Inference
-        out = layer(x, training=False)
-        out -= np.reshape(np.array(layer.beta), (1, 1, 1, 3))
-        out /= np.reshape(np.array(layer.gamma), (1, 1, 1, 3))
-
-        self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
-        self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)
+        inference_out = layer(x, training=False)
+        training_out = layer(x, training=True)
+        self.assertNotAllClose(inference_out, training_out)
 
     def test_trainable_behavior(self):
         layer = layers.BatchNormalization(axis=-1, momentum=0.8, epsilon=1e-7)
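The updated test sets `momentum=0`, which makes the moving statistics track the most recent batch exactly. A one-line sketch of the EMA update used in `call()` (illustrative numbers only):

momentum = 0.0
moving_mean, batch_mean = 0.0, 5.0
# moving_stat <- moving_stat * momentum + batch_stat * (1 - momentum)
moving_mean = moving_mean * momentum + batch_mean * (1.0 - momentum)
assert moving_mean == 5.0  # with momentum=0 the moving stats equal the last batch stats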
@@ -105,7 +114,3 @@ class BatchNormalizationTest(testing.TestCase):
 
         self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
         self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
-
-    def test_masking_correctness(self):
-        # TODO
-        pass