From 5d81faef4de70fc9f9ee93e201a1e27cb444a129 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Thu, 11 May 2023 22:03:40 +0530 Subject: [PATCH] fix documentation in the random module (#141) --- keras_core/backend/jax/random.py | 17 ++++++++++++++--- keras_core/backend/tensorflow/random.py | 17 ++++++++++++++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/keras_core/backend/jax/random.py b/keras_core/backend/jax/random.py index 7c753c4bf..519e680af 100644 --- a/keras_core/backend/jax/random.py +++ b/keras_core/backend/jax/random.py @@ -7,7 +7,7 @@ from keras_core.backend.config import floatx def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - """Produce random number based on the normal distribution. + """Draw random samples from a normal (Gaussian) distribution. Args: shape: The shape of the random values to generate. @@ -34,7 +34,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): - """Produce random number based on the uniform distribution. + """Draw samples from a uniform distribution. + + The generated values follow a uniform distribution in the range + `[minval, maxval)`. The lower bound `minval` is included in the range, + while the upper bound `maxval` is excluded. + + For floats, the default range is `[0, 1)`. For ints, at least `maxval` + must be specified explicitly. Args: shape: The shape of the random values to generate. @@ -63,7 +70,11 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - """Produce random number based on the truncated normal distribution. + """Draw samples from a truncated normal distribution. + + The values are drawn from a normal distribution with specified mean and + standard deviation, discarding and re-drawing any samples that are more + than two standard deviations from the mean. Args: shape: The shape of the random values to generate. diff --git a/keras_core/backend/tensorflow/random.py b/keras_core/backend/tensorflow/random.py index c0dea4268..0fe34c79f 100644 --- a/keras_core/backend/tensorflow/random.py +++ b/keras_core/backend/tensorflow/random.py @@ -12,7 +12,7 @@ def tf_draw_seed(seed): def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - """Produce random number based on the normal distribution. + """Draw random samples from a normal (Gaussian) distribution. Args: shape: The shape of the random values to generate. @@ -40,7 +40,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): - """Produce random number based on the uniform distribution. + """Draw samples from a uniform distribution. + + The generated values follow a uniform distribution in the range + `[minval, maxval)`. The lower bound `minval` is included in the range, + while the upper bound `maxval` is excluded. + + For floats, the default range is `[0, 1)`. For ints, at least `maxval` + must be specified explicitly. Args: shape: The shape of the random values to generate. @@ -73,7 +80,11 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - """Produce random number based on the truncated normal distribution. + """Draw samples from a truncated normal distribution. + + The values are drawn from a normal distribution with specified mean and + standard deviation, discarding and re-drawing any samples that are more + than two standard deviations from the mean. Args: shape: The shape of the random values to generate.