Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
8d63604975
commit
ff037a7ff8
@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import keras_core
|
||||
from keras_core import testing
|
||||
from keras_core.operations import numpy as knp
|
||||
from keras_core.random import random
|
||||
@ -61,3 +63,23 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
|
||||
x_res = random.dropout(x, rate=0.8, seed=0)
|
||||
self.assertGreater(knp.max(x_res), knp.max(x))
|
||||
self.assertGreater(knp.sum(x_res == 0), 2)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
keras_core.backend.backend() != "jax",
|
||||
reason="This test requires `jax` as the backend.",
|
||||
)
|
||||
def test_dropout_jax_jit_stateless(self):
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
x = knp.ones(3)
|
||||
|
||||
@jax.jit
|
||||
def train_step(x):
|
||||
with keras_core.backend.StatelessScope():
|
||||
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
|
||||
return x
|
||||
|
||||
keras_core.utils.traceback_utils.disable_traceback_filtering()
|
||||
x = train_step(x)
|
||||
assert isinstance(x, jnp.ndarray)
|
||||
|
@ -16,7 +16,10 @@ class SeedGenerator:
|
||||
)
|
||||
|
||||
def seed_initializer(*args, **kwargs):
|
||||
return [seed, 0]
|
||||
from keras_core.backend import convert_to_tensor
|
||||
|
||||
dtype = kwargs.get("dtype", None)
|
||||
return convert_to_tensor([seed, 0], dtype=dtype)
|
||||
|
||||
self.state = Variable(
|
||||
seed_initializer,
|
||||
|
Loading…
Reference in New Issue
Block a user