diff --git a/keras_core/layers/convolutional/conv_test.py b/keras_core/layers/convolutional/conv_test.py
index 13781131f..4dceda51e 100644
--- a/keras_core/layers/convolutional/conv_test.py
+++ b/keras_core/layers/convolutional/conv_test.py
@@ -473,4 +473,4 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
         outputs = layer(inputs)
         expected = tf_keras_layer(inputs)
-        self.assertAllClose(outputs, expected)
+        self.assertAllClose(outputs, expected, rtol=1e-5)
diff --git a/keras_core/random/random.py b/keras_core/random/random.py
index 8c827d21f..84bf827a0 100644
--- a/keras_core/random/random.py
+++ b/keras_core/random/random.py
@@ -24,7 +24,9 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     across multiple calls, use as seed an instance of
     `keras_core.random.SeedGenerator`.
     """
-    return normal(shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
+    return backend.random.normal(
+        shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
+    )
 
 
 @keras_core_export("keras_core.random.uniform")
diff --git a/keras_core/random/random_test.py b/keras_core/random/random_test.py
new file mode 100644
index 000000000..27007892b
--- /dev/null
+++ b/keras_core/random/random_test.py
@@ -0,0 +1,63 @@
+import numpy as np
+from absl.testing import parameterized
+
+from keras_core import testing
+from keras_core.operations import numpy as knp
+from keras_core.random import random
+
+
+class RandomTest(testing.TestCase, parameterized.TestCase):
+    @parameterized.parameters(
+        {"seed": 10, "shape": (5,), "mean": 0, "stddev": 1},
+        {"seed": 10, "shape": (2, 3), "mean": 0, "stddev": 1},
+        {"seed": 10, "shape": (2, 3, 4), "mean": 0, "stddev": 1},
+        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 1},
+        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 3},
+    )
+    def test_normal(self, seed, shape, mean, stddev):
+        np.random.seed(seed)
+        np_res = np.random.normal(loc=mean, scale=stddev, size=shape)
+        res = random.normal(shape, mean=mean, stddev=stddev, seed=seed)
+        self.assertEqual(res.shape, shape)
+        self.assertEqual(res.shape, np_res.shape)
+
+    @parameterized.parameters(
+        {"seed": 10, "shape": (5,), "minval": 0, "maxval": 1},
+        {"seed": 10, "shape": (2, 3), "minval": 0, "maxval": 1},
+        {"seed": 10, "shape": (2, 3, 4), "minval": 0, "maxval": 2},
+        {"seed": 10, "shape": (2, 3), "minval": -1, "maxval": 1},
+        {"seed": 10, "shape": (2, 3), "minval": 1, "maxval": 3},
+    )
+    def test_uniform(self, seed, shape, minval, maxval):
+        np.random.seed(seed)
+        np_res = np.random.uniform(low=minval, high=maxval, size=shape)
+        res = random.uniform(shape, minval=minval, maxval=maxval, seed=seed)
+        self.assertEqual(res.shape, shape)
+        self.assertEqual(res.shape, np_res.shape)
+        self.assertLessEqual(knp.max(res), maxval)
+        self.assertGreaterEqual(knp.max(res), minval)
+
+    @parameterized.parameters(
+        {"seed": 10, "shape": (5,), "mean": 0, "stddev": 1},
+        {"seed": 10, "shape": (2, 3), "mean": 0, "stddev": 1},
+        {"seed": 10, "shape": (2, 3, 4), "mean": 0, "stddev": 1},
+        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 1},
+        {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 3},
+    )
+    def test_truncated_normal(self, seed, shape, mean, stddev):
+        np.random.seed(seed)
+        np_res = np.random.normal(loc=mean, scale=stddev, size=shape)
+        res = random.truncated_normal(
+            shape, mean=mean, stddev=stddev, seed=seed
+        )
+        self.assertEqual(res.shape, shape)
+        self.assertEqual(res.shape, np_res.shape)
+        self.assertLessEqual(knp.max(res), mean + 2 * stddev)
+        self.assertGreaterEqual(knp.max(res), mean - 2 * stddev)
+
+    def test_dropout(self):
+        x = knp.ones((3, 5))
+        self.assertAllClose(random.dropout(x, rate=0, seed=0), x)
+        x_res = random.dropout(x, rate=0.8, seed=0)
+        self.assertGreater(knp.max(x_res), knp.max(x))
+        self.assertGreater(knp.sum(x_res == 0), 2)
diff --git a/requirements.txt b/requirements.txt
index 993874378..313ed93b3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
 tensorflow
-torch
+# TODO: Use Torch CPU
+# Remove after resolving Cuda version differences with TF
+torch>=2.0.1+cpu
 jax[cpu]
 namex
 black>=22