From 5e1558381f464a68d7c6a87cb144dfba98b650ee Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com>
Date: Tue, 9 May 2023 17:00:12 +0000
Subject: [PATCH] Adds GroupNormalization and SpectralNormalization layers and
 associated tests (#116)

* Adds unit normalization and tests
* Adds layer normalization and initial tests
* Fixes formatting in docstrings
* Fix type issues for JAX
* Fix nits
* Initial stash for group_normalization and spectral_normalization
* Adds spectral normalization and tests
* Adds group normalization and tests
* Formatting fixes
* Fix small nit in docstring
* Fix docstring and tests

---
 keras_core/layers/__init__.py                  |   6 +
 .../normalization/group_normalization.py       | 236 ++++++++++++++++++
 .../normalization/group_normalization_test.py  | 114 +++++++++
 .../normalization/spectral_normalization.py    | 124 +++++++++
 .../spectral_normalization_test.py             |  36 +++
 5 files changed, 516 insertions(+)
 create mode 100644 keras_core/layers/normalization/group_normalization.py
 create mode 100644 keras_core/layers/normalization/group_normalization_test.py
 create mode 100644 keras_core/layers/normalization/spectral_normalization.py
 create mode 100644 keras_core/layers/normalization/spectral_normalization_test.py

diff --git a/keras_core/layers/__init__.py b/keras_core/layers/__init__.py
index e2ad84ef9..8ea887892 100644
--- a/keras_core/layers/__init__.py
+++ b/keras_core/layers/__init__.py
@@ -30,9 +30,15 @@ from keras_core.layers.merging.subtract import subtract
 from keras_core.layers.normalization.batch_normalization import (
     BatchNormalization,
 )
+from keras_core.layers.normalization.group_normalization import (
+    GroupNormalization,
+)
 from keras_core.layers.normalization.layer_normalization import (
     LayerNormalization,
 )
+from keras_core.layers.normalization.spectral_normalization import (
+    SpectralNormalization,
+)
 from keras_core.layers.normalization.unit_normalization import UnitNormalization
 from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
 from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
diff --git a/keras_core/layers/normalization/group_normalization.py b/keras_core/layers/normalization/group_normalization.py
new file mode 100644
index 000000000..c3f6fc4c7
--- /dev/null
+++ b/keras_core/layers/normalization/group_normalization.py
@@ -0,0 +1,236 @@
+from keras_core import constraints
+from keras_core import initializers
+from keras_core import operations as ops
+from keras_core import regularizers
+from keras_core.api_export import keras_core_export
+from keras_core.layers.input_spec import InputSpec
+from keras_core.layers.layer import Layer
+
+
+@keras_core_export("keras_core.layers.GroupNormalization")
+class GroupNormalization(Layer):
+    """Group normalization layer.
+
+    Group Normalization divides the channels into groups and computes
+    within each group the mean and variance for normalization.
+    Empirically, its accuracy is more stable than batch norm in a wide
+    range of small batch sizes, if learning rate is adjusted linearly
+    with batch sizes.
+
+    Relation to Layer Normalization:
+    If the number of groups is set to 1, then this operation becomes nearly
+    identical to Layer Normalization (see Layer Normalization docs for details).
+
+    Relation to Instance Normalization:
+    If the number of groups is set to the input dimension (number of groups is
+    equal to number of channels), then this operation becomes identical to
+    Instance Normalization.
+
+    Args:
+        groups: Integer, the number of groups for Group Normalization. Can be
+            in the range `[1, N]` where N is the input dimension. The input
+            dimension must be divisible by the number of groups.
+            Defaults to 32.
+        axis: Integer or List/Tuple. The axis or axes to normalize across.
+            Typically, this is the features axis/axes. The left-out axes are
+            typically the batch axis/axes. -1 is the last dimension in the
+            input. Defaults to -1.
+        epsilon: Small float added to variance to avoid dividing by zero.
+            Defaults to 1e-3.
+        center: If `True`, add offset of `beta` to normalized tensor.
+            If `False`, `beta` is ignored. Defaults to `True`.
+        scale: If `True`, multiply by `gamma`. If `False`, `gamma` is not used.
+            When the next layer is linear (for example, `relu`), this can be
+            disabled since the scaling will be done by the next layer.
+            Defaults to `True`.
+        beta_initializer: Initializer for the beta weight. Defaults to zeros.
+        gamma_initializer: Initializer for the gamma weight. Defaults to ones.
+        beta_regularizer: Optional regularizer for the beta weight. None by
+            default.
+        gamma_regularizer: Optional regularizer for the gamma weight. None by
+            default.
+        beta_constraint: Optional constraint for the beta weight.
+            None by default.
+        gamma_constraint: Optional constraint for the gamma weight. None by
+            default. Input shape: Arbitrary. Use the keyword argument
+            `input_shape` (tuple of integers, does not include the samples
+            axis) when using this layer as the first layer in a model.
+            Output shape: Same shape as input.
+        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).
+
+    Reference:
+
+    - [Yuxin Wu & Kaiming He, 2018](https://arxiv.org/abs/1803.08494)
+    """
+
+    def __init__(
+        self,
+        groups=32,
+        axis=-1,
+        epsilon=1e-3,
+        center=True,
+        scale=True,
+        beta_initializer="zeros",
+        gamma_initializer="ones",
+        beta_regularizer=None,
+        gamma_regularizer=None,
+        beta_constraint=None,
+        gamma_constraint=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.supports_masking = True
+        self.groups = groups
+        self.axis = axis
+        self.epsilon = epsilon
+        self.center = center
+        self.scale = scale
+        self.beta_initializer = initializers.get(beta_initializer)
+        self.gamma_initializer = initializers.get(gamma_initializer)
+        self.beta_regularizer = regularizers.get(beta_regularizer)
+        self.gamma_regularizer = regularizers.get(gamma_regularizer)
+        self.beta_constraint = constraints.get(beta_constraint)
+        self.gamma_constraint = constraints.get(gamma_constraint)
+
+    def build(self, input_shape):
+        dim = input_shape[self.axis]
+
+        if dim is None:
+            raise ValueError(
+                f"Axis {self.axis} of input tensor should have a defined "
+                "dimension but the layer received an input with shape "
+                f"{input_shape}."
+            )
+
+        if self.groups == -1:
+            self.groups = dim
+
+        if dim < self.groups:
+            raise ValueError(
+                f"Number of groups ({self.groups}) cannot be more than the "
+                f"number of channels ({dim})."
+            )
+
+        if dim % self.groups != 0:
+            raise ValueError(
+                f"Number of groups ({self.groups}) must be a divisor "
+                f"of the number of channels ({dim})."
+ ) + + self.input_spec = InputSpec( + ndim=len(input_shape), axes={self.axis: dim} + ) + + if self.scale: + self.gamma = self.add_weight( + shape=(dim,), + name="gamma", + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint, + ) + else: + self.gamma = None + + if self.center: + self.beta = self.add_weight( + shape=(dim,), + name="beta", + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint, + ) + else: + self.beta = None + + super().build(input_shape) + + def call(self, inputs): + input_shape = inputs.shape + + reshaped_inputs = self._reshape_into_groups(inputs) + + normalized_inputs = self._apply_normalization( + reshaped_inputs, input_shape + ) + + return ops.reshape(normalized_inputs, input_shape) + + def _reshape_into_groups(self, inputs): + input_shape = inputs.shape + group_shape = [input_shape[i] for i in range(len(input_shape))] + + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + group_shape = ops.stack(group_shape) + reshaped_inputs = ops.reshape(inputs, group_shape) + return reshaped_inputs + + def _apply_normalization(self, reshaped_inputs, input_shape): + group_reduction_axes = list(range(1, len(reshaped_inputs.shape))) + + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + + mean = ops.mean( + reshaped_inputs, axis=group_reduction_axes, keepdims=True + ) + variance = ops.var( + reshaped_inputs, axis=group_reduction_axes, keepdims=True + ) + + gamma, beta = self._get_reshaped_weights(input_shape) + + # Compute the batch normalization. + inv = 1 / ops.sqrt(variance + self.epsilon) + if gamma is not None: + inv *= gamma + + x = beta - mean * inv if beta is not None else -mean * inv + normalized_inputs = reshaped_inputs * ops.cast( + inv, reshaped_inputs.dtype + ) + ops.cast(x, reshaped_inputs.dtype) + + normalized_inputs = ops.cast(normalized_inputs, reshaped_inputs.dtype) + + return normalized_inputs + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = None + beta = None + if self.scale: + gamma = ops.reshape(self.gamma, broadcast_shape) + + if self.center: + beta = ops.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _create_broadcast_shape(self, input_shape): + + broadcast_shape = [1] * len(input_shape) + + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + + return broadcast_shape + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "groups": self.groups, + "axis": self.axis, + "epsilon": self.epsilon, + "center": self.center, + "scale": self.scale, + "beta_initializer": initializers.serialize(self.beta_initializer), + "gamma_initializer": initializers.serialize(self.gamma_initializer), + "beta_regularizer": regularizers.serialize(self.beta_regularizer), + "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), + "beta_constraint": constraints.serialize(self.beta_constraint), + "gamma_constraint": constraints.serialize(self.gamma_constraint), + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras_core/layers/normalization/group_normalization_test.py b/keras_core/layers/normalization/group_normalization_test.py new file mode 100644 index 000000000..ce52f411e --- /dev/null +++ 
@@ -0,0 +1,114 @@
+import numpy as np
+
+from keras_core import constraints
+from keras_core import layers
+from keras_core import regularizers
+from keras_core import testing
+
+
+class GroupNormalizationTest(testing.TestCase):
+    def test_groupnorm(self):
+        self.run_layer_test(
+            layers.GroupNormalization,
+            init_kwargs={
+                "gamma_regularizer": regularizers.L2(0.01),
+                "beta_regularizer": regularizers.L2(0.01),
+            },
+            input_shape=(3, 4, 32),
+            expected_output_shape=(3, 4, 32),
+            expected_num_trainable_weights=2,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=0,
+            expected_num_losses=2,
+            supports_masking=True,
+        )
+
+        self.run_layer_test(
+            layers.GroupNormalization,
+            init_kwargs={
+                "groups": 4,
+                "gamma_constraint": constraints.UnitNorm(),
+                "beta_constraint": constraints.UnitNorm(),
+            },
+            input_shape=(3, 4, 4),
+            expected_output_shape=(3, 4, 4),
+            expected_num_trainable_weights=2,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=0,
+            expected_num_losses=0,
+            supports_masking=True,
+        )
+
+    def test_correctness_instance_norm(self):
+        instance_norm_layer = layers.GroupNormalization(
+            groups=4, axis=-1, scale=False, center=False
+        )
+
+        inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])
+
+        expected_instance_norm_output = np.array(
+            [[[-1.0, -1.0, 1.0, 1.0], [1.0, 1.0, -1.0, -1.0]]]
+        )
+
+        self.assertAllClose(
+            instance_norm_layer(inputs),
+            expected_instance_norm_output,
+            atol=1e-3,
+        )
+
+    def test_correctness_1d(self):
+        layer_with_1_group = layers.GroupNormalization(
+            groups=1, axis=-1, scale=False, center=False
+        )
+        layer_with_2_groups = layers.GroupNormalization(
+            groups=2, axis=1, scale=False, center=False
+        )
+
+        inputs = np.array([[-1.0, -1.0, 1.0, 1.0, 2.0, 2.0, 0, -2.0]])
+
+        expected_output_1_group = np.array(
+            [[-0.898, -0.898, 0.539, 0.539, 1.257, 1.257, -0.180, -1.616]],
+        )
+        self.assertAllClose(
+            layer_with_1_group(inputs),
+            expected_output_1_group,
+            atol=1e-3,
+        )
+
+        expected_output_2_groups = np.array(
+            [[-1.0, -1.0, 1.0, 1.0, 0.904, 0.904, -0.301, -1.507]]
+        )
+        self.assertAllClose(
+            layer_with_2_groups(inputs),
+            expected_output_2_groups,
+            atol=1e-3,
+        )
+
+    def test_correctness_2d(self):
+        layer_with_1_group = layers.GroupNormalization(
+            groups=1, axis=-1, scale=False, center=False
+        )
+        layer_with_2_groups = layers.GroupNormalization(
+            groups=2, axis=2, scale=False, center=False
+        )
+
+        inputs = np.array([[[-1.0, -1.0, 2.0, 2.0], [1.0, 1.0, 0, -2.0]]])
+
+        expected_output_1_group = np.array(
+            [[[-0.898, -0.898, 1.257, 1.257], [0.539, 0.539, -0.180, -1.616]]]
+        )
+
+        self.assertAllClose(
+            layer_with_1_group(inputs),
+            expected_output_1_group,
+            atol=1e-3,
+        )
+
+        expected_output_2_groups = np.array(
+            [[[-1.0, -1.0, 0.904, 0.904], [1.0, 1.0, -0.301, -1.507]]]
+        )
+        self.assertAllClose(
+            layer_with_2_groups(inputs),
+            expected_output_2_groups,
+            atol=1e-3,
+        )
diff --git a/keras_core/layers/normalization/spectral_normalization.py b/keras_core/layers/normalization/spectral_normalization.py
new file mode 100644
index 000000000..cfe6eaf04
--- /dev/null
+++ b/keras_core/layers/normalization/spectral_normalization.py
@@ -0,0 +1,124 @@
+from keras_core import initializers
+from keras_core import operations as ops
+from keras_core.api_export import keras_core_export
+from keras_core.layers import Wrapper
+from keras_core.layers.input_spec import InputSpec
+
+
+@keras_core_export("keras_core.layers.SpectralNormalization")
+class SpectralNormalization(Wrapper):
+    """Performs spectral normalization on the weights of a target layer.
+
+    This wrapper controls the Lipschitz constant of the weights of a layer by
+    constraining their spectral norm, which can stabilize the training of GANs.
+
+    Args:
+        layer: A `keras_core.layers.Layer` instance that
+            has either a `kernel` (e.g. `Conv2D`, `Dense`...)
+            or an `embeddings` attribute (`Embedding` layer).
+        power_iterations: int, the number of iterations during normalization.
+        **kwargs: Base wrapper keyword arguments.
+
+    Examples:
+
+    Wrap `keras_core.layers.Conv2D`:
+    >>> x = np.random.rand(1, 10, 10, 1)
+    >>> conv2d = SpectralNormalization(keras_core.layers.Conv2D(2, 2))
+    >>> y = conv2d(x)
+    >>> y.shape
+    (1, 9, 9, 2)
+
+    Wrap `keras_core.layers.Dense`:
+    >>> x = np.random.rand(1, 10, 10, 1)
+    >>> dense = SpectralNormalization(keras_core.layers.Dense(10))
+    >>> y = dense(x)
+    >>> y.shape
+    (1, 10, 10, 10)
+
+    Reference:
+
+    - [Spectral Normalization for GAN](https://arxiv.org/abs/1802.05957).
+    """
+
+    def __init__(self, layer, power_iterations=1, **kwargs):
+        super().__init__(layer, **kwargs)
+        if power_iterations <= 0:
+            raise ValueError(
+                "`power_iterations` should be greater than zero. Received: "
+                f"`power_iterations={power_iterations}`"
+            )
+        self.power_iterations = power_iterations
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.input_spec = InputSpec(shape=[None] + list(input_shape[1:]))
+
+        if hasattr(self.layer, "kernel"):
+            self.kernel = self.layer.kernel
+        elif hasattr(self.layer, "embeddings"):
+            self.kernel = self.layer.embeddings
+        else:
+            raise ValueError(
+                f"{type(self.layer).__name__} object has no attribute 'kernel' "
+                "or 'embeddings'"
+            )
+
+        self.kernel_shape = self.kernel.shape
+
+        self.vector_u = self.add_weight(
+            shape=(1, self.kernel_shape[-1]),
+            initializer=initializers.TruncatedNormal(stddev=0.02),
+            trainable=False,
+            name="vector_u",
+            dtype=self.kernel.dtype,
+        )
+
+    def call(self, inputs, training=False):
+        if training:
+            self.normalize_weights()
+
+        output = self.layer(inputs)
+        return output
+
+    def compute_output_shape(self, input_shape):
+        return self.layer.compute_output_shape(input_shape)
+
+    def normalize_weights(self):
+        """Generate spectral normalized weights.
+
+        This method will update the value of `self.kernel` with the
+        spectral normalized value, so that the layer is ready for `call()`.
+        """
+
+        weights = ops.reshape(self.kernel, [-1, self.kernel_shape[-1]])
+        vector_u = self.vector_u
+
+        # Skip the update when the weights are all zeros.
+        if not all([w == 0.0 for w in weights]):
+            for _ in range(self.power_iterations):
+                vector_v = self._l2_normalize(
+                    ops.matmul(vector_u, ops.transpose(weights))
+                )
+                vector_u = self._l2_normalize(ops.matmul(vector_v, weights))
+            # vector_u = tf.stop_gradient(vector_u)
+            # vector_v = tf.stop_gradient(vector_v)
+            sigma = ops.matmul(
+                ops.matmul(vector_v, weights), ops.transpose(vector_u)
+            )
+            self.vector_u.assign(ops.cast(vector_u, self.vector_u.dtype))
+            self.kernel.assign(
+                ops.cast(
+                    ops.reshape(self.kernel / sigma, self.kernel_shape),
+                    self.kernel.dtype,
+                )
+            )
+
+    def _l2_normalize(self, x):
+        square_sum = ops.sum(ops.square(x), keepdims=True)
+        x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12))
+        return ops.multiply(x, x_inv_norm)
+
+    def get_config(self):
+        config = {"power_iterations": self.power_iterations}
+        base_config = super().get_config()
+        return {**base_config, **config}
diff --git a/keras_core/layers/normalization/spectral_normalization_test.py b/keras_core/layers/normalization/spectral_normalization_test.py
new file mode 100644
index 000000000..4740a192e
--- /dev/null
+++ b/keras_core/layers/normalization/spectral_normalization_test.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from keras_core import initializers
+from keras_core import layers
+from keras_core import testing
+
+
+class SpectralNormalizationTest(testing.TestCase):
+    def test_basic_spectralnorm(self):
+        self.run_layer_test(
+            layers.SpectralNormalization,
+            init_kwargs={"layer": layers.Dense(2)},
+            input_shape=None,
+            input_data=np.random.uniform(size=(10, 3, 4)),
+            expected_output_shape=(10, 3, 2),
+            expected_num_trainable_weights=2,
+            expected_num_non_trainable_weights=2,
+            expected_num_seed_generators=0,
+            expected_num_losses=0,
+            supports_masking=False,
+        )
+
+    def test_apply_layer(self):
+        images = np.ones((1, 2, 2, 1))
+        sn_wrapper = layers.SpectralNormalization(
+            layers.Conv2D(
+                1, (2, 2), kernel_initializer=initializers.Constant(value=1)
+            ),
+        )
+
+        result = sn_wrapper(images, training=False)
+        result_train = sn_wrapper(images, training=True)
+        expected_output = np.array([[[[4.0]]]], dtype=np.float32)
+        self.assertAllClose(result, expected_output)
+        # The spectral norm of this all-ones kernel is 2.
+        self.assertAllClose(result_train, expected_output / 2)
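
Below is a brief usage sketch of the two layers added by this patch. It is illustrative only and not part of the diff; the import paths and constructor arguments follow the `keras_core` API added above, and the shape comments are what the code above is expected to produce.

    import numpy as np

    from keras_core import layers

    # GroupNormalization: 16 channels split into 4 groups of 4; mean and
    # variance are computed per group, so the output keeps the input shape.
    x = np.random.rand(2, 8, 16).astype("float32")
    group_norm = layers.GroupNormalization(groups=4)
    print(group_norm(x).shape)  # (2, 8, 16)

    # SpectralNormalization: wrap a layer that has a `kernel`. Calling the
    # wrapper with `training=True` runs the power iteration and rescales the
    # kernel by its estimated largest singular value before the wrapped layer
    # is applied; the plain call below uses the inference path.
    spec_norm = layers.SpectralNormalization(layers.Dense(8), power_iterations=1)
    y = spec_norm(np.random.rand(4, 16).astype("float32"))
    print(y.shape)  # (4, 8)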