From 60ca4b1799176f7548c5cbf259783f10ffe974b2 Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Tue, 11 Jul 2023 22:33:54 +0530 Subject: [PATCH] Use SeedGenerator backend when creating variable (#439) --- keras_core/layers/preprocessing/tf_data_layer.py | 6 +++--- keras_core/random/seed_generator.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/keras_core/layers/preprocessing/tf_data_layer.py b/keras_core/layers/preprocessing/tf_data_layer.py index 9a5fb714c..7dab47231 100644 --- a/keras_core/layers/preprocessing/tf_data_layer.py +++ b/keras_core/layers/preprocessing/tf_data_layer.py @@ -1,6 +1,6 @@ from tensorflow import nest -from keras_core import backend +import keras_core.backend from keras_core.layers.layer import Layer from keras_core.random.seed_generator import SeedGenerator from keras_core.utils import backend_utils @@ -22,7 +22,7 @@ class TFDataLayer(Layer): def __call__(self, inputs, **kwargs): if backend_utils.in_tf_graph() and not isinstance( - inputs, backend.KerasTensor + inputs, keras_core.backend.KerasTensor ): # We're in a TF graph, e.g. a tf.data pipeline. self.backend.set_backend("tensorflow") @@ -47,7 +47,7 @@ class TFDataLayer(Layer): @tracking.no_automatic_dependency_tracking def _get_seed_generator(self, backend=None): - if backend is None or backend == self.backend._backend: + if backend is None or backend == keras_core.backend.backend(): return self.generator if not hasattr(self, "_backend_generators"): self._backend_generators = {} diff --git a/keras_core/random/seed_generator.py b/keras_core/random/seed_generator.py index 11d0e9eee..5acbac252 100644 --- a/keras_core/random/seed_generator.py +++ b/keras_core/random/seed_generator.py @@ -2,6 +2,7 @@ import random as python_random import numpy as np +import keras_core.backend from keras_core.api_export import keras_core_export @@ -30,9 +31,9 @@ class SeedGenerator: if kwargs: raise ValueError(f"Unrecognized keyword arguments: {kwargs}") if custom_backend is not None: - backend = custom_backend + self.backend = custom_backend else: - from keras_core import backend + self.backend = keras_core.backend if seed is None: seed = make_default_seed() @@ -43,9 +44,9 @@ class SeedGenerator: def seed_initializer(*args, **kwargs): dtype = kwargs.get("dtype", None) - return backend.convert_to_tensor([seed, 0], dtype=dtype) + return self.backend.convert_to_tensor([seed, 0], dtype=dtype) - self.state = backend.Variable( + self.state = self.backend.Variable( seed_initializer, shape=(2,), dtype="uint32", @@ -65,7 +66,9 @@ def draw_seed(seed): seed_state = seed.state # Use * 1 to create a copy new_seed_value = seed_state.value * 1 - increment = convert_to_tensor(np.array([0, 1]), dtype="uint32") + increment = seed.backend.convert_to_tensor( + np.array([0, 1]), dtype="uint32" + ) seed.state.assign(seed_state + increment) return new_seed_value elif isinstance(seed, int):