Add TruncatedNormal initializer
This commit is contained in:
parent
1181f444f2
commit
de595e39a3
@ -18,7 +18,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
||||
supported. If not specified, `keras.backend.floatx()` is used,
|
||||
which default to `float32` unless you configured it otherwise (via
|
||||
`keras.backend.set_floatx(float_dtype)`).
|
||||
seed: TODO
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
dtype = dtype or floatx()
|
||||
seed = draw_seed(seed)
|
||||
@ -39,7 +46,14 @@ def uniform(shape, minval=0.0, maxval=None, dtype=None, seed=None):
|
||||
supported. If not specified, `keras.backend.floatx()` is used,
|
||||
which default to `float32` unless you configured it otherwise (via
|
||||
`keras.backend.set_floatx(float_dtype)`)
|
||||
seed: TODO
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
dtype = dtype or floatx()
|
||||
seed = draw_seed(seed)
|
||||
@ -60,7 +74,14 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
||||
supported. If not specified, `keras.backend.floatx()` is used,
|
||||
which default to `float32` unless you configured it otherwise (via
|
||||
`keras.backend.set_floatx(float_dtype)`)
|
||||
seed: TODO
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
dtype = dtype or floatx()
|
||||
seed = draw_seed(seed)
|
||||
|
@ -23,7 +23,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
||||
supported. If not specified, `keras.backend.floatx()` is used,
|
||||
which default to `float32` unless you configured it otherwise (via
|
||||
`keras.backend.set_floatx(float_dtype)`).
|
||||
seed: TODO
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
dtype = dtype or floatx()
|
||||
seed = tf_draw_seed(seed)
|
||||
@ -45,7 +52,14 @@ def uniform(shape, minval=0.0, maxval=None, dtype=None, seed=None):
|
||||
supported. If not specified, `keras.backend.floatx()` is used,
|
||||
which default to `float32` unless you configured it otherwise (via
|
||||
`keras.backend.set_floatx(float_dtype)`)
|
||||
seed: TODO
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
dtype = dtype or floatx()
|
||||
seed = tf_draw_seed(seed)
|
||||
@ -70,7 +84,14 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
||||
supported. If not specified, `keras.backend.floatx()` is used,
|
||||
which default to `float32` unless you configured it otherwise (via
|
||||
`keras.backend.set_floatx(float_dtype)`)
|
||||
seed: TODO
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
dtype = dtype or floatx()
|
||||
seed = tf_draw_seed(seed)
|
||||
|
@ -13,6 +13,7 @@ from keras_core.initializers.random_initializers import LecunNormal
|
||||
from keras_core.initializers.random_initializers import LecunUniform
|
||||
from keras_core.initializers.random_initializers import RandomNormal
|
||||
from keras_core.initializers.random_initializers import RandomUniform
|
||||
from keras_core.initializers.random_initializers import TruncatedNormal
|
||||
from keras_core.initializers.random_initializers import VarianceScaling
|
||||
from keras_core.saving import serialization_lib
|
||||
from keras_core.utils.naming import to_snake_case
|
||||
@ -29,6 +30,7 @@ ALL_OBJECTS = {
|
||||
LecunNormal,
|
||||
LecunUniform,
|
||||
RandomNormal,
|
||||
TruncatedNormal,
|
||||
RandomUniform,
|
||||
VarianceScaling,
|
||||
}
|
||||
|
@ -55,6 +55,59 @@ class RandomNormal(Initializer):
|
||||
return {"mean": self.mean, "stddev": self.stddev, "seed": self.seed}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.initializers.TruncatedNormal")
|
||||
class TruncatedNormal(Initializer):
|
||||
"""Initializer that generates a truncated normal distribution.
|
||||
|
||||
The values generated are similar to values from a
|
||||
`RandomNormal` initializer, except that values more
|
||||
than two standard deviations from the mean are
|
||||
discarded and re-drawn.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> # Standalone usage:
|
||||
>>> initializer = TruncatedNormal(mean=0., stddev=1.)
|
||||
>>> values = initializer(shape=(2, 2))
|
||||
|
||||
>>> # Usage in a Keras layer:
|
||||
>>> initializer = TruncatedNormal(mean=0., stddev=1.)
|
||||
>>> layer = Dense(3, kernel_initializer=initializer)
|
||||
|
||||
Args:
|
||||
mean: A python scalar or a scalar keras tensor. Mean of the random values to
|
||||
generate.
|
||||
stddev: A python scalar or a scalar keras tensor. Standard deviation of the
|
||||
random values to generate.
|
||||
seed: A Python integer or instance of
|
||||
`keras_core.backend.SeedGenerator`.
|
||||
Used to make the behavior of the initializer
|
||||
deterministic. Note that an initializer seeded with an integer
|
||||
or None (unseeded) will produce the same random values
|
||||
across multiple calls. To get different random values
|
||||
across multiple calls, use as seed an instance
|
||||
of `keras_core.backend.SeedGenerator`.
|
||||
"""
|
||||
|
||||
def __init__(self, mean=0.0, stddev=1.0, seed=None):
|
||||
self.mean = mean
|
||||
self.stddev = stddev
|
||||
self.seed = seed or random.make_default_seed()
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, shape, dtype=None):
|
||||
return random.truncated_normal(
|
||||
shape=shape,
|
||||
mean=self.mean,
|
||||
stddev=self.stddev,
|
||||
seed=self.seed,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return {"mean": self.mean, "stddev": self.stddev, "seed": self.seed}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.initializers.RandomUniform")
|
||||
class RandomUniform(Initializer):
|
||||
"""Random uniform initializer.
|
||||
|
Loading…
Reference in New Issue
Block a user