Use SeedGenerator backend when creating variable (#439)

This commit is contained in:
Ramesh Sampath 2023-07-11 22:33:54 +05:30 committed by Francois Chollet
parent c69d163133
commit 60ca4b1799
2 changed files with 11 additions and 8 deletions

@ -1,6 +1,6 @@
from tensorflow import nest from tensorflow import nest
from keras_core import backend import keras_core.backend
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.random.seed_generator import SeedGenerator from keras_core.random.seed_generator import SeedGenerator
from keras_core.utils import backend_utils from keras_core.utils import backend_utils
@ -22,7 +22,7 @@ class TFDataLayer(Layer):
def __call__(self, inputs, **kwargs): def __call__(self, inputs, **kwargs):
if backend_utils.in_tf_graph() and not isinstance( if backend_utils.in_tf_graph() and not isinstance(
inputs, backend.KerasTensor inputs, keras_core.backend.KerasTensor
): ):
# We're in a TF graph, e.g. a tf.data pipeline. # We're in a TF graph, e.g. a tf.data pipeline.
self.backend.set_backend("tensorflow") self.backend.set_backend("tensorflow")
@ -47,7 +47,7 @@ class TFDataLayer(Layer):
@tracking.no_automatic_dependency_tracking @tracking.no_automatic_dependency_tracking
def _get_seed_generator(self, backend=None): def _get_seed_generator(self, backend=None):
if backend is None or backend == self.backend._backend: if backend is None or backend == keras_core.backend.backend():
return self.generator return self.generator
if not hasattr(self, "_backend_generators"): if not hasattr(self, "_backend_generators"):
self._backend_generators = {} self._backend_generators = {}

@ -2,6 +2,7 @@ import random as python_random
import numpy as np import numpy as np
import keras_core.backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
@ -30,9 +31,9 @@ class SeedGenerator:
if kwargs: if kwargs:
raise ValueError(f"Unrecognized keyword arguments: {kwargs}") raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
if custom_backend is not None: if custom_backend is not None:
backend = custom_backend self.backend = custom_backend
else: else:
from keras_core import backend self.backend = keras_core.backend
if seed is None: if seed is None:
seed = make_default_seed() seed = make_default_seed()
@ -43,9 +44,9 @@ class SeedGenerator:
def seed_initializer(*args, **kwargs): def seed_initializer(*args, **kwargs):
dtype = kwargs.get("dtype", None) dtype = kwargs.get("dtype", None)
return backend.convert_to_tensor([seed, 0], dtype=dtype) return self.backend.convert_to_tensor([seed, 0], dtype=dtype)
self.state = backend.Variable( self.state = self.backend.Variable(
seed_initializer, seed_initializer,
shape=(2,), shape=(2,),
dtype="uint32", dtype="uint32",
@ -65,7 +66,9 @@ def draw_seed(seed):
seed_state = seed.state seed_state = seed.state
# Use * 1 to create a copy # Use * 1 to create a copy
new_seed_value = seed_state.value * 1 new_seed_value = seed_state.value * 1
increment = convert_to_tensor(np.array([0, 1]), dtype="uint32") increment = seed.backend.convert_to_tensor(
np.array([0, 1]), dtype="uint32"
)
seed.state.assign(seed_state + increment) seed.state.assign(seed_state + increment)
return new_seed_value return new_seed_value
elif isinstance(seed, int): elif isinstance(seed, int):