keras/keras_core/backend/jax/random.py
Chen Qian f6df67f2d2 Add numpy module in jax/ and tensorflow/ (#13)
* Add jax/numpy and tensorflow/numpy

* refactor code

* more

* even better
2023-04-18 18:45:30 -07:00

83 lines
3.2 KiB
Python

import jax
from keras_core.backend.common.random import SeedGenerator
from keras_core.backend.common.random import draw_seed
from keras_core.backend.common.random import make_default_seed
from keras_core.backend.config import floatx
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """Draw random values from a normal (Gaussian) distribution.

    A standard normal sample is generated with `jax.random.normal` and then
    scaled by `stddev` and shifted by `mean`.

    Args:
        shape: The shape of the random values to generate.
        mean: Floats, default to 0. Mean of the random values to generate.
        stddev: Floats, default to 1. Standard deviation of the random values
            to generate.
        dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `keras.backend.floatx()` is used,
            which default to `float32` unless you configured it otherwise (via
            `keras.backend.set_floatx(float_dtype)`).
        seed: Value handed to `draw_seed`; whatever `draw_seed` returns is
            used as the JAX PRNG key. (Accepted types are defined by
            `draw_seed`, which is not shown here.)

    Returns:
        The generated sample, computed as `sample * stddev + mean`.
    """
    if not dtype:
        dtype = floatx()
    prng_key = draw_seed(seed)
    standard_draw = jax.random.normal(prng_key, shape, dtype=dtype)
    scaled_draw = standard_draw * stddev
    return scaled_draw + mean
def uniform(shape, minval=0.0, maxval=None, dtype=None, seed=None):
    """Produce random numbers based on the uniform distribution.

    Args:
        shape: The shape of the random values to generate.
        minval: Floats, default to 0. Lower bound of the range of
            random values to generate (inclusive).
        maxval: Floats, default to None. Upper bound of the range of
            random values to generate (exclusive). `None` is treated as
            1.0.
        dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `keras.backend.floatx()` is used,
            which default to `float32` unless you configured it otherwise (via
            `keras.backend.set_floatx(float_dtype)`).
        seed: Value handed to `draw_seed`; whatever `draw_seed` returns is
            used as the JAX PRNG key. (Accepted types are defined by
            `draw_seed`, which is not shown here.)

    Returns:
        The result of `jax.random.uniform` for the requested shape, dtype
        and bounds.
    """
    dtype = dtype or floatx()
    # `jax.random.uniform` needs a numeric upper bound; forwarding the `None`
    # default would make a plain `uniform(shape)` call fail, so fall back to
    # the conventional upper bound of 1.0 for floating point samples.
    if maxval is None:
        maxval = 1.0
    seed = draw_seed(seed)
    return jax.random.uniform(
        seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
    )
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """Draw random values from a truncated normal distribution.

    A standard normal sample truncated to the interval [-2, 2] is generated
    with `jax.random.truncated_normal`, then scaled by `stddev` and shifted
    by `mean`.

    Args:
        shape: The shape of the random values to generate.
        mean: Floats, default to 0. Mean of the random values to generate.
        stddev: Floats, default to 1. Standard deviation of the random values
            to generate.
        dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `keras.backend.floatx()` is used,
            which default to `float32` unless you configured it otherwise (via
            `keras.backend.set_floatx(float_dtype)`).
        seed: Value handed to `draw_seed`; whatever `draw_seed` returns is
            used as the JAX PRNG key. (Accepted types are defined by
            `draw_seed`, which is not shown here.)

    Returns:
        The generated sample, computed as `sample * stddev + mean`.
    """
    if not dtype:
        dtype = floatx()
    prng_key = draw_seed(seed)
    # Truncation bounds apply to the unit-variance draw, i.e. before the
    # `stddev` / `mean` rescaling below.
    lower_bound, upper_bound = -2.0, 2.0
    standard_draw = jax.random.truncated_normal(
        prng_key, lower_bound, upper_bound, shape=shape, dtype=dtype
    )
    scaled_draw = standard_draw * stddev
    return scaled_draw + mean
def dropout(inputs, rate, noise_shape=None, seed=None):
    """Randomly zero out elements of `inputs` and rescale the rest.

    A boolean keep-mask is sampled with probability `1 - rate` of keeping
    each element. Kept elements are divided by `1 - rate`; dropped elements
    are replaced by zero.

    Args:
        inputs: Array to apply dropout to. Must expose `.shape`.
        rate: Probability of dropping an element.
        noise_shape: Optional shape of the sampled keep-mask; the mask is
            broadcast to `inputs.shape`, so it must be broadcastable to it.
            Entries that are `None` are replaced by the size of the
            corresponding axis of `inputs`. If `noise_shape` itself is
            `None`, `inputs.shape` is used.
        seed: Value handed to `draw_seed`; whatever `draw_seed` returns is
            used as the JAX PRNG key. (Accepted types are defined by
            `draw_seed`, which is not shown here.)

    Returns:
        An array with the shape of `inputs`.
    """
    seed = draw_seed(seed)
    keep_prob = 1.0 - rate
    if noise_shape is None:
        noise_shape = inputs.shape
    else:
        # `jax.random.bernoulli` needs concrete dimensions, so resolve any
        # `None` placeholder against the matching axis of `inputs`.
        noise_shape = tuple(
            inputs.shape[axis] if dim is None else dim
            for axis, dim in enumerate(noise_shape)
        )
    mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape)
    mask = jax.numpy.broadcast_to(mask, inputs.shape)
    return jax.lax.select(
        mask, inputs / keep_prob, jax.numpy.zeros_like(inputs)
    )