Fix Dropout not working with jax backend when using jit + StateScope in training mode (#249)

* Fix Dropout not working with jax backend when using jit + StateScope in training mode

* Use pytest.mark.skipif and remove the ticket number from the test
This commit is contained in:
Tirth Patel 2023-06-02 19:05:42 +00:00 committed by Francois Chollet
parent 913405c237
commit 9e58d0d0fb
2 changed files with 26 additions and 1 deletions

@ -1,6 +1,8 @@
import numpy as np
import pytest
from absl.testing import parameterized
import keras_core
from keras_core import testing
from keras_core.operations import numpy as knp
from keras_core.random import random
@ -61,3 +63,23 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
x_res = random.dropout(x, rate=0.8, seed=0)
self.assertGreater(knp.max(x_res), knp.max(x))
self.assertGreater(knp.sum(x_res == 0), 2)
@pytest.mark.skipif(
keras_core.backend.backend() != "jax",
reason="This test requires `jax` as the backend.",
)
def test_dropout_jax_jit_stateless(self):
import jax
import jax.numpy as jnp
x = knp.ones(3)
@jax.jit
def train_step(x):
with keras_core.backend.StatelessScope():
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
return x
keras_core.utils.traceback_utils.disable_traceback_filtering()
x = train_step(x)
assert isinstance(x, jnp.ndarray)

@ -16,7 +16,10 @@ class SeedGenerator:
)
def seed_initializer(*args, **kwargs):
return [seed, 0]
from keras_core.backend import convert_to_tensor
dtype = kwargs.get("dtype", None)
return convert_to_tensor([seed, 0], dtype=dtype)
self.state = Variable(
seed_initializer,