Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
8d63604975
commit
ff037a7ff8
@ -1,6 +1,8 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
import keras_core
|
||||||
from keras_core import testing
|
from keras_core import testing
|
||||||
from keras_core.operations import numpy as knp
|
from keras_core.operations import numpy as knp
|
||||||
from keras_core.random import random
|
from keras_core.random import random
|
||||||
@ -61,3 +63,23 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
|
|||||||
x_res = random.dropout(x, rate=0.8, seed=0)
|
x_res = random.dropout(x, rate=0.8, seed=0)
|
||||||
self.assertGreater(knp.max(x_res), knp.max(x))
|
self.assertGreater(knp.max(x_res), knp.max(x))
|
||||||
self.assertGreater(knp.sum(x_res == 0), 2)
|
self.assertGreater(knp.sum(x_res == 0), 2)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
keras_core.backend.backend() != "jax",
|
||||||
|
reason="This test requires `jax` as the backend.",
|
||||||
|
)
|
||||||
|
def test_dropout_jax_jit_stateless(self):
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
x = knp.ones(3)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def train_step(x):
|
||||||
|
with keras_core.backend.StatelessScope():
|
||||||
|
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
keras_core.utils.traceback_utils.disable_traceback_filtering()
|
||||||
|
x = train_step(x)
|
||||||
|
assert isinstance(x, jnp.ndarray)
|
||||||
|
@ -16,7 +16,10 @@ class SeedGenerator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def seed_initializer(*args, **kwargs):
|
def seed_initializer(*args, **kwargs):
|
||||||
return [seed, 0]
|
from keras_core.backend import convert_to_tensor
|
||||||
|
|
||||||
|
dtype = kwargs.get("dtype", None)
|
||||||
|
return convert_to_tensor([seed, 0], dtype=dtype)
|
||||||
|
|
||||||
self.state = Variable(
|
self.state = Variable(
|
||||||
seed_initializer,
|
seed_initializer,
|
||||||
|
Loading…
Reference in New Issue
Block a user