import math

import numpy as np
import tensorflow as tf

from keras_core import backend
from keras_core import operations as ops
from keras_core.layers.layer import Layer


class Normalization(Layer):
    """A preprocessing layer that normalizes continuous features.

    This layer will shift and scale inputs into a distribution centered
    around 0 with standard deviation 1. It accomplishes this by
    precomputing the mean and variance of the data, and calling
    `(input - mean) / sqrt(var)` at runtime.

    The mean and variance values for the layer must be either supplied on
    construction or learned via `adapt()`. `adapt()` will compute the mean
    and variance of the data and store them as the layer's weights.
    `adapt()` should be called before `fit()`, `evaluate()`, or `predict()`.

    Args:
        axis: Integer, tuple of integers, or None. The axis or axes that
            should have a separate mean and variance for each index in the
            shape. For example, if shape is `(None, 5)` and `axis=1`, the
            layer will track 5 separate mean and variance values for the
            last axis. If `axis` is set to `None`, the layer will normalize
            all elements in the input by a scalar mean and variance.
            When `-1`, the last axis of the input is assumed to be a
            feature dimension and is normalized per index.
            Note that in the specific case of batched scalar inputs where
            the only axis is the batch axis, the default will normalize
            each index in the batch separately.
            In this case, consider passing `axis=None`. Defaults to `-1`.
        mean: The mean value(s) to use during normalization. The passed
            value(s) will be broadcast to the shape of the kept axes above;
            if the value(s) cannot be broadcast, an error will be raised
            when this layer's `build()` method is called.
        variance: The variance value(s) to use during normalization. The
            passed value(s) will be broadcast to the shape of the kept axes
            above; if the value(s) cannot be broadcast, an error will be
            raised when this layer's `build()` method is called.
        invert: If `True`, this layer will apply the inverse transformation
            to its inputs: it would turn a normalized input back into its
            original form.

    Examples:

    Calculate a global mean and variance by analyzing the dataset in
    `adapt()`.

    >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')
    >>> input_data = np.array([1., 2., 3.], dtype='float32')
    >>> layer = keras_core.layers.Normalization(axis=None)
    >>> layer.adapt(adapt_data)
    >>> layer(input_data)
    array([-1.4142135, -0.70710677, 0.], dtype=float32)

    Calculate a mean and variance for each index on the last axis.

    >>> adapt_data = np.array([[0., 7., 4.],
    ...                        [2., 9., 6.],
    ...                        [0., 7., 4.],
    ...                        [2., 9., 6.]], dtype='float32')
    >>> input_data = np.array([[0., 7., 4.]], dtype='float32')
    >>> layer = keras_core.layers.Normalization(axis=-1)
    >>> layer.adapt(adapt_data)
    >>> layer(input_data)
    array([-1., -1., -1.], dtype=float32)

    Pass the mean and variance directly.

    >>> input_data = np.array([[1.], [2.], [3.]], dtype='float32')
    >>> layer = keras_core.layers.Normalization(mean=3., variance=2.)
    >>> layer(input_data)
    array([[-1.4142135 ],
           [-0.70710677],
           [ 0.        ]], dtype=float32)

    Use the layer to de-normalize inputs (after adapting the layer).

    >>> adapt_data = np.array([[0., 7., 4.],
    ...                        [2., 9., 6.],
    ...                        [0., 7., 4.],
    ...                        [2., 9., 6.]], dtype='float32')
    >>> input_data = np.array([[1., 2., 3.]], dtype='float32')
    >>> layer = keras_core.layers.Normalization(axis=-1, invert=True)
    >>> layer.adapt(adapt_data)
    >>> layer(input_data)
    array([2., 10., 8.], dtype=float32)
    """

    def __init__(
        self, axis=-1, mean=None, variance=None, invert=False, **kwargs
    ):
        super().__init__(**kwargs)

        # Standardize `axis` to a tuple.
        if axis is None:
            axis = ()
        elif isinstance(axis, int):
            axis = (axis,)
        else:
            axis = tuple(axis)
        self.axis = axis
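        # For example (illustrative; these literals are not from the
        # original source): `axis=None` becomes `()`, `axis=-1` becomes
        # `(-1,)`, and `axis=[1, 2]` becomes `(1, 2)`.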

        # Set `mean` and `variance` if passed.
        if (mean is not None) != (variance is not None):
            raise ValueError(
                "When setting values directly, both `mean` and `variance` "
                f"must be set. Received: mean={mean} and variance={variance}"
            )
        self.input_mean = mean
        self.input_variance = variance
        self.invert = invert
        self.supports_masking = True

    def build(self, input_shape):
        ndim = len(input_shape)
        self._build_input_shape = input_shape

        if any(a < -ndim or a >= ndim for a in self.axis):
            raise ValueError(
                "All `axis` values must be in the range [-ndim, ndim). "
                f"Received inputs with ndim={ndim}, while axis={self.axis}"
            )

        # Axes to be kept, replacing negative values with positive
        # equivalents. Sorted to avoid transposing axes.
        self._keep_axis = tuple(
            sorted(d if d >= 0 else d + ndim for d in self.axis)
        )
        # All axes to be kept should have known shape.
        for d in self._keep_axis:
            if input_shape[d] is None:
                raise ValueError(
                    "All `axis` values to be kept must have a known shape. "
                    f"Received axis={self.axis}, "
                    f"inputs.shape={input_shape}, "
                    f"with unknown axis at index {d}"
                )
        # Axes to be reduced.
        self._reduce_axis = tuple(
            d for d in range(ndim) if d not in self._keep_axis
        )
        # 1 if an axis should be reduced, 0 otherwise.
        self._reduce_axis_mask = [
            0 if d in self._keep_axis else 1 for d in range(ndim)
        ]
        # Broadcast any reduced axes.
        self._broadcast_shape = [
            input_shape[d] if d in self._keep_axis else 1
            for d in range(ndim)
        ]
        mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)
        self._mean_and_var_shape = mean_and_var_shape
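        # Worked example (illustrative; not from the original source): for
        # `input_shape=(None, 5)` and `axis=(-1,)`, the lines above yield
        # `_keep_axis=(1,)`, `_reduce_axis=(0,)`, `_broadcast_shape=[1, 5]`,
        # and `mean_and_var_shape=(5,)`: one mean/variance pair is kept per
        # feature, and statistics are reduced over the batch axis.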

        if self.input_mean is None:
            self.adapt_mean = self.add_weight(
                name="mean",
                shape=mean_and_var_shape,
                dtype=self.compute_dtype,
                initializer="zeros",
                trainable=False,
            )
            self.adapt_variance = self.add_weight(
                name="variance",
                shape=mean_and_var_shape,
                dtype=self.compute_dtype,
                initializer="ones",
                trainable=False,
            )
            self.built = True
            self.finalize_state()
        else:
            # In the no-adapt case, make constant tensors for mean and
            # variance with proper broadcast shape for use during `call()`.
            mean = ops.convert_to_tensor(self.input_mean)
            variance = ops.convert_to_tensor(self.input_variance)
            mean = ops.reshape(mean, self._broadcast_shape)
            variance = ops.reshape(variance, self._broadcast_shape)
            self.mean = ops.cast(mean, dtype=self.compute_dtype)
            self.variance = ops.cast(variance, dtype=self.compute_dtype)
            self.built = True

    def adapt(self, data):
        """Computes the mean and variance of values in a dataset.

        Calling `adapt()` on a `Normalization` layer is an alternative to
        passing in `mean` and `variance` arguments during layer
        construction. A `Normalization` layer should always either be
        adapted over a dataset or passed `mean` and `variance`.

        During `adapt()`, the layer will compute a `mean` and `variance`
        separately for each position in each axis specified by the `axis`
        argument. To calculate a single `mean` and `variance` over the
        input data, simply pass `axis=None` to the layer.

        Args:
            data: The data to train on. It can be passed either as a
                `tf.data.Dataset`, as a NumPy array, or as a backend-native
                eager tensor. If a dataset, *it must be batched*. Keras
                will assume that the data is batched, and if that
                assumption doesn't hold, the mean and variance may be
                incorrectly computed.
        """
        if isinstance(data, np.ndarray) or backend.is_tensor(data):
            input_shape = data.shape
        elif isinstance(data, tf.data.Dataset):
            input_shape = tuple(data.element_spec.shape)
            if len(input_shape) == 1:
                # Batch the dataset if it isn't batched yet.
                data = data.batch(128)
                input_shape = tuple(data.element_spec.shape)

        if not self.built:
            self.build(input_shape)
        else:
            for d in self._keep_axis:
                if input_shape[d] != self._build_input_shape[d]:
                    raise ValueError(
                        "The layer was built with "
                        f"input_shape={self._build_input_shape}, "
                        "but adapt() is being called with data with "
                        f"an incompatible shape, data.shape={input_shape}"
                    )

        if isinstance(data, np.ndarray):
            total_mean = np.mean(data, axis=self._reduce_axis)
            total_var = np.var(data, axis=self._reduce_axis)
        elif backend.is_tensor(data):
            total_mean = ops.mean(data, axis=self._reduce_axis)
            total_var = ops.var(data, axis=self._reduce_axis)
        elif isinstance(data, tf.data.Dataset):
            total_mean = ops.zeros(self._mean_and_var_shape)
            total_var = ops.zeros(self._mean_and_var_shape)
            total_count = 0
            for batch in data:
                batch = backend.convert_to_tensor(
                    batch, dtype=self.compute_dtype
                )
                batch_mean = ops.mean(batch, axis=self._reduce_axis)
                batch_var = ops.var(batch, axis=self._reduce_axis)
                if self._reduce_axis:
                    batch_reduce_shape = (
                        batch.shape[d] for d in self._reduce_axis
                    )
                    batch_count = math.prod(batch_reduce_shape)
                else:
                    batch_count = 1

                total_count += batch_count
                batch_weight = float(batch_count) / total_count
                existing_weight = 1.0 - batch_weight

                new_total_mean = (
                    total_mean * existing_weight + batch_mean * batch_weight
                )
                # The variance is computed using the lack-of-fit sum of
                # squares formula (see
                # https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
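                # In scalar form (notation here is illustrative, not from
                # the original source): with weights `w_b` for the new
                # batch and `w_e = 1 - w_b` for the existing statistics,
                #     var_new = w_e * (var_e + (mean_e - mean_new) ** 2)
                #             + w_b * (var_b + (mean_b - mean_new) ** 2)
                # i.e. each group's variance is shifted by the squared
                # distance of its mean from the combined mean before the
                # weighted average is taken.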
                total_var = (
                    total_var + (total_mean - new_total_mean) ** 2
                ) * existing_weight + (
                    batch_var + (batch_mean - new_total_mean) ** 2
                ) * batch_weight
                total_mean = new_total_mean

        self.adapt_mean.assign(total_mean)
        self.adapt_variance.assign(total_var)
        self.finalize_state()

    def finalize_state(self):
        if self.input_mean is not None or not self.built:
            return

        # In the adapt case, we make constant tensors for mean and variance
        # with proper broadcast shape and dtype each time `finalize_state`
        # is called.
        self.mean = ops.reshape(self.adapt_mean, self._broadcast_shape)
        self.mean = ops.cast(self.mean, self.compute_dtype)
        self.variance = ops.reshape(
            self.adapt_variance, self._broadcast_shape
        )
        self.variance = ops.cast(self.variance, self.compute_dtype)

    def call(self, inputs):
        if self.invert:
            # Undo the normalization: `x * stddev + mean`.
            return self.mean + (
                inputs
                * ops.maximum(ops.sqrt(self.variance), backend.epsilon())
            )
        else:
            # `(x - mean) / stddev`, with `backend.epsilon()` guarding
            # against division by zero.
            return (inputs - self.mean) / ops.maximum(
                ops.sqrt(self.variance), backend.epsilon()
            )

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config()
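        # Note (illustrative; not from the original source):
        # `np.array(3.0).tolist()` below yields the plain float `3.0`, and
        # `np.array(None).tolist()` yields `None`, so an adapted layer,
        # whose statistics live in weights rather than in the config,
        # serializes `mean=None` and `variance=None`.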
        config.update(
            {
                "axis": self.axis,
                "invert": self.invert,
                "mean": np.array(self.input_mean).tolist(),
                "variance": np.array(self.input_variance).tolist(),
            }
        )
        return config

    def load_own_variables(self, store):
        # Ensure that we call `finalize_state` after variable loading.
        super().load_own_variables(store)
        self.finalize_state()
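

if __name__ == "__main__":
    # Minimal smoke test: a sketch that is not part of the original file;
    # it assumes a working keras_core backend is installed alongside
    # TensorFlow and NumPy.
    data = np.array([[0.0, 7.0, 4.0], [2.0, 9.0, 6.0]], dtype="float32")
    layer = Normalization(axis=-1)
    layer.adapt(data)
    print(layer(data))  # each feature column becomes +/-1 around its mean

    # The dataset path requires batched data, as the `adapt()` docstring
    # notes.
    ds = tf.data.Dataset.from_tensor_slices(data).batch(2)
    dataset_layer = Normalization(axis=-1)
    dataset_layer.adapt(ds)
    print(dataset_layer(data))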