keras/keras_core/random/random_test.py
Ramesh Sampath bf734b0eec Adds Random Unit Tests (#174)
* Add torch backend - random

* Add Random Unit Test

* Adds Random Unit Test

* Random Test

* Adds Random Test
2023-05-16 00:21:36 -05:00

64 lines
2.8 KiB
Python

import numpy as np
from absl.testing import parameterized
from keras_core import testing
from keras_core.operations import numpy as knp
from keras_core.random import random
class RandomTest(testing.TestCase, parameterized.TestCase):
    """Tests for the sampling ops and `dropout` in `keras_core.random`."""

    @parameterized.parameters(
        {"seed": 10, "shape": (5,), "mean": 0, "stddev": 1},
        {"seed": 10, "shape": (2, 3), "mean": 0, "stddev": 1},
        {"seed": 10, "shape": (2, 3, 4), "mean": 0, "stddev": 1},
        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 1},
        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 3},
    )
    def test_normal(self, seed, shape, mean, stddev):
        """`random.normal` returns a tensor of the requested shape."""
        # Reference draw from NumPy's equivalent sampler; only its shape is
        # compared. A local Generator is used instead of `np.random.seed` so
        # the test does not mutate NumPy's global RNG state.
        np_res = np.random.default_rng(seed).normal(
            loc=mean, scale=stddev, size=shape
        )
        res = random.normal(shape, mean=mean, stddev=stddev, seed=seed)
        self.assertEqual(res.shape, shape)
        self.assertEqual(res.shape, np_res.shape)

    @parameterized.parameters(
        {"seed": 10, "shape": (5,), "minval": 0, "maxval": 1},
        {"seed": 10, "shape": (2, 3), "minval": 0, "maxval": 1},
        {"seed": 10, "shape": (2, 3, 4), "minval": 0, "maxval": 2},
        {"seed": 10, "shape": (2, 3), "minval": -1, "maxval": 1},
        {"seed": 10, "shape": (2, 3), "minval": 1, "maxval": 3},
    )
    def test_uniform(self, seed, shape, minval, maxval):
        """`random.uniform` has the requested shape and stays within bounds."""
        # Local Generator: avoids touching NumPy's global RNG state.
        np_res = np.random.default_rng(seed).uniform(
            low=minval, high=maxval, size=shape
        )
        res = random.uniform(shape, minval=minval, maxval=maxval, seed=seed)
        self.assertEqual(res.shape, shape)
        self.assertEqual(res.shape, np_res.shape)
        # Check both ends of the range: the largest sample must not exceed
        # `maxval`, and the *smallest* sample must not fall below `minval`
        # (checking the max against `minval` would be nearly vacuous).
        self.assertLessEqual(knp.max(res), maxval)
        self.assertGreaterEqual(knp.min(res), minval)

    @parameterized.parameters(
        {"seed": 10, "shape": (5,), "mean": 0, "stddev": 1},
        {"seed": 10, "shape": (2, 3), "mean": 0, "stddev": 1},
        {"seed": 10, "shape": (2, 3, 4), "mean": 0, "stddev": 1},
        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 1},
        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 3},
    )
    def test_truncated_normal(self, seed, shape, mean, stddev):
        """`random.truncated_normal` has the requested shape and stays
        within `mean +/- 2 * stddev`."""
        # Local Generator: avoids touching NumPy's global RNG state.
        np_res = np.random.default_rng(seed).normal(
            loc=mean, scale=stddev, size=shape
        )
        res = random.truncated_normal(
            shape, mean=mean, stddev=stddev, seed=seed
        )
        self.assertEqual(res.shape, shape)
        self.assertEqual(res.shape, np_res.shape)
        # The test expects every sample within two standard deviations of
        # the mean: bound the max from above and the *min* from below.
        self.assertLessEqual(knp.max(res), mean + 2 * stddev)
        self.assertGreaterEqual(knp.min(res), mean - 2 * stddev)

    def test_dropout(self):
        """`random.dropout` is the identity at rate 0 and zeroes and
        rescales entries at a high rate."""
        x = knp.ones((3, 5))
        # rate=0 must leave the input unchanged.
        self.assertAllClose(random.dropout(x, rate=0, seed=0), x)
        x_res = random.dropout(x, rate=0.8, seed=0)
        # The test expects surviving entries to be scaled up, so the output's
        # max exceeds the all-ones input's max of 1.
        self.assertGreater(knp.max(x_res), knp.max(x))
        # With rate=0.8 on 15 entries, several should be dropped to zero.
        self.assertGreater(knp.sum(x_res == 0), 2)