Merge branch 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-06-03 10:36:42 -07:00
parent 8d63604975
commit ff037a7ff8
2 changed files with 26 additions and 1 deletions

@ -1,6 +1,8 @@
import numpy as np
import pytest
from absl.testing import parameterized
import keras_core
from keras_core import testing
from keras_core.operations import numpy as knp
from keras_core.random import random
@ -61,3 +63,23 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
x_res = random.dropout(x, rate=0.8, seed=0)
self.assertGreater(knp.max(x_res), knp.max(x))
self.assertGreater(knp.sum(x_res == 0), 2)
@pytest.mark.skipif(
keras_core.backend.backend() != "jax",
reason="This test requires `jax` as the backend.",
)
def test_dropout_jax_jit_stateless(self):
import jax
import jax.numpy as jnp
x = knp.ones(3)
@jax.jit
def train_step(x):
with keras_core.backend.StatelessScope():
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
return x
keras_core.utils.traceback_utils.disable_traceback_filtering()
x = train_step(x)
assert isinstance(x, jnp.ndarray)

@ -16,7 +16,10 @@ class SeedGenerator:
)
def seed_initializer(*args, **kwargs):
return [seed, 0]
from keras_core.backend import convert_to_tensor
dtype = kwargs.get("dtype", None)
return convert_to_tensor([seed, 0], dtype=dtype)
self.state = Variable(
seed_initializer,