2023-04-09 19:21:45 +00:00
|
|
|
import random as python_random
|
|
|
|
|
2023-05-15 22:01:46 +00:00
|
|
|
from keras_core.api_export import keras_core_export
|
2023-04-13 04:07:17 +00:00
|
|
|
|
2023-05-15 22:01:46 +00:00
|
|
|
|
|
|
|
@keras_core_export("keras_core.random.SeedGenerator")
class SeedGenerator:
    """Stateful source of random seeds.

    The generator owns a non-trainable two-element ``uint32`` variable,
    ``self.state``, laid out as ``[seed, counter]`` and initialized to
    ``[seed, 0]``. ``draw_seed`` in this module returns a copy of that state
    and then increments the counter element by one, so successive draws from
    the same generator yield different seed tensors.
    """

    def __init__(self, seed=None):
        """Create a seed generator.

        Args:
            seed: Initial integer seed. If ``None`` (the default), a random
                seed is chosen with ``make_default_seed()``. The value is
                stored in a ``uint32`` variable, so it should lie in
                ``[0, 2**32 - 1]``.

        Raises:
            ValueError: If ``seed`` is neither ``None`` nor an ``int``.
        """
        # NOTE(review): imported inside the function, presumably to avoid an
        # import cycle with the backend package -- confirm.
        from keras_core.backend import Variable

        if seed is None:
            seed = make_default_seed()
        if not isinstance(seed, int):
            raise ValueError(
                "Argument `seed` must be an integer. " f"Received: seed={seed}"
            )

        def seed_initializer(*args, **kwargs):
            # Accepts and ignores whatever arguments the initializer is called
            # with; shape and dtype are fixed in the Variable call below.
            return [seed, 0]

        # State layout: [seed, counter]; the counter starts at 0.
        self.state = Variable(
            seed_initializer, shape=(2,), dtype="uint32", trainable=False
        )
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def make_default_seed():
    """Return a random integer seed drawn uniformly from ``[1, 10**9]``.

    The value comes from the module-level functions of Python's ``random``
    module, so it is reproducible only if that module's global state has been
    seeded (e.g. with ``random.seed``).
    """
    lowest, highest = 1, int(1e9)
    return python_random.randint(lowest, highest)
|
|
|
|
|
|
|
|
|
|
|
|
def draw_seed(seed):
    """Resolve ``seed`` into a two-element ``uint32`` seed tensor.

    Args:
        seed: One of
            - a ``SeedGenerator``: a copy of its current state is returned,
              and the generator's state is then advanced by adding
              ``[0, 1]`` (incrementing the second element);
            - an ``int``: ``[seed, 0]`` is returned;
            - ``None``: ``[make_default_seed(), 0]`` is returned, so a fresh
              random seed is produced on every call.

    Returns:
        A seed tensor ``[seed_value, counter]``.

    Raises:
        ValueError: If ``seed`` is of any other type.
    """
    from keras_core.backend import convert_to_tensor

    if isinstance(seed, SeedGenerator):
        generator_state = seed.state
        # `* 1` produces a copy, so the value handed back is not the
        # variable that gets reassigned just below.
        snapshot = generator_state.value * 1
        increment = convert_to_tensor([0, 1], dtype="uint32")
        generator_state.assign(generator_state + increment)
        return snapshot

    if isinstance(seed, int):
        base_seed = seed
    elif seed is None:
        base_seed = make_default_seed()
    else:
        raise ValueError(
            "Argument `seed` must be either an integer "
            "or an instance of `SeedGenerator`. "
            f"Received: seed={seed} (of type {type(seed)})"
        )
    return convert_to_tensor([base_seed, 0], dtype="uint32")
|