Fix Dropout not working with jax backend when using jit + StateScope in training mode (#249)
* Fix Dropout not working with jax backend when using jit + StateScope in training mode * Use pytest.mark.skipif and remove the ticket number from the test
This commit is contained in:
parent
913405c237
commit
9e58d0d0fb
@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import keras_core
|
||||
from keras_core import testing
|
||||
from keras_core.operations import numpy as knp
|
||||
from keras_core.random import random
|
||||
@ -61,3 +63,23 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
|
||||
x_res = random.dropout(x, rate=0.8, seed=0)
|
||||
self.assertGreater(knp.max(x_res), knp.max(x))
|
||||
self.assertGreater(knp.sum(x_res == 0), 2)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
keras_core.backend.backend() != "jax",
|
||||
reason="This test requires `jax` as the backend.",
|
||||
)
|
||||
def test_dropout_jax_jit_stateless(self):
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
x = knp.ones(3)
|
||||
|
||||
@jax.jit
|
||||
def train_step(x):
|
||||
with keras_core.backend.StatelessScope():
|
||||
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
|
||||
return x
|
||||
|
||||
keras_core.utils.traceback_utils.disable_traceback_filtering()
|
||||
x = train_step(x)
|
||||
assert isinstance(x, jnp.ndarray)
|
||||
|
@ -16,7 +16,10 @@ class SeedGenerator:
|
||||
)
|
||||
|
||||
def seed_initializer(*args, **kwargs):
|
||||
return [seed, 0]
|
||||
from keras_core.backend import convert_to_tensor
|
||||
|
||||
dtype = kwargs.get("dtype", None)
|
||||
return convert_to_tensor([seed, 0], dtype=dtype)
|
||||
|
||||
self.state = Variable(
|
||||
seed_initializer,
|
||||
|
Loading…
Reference in New Issue
Block a user