keras/keras_core/backend/tensorflow/random.py
Chen Qian eabdb87f9f Add some numpy ops (#1)
* Add numpy ops (initial batch) and some config

* Add unit test

* fix call

* Revert "fix call"

This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0.

* full unit test coverage

* fix setup.py
2023-04-12 11:31:58 -07:00

88 lines
3.1 KiB
Python

import tensorflow as tf
from keras_core.backend import floatx
from keras_core.backend.random import draw_seed
def tf_draw_seed(seed):
# TF ops only accept int32/64 seeds but our base seed is uint32.
return tf.cast(draw_seed(seed), dtype="int32")
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the normal distribution.
Args:
shape: The shape of the random values to generate.
mean: Floats, default to 0. Mean of the random values to generate.
stddev: Floats, default to 1. Standard deviation of the random values
to generate.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`).
seed: TODO
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
return tf.random.stateless_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)
def uniform(shape, minval=0.0, maxval=None, dtype=None, seed=None):
"""Produce random number based on the uniform distribution.
Args:
shape: The shape of the random values to generate.
minval: Floats, default to 0. Lower bound of the range of
random values to generate (inclusive).
minval: Floats, default to None. Upper bound of the range of
random values to generate (exclusive).
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`)
seed: TODO
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
return tf.random.stateless_uniform(
shape=shape,
minval=minval,
maxval=maxval,
dtype=dtype,
seed=seed,
)
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the truncated normal distribution.
Args:
shape: The shape of the random values to generate.
mean: Floats, default to 0. Mean of the random values to generate.
stddev: Floats, default to 1. Standard deviation of the random values
to generate.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise (via
`keras.backend.set_floatx(float_dtype)`)
seed: TODO
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
return tf.random.stateless_truncated_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)
def dropout(inputs, rate, noise_shape=None, seed=None):
seed = tf_draw_seed(seed)
return tf.nn.experimental.stateless_dropout(
inputs,
rate=rate,
noise_shape=noise_shape,
seed=seed,
)