From 13039c01b05ee23e27d24d6a6ae7c0cb562c0f2e Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 10 May 2023 21:41:31 -0700 Subject: [PATCH] Exclude jax from RNN tests for now --- keras_core/layers/rnn/rnn_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/keras_core/layers/rnn/rnn_test.py b/keras_core/layers/rnn/rnn_test.py index 841fcbe49..5f429650b 100644 --- a/keras_core/layers/rnn/rnn_test.py +++ b/keras_core/layers/rnn/rnn_test.py @@ -1,5 +1,7 @@ import numpy as np +import pytest +from keras_core import backend from keras_core import layers from keras_core import operations as ops from keras_core import testing @@ -66,6 +68,10 @@ class TwoStatesRNNCell(layers.Layer): return output, [output_1, output_2] +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only implemented for TF for now.", +) class RNNTest(testing.TestCase): def test_compute_output_shape_single_state(self): sequence = np.ones((3, 4, 5))