Standardize annotation to skip tests that require dynamic shapes. (#65)

This commit is contained in:
hertschuh 2023-05-01 11:14:28 -07:00 committed by Francois Chollet
parent b9b4749e33
commit a3eeae0249
3 changed files with 11 additions and 10 deletions

@ -30,7 +30,8 @@ class DropoutTest(testing.TestCase):
self.assertAllClose(np.max(outputs), 2.0)
@pytest.mark.skipif(
backend.backend() != "tensorflow", reason="Requires dynamic shapes"
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_dropout_partial_noise_shape_dynamic(self):
inputs = np.ones((20, 5, 10))

@ -2,15 +2,15 @@ import numpy as np
import pytest
import tensorflow as tf
from keras_core import backend
from keras_core import testing
from keras_core.backend import backend
from keras_core.backend.keras_tensor import KerasTensor
from keras_core.operations import nn as knn
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Dynamic shapes are only supported in TensorFlow backend.",
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class NNOpsDynamicShapeTest(testing.TestCase):
def test_relu(self):

@ -2,8 +2,8 @@ import numpy as np
import pytest
from tensorflow.python.ops.numpy_ops import np_config
from keras_core import backend
from keras_core import testing
from keras_core.backend import backend
from keras_core.backend.keras_tensor import KerasTensor
from keras_core.operations import numpy as knp
@ -12,8 +12,8 @@ np_config.enable_numpy_behavior()
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Dynamic shapes are only supported in TensorFlow backend.",
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):
def test_add(self):
@ -624,8 +624,8 @@ class NumpyTwoInputOpsStaticShapeTest(testing.TestCase):
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Dynamic shapes are only supported in TensorFlow backend.",
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
def test_mean(self):
@ -2635,7 +2635,7 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
np.meshgrid(x, y, z, indexing="ij"),
)
if backend() == "tensorflow":
if backend.backend() == "tensorflow":
# Arguments to `jax.numpy.meshgrid` must be 1D now.
x = np.ones([1, 2, 3])
y = np.ones([4, 5, 6, 6])