diff --git a/keras_core/random/random_test.py b/keras_core/random/random_test.py index 27007892b..009de3798 100644 --- a/keras_core/random/random_test.py +++ b/keras_core/random/random_test.py @@ -1,6 +1,8 @@ import numpy as np +import pytest from absl.testing import parameterized +import keras_core from keras_core import testing from keras_core.operations import numpy as knp from keras_core.random import random @@ -61,3 +63,23 @@ class RandomTest(testing.TestCase, parameterized.TestCase): x_res = random.dropout(x, rate=0.8, seed=0) self.assertGreater(knp.max(x_res), knp.max(x)) self.assertGreater(knp.sum(x_res == 0), 2) + + @pytest.mark.skipif( + keras_core.backend.backend() != "jax", + reason="This test requires `jax` as the backend.", + ) + def test_dropout_jax_jit_stateless(self): + import jax + import jax.numpy as jnp + + x = knp.ones(3) + + @jax.jit + def train_step(x): + with keras_core.backend.StatelessScope(): + x = keras_core.layers.Dropout(rate=0.1)(x, training=True) + return x + + keras_core.utils.traceback_utils.disable_traceback_filtering() + x = train_step(x) + assert isinstance(x, jnp.ndarray) diff --git a/keras_core/random/seed_generator.py b/keras_core/random/seed_generator.py index 645070e35..f94b104c7 100644 --- a/keras_core/random/seed_generator.py +++ b/keras_core/random/seed_generator.py @@ -16,7 +16,10 @@ class SeedGenerator: ) def seed_initializer(*args, **kwargs): - return [seed, 0] + from keras_core.backend import convert_to_tensor + + dtype = kwargs.get("dtype", None) + return convert_to_tensor([seed, 0], dtype=dtype) self.state = Variable( seed_initializer,