# keras/keras_core/layers/preprocessing/normalization.py

import math

import numpy as np
import tensorflow as tf

from keras_core import backend
from keras_core import operations as ops
from keras_core.layers.layer import Layer


class Normalization(Layer):
"""A preprocessing layer that normalizes continuous features.
This layer will shift and scale inputs into a distribution centered around
0 with standard deviation 1. It accomplishes this by precomputing the mean
and variance of the data, and calling `(input - mean) / sqrt(var)` at
runtime.
The mean and variance values for the layer must be either supplied on
construction or learned via `adapt()`. `adapt()` will compute the mean and
variance of the data and store them as the layer's weights. `adapt()` should
be called before `fit()`, `evaluate()`, or `predict()`.
Args:
axis: Integer, tuple of integers, or None. The axis or axes that should
have a separate mean and variance for each index in the shape.
For example, if shape is `(None, 5)` and `axis=1`, the layer will
track 5 separate mean and variance values for the last axis.
If `axis` is set to `None`, the layer will normalize
all elements in the input by a scalar mean and variance.
When `-1`, the last axis of the input is assumed to be a
feature dimension and is normalized per index.
Note that in the specific case of batched scalar inputs where
the only axis is the batch axis, the default will normalize
each index in the batch separately.
In this case, consider passing `axis=None`. Defaults to `-1`.
mean: The mean value(s) to use during normalization. The passed value(s)
will be broadcast to the shape of the kept axes above;
if the value(s) cannot be broadcast, an error will be raised when
this layer's `build()` method is called.
variance: The variance value(s) to use during normalization. The passed
value(s) will be broadcast to the shape of the kept axes above;
if the value(s) cannot be broadcast, an error will be raised when
this layer's `build()` method is called.
invert: If `True`, this layer will apply the inverse transformation
to its inputs: it would turn a normalized input back into its
original form.
    Examples:

    Calculate a global mean and variance by analyzing the dataset in
    `adapt()`.

    >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')
    >>> input_data = np.array([1., 2., 3.], dtype='float32')
    >>> layer = keras_core.layers.Normalization(axis=None)
    >>> layer.adapt(adapt_data)
    >>> layer(input_data)
    array([-1.4142135, -0.70710677, 0.], dtype=float32)

    Calculate a mean and variance for each index on the last axis.

    >>> adapt_data = np.array([[0., 7., 4.],
    ...                        [2., 9., 6.],
    ...                        [0., 7., 4.],
    ...                        [2., 9., 6.]], dtype='float32')
    >>> input_data = np.array([[0., 7., 4.]], dtype='float32')
    >>> layer = keras_core.layers.Normalization(axis=-1)
    >>> layer.adapt(adapt_data)
    >>> layer(input_data)
    array([[-1., -1., -1.]], dtype=float32)

    Pass the mean and variance directly.

    >>> input_data = np.array([[1.], [2.], [3.]], dtype='float32')
    >>> layer = keras_core.layers.Normalization(mean=3., variance=2.)
    >>> layer(input_data)
    array([[-1.4142135 ],
           [-0.70710677],
           [ 0.        ]], dtype=float32)

    Use the layer to de-normalize inputs (after adapting the layer).

    >>> adapt_data = np.array([[0., 7., 4.],
    ...                        [2., 9., 6.],
    ...                        [0., 7., 4.],
    ...                        [2., 9., 6.]], dtype='float32')
    >>> input_data = np.array([[1., 2., 3.]], dtype='float32')
    >>> layer = keras_core.layers.Normalization(axis=-1, invert=True)
    >>> layer.adapt(adapt_data)
    >>> layer(input_data)
    array([[ 2., 10.,  8.]], dtype=float32)
"""

    def __init__(
self, axis=-1, mean=None, variance=None, invert=False, **kwargs
):
super().__init__(**kwargs)
# Standardize `axis` to a tuple.
if axis is None:
axis = ()
elif isinstance(axis, int):
axis = (axis,)
else:
axis = tuple(axis)
self.axis = axis
# Set `mean` and `variance` if passed.
if (mean is not None) != (variance is not None):
raise ValueError(
"When setting values directly, both `mean` and `variance` "
f"must be set. Received: mean={mean} and variance={variance}"
)
self.input_mean = mean
self.input_variance = variance
self.invert = invert
self.supports_masking = True

    def build(self, input_shape):
ndim = len(input_shape)
self._build_input_shape = input_shape
if any(a < -ndim or a >= ndim for a in self.axis):
raise ValueError(
"All `axis` values must be in the range [-ndim, ndim). "
f"Received inputs with ndim={ndim}, while axis={self.axis}"
)
# Axes to be kept, replacing negative values with positive equivalents.
# Sorted to avoid transposing axes.
self._keep_axis = tuple(
sorted([d if d >= 0 else d + ndim for d in self.axis])
)
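        # For example, with ndim=3 and axis=(-1,), `self._keep_axis`
        # becomes (2,).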
# All axes to be kept should have known shape.
for d in self._keep_axis:
if input_shape[d] is None:
raise ValueError(
"All `axis` values to be kept must have a known shape. "
f"Received axis={self.axis}, "
f"inputs.shape={input_shape}, "
f"with unknown axis at index {d}"
)
# Axes to be reduced.
self._reduce_axis = tuple(
d for d in range(ndim) if d not in self._keep_axis
)
# 1 if an axis should be reduced, 0 otherwise.
self._reduce_axis_mask = [
0 if d in self._keep_axis else 1 for d in range(ndim)
]
# Broadcast any reduced axes.
self._broadcast_shape = [
input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
]
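        # For example, input_shape=(None, 5) with axis=(-1,) yields
        # `self._broadcast_shape == [1, 5]`.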
mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)
self._mean_and_var_shape = mean_and_var_shape
if self.input_mean is None:
self.adapt_mean = self.add_weight(
name="mean",
shape=mean_and_var_shape,
dtype=self.compute_dtype,
initializer="zeros",
trainable=False,
)
self.adapt_variance = self.add_weight(
name="variance",
shape=mean_and_var_shape,
dtype=self.compute_dtype,
initializer="ones",
trainable=False,
)
self.built = True
self.finalize_state()
else:
# In the no adapt case, make constant tensors for mean and variance
# with proper broadcast shape for use during call.
mean = ops.convert_to_tensor(self.input_mean)
variance = ops.convert_to_tensor(self.input_variance)
mean = ops.reshape(mean, self._broadcast_shape)
variance = ops.reshape(variance, self._broadcast_shape)
self.mean = ops.cast(mean, dtype=self.compute_dtype)
self.variance = ops.cast(variance, dtype=self.compute_dtype)
self.built = True

    def adapt(self, data):
"""Computes the mean and variance of values in a dataset.
Calling `adapt()` on a `Normalization` layer is an alternative to
passing in `mean` and `variance` arguments during layer construction. A
`Normalization` layer should always either be adapted over a dataset or
passed `mean` and `variance`.
During `adapt()`, the layer will compute a `mean` and `variance`
separately for each position in each axis specified by the `axis`
argument. To calculate a single `mean` and `variance` over the input
data, simply pass `axis=None` to the layer.
Arg:
data: The data to train on. It can be passed either as a
`tf.data.Dataset`, as a NumPy array, or as a backend-native
eager tensor.
If a dataset, *it must be batched*. Keras will assume that the
data is batched, and if that assumption doesn't hold, the mean
and variance may be incorrectly computed.
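
        Example:

        A sketch of typical use, adapting over a small array and then
        normalizing new values with the learned statistics (here the mean
        is 3 and the variance is 2, as in the layer docstring above):

        >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')
        >>> layer = keras_core.layers.Normalization(axis=None)
        >>> layer.adapt(adapt_data)
        >>> layer(np.array([1., 3., 5.], dtype='float32'))
        array([-1.4142135, 0., 1.4142135], dtype=float32)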
"""
if isinstance(data, np.ndarray) or backend.is_tensor(data):
input_shape = data.shape
elif isinstance(data, tf.data.Dataset):
input_shape = tuple(data.element_spec.shape)
if len(input_shape) == 1:
# Batch dataset if it isn't batched
data = data.batch(128)
input_shape = tuple(data.element_spec.shape)
if not self.built:
self.build(input_shape)
else:
for d in self._keep_axis:
if input_shape[d] != self._build_input_shape[d]:
raise ValueError(
"The layer was built with "
f"input_shape={self._build_input_shape}, "
"but adapt() is being called with data with "
f"an incompatible shape, data.shape={input_shape}"
)
if isinstance(data, np.ndarray):
total_mean = np.mean(data, axis=self._reduce_axis)
total_var = np.var(data, axis=self._reduce_axis)
elif backend.is_tensor(data):
total_mean = ops.mean(data, axis=self._reduce_axis)
total_var = ops.var(data, axis=self._reduce_axis)
elif isinstance(data, tf.data.Dataset):
total_mean = ops.zeros(self._mean_and_var_shape)
total_var = ops.zeros(self._mean_and_var_shape)
total_count = 0
for batch in data:
batch = backend.convert_to_tensor(
batch, dtype=self.compute_dtype
)
batch_mean = ops.mean(batch, axis=self._reduce_axis)
batch_var = ops.var(batch, axis=self._reduce_axis)
if self._reduce_axis:
batch_reduce_shape = (
batch.shape[d] for d in self._reduce_axis
)
batch_count = math.prod(batch_reduce_shape)
else:
batch_count = 1
total_count += batch_count
batch_weight = float(batch_count) / total_count
existing_weight = 1.0 - batch_weight
new_total_mean = (
total_mean * existing_weight + batch_mean * batch_weight
)
# The variance is computed using the lack-of-fit sum of squares
# formula (see
# https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
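                # Written out, the update below is:
                #   new_var = (old_var + (old_mean - new_mean)**2) * w_old
                #           + (batch_var + (batch_mean - new_mean)**2) * w_b
                # where w_b = batch_count / total_count and w_old = 1 - w_b.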
total_var = (
total_var + (total_mean - new_total_mean) ** 2
) * existing_weight + (
batch_var + (batch_mean - new_total_mean) ** 2
) * batch_weight
total_mean = new_total_mean
self.adapt_mean.assign(total_mean)
self.adapt_variance.assign(total_var)
self.finalize_state()

    def finalize_state(self):
if self.input_mean is not None or not self.built:
return
# In the adapt case, we make constant tensors for mean and variance with
# proper broadcast shape and dtype each time `finalize_state` is called.
self.mean = ops.reshape(self.adapt_mean, self._broadcast_shape)
self.mean = ops.cast(self.mean, self.compute_dtype)
self.variance = ops.reshape(self.adapt_variance, self._broadcast_shape)
self.variance = ops.cast(self.variance, self.compute_dtype)

    def call(self, inputs):
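        # Guard `sqrt(variance)` with `epsilon()` so zero-variance
        # (constant) features do not cause a division by zero.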
if self.invert:
return self.mean + (
inputs * ops.maximum(ops.sqrt(self.variance), backend.epsilon())
)
else:
return (inputs - self.mean) / ops.maximum(
ops.sqrt(self.variance), backend.epsilon()
)

    def compute_output_shape(self, input_shape):
return input_shape

    def get_config(self):
config = super().get_config()
config.update(
{
"axis": self.axis,
"invert": self.invert,
"mean": np.array(self.input_mean).tolist(),
"variance": np.array(self.input_variance).tolist(),
}
)
return config

    def load_own_variables(self, store):
# Ensure that we call finalize_state after variable loading.
super().load_own_variables(store)
self.finalize_state()
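

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library):
    # adapt per-feature statistics on a small array, then normalize it.
    demo_data = np.array(
        [[0.0, 7.0, 4.0], [2.0, 9.0, 6.0]], dtype="float32"
    )
    demo_layer = Normalization(axis=-1)
    demo_layer.adapt(demo_data)
    # Each column now has zero mean and unit variance:
    # [[-1., -1., -1.], [ 1.,  1.,  1.]]
    print(demo_layer(demo_data))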