fix documentation in the random module (#141)

This commit is contained in:
Aakash Kumar Nain 2023-05-11 22:03:40 +05:30 committed by Francois Chollet
parent 13039c01b0
commit 5d81faef4d
2 changed files with 28 additions and 6 deletions

@ -7,7 +7,7 @@ from keras_core.backend.config import floatx
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the normal distribution.
"""Draw random samples from a normal (Gaussian) distribution.
Args:
shape: The shape of the random values to generate.
@ -34,7 +34,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""Produce random number based on the uniform distribution.
"""Draw samples from a uniform distribution.
The generated values follow a uniform distribution in the range
`[minval, maxval)`. The lower bound `minval` is included in the range,
while the upper bound `maxval` is excluded.
For floats, the default range is `[0, 1)`. For ints, at least `maxval`
must be specified explicitly.
Args:
shape: The shape of the random values to generate.
@ -63,7 +70,11 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the truncated normal distribution.
"""Draw samples from a truncated normal distribution.
The values are drawn from a normal distribution with specified mean and
standard deviation, discarding and re-drawing any samples that are more
than two standard deviations from the mean.
Args:
shape: The shape of the random values to generate.

@ -12,7 +12,7 @@ def tf_draw_seed(seed):
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the normal distribution.
"""Draw random samples from a normal (Gaussian) distribution.
Args:
shape: The shape of the random values to generate.
@ -40,7 +40,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""Produce random number based on the uniform distribution.
"""Draw samples from a uniform distribution.
The generated values follow a uniform distribution in the range
`[minval, maxval)`. The lower bound `minval` is included in the range,
while the upper bound `maxval` is excluded.
For floats, the default range is `[0, 1)`. For ints, at least `maxval`
must be specified explicitly.
Args:
shape: The shape of the random values to generate.
@ -73,7 +80,11 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the truncated normal distribution.
"""Draw samples from a truncated normal distribution.
The values are drawn from a normal distribution with specified mean and
standard deviation, discarding and re-drawing any samples that are more
than two standard deviations from the mean.
Args:
shape: The shape of the random values to generate.