Exclude jax from RNN tests for now

This commit is contained in:
Francois Chollet 2023-05-10 21:41:31 -07:00
parent 4f3274bc5f
commit 13039c01b0

@ -1,5 +1,7 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
from keras_core import operations as ops
from keras_core import testing
@ -66,6 +68,10 @@ class TwoStatesRNNCell(layers.Layer):
return output, [output_1, output_2]
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Only implemented for TF for now.",
)
class RNNTest(testing.TestCase):
def test_compute_output_shape_single_state(self):
sequence = np.ones((3, 4, 5))