Use SeedGenerator backend when creating variable (#439)
This commit is contained in:
parent
c69d163133
commit
60ca4b1799
@ -1,6 +1,6 @@
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core import backend
|
||||
import keras_core.backend
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.random.seed_generator import SeedGenerator
|
||||
from keras_core.utils import backend_utils
|
||||
@ -22,7 +22,7 @@ class TFDataLayer(Layer):
|
||||
|
||||
def __call__(self, inputs, **kwargs):
|
||||
if backend_utils.in_tf_graph() and not isinstance(
|
||||
inputs, backend.KerasTensor
|
||||
inputs, keras_core.backend.KerasTensor
|
||||
):
|
||||
# We're in a TF graph, e.g. a tf.data pipeline.
|
||||
self.backend.set_backend("tensorflow")
|
||||
@ -47,7 +47,7 @@ class TFDataLayer(Layer):
|
||||
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
def _get_seed_generator(self, backend=None):
|
||||
if backend is None or backend == self.backend._backend:
|
||||
if backend is None or backend == keras_core.backend.backend():
|
||||
return self.generator
|
||||
if not hasattr(self, "_backend_generators"):
|
||||
self._backend_generators = {}
|
||||
|
@ -2,6 +2,7 @@ import random as python_random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import keras_core.backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
|
||||
|
||||
@ -30,9 +31,9 @@ class SeedGenerator:
|
||||
if kwargs:
|
||||
raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
|
||||
if custom_backend is not None:
|
||||
backend = custom_backend
|
||||
self.backend = custom_backend
|
||||
else:
|
||||
from keras_core import backend
|
||||
self.backend = keras_core.backend
|
||||
|
||||
if seed is None:
|
||||
seed = make_default_seed()
|
||||
@ -43,9 +44,9 @@ class SeedGenerator:
|
||||
|
||||
def seed_initializer(*args, **kwargs):
|
||||
dtype = kwargs.get("dtype", None)
|
||||
return backend.convert_to_tensor([seed, 0], dtype=dtype)
|
||||
return self.backend.convert_to_tensor([seed, 0], dtype=dtype)
|
||||
|
||||
self.state = backend.Variable(
|
||||
self.state = self.backend.Variable(
|
||||
seed_initializer,
|
||||
shape=(2,),
|
||||
dtype="uint32",
|
||||
@ -65,7 +66,9 @@ def draw_seed(seed):
|
||||
seed_state = seed.state
|
||||
# Use * 1 to create a copy
|
||||
new_seed_value = seed_state.value * 1
|
||||
increment = convert_to_tensor(np.array([0, 1]), dtype="uint32")
|
||||
increment = seed.backend.convert_to_tensor(
|
||||
np.array([0, 1]), dtype="uint32"
|
||||
)
|
||||
seed.state.assign(seed_state + increment)
|
||||
return new_seed_value
|
||||
elif isinstance(seed, int):
|
||||
|
Loading…
Reference in New Issue
Block a user