diff --git a/keras_core/layers/rnn/rnn_test.py b/keras_core/layers/rnn/rnn_test.py index 841fcbe49..5f429650b 100644 --- a/keras_core/layers/rnn/rnn_test.py +++ b/keras_core/layers/rnn/rnn_test.py @@ -1,5 +1,7 @@ import numpy as np +import pytest +from keras_core import backend from keras_core import layers from keras_core import operations as ops from keras_core import testing @@ -66,6 +68,10 @@ class TwoStatesRNNCell(layers.Layer): return output, [output_1, output_2] +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only implemented for TF for now.", +) class RNNTest(testing.TestCase): def test_compute_output_shape_single_state(self): sequence = np.ones((3, 4, 5))