Minor refactor of backend.random

This commit is contained in:
Francois Chollet 2023-04-12 21:07:17 -07:00
parent 1fc98ab59b
commit af0040d153
7 changed files with 34 additions and 35 deletions

@ -29,7 +29,7 @@ class MiniDropout(Layer):
def __init__(self, rate, name=None):
super().__init__(name=name)
self.rate = rate
self.seed_generator = backend.random.RandomSeedGenerator(1337)
self.seed_generator = backend.random.SeedGenerator(1337)
def call(self, inputs):
return backend.random.dropout(

@ -32,7 +32,7 @@ class MiniDropout(Layer):
def __init__(self, rate, name=None):
super().__init__(name=name)
self.rate = rate
self.seed_generator = backend.random.RandomSeedGenerator(1337)
self.seed_generator = backend.random.SeedGenerator(1337)
def call(self, inputs):
return backend.random.dropout(

@ -1,7 +1,9 @@
import random as python_random
from keras_core.backend import backend
class RandomSeedGenerator:
class SeedGenerator:
def __init__(self, seed):
from keras_core.backend import Variable
@ -20,7 +22,7 @@ def make_default_seed():
def draw_seed(seed):
from keras_core.backend import convert_to_tensor
if isinstance(seed, RandomSeedGenerator):
if isinstance(seed, SeedGenerator):
new_seed_value = seed.state.value
seed.state.assign(
seed.state + convert_to_tensor([0, 1], dtype="uint32")
@ -32,6 +34,12 @@ def draw_seed(seed):
return convert_to_tensor([make_default_seed(), 0], dtype="uint32")
raise ValueError(
"Argument `seed` must be either an integer "
"or an instance of `RandomSeedGenerator`. "
"or an instance of `SeedGenerator`. "
f"Received: seed={seed} (of type {type(seed)})"
)
if backend() == "jax":
from keras_core.backend.jax.random import *
else:
from keras_core.backend.tensorflow.random import *

@ -1,9 +0,0 @@
from keras_core.backend import backend
from keras_core.backend.random.random_seed_generator import RandomSeedGenerator
from keras_core.backend.random.random_seed_generator import draw_seed
from keras_core.backend.random.random_seed_generator import make_default_seed
if backend() == "jax":
from keras_core.backend.jax.random import *
else:
from keras_core.backend.tensorflow.random import *

@ -35,13 +35,13 @@ class VarianceScaling(Initializer):
distribution: Random distribution to use.
One of `"truncated_normal"`, `"untruncated_normal"`, or `"uniform"`.
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
"""
def __init__(
@ -143,13 +143,13 @@ class GlorotUniform(VarianceScaling):
Args:
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
References:
@ -185,13 +185,13 @@ class GlorotNormal(VarianceScaling):
Args:
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
References:
- [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
@ -232,13 +232,13 @@ class LecunNormal(VarianceScaling):
Args:
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
References:
- [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
@ -272,13 +272,13 @@ class LecunUniform(VarianceScaling):
Args:
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
References:
- [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
@ -312,13 +312,13 @@ class HeNormal(VarianceScaling):
Args:
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
Reference:
- [He et al., 2015](https://arxiv.org/abs/1502.01852)
@ -352,13 +352,13 @@ class HeUniform(VarianceScaling):
Args:
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
Reference:
- [He et al., 2015](https://arxiv.org/abs/1502.01852)
@ -422,13 +422,13 @@ class RandomNormal(Initializer):
stddev: A python scalar or a scalar keras tensor. Standard deviation of the
random values to generate.
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
"""
def __init__(self, mean=0.0, stddev=1.0, seed=None):
@ -471,13 +471,13 @@ class RandomUniform(Initializer):
maxval: A python scalar or a scalar keras tensor. Upper bound of the range of
random values to generate (exclusive).
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
`keras_core.backend.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
of `keras_core.backend.SeedGenerator`.
"""
def __init__(self, minval=0.0, maxval=1.0, seed=None):

@ -71,7 +71,7 @@ class Layer(Operation):
and not isinstance(x, Metric),
self._layers,
),
# TODO: RandomSeedGenerator tracking
# TODO: SeedGenerator tracking
}
)
@ -176,7 +176,7 @@ class Layer(Operation):
@property
def variables(self):
# TODO: include not just weights by any variables (also from metrics, optimizers, RandomSeedGenerators)
# TODO: include not just weights by any variables (also from metrics, optimizers, SeedGenerators)
variables = self.weights[:]
return variables

@ -29,7 +29,7 @@ class MiniDropout(Layer):
def __init__(self, rate, name=None):
super().__init__(name=name)
self.rate = rate
self.seed_generator = backend.random.RandomSeedGenerator(1337)
self.seed_generator = backend.random.SeedGenerator(1337)
def call(self, inputs):
return backend.random.dropout(