diff --git a/keras/backend.py b/keras/backend.py
index c66c4f98c..9edd41208 100644
--- a/keras/backend.py
+++ b/keras/backend.py
@@ -1954,60 +1954,107 @@ class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
     else:
       return random.randint(1, 1e9)
 
-  def random_normal(self, shape, mean=0., stddev=1., dtype=None):
+  def random_normal(self, shape, mean=0., stddev=1., dtype=None, nonce=None):
+    """Produce random numbers based on the normal distribution.
+
+    Args:
+      shape: The shape of the random values to generate.
+      mean: Float, defaults to 0. Mean of the random values to generate.
+      stddev: Float, defaults to 1. Standard deviation of the random values
+        to generate.
+      dtype: Optional dtype of the tensor. Only floating point types are
+        supported. If not specified, `tf.keras.backend.floatx()` is used,
+        which defaults to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`).
+      nonce: Optional integer scalar that will be folded into the seed in
+        stateless mode.
+    """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.normal(
         shape=shape, mean=mean, stddev=stddev, dtype=dtype,
         seed=self.make_legacy_seed())
 
-  def random_uniform(self, shape, minval=0., maxval=None, dtype=None):
+  def random_uniform(self, shape, minval=0., maxval=None, dtype=None,
+                     nonce=None):
+    """Produce random numbers based on the uniform distribution.
+
+    Args:
+      shape: The shape of the random values to generate.
+      minval: Float, defaults to 0. Lower bound of the range of random
+        values to generate (inclusive).
+      maxval: Float, defaults to None. Upper bound of the range of random
+        values to generate (exclusive).
+      dtype: Optional dtype of the tensor. Only floating point types are
+        supported. If not specified, `tf.keras.backend.floatx()` is used,
+        which defaults to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`).
+      nonce: Optional integer scalar that will be folded into the seed in
+        stateless mode.
+    """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.uniform(
           shape=shape, minval=minval, maxval=maxval, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_uniform(
           shape=shape, minval=minval, maxval=maxval, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.uniform(
         shape=shape, minval=minval, maxval=maxval, dtype=dtype,
         seed=self.make_legacy_seed())
 
-  def truncated_normal(self, shape, mean=0., stddev=1., dtype=None):
+  def truncated_normal(self, shape, mean=0., stddev=1., dtype=None, nonce=None):
+    """Produce random numbers based on the truncated normal distribution.
+
+    Args:
+      shape: The shape of the random values to generate.
+      mean: Float, defaults to 0. Mean of the random values to generate.
+      stddev: Float, defaults to 1. Standard deviation of the random values
+        to generate.
+      dtype: Optional dtype of the tensor. Only floating point types are
+        supported. If not specified, `tf.keras.backend.floatx()` is used,
+        which defaults to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`).
+      nonce: Optional integer scalar that will be folded into the seed in
+        stateless mode.
+    """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.truncated_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_truncated_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.truncated_normal(
         shape=shape, mean=mean, stddev=stddev, dtype=dtype,
         seed=self.make_legacy_seed())
 
   def dropout(self, inputs, rate, noise_shape=None):
     self._maybe_init()
-    if self._rng_type == self.RNG_STATEFUL:
+    if self._rng_type in [self.RNG_STATEFUL, self.RNG_STATELESS]:
       return tf.nn.experimental.stateless_dropout(
           inputs, rate=rate, noise_shape=noise_shape,
           seed=self.make_seed_for_stateless_op())
-    elif self._rng_type == self.RNG_STATELESS:
-      return tf.nn.experimental.stateless_dropout(
-          inputs, rate=rate, noise_shape=noise_shape,
-          seed=self.make_seed_for_stateless_op())
-    # We don't support stateless in this case, otherwise the dropout
-    # will always have identical behavior across the batches.
     return tf.nn.dropout(inputs, rate=rate, noise_shape=noise_shape,
                          seed=self.make_legacy_seed())
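Reviewer note: the `nonce` plumbing above relies on `tf.random.experimental.stateless_fold_in`, which derives a fresh seed from `(seed, nonce)` without mutating the base seed. A minimal sketch of that behavior, not part of the patch; the base seed value below is an assumption:

```python
import tensorflow as tf

base_seed = tf.constant([0, 1337], dtype=tf.int32)  # hypothetical base seed

# Folding the same nonce twice yields the same derived seed, so the first
# two draws below are identical; a different nonce gives a distinct stream.
seed_a = tf.random.experimental.stateless_fold_in(base_seed, 42)
seed_b = tf.random.experimental.stateless_fold_in(base_seed, 42)
seed_c = tf.random.experimental.stateless_fold_in(base_seed, 7)

x = tf.random.stateless_normal(shape=[2, 3], seed=seed_a)
y = tf.random.stateless_normal(shape=[2, 3], seed=seed_b)  # equal to x
z = tf.random.stateless_normal(shape=[2, 3], seed=seed_c)  # differs from x
```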
diff --git a/keras/backend_test.py b/keras/backend_test.py
index 40fd8f02c..cee51d964 100644
--- a/keras/backend_test.py
+++ b/keras/backend_test.py
@@ -2353,7 +2353,7 @@ class ContextValueCacheTest(tf.test.TestCase):
 
 
 @test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
-class RandomGeneratorTest(tf.test.TestCase):
+class RandomGeneratorTest(tf.test.TestCase, parameterized.TestCase):
 
   def test_generator_reproducibility(self):
     seed = 1337
@@ -2471,6 +2471,25 @@ class RandomGeneratorTest(tf.test.TestCase):
     output3 = gen2.random_normal(shape=[2, 3])
     self.assertAllClose(output3, output1)
 
+  @parameterized.named_parameters(
+      ('seeded', 1337), ('unseeded', None)
+  )
+  def test_stateless_with_seed_delta(self, seed):
+    gen = backend.RandomGenerator(seed=seed, rng_type='stateless')
+    output1 = gen.random_normal(shape=[2, 3], nonce=hash((1, 1)))
+    seed1 = gen._seed
+    output2 = gen.random_normal(shape=[2, 3], nonce=hash((1, 1)))
+    seed2 = gen._seed
+    output3 = gen.random_normal(shape=[2, 3], nonce=hash((2, 1)))
+    seed3 = gen._seed
+
+    self.assertAllClose(output1, output2)
+    # A different nonce will produce different values.
+    self.assertNotAllClose(output1, output3)
+    # Make sure the internal seed is not changed at all.
+    self.assertEqual(seed1, seed2)
+    self.assertEqual(seed1, seed3)
+
   def test_unknown_rng_type(self):
     with self.assertRaisesRegex(ValueError, 'Got: unknown'):
       backend.RandomGenerator(seed=None, rng_type='unknown')
diff --git a/keras/initializers/initializers_test.py b/keras/initializers/initializers_test.py
index 4e62ee7bf..b460aab6b 100644
--- a/keras/initializers/initializers_test.py
+++ b/keras/initializers/initializers_test.py
@@ -14,8 +14,7 @@
 # ==============================================================================
 """Tests for Keras initializers."""
 
-import tensorflow.compat.v2 as tf
-
+from absl.testing import parameterized
 import numpy as np
 
 from keras import backend
@@ -26,6 +25,8 @@ from keras.testing_infra import test_utils
 from keras.engine import input_layer
 from keras.layers import core
 
+import tensorflow.compat.v2 as tf
+
 
 def _compute_fans(shape):
   """Computes the number of input and output units for a weight shape.
@@ -55,7 +56,7 @@ def _compute_fans(shape):
 
 
 @test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
-class KerasInitializersTest(tf.test.TestCase):
+class KerasInitializersTest(tf.test.TestCase, parameterized.TestCase):
 
   def _runner(self, init, shape, target_mean=None, target_std=None,
               target_max=None, target_min=None):
@@ -254,33 +255,54 @@ class KerasInitializersTest(tf.test.TestCase):
     initializer = initializers.deserialize(external_serialized_json)
     self.assertEqual(initializer.distribution, 'truncated_normal')
 
-  def test_partition(self):
+  @parameterized.named_parameters(
+      ('Zeros', initializers.ZerosV2, {}),
+      ('Ones', initializers.OnesV2, {}),
+      ('Constant', initializers.ConstantV2, {}),
+      ('RandomUniform', initializers.RandomUniformV2, {}),
+      ('RandomUniform_seeded', initializers.RandomUniformV2, {'seed': 123}),
+      ('RandomNormal', initializers.RandomNormalV2, {}),
+      ('RandomNormal_seeded', initializers.RandomNormalV2, {'seed': 123}),
+      ('TruncatedNormal', initializers.TruncatedNormalV2, {}),
+      ('TruncatedNormal_seeded', initializers.TruncatedNormalV2, {'seed': 123}),
+      ('LecunUniform', initializers.LecunUniformV2, {}),
+      ('LecunUniform_seeded', initializers.LecunUniformV2, {'seed': 123}),
+      ('GlorotUniform', initializers.GlorotUniformV2, {}),
+      ('GlorotUniform_seeded', initializers.GlorotUniformV2, {'seed': 123}),
+      ('HeUniform', initializers.HeUniformV2, {}),
+      ('HeUniform_seeded', initializers.HeUniformV2, {'seed': 123}),
+  )
+  def test_partition(self, initializer_cls, kwargs):
     with self.cached_session():
-      partition_enabled_initializers = [
-          initializers.ZerosV2(),
-          initializers.OnesV2(),
-          initializers.RandomUniformV2(),
-          initializers.RandomNormalV2(),
-          initializers.TruncatedNormalV2(),
-          initializers.LecunUniformV2(),
-          initializers.GlorotUniformV2(),
-          initializers.HeUniformV2()
-      ]
-      for initializer in partition_enabled_initializers:
-        got = initializer(
-            shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
-        self.assertEqual(got.shape, (2, 2))
+      initializer = initializer_cls(**kwargs)
+      result = initializer(
+          shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+      self.assertEqual(result.shape, (2, 2))
 
-      partition_forbidden_initializers = [
-          initializers.OrthogonalV2(),
-          initializers.IdentityV2()
-      ]
-      for initializer in partition_forbidden_initializers:
-        with self.assertRaisesRegex(
-            ValueError,
-            "initializer doesn't support partition-related arguments"):
-          initializer(
-              shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+      if hasattr(initializer, 'seed'):
+        # Make sure the results are different when partition_shape is the
+        # same but partition_offset is different, for random initializers.
+        result_2 = initializer(
+            shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
+        self.assertNotAllClose(result, result_2)
+
+        # Make sure the initializer produces the same result when given the
+        # same partition offset.
+        # TODO(scottzhu): Enable this assert when initializer is fully
+        # stateless.
+        # result_3 = initializer(
+        #     shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
+        # self.assertAllClose(result_2, result_3)
+
+  @parameterized.named_parameters(
+      ('Orthogonal', initializers.OrthogonalV2),
+      ('Identity', initializers.IdentityV2),
+  )
+  def test_partition_unsupported(self, initializer_cls):
+    with self.assertRaisesRegex(
+        ValueError,
+        "initializer doesn't support partition-related arguments"):
+      initializer_cls()(
+          shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
 
 
 if __name__ == '__main__':
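For context, a usage sketch of the partition arguments these tests exercise; the offsets below are illustrative. Each call materializes only a `partition_shape` slice of the full variable, and with this change the offset is hashed into a nonce so different partitions draw different values:

```python
from keras import initializers

init = initializers.RandomNormalV2(seed=123)
# Build a (4, 2) variable as two (2, 2) partitions.
top = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
bottom = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(2, 0))
# Both slices have shape (2, 2), but their values differ because the two
# offsets hash to different nonces.
```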
diff --git a/keras/initializers/initializers_v2.py b/keras/initializers/initializers_v2.py
index 27b5c91f3..0b8ae4f01 100644
--- a/keras/initializers/initializers_v2.py
+++ b/keras/initializers/initializers_v2.py
@@ -233,6 +233,10 @@ class Constant(Initializer):
         (via `tf.keras.backend.set_floatx(float_dtype)`).
       **kwargs: Additional keyword arguments.
     """
+    _validate_kwargs(self.__class__.__name__, kwargs)
+    dtype = _get_dtype(dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     layout = kwargs.pop('layout', None)
     if layout:
       return utils.call_with_layout(tf.constant, layout, self.value,
@@ -299,15 +303,17 @@ class RandomUniform(Initializer):
       raise ValueError(f'Expected float or integer dtype, got {dtype}.')
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.random_uniform, layout, shape, self.minval,
-          self.maxval, dtype)
-    return self._random_generator.random_uniform(shape, self.minval,
-                                                 self.maxval, dtype)
+          self.maxval, dtype, nonce)
+    return self._random_generator.random_uniform(
+        shape, self.minval, self.maxval, dtype, nonce)
 
   def get_config(self):
     return {
@@ -369,15 +375,17 @@ class RandomNormal(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.random_normal, layout, shape, self.mean,
-          self.stddev, dtype)
-    return self._random_generator.random_normal(shape, self.mean, self.stddev,
-                                                dtype)
+          self.stddev, dtype, nonce)
+    return self._random_generator.random_normal(
+        shape, self.mean, self.stddev, dtype, nonce)
 
   def get_config(self):
     return {
@@ -444,15 +452,17 @@ class TruncatedNormal(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
    if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.truncated_normal, layout, shape, self.mean,
-          self.stddev, dtype)
-    return self._random_generator.truncated_normal(shape, self.mean,
-                                                   self.stddev, dtype)
+          self.stddev, dtype, nonce)
+    return self._random_generator.truncated_normal(
+        shape, self.mean, self.stddev, dtype, nonce)
 
   def get_config(self):
     return {
@@ -550,15 +560,19 @@ class VarianceScaling(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
-      return utils.call_with_layout(self._generate_init_val, layout,
-                                    shape=shape, dtype=dtype)
-    return self._generate_init_val(shape=shape, dtype=dtype)
+      return utils.call_with_layout(
+          self._generate_init_val, layout, shape=shape, dtype=dtype,
+          nonce=nonce)
+    return self._generate_init_val(shape=shape, dtype=dtype,
+                                   nonce=nonce)
 
-  def _generate_init_val(self, shape, dtype):
+  def _generate_init_val(self, shape, dtype, nonce):
     scale = self.scale
     fan_in, fan_out = _compute_fans(shape)
     if self.mode == 'fan_in':
@@ -570,13 +584,16 @@ class VarianceScaling(Initializer):
     if self.distribution == 'truncated_normal':
       # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
       stddev = math.sqrt(scale) / .87962566103423978
-      return self._random_generator.truncated_normal(shape, 0.0, stddev, dtype)
+      return self._random_generator.truncated_normal(
+          shape, 0.0, stddev, dtype, nonce)
     elif self.distribution == 'untruncated_normal':
       stddev = math.sqrt(scale)
-      return self._random_generator.random_normal(shape, 0.0, stddev, dtype)
+      return self._random_generator.random_normal(
+          shape, 0.0, stddev, dtype, nonce)
     else:
       limit = math.sqrt(3.0 * scale)
-      return self._random_generator.random_uniform(shape, -limit, limit, dtype)
+      return self._random_generator.random_uniform(
+          shape, -limit, limit, dtype, nonce)
 
   def get_config(self):
     return {
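A note on the nonce derivation repeated in each `__call__` above: it is simply `hash(partition_offset)`. A minimal sketch, assuming an example offset tuple:

```python
partition_offset = (1, 0)  # e.g. kwargs.get(_PARTITION_OFFSET, None)
nonce = hash(partition_offset) if partition_offset else None
```

Since Python does not salt hashes of tuples of ints (unlike `str` and `bytes`), the same offset maps to the same nonce across runs, which keeps partitioned initialization reproducible.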