Address the issue of "partition_offset" being ignored by keras initializers.

With the upcoming stateless initializer change, the current behavior
of the Keras initializers will cause issues for unseeded initializers
under PSS (ParameterServerStrategy), which calls an initializer
multiple times with potentially the same shape but different offsets.
We need to pass the offset information to the Keras backend RNG and
let it produce different results for different offset values. Since
the offset is only provided at __call__() time, we need to add an
extra param to all the methods in the Keras RNG.
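For illustration (not part of this change), a minimal sketch of the
fold-in mechanism described above: tf.random.experimental.stateless_fold_in
deterministically derives a new seed from a base seed and an integer, so
the same offset reproduces the same values while different offsets diverge.

    import tensorflow as tf

    base_seed = [1, 2]  # a stateless-op seed: two integer scalars
    seed_a = tf.random.experimental.stateless_fold_in(base_seed, 7)
    seed_b = tf.random.experimental.stateless_fold_in(base_seed, 8)
    # Same shape, different folded-in value -> different, but still
    # deterministic, results.
    a = tf.random.stateless_normal(shape=[2, 2], seed=seed_a)
    b = tf.random.stateless_normal(shape=[2, 2], seed=seed_b)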

PiperOrigin-RevId: 445261970
Scott Zhu, 2022-04-28 15:20:50 -07:00; committed by TensorFlower Gardener
parent 574c7a20b9
commit 78fd5adeb1
4 changed files with 163 additions and 58 deletions

--- a/keras/backend.py
+++ b/keras/backend.py
@@ -1954,60 +1954,107 @@ class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
     else:
       return random.randint(1, 1e9)
 
-  def random_normal(self, shape, mean=0., stddev=1., dtype=None):
+  def random_normal(self, shape, mean=0., stddev=1., dtype=None, nonce=None):
     """Produce random number based on the normal distribution.
 
     Args:
       shape: The shape of the random values to generate.
       mean: Floats, default to 0. Mean of the random values to generate.
       stddev: Floats, default to 1. Standard deviation of the random values to
         generate.
       dtype: Optional dtype of the tensor. Only floating point types are
         supported. If not specified, `tf.keras.backend.floatx()` is used, which
         default to `float32` unless you configured it otherwise (via
         `tf.keras.backend.set_floatx(float_dtype)`)
+      nonce: Optional integer scalar, that will be folded into the seed in the
+        stateless mode.
     """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.normal(
         shape=shape, mean=mean, stddev=stddev, dtype=dtype,
         seed=self.make_legacy_seed())
 
-  def random_uniform(self, shape, minval=0., maxval=None, dtype=None):
+  def random_uniform(self, shape, minval=0., maxval=None, dtype=None,
+                     nonce=None):
     """Produce random number based on the uniform distribution.
 
     Args:
       shape: The shape of the random values to generate.
       minval: Floats, default to 0. Lower bound of the range of
         random values to generate (inclusive).
       maxval: Floats, default to None. Upper bound of the range of
         random values to generate (exclusive).
       dtype: Optional dtype of the tensor. Only floating point types are
         supported. If not specified, `tf.keras.backend.floatx()` is used, which
         default to `float32` unless you configured it otherwise (via
         `tf.keras.backend.set_floatx(float_dtype)`)
+      nonce: Optional integer scalar, that will be folded into the seed in the
+        stateless mode.
     """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.uniform(
           shape=shape, minval=minval, maxval=maxval, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_uniform(
           shape=shape, minval=minval, maxval=maxval, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.uniform(
         shape=shape, minval=minval, maxval=maxval, dtype=dtype,
         seed=self.make_legacy_seed())
 
-  def truncated_normal(self, shape, mean=0., stddev=1., dtype=None):
+  def truncated_normal(self, shape, mean=0., stddev=1., dtype=None, nonce=None):
     """Produce random number based on the truncated normal distribution.
 
     Args:
       shape: The shape of the random values to generate.
       mean: Floats, default to 0. Mean of the random values to generate.
       stddev: Floats, default to 1. Standard deviation of the random values to
         generate.
       dtype: Optional dtype of the tensor. Only floating point types are
         supported. If not specified, `tf.keras.backend.floatx()` is used, which
         default to `float32` unless you configured it otherwise (via
         `tf.keras.backend.set_floatx(float_dtype)`)
+      nonce: Optional integer scalar, that will be folded into the seed in the
+        stateless mode.
     """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.truncated_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_truncated_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.truncated_normal(
         shape=shape, mean=mean, stddev=stddev, dtype=dtype,
         seed=self.make_legacy_seed())
 
   def dropout(self, inputs, rate, noise_shape=None):
     self._maybe_init()
-    if self._rng_type == self.RNG_STATEFUL:
+    if self._rng_type in [self.RNG_STATEFUL, self.RNG_STATELESS]:
       return tf.nn.experimental.stateless_dropout(
           inputs, rate=rate, noise_shape=noise_shape,
           seed=self.make_seed_for_stateless_op())
-    elif self._rng_type == self.RNG_STATELESS:
-      return tf.nn.experimental.stateless_dropout(
-          inputs, rate=rate, noise_shape=noise_shape,
-          seed=self.make_seed_for_stateless_op())
     # We don't support stateless in this case, otherwise the dropout
     # will always have identical behavior across the batches.
     return tf.nn.dropout(inputs, rate=rate, noise_shape=noise_shape,
                          seed=self.make_legacy_seed())
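In the stateless mode, the nonce never mutates the generator's internal
seed: repeated calls with the same nonce reproduce the same values, and a
different nonce yields different ones. A small usage sketch of the internal
API above, mirroring the new test below:

    from keras import backend

    gen = backend.RandomGenerator(seed=1337, rng_type='stateless')
    x1 = gen.random_normal(shape=[2, 3], nonce=hash((0, 0)))
    x2 = gen.random_normal(shape=[2, 3], nonce=hash((0, 0)))  # same as x1
    x3 = gen.random_normal(shape=[2, 3], nonce=hash((1, 0)))  # differs from x1

The dropout change follows the same reasoning in reverse: with a fixed
stateless seed the dropout mask would be identical on every call, so the
path without a usable RNG state keeps the stateful tf.nn.dropout.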

--- a/keras/backend_test.py
+++ b/keras/backend_test.py
@@ -2353,7 +2353,7 @@ class ContextValueCacheTest(tf.test.TestCase):
 @test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
-class RandomGeneratorTest(tf.test.TestCase):
+class RandomGeneratorTest(tf.test.TestCase, parameterized.TestCase):
 
   def test_generator_reproducibility(self):
     seed = 1337
@@ -2471,6 +2471,25 @@ class RandomGeneratorTest(tf.test.TestCase):
     output3 = gen2.random_normal(shape=[2, 3])
     self.assertAllClose(output3, output1)
 
+  @parameterized.named_parameters(
+      ('seeded', 1337), ('unseeded', None)
+  )
+  def test_stateless_with_seed_delta(self, seed):
+    gen = backend.RandomGenerator(seed=seed, rng_type='stateless')
+    output1 = gen.random_normal(shape=[2, 3], nonce=hash((1, 1)))
+    seed1 = gen._seed
+    output2 = gen.random_normal(shape=[2, 3], nonce=hash((1, 1)))
+    seed2 = gen._seed
+    output3 = gen.random_normal(shape=[2, 3], nonce=hash((2, 1)))
+    seed3 = gen._seed
+    self.assertAllClose(output1, output2)
+    # Different seed_delta will produce different value.
+    self.assertNotAllClose(output1, output3)
+    # Make sure the internal seed is not changed at all.
+    self.assertEqual(seed1, seed2)
+    self.assertEqual(seed1, seed3)
+
   def test_unknown_rng_type(self):
     with self.assertRaisesRegex(ValueError, 'Got: unknown'):
       backend.RandomGenerator(seed=None, rng_type='unknown')

--- a/keras/initializers/initializers_test.py
+++ b/keras/initializers/initializers_test.py
@@ -14,8 +14,7 @@
 # ==============================================================================
 """Tests for Keras initializers."""
 
-import tensorflow.compat.v2 as tf
-
+from absl.testing import parameterized
 import numpy as np
 from keras import backend
@@ -26,6 +25,8 @@ from keras.testing_infra import test_utils
 from keras.engine import input_layer
 from keras.layers import core
+import tensorflow.compat.v2 as tf
+
 
 def _compute_fans(shape):
   """Computes the number of input and output units for a weight shape.
@@ -55,7 +56,7 @@ def _compute_fans(shape):
 @test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
-class KerasInitializersTest(tf.test.TestCase):
+class KerasInitializersTest(tf.test.TestCase, parameterized.TestCase):
 
   def _runner(self, init, shape, target_mean=None, target_std=None,
               target_max=None, target_min=None):
@@ -254,33 +255,54 @@ class KerasInitializersTest(tf.test.TestCase):
     initializer = initializers.deserialize(external_serialized_json)
     self.assertEqual(initializer.distribution, 'truncated_normal')
 
-  def test_partition(self):
+  @parameterized.named_parameters(
+      ('Zeros', initializers.ZerosV2, {}),
+      ('Ones', initializers.OnesV2, {}),
+      ('Constant', initializers.ConstantV2, {}),
+      ('RandomUniform', initializers.RandomUniformV2, {}),
+      ('RandomUniform_seeded', initializers.RandomUniformV2, {'seed': 123}),
+      ('RandomNormal', initializers.RandomNormalV2, {}),
+      ('RandomNormal_seeded', initializers.RandomNormalV2, {'seed': 123}),
+      ('TruncatedNormal', initializers.TruncatedNormalV2, {}),
+      ('TruncatedNormal_seeded', initializers.TruncatedNormalV2, {'seed': 123}),
+      ('LecunUniform', initializers.LecunUniformV2, {}),
+      ('LecunUniform_seeded', initializers.LecunUniformV2, {'seed': 123}),
+      ('GlorotUniform', initializers.GlorotUniformV2, {}),
+      ('GlorotUniform_seeded', initializers.GlorotUniformV2, {'seed': 123}),
+      ('HeUniform', initializers.HeUniformV2, {}),
+      ('HeUniform_seeded', initializers.HeUniformV2, {'seed': 123}),
+  )
+  def test_partition(self, initializer_cls, kwargs):
     with self.cached_session():
-      partition_enabled_initializers = [
-          initializers.ZerosV2(),
-          initializers.OnesV2(),
-          initializers.RandomUniformV2(),
-          initializers.RandomNormalV2(),
-          initializers.TruncatedNormalV2(),
-          initializers.LecunUniformV2(),
-          initializers.GlorotUniformV2(),
-          initializers.HeUniformV2()
-      ]
-      for initializer in partition_enabled_initializers:
-        got = initializer(
-            shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
-        self.assertEqual(got.shape, (2, 2))
+      initializer = initializer_cls(**kwargs)
+      result = initializer(
+          shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+      self.assertEqual(result.shape, (2, 2))
 
-      partition_forbidden_initializers = [
-          initializers.OrthogonalV2(),
-          initializers.IdentityV2()
-      ]
-      for initializer in partition_forbidden_initializers:
-        with self.assertRaisesRegex(
-            ValueError,
-            "initializer doesn't support partition-related arguments"):
-          initializer(
-              shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+      if hasattr(initializer, 'seed'):
+        # Make sure the result are different when the partition_shape is same,
+        # but partition_offset is different, for random related initializers.
+        result_2 = initializer(
+            shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
+        self.assertNotAllClose(result, result_2)
+        # Make sure initializer produce same result when provide same
+        # partition offset.
+        # TODO(scottzhu): Enable this assert when initializer is fully stateless
+        # result_3 = initializer(
+        #     shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
+        # self.assertAllClose(result_2, result_3)
+
+  @parameterized.named_parameters(
+      ('Orthogonal', initializers.OrthogonalV2),
+      ('Identity', initializers.IdentityV2),
+  )
+  def test_partition_unsupported(self, initializer_cls):
+    with self.assertRaisesRegex(
+        ValueError,
+        "initializer doesn't support partition-related arguments"):
+      initializer_cls()(
+          shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+
 
 if __name__ == '__main__':
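For context, the call pattern these tests exercise: a partitioned-variable
creator (e.g. sharded variables under ParameterServerStrategy) invokes the
same initializer once per shard, passing the full shape plus each shard's
partition_shape and partition_offset. A sketch with an illustrative
row-wise split:

    import tensorflow as tf

    init = tf.keras.initializers.GlorotUniform()
    # A (4, 2) variable split row-wise into two (2, 2) shards:
    shard0 = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
    shard1 = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(2, 0))
    # With stateless seeding, both shards would receive identical values
    # unless the offset is folded into the seed, as this change does.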

--- a/keras/initializers/initializers_v2.py
+++ b/keras/initializers/initializers_v2.py
@@ -233,6 +233,10 @@ class Constant(Initializer):
         (via `tf.keras.backend.set_floatx(float_dtype)`).
       **kwargs: Additional keyword arguments.
     """
+    _validate_kwargs(self.__class__.__name__, kwargs)
+    dtype = _get_dtype(dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     layout = kwargs.pop('layout', None)
     if layout:
       return utils.call_with_layout(tf.constant, layout, self.value,
@@ -299,15 +303,17 @@ class RandomUniform(Initializer):
       raise ValueError(f'Expected float or integer dtype, got {dtype}.')
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.random_uniform, layout, shape, self.minval,
-          self.maxval, dtype)
-    return self._random_generator.random_uniform(shape, self.minval,
-                                                 self.maxval, dtype)
+          self.maxval, dtype, nonce)
+    return self._random_generator.random_uniform(
+        shape, self.minval, self.maxval, dtype, nonce)
 
   def get_config(self):
     return {
@@ -369,15 +375,17 @@ class RandomNormal(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.random_normal, layout, shape, self.mean,
-          self.stddev, dtype)
-    return self._random_generator.random_normal(shape, self.mean, self.stddev,
-                                                dtype)
+          self.stddev, dtype, nonce)
+    return self._random_generator.random_normal(
+        shape, self.mean, self.stddev, dtype, nonce)
 
   def get_config(self):
     return {
@@ -444,15 +452,17 @@ class TruncatedNormal(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.truncated_normal, layout, shape, self.mean,
-          self.stddev, dtype)
-    return self._random_generator.truncated_normal(shape, self.mean,
-                                                   self.stddev, dtype)
+          self.stddev, dtype, nonce)
+    return self._random_generator.truncated_normal(
+        shape, self.mean, self.stddev, dtype, nonce)
 
   def get_config(self):
     return {
@@ -550,15 +560,19 @@ class VarianceScaling(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
-      return utils.call_with_layout(self._generate_init_val, layout,
-                                    shape=shape, dtype=dtype)
-    return self._generate_init_val(shape=shape, dtype=dtype)
+      return utils.call_with_layout(
+          self._generate_init_val, layout, shape=shape, dtype=dtype,
+          nonce=nonce)
+    return self._generate_init_val(shape=shape, dtype=dtype,
+                                   nonce=nonce)
 
-  def _generate_init_val(self, shape, dtype):
+  def _generate_init_val(self, shape, dtype, nonce):
     scale = self.scale
     fan_in, fan_out = _compute_fans(shape)
     if self.mode == 'fan_in':
@@ -570,13 +584,16 @@ class VarianceScaling(Initializer):
     if self.distribution == 'truncated_normal':
       # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
       stddev = math.sqrt(scale) / .87962566103423978
-      return self._random_generator.truncated_normal(shape, 0.0, stddev, dtype)
+      return self._random_generator.truncated_normal(
+          shape, 0.0, stddev, dtype, nonce)
     elif self.distribution == 'untruncated_normal':
       stddev = math.sqrt(scale)
-      return self._random_generator.random_normal(shape, 0.0, stddev, dtype)
+      return self._random_generator.random_normal(
+          shape, 0.0, stddev, dtype, nonce)
     else:
       limit = math.sqrt(3.0 * scale)
-      return self._random_generator.random_uniform(shape, -limit, limit, dtype)
+      return self._random_generator.random_uniform(
+          shape, -limit, limit, dtype, nonce)
 
   def get_config(self):
     return {
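To make the dispatch above concrete, a worked example under GlorotUniform's
defaults (scale=1.0, mode='fan_avg', distribution='uniform'); the shape is
illustrative:

    import math

    # For a 2D shape (4, 2): fan_in = 4, fan_out = 2.
    scale = 1.0 / max(1.0, (4 + 2) / 2.0)  # fan_avg -> scale = 1/3
    limit = math.sqrt(3.0 * scale)         # = 1.0, i.e. samples from U(-1, 1)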