Exclude jax from RNN tests for now
This commit is contained in:
parent
4f3274bc5f
commit
13039c01b0
@ -1,5 +1,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import testing
|
||||
@ -66,6 +68,10 @@ class TwoStatesRNNCell(layers.Layer):
|
||||
return output, [output_1, output_2]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() != "tensorflow",
|
||||
reason="Only implemented for TF for now.",
|
||||
)
|
||||
class RNNTest(testing.TestCase):
|
||||
def test_compute_output_shape_single_state(self):
|
||||
sequence = np.ones((3, 4, 5))
|
||||
|
Loading…
Reference in New Issue
Block a user