Address the issue of "partition_offset" being ignored by Keras initializers.

With the upcoming stateless initializer change, the current behavior of Keras initializers will cause issues for unseeded initializers under PSS (ParameterServerStrategy), which calls an initializer multiple times with potentially the same shape but different partition offsets. We need to pass the offset information to the Keras backend RNG so that it produces different results for different offset values. Since the offset is only provided at __call__() time, we add an extra parameter to all the methods of the Keras RNG.

PiperOrigin-RevId: 445261970
parent 574c7a20b9
commit 78fd5adeb1
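The mechanism behind this change: fold a nonce derived from the partition offset into the stateless seed, so that calls with the same shape but different offsets yield different (yet reproducible) values. Below is a minimal illustrative sketch of that idea, not the Keras implementation; the base seed, helper name, and offsets are invented for the example.

import tensorflow as tf

# Illustrative base seed; a stateless RNG would otherwise reuse it verbatim
# for every call, making all partitions identical.
base_seed = [42, 7]

def partitioned_normal(shape, partition_offset):
  # Fold a nonce derived from the partition offset into the seed, so the same
  # base seed produces different, but reproducible, values per partition.
  nonce = hash(partition_offset) % (2 ** 31)
  seed = tf.random.experimental.stateless_fold_in(base_seed, nonce)
  return tf.random.stateless_normal(shape=shape, seed=seed)

a = partitioned_normal((2, 2), partition_offset=(0, 0))
b = partitioned_normal((2, 2), partition_offset=(2, 0))  # same shape, new offset
# a and b differ, while repeating either call reproduces its values.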
@@ -1954,60 +1954,107 @@ class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
     else:
       return random.randint(1, 1e9)
 
-  def random_normal(self, shape, mean=0., stddev=1., dtype=None):
+  def random_normal(self, shape, mean=0., stddev=1., dtype=None, nonce=None):
+    """Produce random numbers based on the normal distribution.
+
+    Args:
+      shape: The shape of the random values to generate.
+      mean: Floats, default to 0. Mean of the random values to generate.
+      stddev: Floats, default to 1. Standard deviation of the random values to
+        generate.
+      dtype: Optional dtype of the tensor. Only floating point types are
+        supported. If not specified, `tf.keras.backend.floatx()` is used, which
+        defaults to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`)
+      nonce: Optional integer scalar that will be folded into the seed in the
+        stateless mode.
+    """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.normal(
         shape=shape, mean=mean, stddev=stddev, dtype=dtype,
         seed=self.make_legacy_seed())
 
-  def random_uniform(self, shape, minval=0., maxval=None, dtype=None):
+  def random_uniform(self, shape, minval=0., maxval=None, dtype=None,
+                     nonce=None):
+    """Produce random numbers based on the uniform distribution.
+
+    Args:
+      shape: The shape of the random values to generate.
+      minval: Floats, default to 0. Lower bound of the range of
+        random values to generate (inclusive).
+      maxval: Floats, default to None. Upper bound of the range of
+        random values to generate (exclusive).
+      dtype: Optional dtype of the tensor. Only floating point types are
+        supported. If not specified, `tf.keras.backend.floatx()` is used, which
+        defaults to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`)
+      nonce: Optional integer scalar that will be folded into the seed in the
+        stateless mode.
+    """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.uniform(
           shape=shape, minval=minval, maxval=maxval, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_uniform(
           shape=shape, minval=minval, maxval=maxval, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.uniform(
         shape=shape, minval=minval, maxval=maxval, dtype=dtype,
         seed=self.make_legacy_seed())
 
-  def truncated_normal(self, shape, mean=0., stddev=1., dtype=None):
+  def truncated_normal(self, shape, mean=0., stddev=1., dtype=None, nonce=None):
+    """Produce random numbers based on the truncated normal distribution.
+
+    Args:
+      shape: The shape of the random values to generate.
+      mean: Floats, default to 0. Mean of the random values to generate.
+      stddev: Floats, default to 1. Standard deviation of the random values to
+        generate.
+      dtype: Optional dtype of the tensor. Only floating point types are
+        supported. If not specified, `tf.keras.backend.floatx()` is used, which
+        defaults to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`)
+      nonce: Optional integer scalar that will be folded into the seed in the
+        stateless mode.
+    """
     self._maybe_init()
     dtype = dtype or floatx()
     if self._rng_type == self.RNG_STATEFUL:
       return self._generator.truncated_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype)
     elif self._rng_type == self.RNG_STATELESS:
+      seed = self.make_seed_for_stateless_op()
+      if nonce:
+        seed = tf.random.experimental.stateless_fold_in(seed, nonce)
       return tf.random.stateless_truncated_normal(
           shape=shape, mean=mean, stddev=stddev, dtype=dtype,
-          seed=self.make_seed_for_stateless_op())
+          seed=seed)
     return tf.random.truncated_normal(
         shape=shape, mean=mean, stddev=stddev, dtype=dtype,
         seed=self.make_legacy_seed())
 
   def dropout(self, inputs, rate, noise_shape=None):
     self._maybe_init()
-    if self._rng_type == self.RNG_STATEFUL:
+    if self._rng_type in [self.RNG_STATEFUL, self.RNG_STATELESS]:
       return tf.nn.experimental.stateless_dropout(
           inputs, rate=rate, noise_shape=noise_shape,
           seed=self.make_seed_for_stateless_op())
-    elif self._rng_type == self.RNG_STATELESS:
-      return tf.nn.experimental.stateless_dropout(
-          inputs, rate=rate, noise_shape=noise_shape,
-          seed=self.make_seed_for_stateless_op())
-    # We don't support stateless in this case, otherwise the dropout
-    # will always have identical behavior across the batches.
     return tf.nn.dropout(inputs, rate=rate, noise_shape=noise_shape,
                          seed=self.make_legacy_seed())
@@ -2353,7 +2353,7 @@ class ContextValueCacheTest(tf.test.TestCase):
 
 
 @test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
-class RandomGeneratorTest(tf.test.TestCase):
+class RandomGeneratorTest(tf.test.TestCase, parameterized.TestCase):
 
   def test_generator_reproducibility(self):
     seed = 1337
@@ -2471,6 +2471,25 @@ class RandomGeneratorTest(tf.test.TestCase):
     output3 = gen2.random_normal(shape=[2, 3])
     self.assertAllClose(output3, output1)
 
+  @parameterized.named_parameters(
+      ('seeded', 1337), ('unseeded', None)
+  )
+  def test_stateless_with_seed_delta(self, seed):
+    gen = backend.RandomGenerator(seed=seed, rng_type='stateless')
+    output1 = gen.random_normal(shape=[2, 3], nonce=hash((1, 1)))
+    seed1 = gen._seed
+    output2 = gen.random_normal(shape=[2, 3], nonce=hash((1, 1)))
+    seed2 = gen._seed
+    output3 = gen.random_normal(shape=[2, 3], nonce=hash((2, 1)))
+    seed3 = gen._seed
+
+    self.assertAllClose(output1, output2)
+    # Different seed_delta will produce different values.
+    self.assertNotAllClose(output1, output3)
+    # Make sure the internal seed is not changed at all.
+    self.assertEqual(seed1, seed2)
+    self.assertEqual(seed1, seed3)
+
   def test_unknown_rng_type(self):
     with self.assertRaisesRegex(ValueError, 'Got: unknown'):
       backend.RandomGenerator(seed=None, rng_type='unknown')
@@ -14,8 +14,7 @@
 # ==============================================================================
 """Tests for Keras initializers."""
 
-import tensorflow.compat.v2 as tf
-
+from absl.testing import parameterized
 import numpy as np
 
 from keras import backend
@@ -26,6 +25,8 @@ from keras.testing_infra import test_utils
 from keras.engine import input_layer
 from keras.layers import core
 
+import tensorflow.compat.v2 as tf
+
 
 def _compute_fans(shape):
   """Computes the number of input and output units for a weight shape.
@@ -55,7 +56,7 @@ def _compute_fans(shape):
 
 
 @test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
-class KerasInitializersTest(tf.test.TestCase):
+class KerasInitializersTest(tf.test.TestCase, parameterized.TestCase):
 
   def _runner(self, init, shape, target_mean=None, target_std=None,
               target_max=None, target_min=None):
@@ -254,33 +255,54 @@ class KerasInitializersTest(tf.test.TestCase):
     initializer = initializers.deserialize(external_serialized_json)
     self.assertEqual(initializer.distribution, 'truncated_normal')
 
-  def test_partition(self):
+  @parameterized.named_parameters(
+      ('Zeros', initializers.ZerosV2, {}),
+      ('Ones', initializers.OnesV2, {}),
+      ('Constant', initializers.ConstantV2, {}),
+      ('RandomUniform', initializers.RandomUniformV2, {}),
+      ('RandomUniform_seeded', initializers.RandomUniformV2, {'seed': 123}),
+      ('RandomNormal', initializers.RandomNormalV2, {}),
+      ('RandomNormal_seeded', initializers.RandomNormalV2, {'seed': 123}),
+      ('TruncatedNormal', initializers.TruncatedNormalV2, {}),
+      ('TruncatedNormal_seeded', initializers.TruncatedNormalV2, {'seed': 123}),
+      ('LecunUniform', initializers.LecunUniformV2, {}),
+      ('LecunUniform_seeded', initializers.LecunUniformV2, {'seed': 123}),
+      ('GlorotUniform', initializers.GlorotUniformV2, {}),
+      ('GlorotUniform_seeded', initializers.GlorotUniformV2, {'seed': 123}),
+      ('HeUniform', initializers.HeUniformV2, {}),
+      ('HeUniform_seeded', initializers.HeUniformV2, {'seed': 123}),
+  )
+  def test_partition(self, initializer_cls, kwargs):
     with self.cached_session():
-      partition_enabled_initializers = [
-          initializers.ZerosV2(),
-          initializers.OnesV2(),
-          initializers.RandomUniformV2(),
-          initializers.RandomNormalV2(),
-          initializers.TruncatedNormalV2(),
-          initializers.LecunUniformV2(),
-          initializers.GlorotUniformV2(),
-          initializers.HeUniformV2()
-      ]
-      for initializer in partition_enabled_initializers:
-        got = initializer(
-            shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
-        self.assertEqual(got.shape, (2, 2))
+      initializer = initializer_cls(**kwargs)
+      result = initializer(
+          shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+      self.assertEqual(result.shape, (2, 2))
 
-      partition_forbidden_initializers = [
-          initializers.OrthogonalV2(),
-          initializers.IdentityV2()
-      ]
-      for initializer in partition_forbidden_initializers:
-        with self.assertRaisesRegex(
-            ValueError,
-            "initializer doesn't support partition-related arguments"):
-          initializer(
-              shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+      if hasattr(initializer, 'seed'):
+        # Make sure the results are different when the partition_shape is the
+        # same but partition_offset is different, for random-related initializers.
+        result_2 = initializer(
+            shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
+        self.assertNotAllClose(result, result_2)
+
+        # Make sure the initializer produces the same result when provided the
+        # same partition offset.
+        # TODO(scottzhu): Enable this assert when initializer is fully stateless
+        # result_3 = initializer(
+        #     shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
+        # self.assertAllClose(result_2, result_3)
+
+  @parameterized.named_parameters(
+      ('Orthogonal', initializers.OrthogonalV2),
+      ('Identity', initializers.IdentityV2),
+  )
+  def test_partition_unsupported(self, initializer_cls):
+    with self.assertRaisesRegex(
+        ValueError,
+        "initializer doesn't support partition-related arguments"):
+      initializer_cls()(
+          shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
 
 
 if __name__ == '__main__':
@@ -233,6 +233,10 @@ class Constant(Initializer):
         (via `tf.keras.backend.set_floatx(float_dtype)`).
       **kwargs: Additional keyword arguments.
     """
+    _validate_kwargs(self.__class__.__name__, kwargs)
+    dtype = _get_dtype(dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     layout = kwargs.pop('layout', None)
     if layout:
       return utils.call_with_layout(tf.constant, layout, self.value,
@@ -299,15 +303,17 @@ class RandomUniform(Initializer):
       raise ValueError(f'Expected float or integer dtype, got {dtype}.')
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.random_uniform, layout, shape, self.minval,
-          self.maxval, dtype)
-    return self._random_generator.random_uniform(shape, self.minval,
-                                                  self.maxval, dtype)
+          self.maxval, dtype, nonce)
+    return self._random_generator.random_uniform(
+        shape, self.minval, self.maxval, dtype, nonce)
 
   def get_config(self):
     return {
@@ -369,15 +375,17 @@ class RandomNormal(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.random_normal, layout, shape, self.mean,
-          self.stddev, dtype)
-    return self._random_generator.random_normal(shape, self.mean, self.stddev,
-                                                 dtype)
+          self.stddev, dtype, nonce)
+    return self._random_generator.random_normal(
+        shape, self.mean, self.stddev, dtype, nonce)
 
   def get_config(self):
     return {
@@ -444,15 +452,17 @@ class TruncatedNormal(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
       return utils.call_with_layout(
           self._random_generator.truncated_normal, layout, shape, self.mean,
-          self.stddev, dtype)
-    return self._random_generator.truncated_normal(shape, self.mean,
-                                                    self.stddev, dtype)
+          self.stddev, dtype, nonce)
+    return self._random_generator.truncated_normal(
+        shape, self.mean, self.stddev, dtype, nonce)
 
   def get_config(self):
     return {
@@ -550,15 +560,19 @@ class VarianceScaling(Initializer):
     dtype = _assert_float_dtype(_get_dtype(dtype))
     if _PARTITION_SHAPE in kwargs:
       shape = kwargs[_PARTITION_SHAPE]
+    partition_offset = kwargs.get(_PARTITION_OFFSET, None)
+    nonce = hash(partition_offset) if partition_offset else None
     layout = kwargs.pop('layout', None)
     if layout:
       self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
       _ensure_keras_seeded()
-      return utils.call_with_layout(self._generate_init_val, layout,
-                                    shape=shape, dtype=dtype)
-    return self._generate_init_val(shape=shape, dtype=dtype)
+      return utils.call_with_layout(
+          self._generate_init_val, layout, shape=shape, dtype=dtype,
+          nonce=nonce)
+    return self._generate_init_val(shape=shape, dtype=dtype,
+                                   nonce=nonce)
 
-  def _generate_init_val(self, shape, dtype):
+  def _generate_init_val(self, shape, dtype, nonce):
     scale = self.scale
     fan_in, fan_out = _compute_fans(shape)
     if self.mode == 'fan_in':
@@ -570,13 +584,16 @@ class VarianceScaling(Initializer):
     if self.distribution == 'truncated_normal':
       # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
       stddev = math.sqrt(scale) / .87962566103423978
-      return self._random_generator.truncated_normal(shape, 0.0, stddev, dtype)
+      return self._random_generator.truncated_normal(
+          shape, 0.0, stddev, dtype, nonce)
     elif self.distribution == 'untruncated_normal':
       stddev = math.sqrt(scale)
-      return self._random_generator.random_normal(shape, 0.0, stddev, dtype)
+      return self._random_generator.random_normal(
+          shape, 0.0, stddev, dtype, nonce)
     else:
       limit = math.sqrt(3.0 * scale)
-      return self._random_generator.random_uniform(shape, -limit, limit, dtype)
+      return self._random_generator.random_uniform(
+          shape, -limit, limit, dtype, nonce)
 
   def get_config(self):
     return {
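As a usage note tied to the new test (a sketch, assuming the in-repo keras package and its V2 initializer aliases are importable): an unseeded random initializer asked for two partitions of the same variable should now return different slices.

from keras import initializers

init = initializers.RandomUniformV2()  # unseeded
p0 = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
p1 = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(1, 0))
# Both are (2, 2) slices of the logical (4, 2) weight; p0 and p1 should differ.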