2023-04-09 19:21:45 +00:00
|
|
|
import tensorflow as tf
|
2023-04-12 18:31:58 +00:00
|
|
|
|
2023-04-19 01:45:30 +00:00
|
|
|
from keras_core.backend.config import floatx
|
2023-05-15 22:01:46 +00:00
|
|
|
from keras_core.random.seed_generator import SeedGenerator
|
|
|
|
from keras_core.random.seed_generator import draw_seed
|
|
|
|
from keras_core.random.seed_generator import make_default_seed
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def tf_draw_seed(seed):
    """Draw a seed and convert it into a form TF stateless ops accept.

    Args:
        seed: Anything `draw_seed` accepts; forwarded to it unchanged.

    Returns:
        The drawn seed, cast to an `int32` tensor.
    """
    # `draw_seed` produces a uint32 base seed, but TF stateless random ops
    # only take int32/int64 seeds, hence the cast below.
    drawn = draw_seed(seed)
    return tf.cast(drawn, dtype="int32")
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """Sample values from a normal (Gaussian) distribution.

    Args:
        shape: Shape of the tensor of random values to produce.
        mean: Float, defaults to 0. Mean of the distribution.
        stddev: Float, defaults to 1. Standard deviation of the
            distribution.
        dtype: Optional dtype of the result; only floating point types are
            supported. When omitted, `keras.backend.floatx()` is used,
            which is `float32` unless it was changed via
            `keras.backend.set_floatx(float_dtype)`.
        seed: A Python integer or a `SeedGenerator` instance (defined in
            `keras_core.random.seed_generator`). Seeding with an integer,
            or leaving the seed as `None` (unseeded), yields the same values
            on every call; pass a `SeedGenerator` instance to get different
            values across calls.

    Returns:
        The tensor produced by `tf.random.stateless_normal`.
    """
    # Fall back to the configured global float type for any falsy dtype.
    if not dtype:
        dtype = floatx()
    tf_seed = tf_draw_seed(seed)
    return tf.random.stateless_normal(
        shape=shape,
        mean=mean,
        stddev=stddev,
        dtype=dtype,
        seed=tf_seed,
    )
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
2023-04-23 15:28:21 +00:00
|
|
|
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    """Sample values from a uniform distribution over `[minval, maxval)`.

    The lower bound `minval` is included in the range and the upper bound
    `maxval` is excluded. With the default bounds the values lie in
    `[0, 1)`.

    Args:
        shape: Shape of the tensor of random values to produce.
        minval: Float, defaults to 0. Inclusive lower bound of the range.
        maxval: Float, defaults to 1. Exclusive upper bound of the range.
        dtype: Optional dtype of the result; only floating point types are
            supported. When omitted, `keras.backend.floatx()` is used,
            which is `float32` unless it was changed via
            `keras.backend.set_floatx(float_dtype)`.
        seed: A Python integer or a `SeedGenerator` instance (defined in
            `keras_core.random.seed_generator`). Seeding with an integer,
            or leaving the seed as `None` (unseeded), yields the same values
            on every call; pass a `SeedGenerator` instance to get different
            values across calls.

    Returns:
        The tensor produced by `tf.random.stateless_uniform`.

    Note:
        The underlying TF op also accepts integer dtypes, in which case
        TF requires at least `maxval` to be given explicitly; this wrapper
        only documents floating point support.
    """
    resolved_dtype = dtype if dtype else floatx()
    tf_seed = tf_draw_seed(seed)
    return tf.random.stateless_uniform(
        shape=shape,
        minval=minval,
        maxval=maxval,
        dtype=resolved_dtype,
        seed=tf_seed,
    )
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """Sample values from a truncated normal distribution.

    Values come from a normal distribution with the given mean and
    standard deviation. Any sample lying more than two standard deviations
    from the mean is discarded and drawn again.

    Args:
        shape: Shape of the tensor of random values to produce.
        mean: Float, defaults to 0. Mean of the distribution.
        stddev: Float, defaults to 1. Standard deviation of the
            distribution.
        dtype: Optional dtype of the result; only floating point types are
            supported. When omitted, `keras.backend.floatx()` is used,
            which is `float32` unless it was changed via
            `keras.backend.set_floatx(float_dtype)`.
        seed: A Python integer or a `SeedGenerator` instance (defined in
            `keras_core.random.seed_generator`). Seeding with an integer,
            or leaving the seed as `None` (unseeded), yields the same values
            on every call; pass a `SeedGenerator` instance to get different
            values across calls.

    Returns:
        The tensor produced by `tf.random.stateless_truncated_normal`.
    """
    out_dtype = dtype if dtype else floatx()
    tf_seed = tf_draw_seed(seed)
    sample = tf.random.stateless_truncated_normal(
        shape=shape,
        mean=mean,
        stddev=stddev,
        dtype=out_dtype,
        seed=tf_seed,
    )
    return sample
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def dropout(inputs, rate, noise_shape=None, seed=None):
    """Apply dropout to `inputs` via TF's stateless dropout op.

    Args:
        inputs: Tensor to apply dropout to.
        rate: Forwarded as `rate`; per the TF docs, the probability that
            each element is dropped.
        noise_shape: Optional shape for the random keep/drop mask,
            forwarded unchanged.
        seed: Anything `tf_draw_seed` accepts; it is converted to an int32
            TF seed before use.

    Returns:
        The result of `tf.nn.experimental.stateless_dropout`.
    """
    tf_seed = tf_draw_seed(seed)
    return tf.nn.experimental.stateless_dropout(
        inputs, rate=rate, noise_shape=noise_shape, seed=tf_seed
    )
|