Adds GroupNormalization and SpectralNormalization layers and associated tests (#116)
* Adds unit normalization and tests
* Adds layer normalization and initial tests
* Fixes formatting in docstrings
* Fix type issues for JAX
* Fix nits
* Initial stash for group_normalization and spectral_normalization
* Adds spectral normalization and tests
* Adds group normalization and tests
* Formatting fixes
* Fix small nit in docstring
* Fix docstring and tests
parent a42ea97417
commit 5e1558381f
@@ -30,9 +30,15 @@ from keras_core.layers.merging.subtract import subtract
from keras_core.layers.normalization.batch_normalization import (
    BatchNormalization,
)
from keras_core.layers.normalization.group_normalization import (
    GroupNormalization,
)
from keras_core.layers.normalization.layer_normalization import (
    LayerNormalization,
)
from keras_core.layers.normalization.spectral_normalization import (
    SpectralNormalization,
)
from keras_core.layers.normalization.unit_normalization import UnitNormalization
from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
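For orientation, a minimal usage sketch of the two layers these exports add (illustrative shapes; not part of the diff, and assuming keras_core as of this commit):

import numpy as np
from keras_core import layers

x = np.random.rand(2, 8, 16).astype("float32")

# 16 channels normalized in 4 groups of 4 (default groups=32 would not divide 16).
gn = layers.GroupNormalization(groups=4)
print(gn(x).shape)  # (2, 8, 16)

# Wrap a Dense layer so its kernel is spectrally normalized during training.
sn = layers.SpectralNormalization(layers.Dense(8))
print(sn(x).shape)  # (2, 8, 8)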
keras_core/layers/normalization/group_normalization.py (new file, 236 lines)
@@ -0,0 +1,236 @@
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer


@keras_core_export("keras_core.layers.GroupNormalization")
class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes
    within each group the mean and variance for normalization.
    Empirically, its accuracy is more stable than batch norm in a wide
    range of small batch sizes, if learning rate is adjusted linearly
    with batch sizes.

    Relation to Layer Normalization:
    If the number of groups is set to 1, then this operation becomes nearly
    identical to Layer Normalization (see Layer Normalization docs for
    details).

    Relation to Instance Normalization:
    If the number of groups is set to the input dimension (number of groups
    is equal to number of channels), then this operation becomes identical
    to Instance Normalization.

    Args:
        groups: Integer, the number of groups for Group Normalization. Can
            be in the range `[1, N]` where N is the input dimension. The
            input dimension must be divisible by the number of groups.
            Defaults to 32.
        axis: Integer or List/Tuple. The axis or axes to normalize across.
            Typically, this is the features axis/axes. The left-out axes are
            typically the batch axis/axes. `-1` is the last dimension in the
            input. Defaults to `-1`.
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3.
        center: If `True`, add offset of `beta` to normalized tensor.
            If `False`, `beta` is ignored. Defaults to `True`.
        scale: If `True`, multiply by `gamma`. If `False`, `gamma` is not
            used. When the next layer is linear (also e.g. `relu`), this can
            be disabled since the scaling will be done by the next layer.
            Defaults to `True`.
        beta_initializer: Initializer for the beta weight. Defaults to zeros.
        gamma_initializer: Initializer for the gamma weight. Defaults to ones.
        beta_regularizer: Optional regularizer for the beta weight. None by
            default.
        gamma_regularizer: Optional regularizer for the gamma weight. None by
            default.
        beta_constraint: Optional constraint for the beta weight.
            None by default.
        gamma_constraint: Optional constraint for the gamma weight. None by
            default.
        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).

    Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
        integers, does not include the samples axis) when using this layer
        as the first layer in a model.

    Output shape: Same shape as input.

    Reference:

    - [Yuxin Wu & Kaiming He, 2018](https://arxiv.org/abs/1803.08494)
    """
    def __init__(
        self,
        groups=32,
        axis=-1,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError(
                f"Axis {self.axis} of input tensor should have a defined "
                "dimension but the layer received an input with shape "
                f"{input_shape}."
            )

        if self.groups == -1:
            self.groups = dim

        if dim < self.groups:
            raise ValueError(
                f"Number of groups ({self.groups}) cannot be more than the "
                f"number of channels ({dim})."
            )

        if dim % self.groups != 0:
            raise ValueError(
                f"Number of channels ({dim}) must be a multiple "
                f"of the number of groups ({self.groups})."
            )

        self.input_spec = InputSpec(
            ndim=len(input_shape), axes={self.axis: dim}
        )

        if self.scale:
            self.gamma = self.add_weight(
                shape=(dim,),
                name="gamma",
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
            )
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                shape=(dim,),
                name="beta",
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
            )
        else:
            self.beta = None

        super().build(input_shape)
    def call(self, inputs):
        input_shape = inputs.shape

        reshaped_inputs = self._reshape_into_groups(inputs)

        normalized_inputs = self._apply_normalization(
            reshaped_inputs, input_shape
        )

        return ops.reshape(normalized_inputs, input_shape)

    def _reshape_into_groups(self, inputs):
        input_shape = inputs.shape
        group_shape = [input_shape[i] for i in range(len(input_shape))]

        group_shape[self.axis] = input_shape[self.axis] // self.groups
        group_shape.insert(self.axis, self.groups)
        group_shape = ops.stack(group_shape)
        reshaped_inputs = ops.reshape(inputs, group_shape)
        return reshaped_inputs

    def _apply_normalization(self, reshaped_inputs, input_shape):
        group_reduction_axes = list(range(1, len(reshaped_inputs.shape)))

        # Drop the groups axis from the reduction axes, so statistics are
        # computed per (batch, group).
        axis = -2 if self.axis == -1 else self.axis - 1
        group_reduction_axes.pop(axis)

        mean = ops.mean(
            reshaped_inputs, axis=group_reduction_axes, keepdims=True
        )
        variance = ops.var(
            reshaped_inputs, axis=group_reduction_axes, keepdims=True
        )

        gamma, beta = self._get_reshaped_weights(input_shape)

        # Compute the normalization in fused form: gamma * (x - mean) / std
        # + beta, with gamma folded into `inv` and beta/mean folded into `x`.
        inv = 1 / ops.sqrt(variance + self.epsilon)
        if gamma is not None:
            inv *= gamma

        x = beta - mean * inv if beta is not None else -mean * inv
        normalized_inputs = reshaped_inputs * ops.cast(
            inv, reshaped_inputs.dtype
        ) + ops.cast(x, reshaped_inputs.dtype)

        normalized_inputs = ops.cast(normalized_inputs, reshaped_inputs.dtype)

        return normalized_inputs

    def _get_reshaped_weights(self, input_shape):
        broadcast_shape = self._create_broadcast_shape(input_shape)
        gamma = None
        beta = None
        if self.scale:
            gamma = ops.reshape(self.gamma, broadcast_shape)

        if self.center:
            beta = ops.reshape(self.beta, broadcast_shape)
        return gamma, beta

    def _create_broadcast_shape(self, input_shape):
        broadcast_shape = [1] * len(input_shape)

        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(self.axis, self.groups)

        return broadcast_shape

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(self.gamma_initializer),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        base_config = super().get_config()
        return {**base_config, **config}
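For readers tracing `_reshape_into_groups` and `_apply_normalization`, here is a hedged NumPy re-derivation of the channels-last (`axis=-1`) case; `group_norm_ref` is a hypothetical helper for illustration only, not part of the diff:

import numpy as np

def group_norm_ref(x, groups, eps=1e-3):
    # x: (batch, steps, channels); mirrors _reshape_into_groups for axis=-1.
    n, t, c = x.shape
    g = x.reshape(n, t, groups, c // groups)
    # Mirrors _apply_normalization: reduce over every axis except
    # batch (axis 0) and the inserted groups axis (axis 2).
    mean = g.mean(axis=(1, 3), keepdims=True)
    var = g.var(axis=(1, 3), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, t, c)

x = np.array([[[-1.0, -1.0, 2.0, 2.0], [1.0, 1.0, 0.0, -2.0]]])
print(group_norm_ref(x, groups=1))

With `groups=1` this reproduces `expected_output_1_group` in the 2-D correctness test below.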
keras_core/layers/normalization/group_normalization_test.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import numpy as np

from keras_core import constraints
from keras_core import layers
from keras_core import regularizers
from keras_core import testing


class GroupNormalizationTest(testing.TestCase):
    def test_groupnorm(self):
        self.run_layer_test(
            layers.GroupNormalization,
            init_kwargs={
                "gamma_regularizer": regularizers.L2(0.01),
                "beta_regularizer": regularizers.L2(0.01),
            },
            input_shape=(3, 4, 32),
            expected_output_shape=(3, 4, 32),
            expected_num_trainable_weights=2,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=2,
            supports_masking=True,
        )

        self.run_layer_test(
            layers.GroupNormalization,
            init_kwargs={
                "groups": 4,
                "gamma_constraint": constraints.UnitNorm(),
                "beta_constraint": constraints.UnitNorm(),
            },
            input_shape=(3, 4, 4),
            expected_output_shape=(3, 4, 4),
            expected_num_trainable_weights=2,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=True,
        )

    def test_correctness_instance_norm(self):
        instance_norm_layer = layers.GroupNormalization(
            groups=4, axis=-1, scale=False, center=False
        )

        inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])

        expected_instance_norm_output = np.array(
            [[[-1.0, -1.0, 1.0, 1.0], [1.0, 1.0, -1.0, -1.0]]]
        )

        self.assertAllClose(
            instance_norm_layer(inputs),
            expected_instance_norm_output,
            atol=1e-3,
        )

    def test_correctness_1d(self):
        layer_with_1_group = layers.GroupNormalization(
            groups=1, axis=-1, scale=False, center=False
        )
        layer_with_2_groups = layers.GroupNormalization(
            groups=2, axis=1, scale=False, center=False
        )

        inputs = np.array([[-1.0, -1.0, 1.0, 1.0, 2.0, 2.0, 0, -2.0]])

        expected_output_1_group = np.array(
            [[-0.898, -0.898, 0.539, 0.539, 1.257, 1.257, -0.180, -1.616]],
        )
        self.assertAllClose(
            layer_with_1_group(inputs),
            expected_output_1_group,
            atol=1e-3,
        )

        expected_output_2_groups = np.array(
            [[-1.0, -1.0, 1.0, 1.0, 0.904, 0.904, -0.301, -1.507]]
        )
        self.assertAllClose(
            layer_with_2_groups(inputs),
            expected_output_2_groups,
            atol=1e-3,
        )

    def test_correctness_2d(self):
        layer_with_1_group = layers.GroupNormalization(
            groups=1, axis=-1, scale=False, center=False
        )
        layer_with_2_groups = layers.GroupNormalization(
            groups=2, axis=2, scale=False, center=False
        )

        inputs = np.array([[[-1.0, -1.0, 2.0, 2.0], [1.0, 1.0, 0, -2.0]]])

        expected_output_1_group = np.array(
            [[[-0.898, -0.898, 1.257, 1.257], [0.539, 0.539, -0.180, -1.616]]]
        )

        self.assertAllClose(
            layer_with_1_group(inputs),
            expected_output_1_group,
            atol=1e-3,
        )

        expected_output_2_groups = np.array(
            [[[-1.0, -1.0, 0.904, 0.904], [1.0, 1.0, -0.301, -1.507]]]
        )
        self.assertAllClose(
            layer_with_2_groups(inputs),
            expected_output_2_groups,
            atol=1e-3,
        )
keras_core/layers/normalization/spectral_normalization.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers import Wrapper
from keras_core.layers.input_spec import InputSpec


@keras_core_export("keras_core.layers.SpectralNormalization")
class SpectralNormalization(Wrapper):
    """Performs spectral normalization on the weights of a target layer.

    This wrapper controls the Lipschitz constant of the weights of a layer
    by constraining their spectral norm, which can stabilize the training
    of GANs.

    Args:
        layer: A `keras_core.layers.Layer` instance that
            has either a `kernel` (e.g. `Conv2D`, `Dense`...)
            or an `embeddings` attribute (`Embedding` layer).
        power_iterations: int, the number of iterations during normalization.
        **kwargs: Base wrapper keyword arguments.

    Examples:

    Wrap `keras_core.layers.Conv2D`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = SpectralNormalization(keras_core.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    (1, 9, 9, 2)

    Wrap `keras_core.layers.Dense`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = SpectralNormalization(keras_core.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    (1, 10, 10, 10)

    Reference:

    - [Spectral Normalization for GANs](https://arxiv.org/abs/1802.05957).
    """

    def __init__(self, layer, power_iterations=1, **kwargs):
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero. Received: "
                f"`power_iterations={power_iterations}`"
            )
        self.power_iterations = power_iterations

    def build(self, input_shape):
        super().build(input_shape)
        self.input_spec = InputSpec(shape=[None] + list(input_shape[1:]))

        if hasattr(self.layer, "kernel"):
            self.kernel = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.kernel = self.layer.embeddings
        else:
            raise ValueError(
                f"{type(self.layer).__name__} object has no attribute "
                "'kernel' or 'embeddings'"
            )

        self.kernel_shape = self.kernel.shape

        self.vector_u = self.add_weight(
            shape=(1, self.kernel_shape[-1]),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="vector_u",
            dtype=self.kernel.dtype,
        )

    def call(self, inputs, training=False):
        if training:
            self.normalize_weights()

        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        return self.layer.compute_output_shape(input_shape)

    def normalize_weights(self):
        """Generate spectral normalized weights.

        This method will update the value of `self.kernel` with the
        spectral normalized value, so that the layer is ready for `call()`.
        """
        weights = ops.reshape(self.kernel, [-1, self.kernel_shape[-1]])
        vector_u = self.vector_u

        # Skip the update if all weights are zero.
        if not all([w == 0.0 for w in weights]):
            for _ in range(self.power_iterations):
                vector_v = self._l2_normalize(
                    ops.matmul(vector_u, ops.transpose(weights))
                )
                vector_u = self._l2_normalize(ops.matmul(vector_v, weights))
            # vector_u = tf.stop_gradient(vector_u)
            # vector_v = tf.stop_gradient(vector_v)
            sigma = ops.matmul(
                ops.matmul(vector_v, weights), ops.transpose(vector_u)
            )
            self.vector_u.assign(ops.cast(vector_u, self.vector_u.dtype))
            self.kernel.assign(
                ops.cast(
                    ops.reshape(self.kernel / sigma, self.kernel_shape),
                    self.kernel.dtype,
                )
            )

    def _l2_normalize(self, x):
        square_sum = ops.sum(ops.square(x), keepdims=True)
        x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12))
        return ops.multiply(x, x_inv_norm)

    def get_config(self):
        config = {"power_iterations": self.power_iterations}
        base_config = super().get_config()
        return {**base_config, **config}
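For intuition about `normalize_weights()`, here is a self-contained NumPy sketch of the same power iteration (hypothetical helper names, illustrative sizes), checked against the exact largest singular value from SVD:

import numpy as np

def l2_normalize(x, eps=1e-12):
    return x / np.sqrt(np.maximum(np.sum(np.square(x)), eps))

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 16))               # kernel flattened to (-1, out_dim)
u = l2_normalize(rng.normal(size=(1, 16)))  # plays the role of vector_u

for _ in range(50):                 # power_iterations (1 per call in the layer)
    v = l2_normalize(u @ w.T)       # shape (1, 64)
    u = l2_normalize(v @ w)         # shape (1, 16)

sigma = (v @ w @ u.T).item()        # estimate of the largest singular value
print(sigma, np.linalg.svd(w, compute_uv=False)[0])  # approximately equal

Dividing the kernel by `sigma` then drives its spectral norm toward 1, which is exactly what `self.kernel.assign(...)` does above.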
keras_core/layers/normalization/spectral_normalization_test.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import numpy as np

from keras_core import initializers
from keras_core import layers
from keras_core import testing


class SpectralNormalizationTest(testing.TestCase):
    def test_basic_spectralnorm(self):
        self.run_layer_test(
            layers.SpectralNormalization,
            init_kwargs={"layer": layers.Dense(2)},
            input_shape=None,
            input_data=np.random.uniform(size=(10, 3, 4)),
            expected_output_shape=(10, 3, 2),
            expected_num_trainable_weights=2,
            expected_num_non_trainable_weights=2,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    def test_apply_layer(self):
        images = np.ones((1, 2, 2, 1))
        sn_wrapper = layers.SpectralNormalization(
            layers.Conv2D(
                1, (2, 2), kernel_initializer=initializers.Constant(value=1)
            ),
        )

        result = sn_wrapper(images, training=False)
        result_train = sn_wrapper(images, training=True)
        expected_output = np.array([[[[4.0]]]], dtype=np.float32)

        self.assertAllClose(result, expected_output)
        # The largest eigenvalue of a 2x2 matrix of ones is 2, so the
        # normalized kernel halves the output.
        self.assertAllClose(result_train, expected_output / 2)