keras/keras_core/initializers/random_initializers_test.py
Matt Watson 273d3be5cd Fixed SeedGenerator in the torch backend (#434)
We were ignoring the second seed entirely, which is quite incorrect as
that is the only part of the seed we are incrementing.

I'm just summing the fixed seed and increment counter together for now.
Not sure if that's the best approach for truly random seeding, or
what approach torch is even using for their seeds, but I figure this is
much better than nothing.
2023-07-10 10:43:08 -07:00

145 lines
5.0 KiB
Python

import numpy as np
from keras_core import backend
from keras_core import initializers
from keras_core import testing
from keras_core import utils
class InitializersTest(testing.TestCase):
    """Tests for the random initializers in `keras_core.initializers`.

    Covers constructor attributes, output shapes, coarse distribution
    statistics, seeding semantics (int seed vs. `SeedGenerator` vs. `None`),
    config round-tripping via `run_class_serialization_test`, and lookup
    through `initializers.get`.
    """

    def test_random_normal(self):
        """RandomNormal: attributes, stddev, serialization and seeding."""
        utils.set_random_seed(1337)
        shape = (25, 20)
        mean = 0.0
        stddev = 1.0
        seed = 1234
        initializer = initializers.RandomNormal(
            mean=mean, stddev=stddev, seed=seed
        )
        values = initializer(shape=shape)
        self.assertEqual(initializer.mean, mean)
        self.assertEqual(initializer.stddev, stddev)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        # Loose tolerance: the sample stddev of 500 draws is only approximate.
        self.assertAllClose(
            np.std(backend.convert_to_numpy(values)), stddev, atol=1e-1
        )
        self.run_class_serialization_test(initializer)

        # Test that a fixed seed yields the same results each call.
        initializer = initializers.RandomNormal(
            mean=mean, stddev=stddev, seed=1337
        )
        values = initializer(shape=shape)
        next_values = initializer(shape=shape)
        self.assertAllClose(values, next_values)

        # Test that a SeedGenerator yields different results each call.
        initializer = initializers.RandomNormal(
            mean=mean, stddev=stddev, seed=backend.random.SeedGenerator(1337)
        )
        values = initializer(shape=shape)
        next_values = initializer(shape=shape)
        self.assertNotAllClose(values, next_values)

        # An initializer seeded with a SeedGenerator can be built and called.
        # TODO: assert a config round-trip here once SeedGenerator
        # serialization support is confirmed; this only checks the call.
        initializer = initializers.RandomNormal(
            mean=mean, stddev=stddev, seed=backend.random.SeedGenerator(1337)
        )
        values = initializer(shape=shape)
        self.assertEqual(values.shape, shape)

        # Test that unseeded generator gets different results after cloning
        initializer = initializers.RandomNormal(
            mean=mean, stddev=stddev, seed=None
        )
        values = initializer(shape=shape)
        cloned_initializer = initializers.RandomNormal.from_config(
            initializer.get_config()
        )
        new_values = cloned_initializer(shape=shape)
        self.assertNotAllClose(values, new_values)

        # Test that seeded generator gets same results after cloning
        initializer = initializers.RandomNormal(
            mean=mean, stddev=stddev, seed=1337
        )
        values = initializer(shape=shape)
        cloned_initializer = initializers.RandomNormal.from_config(
            initializer.get_config()
        )
        new_values = cloned_initializer(shape=shape)
        self.assertAllClose(values, new_values)

    def test_random_uniform(self):
        """RandomUniform: attributes, value range and serialization."""
        shape = (5, 5)
        minval = -1.0
        maxval = 1.0
        seed = 1234
        initializer = initializers.RandomUniform(
            minval=minval, maxval=maxval, seed=seed
        )
        values = initializer(shape=shape)
        self.assertEqual(initializer.minval, minval)
        self.assertEqual(initializer.maxval, maxval)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        # Samples are expected to lie in the half-open range [minval, maxval).
        values = backend.convert_to_numpy(values)
        self.assertGreaterEqual(np.min(values), minval)
        self.assertLess(np.max(values), maxval)
        self.run_class_serialization_test(initializer)

    def test_variance_scaling(self):
        """VarianceScaling: output stddev for fan_in and fan_out modes."""
        utils.set_random_seed(1337)
        shape = (25, 20)
        scale = 2.0
        seed = 1234
        initializer = initializers.VarianceScaling(
            scale=scale, seed=seed, mode="fan_in"
        )
        values = initializer(shape=shape)
        self.assertEqual(initializer.scale, scale)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        # fan_in mode: expected stddev is sqrt(scale / shape[0]).
        self.assertAllClose(
            np.std(backend.convert_to_numpy(values)),
            np.sqrt(scale / 25),
            atol=1e-1,
        )
        self.run_class_serialization_test(initializer)

        initializer = initializers.VarianceScaling(
            scale=scale, seed=seed, mode="fan_out"
        )
        values = initializer(shape=shape)
        self.assertEqual(initializer.scale, scale)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        # fan_out mode: expected stddev is sqrt(scale / shape[1]).
        self.assertAllClose(
            np.std(backend.convert_to_numpy(values)),
            np.sqrt(scale / 20),
            atol=1e-1,
        )
        self.run_class_serialization_test(initializer)

    def test_orthogonal_initializer(self):
        """OrthogonalInitializer: attributes, orthogonality and serialization."""
        shape = (5, 5)
        gain = 2.0
        seed = 1234
        initializer = initializers.OrthogonalInitializer(gain=gain, seed=seed)
        values = initializer(shape=shape)
        self.assertEqual(initializer.gain, gain)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        # For a square output, an orthogonal matrix scaled by `gain`
        # satisfies W^T W == gain**2 * I (unit-norm, mutually orthogonal
        # columns, each scaled by `gain`).
        array = backend.convert_to_numpy(values)
        self.assertAllClose(
            np.matmul(array.T, array), gain**2 * np.eye(shape[-1]), atol=1e-4
        )
        self.run_class_serialization_test(initializer)

    def test_get_method(self):
        """`initializers.get` resolves names, passes through None, rejects typos."""
        obj = initializers.get("glorot_normal")
        # assertTrue(obj, cls) would treat `cls` as the failure message and
        # pass for any truthy object; check the type explicitly instead.
        self.assertIsInstance(obj, initializers.GlorotNormal)

        obj = initializers.get(None)
        self.assertIsNone(obj)

        with self.assertRaises(ValueError):
            initializers.get("typo")