Use SeedGenerator backend when creating variable (#439)
This commit is contained in:
parent
c69d163133
commit
60ca4b1799
@ -1,6 +1,6 @@
|
|||||||
from tensorflow import nest
|
from tensorflow import nest
|
||||||
|
|
||||||
from keras_core import backend
|
import keras_core.backend
|
||||||
from keras_core.layers.layer import Layer
|
from keras_core.layers.layer import Layer
|
||||||
from keras_core.random.seed_generator import SeedGenerator
|
from keras_core.random.seed_generator import SeedGenerator
|
||||||
from keras_core.utils import backend_utils
|
from keras_core.utils import backend_utils
|
||||||
@ -22,7 +22,7 @@ class TFDataLayer(Layer):
|
|||||||
|
|
||||||
def __call__(self, inputs, **kwargs):
|
def __call__(self, inputs, **kwargs):
|
||||||
if backend_utils.in_tf_graph() and not isinstance(
|
if backend_utils.in_tf_graph() and not isinstance(
|
||||||
inputs, backend.KerasTensor
|
inputs, keras_core.backend.KerasTensor
|
||||||
):
|
):
|
||||||
# We're in a TF graph, e.g. a tf.data pipeline.
|
# We're in a TF graph, e.g. a tf.data pipeline.
|
||||||
self.backend.set_backend("tensorflow")
|
self.backend.set_backend("tensorflow")
|
||||||
@ -47,7 +47,7 @@ class TFDataLayer(Layer):
|
|||||||
|
|
||||||
@tracking.no_automatic_dependency_tracking
|
@tracking.no_automatic_dependency_tracking
|
||||||
def _get_seed_generator(self, backend=None):
|
def _get_seed_generator(self, backend=None):
|
||||||
if backend is None or backend == self.backend._backend:
|
if backend is None or backend == keras_core.backend.backend():
|
||||||
return self.generator
|
return self.generator
|
||||||
if not hasattr(self, "_backend_generators"):
|
if not hasattr(self, "_backend_generators"):
|
||||||
self._backend_generators = {}
|
self._backend_generators = {}
|
||||||
|
@ -2,6 +2,7 @@ import random as python_random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import keras_core.backend
|
||||||
from keras_core.api_export import keras_core_export
|
from keras_core.api_export import keras_core_export
|
||||||
|
|
||||||
|
|
||||||
@ -30,9 +31,9 @@ class SeedGenerator:
|
|||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
|
raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
|
||||||
if custom_backend is not None:
|
if custom_backend is not None:
|
||||||
backend = custom_backend
|
self.backend = custom_backend
|
||||||
else:
|
else:
|
||||||
from keras_core import backend
|
self.backend = keras_core.backend
|
||||||
|
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = make_default_seed()
|
seed = make_default_seed()
|
||||||
@ -43,9 +44,9 @@ class SeedGenerator:
|
|||||||
|
|
||||||
def seed_initializer(*args, **kwargs):
|
def seed_initializer(*args, **kwargs):
|
||||||
dtype = kwargs.get("dtype", None)
|
dtype = kwargs.get("dtype", None)
|
||||||
return backend.convert_to_tensor([seed, 0], dtype=dtype)
|
return self.backend.convert_to_tensor([seed, 0], dtype=dtype)
|
||||||
|
|
||||||
self.state = backend.Variable(
|
self.state = self.backend.Variable(
|
||||||
seed_initializer,
|
seed_initializer,
|
||||||
shape=(2,),
|
shape=(2,),
|
||||||
dtype="uint32",
|
dtype="uint32",
|
||||||
@ -65,7 +66,9 @@ def draw_seed(seed):
|
|||||||
seed_state = seed.state
|
seed_state = seed.state
|
||||||
# Use * 1 to create a copy
|
# Use * 1 to create a copy
|
||||||
new_seed_value = seed_state.value * 1
|
new_seed_value = seed_state.value * 1
|
||||||
increment = convert_to_tensor(np.array([0, 1]), dtype="uint32")
|
increment = seed.backend.convert_to_tensor(
|
||||||
|
np.array([0, 1]), dtype="uint32"
|
||||||
|
)
|
||||||
seed.state.assign(seed_state + increment)
|
seed.state.assign(seed_state + increment)
|
||||||
return new_seed_value
|
return new_seed_value
|
||||||
elif isinstance(seed, int):
|
elif isinstance(seed, int):
|
||||||
|
Loading…
Reference in New Issue
Block a user