Add TruncatedNormal initializer

This commit is contained in:
Francois Chollet 2023-04-22 20:20:56 -07:00
parent 1181f444f2
commit de595e39a3
4 changed files with 103 additions and 6 deletions

@ -18,7 +18,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`).
seed: TODO
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
dtype = dtype or floatx()
seed = draw_seed(seed)
@ -39,7 +46,14 @@ def uniform(shape, minval=0.0, maxval=None, dtype=None, seed=None):
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`)
seed: TODO
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
dtype = dtype or floatx()
seed = draw_seed(seed)
@ -60,7 +74,14 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`)
seed: TODO
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
dtype = dtype or floatx()
seed = draw_seed(seed)

@ -23,7 +23,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`).
seed: TODO
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
@ -45,7 +52,14 @@ def uniform(shape, minval=0.0, maxval=None, dtype=None, seed=None):
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`)
seed: TODO
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
@ -70,7 +84,14 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`)
seed: TODO
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)

@ -13,6 +13,7 @@ from keras_core.initializers.random_initializers import LecunNormal
from keras_core.initializers.random_initializers import LecunUniform
from keras_core.initializers.random_initializers import RandomNormal
from keras_core.initializers.random_initializers import RandomUniform
from keras_core.initializers.random_initializers import TruncatedNormal
from keras_core.initializers.random_initializers import VarianceScaling
from keras_core.saving import serialization_lib
from keras_core.utils.naming import to_snake_case
@ -29,6 +30,7 @@ ALL_OBJECTS = {
LecunNormal,
LecunUniform,
RandomNormal,
TruncatedNormal,
RandomUniform,
VarianceScaling,
}

@ -55,6 +55,59 @@ class RandomNormal(Initializer):
return {"mean": self.mean, "stddev": self.stddev, "seed": self.seed}
@keras_core_export("keras_core.initializers.TruncatedNormal")
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
The values generated are similar to values from a
`RandomNormal` initializer, except that values more
than two standard deviations from the mean are
discarded and re-drawn.
Examples:
>>> # Standalone usage:
>>> initializer = TruncatedNormal(mean=0., stddev=1.)
>>> values = initializer(shape=(2, 2))
>>> # Usage in a Keras layer:
>>> initializer = TruncatedNormal(mean=0., stddev=1.)
>>> layer = Dense(3, kernel_initializer=initializer)
Args:
mean: A python scalar or a scalar keras tensor. Mean of the random values to
generate.
stddev: A python scalar or a scalar keras tensor. Standard deviation of the
random values to generate.
seed: A Python integer or instance of
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
def __init__(self, mean=0.0, stddev=1.0, seed=None):
self.mean = mean
self.stddev = stddev
self.seed = seed or random.make_default_seed()
super().__init__()
def __call__(self, shape, dtype=None):
return random.truncated_normal(
shape=shape,
mean=self.mean,
stddev=self.stddev,
seed=self.seed,
dtype=dtype,
)
def get_config(self):
return {"mean": self.mean, "stddev": self.stddev, "seed": self.seed}
@keras_core_export("keras_core.initializers.RandomUniform")
class RandomUniform(Initializer):
"""Random uniform initializer.