import math

from keras_core.api_export import keras_core_export
from keras_core.backend import random
from keras_core.initializers.initializer import Initializer


@keras_core_export("keras_core.initializers.RandomNormal")
class RandomNormal(Initializer):
    """Random normal initializer.

    Draws samples from a normal distribution for given parameters.

    Examples:

    >>> # Standalone usage:
    >>> initializer = RandomNormal(mean=0.0, stddev=1.0)
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = RandomNormal(mean=0.0, stddev=1.0)
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        mean: A python scalar or a scalar keras tensor. Mean of the random
            values to generate.
        stddev: A python scalar or a scalar keras tensor. Standard deviation
            of the random values to generate.
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.
    """

    def __init__(self, mean=0.0, stddev=1.0, seed=None):
        self.mean = mean
        self.stddev = stddev
        self.seed = seed or random.make_default_seed()
        super().__init__()

    def __call__(self, shape, dtype=None):
        return random.normal(
            shape=shape,
            mean=self.mean,
            stddev=self.stddev,
            seed=self.seed,
            dtype=dtype,
        )

    def get_config(self):
        return {"mean": self.mean, "stddev": self.stddev, "seed": self.seed}


@keras_core_export("keras_core.initializers.RandomUniform")
class RandomUniform(Initializer):
    """Random uniform initializer.

    Draws samples from a uniform distribution for given parameters.

    Examples:

    >>> # Standalone usage:
    >>> initializer = RandomUniform(minval=0.0, maxval=1.0)
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = RandomUniform(minval=0.0, maxval=1.0)
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        minval: A python scalar or a scalar keras tensor. Lower bound of the
            range of random values to generate (inclusive).
        maxval: A python scalar or a scalar keras tensor. Upper bound of the
            range of random values to generate (exclusive).
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.
    """

    def __init__(self, minval=0.0, maxval=1.0, seed=None):
        self.minval = minval
        self.maxval = maxval
        self.seed = seed or random.make_default_seed()
        super().__init__()

    def __call__(self, shape, dtype=None):
        return random.uniform(
            shape=shape,
            minval=self.minval,
            maxval=self.maxval,
            seed=self.seed,
            dtype=dtype,
        )

    def get_config(self):
        return {
            "minval": self.minval,
            "maxval": self.maxval,
            "seed": self.seed,
        }


@keras_core_export("keras_core.initializers.VarianceScaling")
class VarianceScaling(Initializer):
    """Initializer that adapts its scale to the shape of its input tensors.

    With `distribution="truncated_normal" or "untruncated_normal"`, samples
    are drawn from a truncated/untruncated normal distribution with a mean of
    zero and a standard deviation (after truncation, if used)
    `stddev = sqrt(scale / n)`, where `n` is:

    - number of input units in the weight tensor, if `mode="fan_in"`
    - number of output units, if `mode="fan_out"`
    - average of the numbers of input and output units, if `mode="fan_avg"`

    With `distribution="uniform"`, samples are drawn from a uniform
    distribution within `[-limit, limit]`, where `limit = sqrt(3 * scale / n)`.
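
    For instance (an illustrative example, not a prescribed setting): a
    `Dense` kernel of shape `(64, 32)` with `scale=1.0` and `mode="fan_in"`
    gives `n = 64`, so `distribution="untruncated_normal"` uses
    `stddev = sqrt(1.0 / 64) = 0.125`, while `distribution="uniform"` samples
    from `[-sqrt(3 / 64), sqrt(3 / 64)]`, roughly `[-0.217, 0.217]`.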

    Examples:

    >>> # Standalone usage:
    >>> initializer = VarianceScaling(
    ...     scale=0.1, mode='fan_in', distribution='uniform')
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = VarianceScaling(
    ...     scale=0.1, mode='fan_in', distribution='uniform')
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        scale: Scaling factor (positive float).
        mode: One of `"fan_in"`, `"fan_out"`, `"fan_avg"`.
        distribution: Random distribution to use.
            One of `"truncated_normal"`, `"untruncated_normal"`, or
            `"uniform"`.
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.
    """

    def __init__(
        self,
        scale=1.0,
        mode="fan_in",
        distribution="truncated_normal",
        seed=None,
    ):
        if scale <= 0.0:
            raise ValueError(
                "Argument `scale` must be a positive float. "
                f"Received: scale={scale}"
            )
        allowed_modes = {"fan_in", "fan_out", "fan_avg"}
        if mode not in allowed_modes:
            raise ValueError(
                f"Invalid `mode` argument: {mode}. "
                f"Please use one of {allowed_modes}"
            )
        distribution = distribution.lower()
        if distribution == "normal":
            distribution = "truncated_normal"
        allowed_distributions = {
            "uniform",
            "truncated_normal",
            "untruncated_normal",
        }
        if distribution not in allowed_distributions:
            raise ValueError(
                f"Invalid `distribution` argument: {distribution}. "
                f"Please use one of {allowed_distributions}"
            )
        self.scale = scale
        self.mode = mode
        self.distribution = distribution
        self.seed = seed or random.make_default_seed()

    def __call__(self, shape, dtype=None):
        """Returns a tensor object initialized as specified by the initializer.

        Args:
            shape: Shape of the tensor.
            dtype: Optional dtype of the tensor. Only floating point types
                are supported. If not specified,
                `keras_core.backend.floatx()` is used, which defaults to
                `float32` unless you configured it otherwise (via
                `keras_core.backend.set_floatx(float_dtype)`).
        """
        scale = self.scale
        fan_in, fan_out = compute_fans(shape)
        if self.mode == "fan_in":
            scale /= max(1.0, fan_in)
        elif self.mode == "fan_out":
            scale /= max(1.0, fan_out)
        else:
            scale /= max(1.0, (fan_in + fan_out) / 2.0)
        if self.distribution == "truncated_normal":
            # The constant is the standard deviation of a standard normal
            # truncated to (-2, 2); dividing by it restores unit variance
            # after truncation.
            stddev = math.sqrt(scale) / 0.87962566103423978
            return random.truncated_normal(
                shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed
            )
        elif self.distribution == "untruncated_normal":
            stddev = math.sqrt(scale)
            return random.normal(
                shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed
            )
        else:
            limit = math.sqrt(3.0 * scale)
            return random.uniform(
                shape, minval=-limit, maxval=limit, dtype=dtype, seed=self.seed
            )

    def get_config(self):
        return {
            "scale": self.scale,
            "mode": self.mode,
            "distribution": self.distribution,
            "seed": self.seed,
        }


@keras_core_export("keras_core.initializers.GlorotUniform")
class GlorotUniform(VarianceScaling):
    """The Glorot uniform initializer, also called Xavier uniform initializer.

    Draws samples from a uniform distribution within `[-limit, limit]`, where
    `limit = sqrt(6 / (fan_in + fan_out))` (`fan_in` is the number of input
    units in the weight tensor and `fan_out` is the number of output units).
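
    This is a special case of `VarianceScaling` with `scale=1.0`,
    `mode="fan_avg"` and `distribution="uniform"`: with
    `n = (fan_in + fan_out) / 2`, the limit `sqrt(3 * scale / n)` reduces to
    `sqrt(6 / (fan_in + fan_out))`.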

    Examples:

    >>> # Standalone usage:
    >>> initializer = GlorotUniform()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = GlorotUniform()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.

    Reference:

    - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=1.0, mode="fan_avg", distribution="uniform", seed=seed
        )

    def get_config(self):
        return {"seed": self.seed}


@keras_core_export("keras_core.initializers.GlorotNormal")
class GlorotNormal(VarianceScaling):
    """The Glorot normal initializer, also called Xavier normal initializer.

    Draws samples from a truncated normal distribution centered on 0 with
    `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of
    input units in the weight tensor and `fan_out` is the number of output
    units in the weight tensor.

    Examples:

    >>> # Standalone usage:
    >>> initializer = GlorotNormal()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = GlorotNormal()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.

    Reference:

    - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=1.0,
            mode="fan_avg",
            distribution="truncated_normal",
            seed=seed,
        )

    def get_config(self):
        return {"seed": self.seed}


@keras_core_export("keras_core.initializers.LecunNormal")
class LecunNormal(VarianceScaling):
    """Lecun normal initializer.

    Initializers allow you to pre-specify an initialization strategy, encoded
    in the Initializer object, without knowing the shape and dtype of the
    variable being initialized.

    Draws samples from a truncated normal distribution centered on 0 with
    `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of input units
    in the weight tensor.

    Examples:

    >>> # Standalone usage:
    >>> initializer = LecunNormal()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = LecunNormal()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.

    Reference:

    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=1.0, mode="fan_in", distribution="truncated_normal", seed=seed
        )

    def get_config(self):
        return {"seed": self.seed}


@keras_core_export("keras_core.initializers.LecunUniform")
class LecunUniform(VarianceScaling):
    """Lecun uniform initializer.

    Draws samples from a uniform distribution within `[-limit, limit]`, where
    `limit = sqrt(3 / fan_in)` (`fan_in` is the number of input units in the
    weight tensor).

    Examples:

    >>> # Standalone usage:
    >>> initializer = LecunUniform()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = LecunUniform()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.

    Reference:

    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=1.0, mode="fan_in", distribution="uniform", seed=seed
        )

    def get_config(self):
        return {"seed": self.seed}


@keras_core_export("keras_core.initializers.HeNormal")
class HeNormal(VarianceScaling):
    """He normal initializer.

    It draws samples from a truncated normal distribution centered on 0 with
    `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units
    in the weight tensor.
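
    This is a special case of `VarianceScaling` with `scale=2.0`,
    `mode="fan_in"` and `distribution="truncated_normal"`: with `n = fan_in`,
    the standard deviation `sqrt(scale / n)` reduces to `sqrt(2 / fan_in)`.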

    Examples:

    >>> # Standalone usage:
    >>> initializer = HeNormal()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = HeNormal()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.

    Reference:

    - [He et al., 2015](https://arxiv.org/abs/1502.01852)
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=2.0, mode="fan_in", distribution="truncated_normal", seed=seed
        )

    def get_config(self):
        return {"seed": self.seed}


@keras_core_export("keras_core.initializers.HeUniform")
class HeUniform(VarianceScaling):
    """He uniform variance scaling initializer.

    Draws samples from a uniform distribution within `[-limit, limit]`, where
    `limit = sqrt(6 / fan_in)` (`fan_in` is the number of input units in the
    weight tensor).

    Examples:

    >>> # Standalone usage:
    >>> initializer = HeUniform()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = HeUniform()
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        seed: A Python integer or instance of
            `keras_core.backend.SeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or `None` (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.SeedGenerator`.

    Reference:

    - [He et al., 2015](https://arxiv.org/abs/1502.01852)
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=2.0, mode="fan_in", distribution="uniform", seed=seed
        )

    def get_config(self):
        return {"seed": self.seed}


def compute_fans(shape):
    """Computes the number of input and output units for a weight shape.

    Args:
        shape: Integer shape tuple.

    Returns:
        A tuple of integer scalars: `(fan_in, fan_out)`.
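
    Example (illustrative only, using a hypothetical 3x3 convolution kernel
    with 16 input channels and 32 filters):

    >>> compute_fans((3, 3, 16, 32))
    (144, 288)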
    """
    shape = tuple(shape)
    if len(shape) < 1:  # Just to avoid errors for constants.
        fan_in = fan_out = 1
    elif len(shape) == 1:
        fan_in = fan_out = shape[0]
    elif len(shape) == 2:
        fan_in = shape[0]
        fan_out = shape[1]
    else:
        # Assuming convolution kernels (2D, 3D, or more).
        # kernel shape: (..., input_depth, depth)
        receptive_field_size = 1
        for dim in shape[:-2]:
            receptive_field_size *= dim
        fan_in = shape[-2] * receptive_field_size
        fan_out = shape[-1] * receptive_field_size
    return int(fan_in), int(fan_out)