keras/keras_core/layers/reshaping/cropping3d_test.py
Neel Kovelamudi 2d40cb20b9 Adds CategoryEncoding layer, bincount op, and tests (#161)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper

* Add bincount op

* Add CategoryEncoding layer and tests

* Fix formatting

* Fix JAX issues

* Fix JAX bincount

* Formatting and small fix

* Fix nits and docstrings

* Add args to bincount op test
2023-05-14 00:07:43 +00:00

162 lines
5.5 KiB
Python

import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase, parameterized.TestCase):
    """Tests for the `layers.Cropping3D` layer."""

    @staticmethod
    def _make_inputs(data_format):
        """Return a random 5D array whose spatial dims are (7, 9, 13).

        The batch size is 3 and there are 5 channels; the channel axis is
        placed according to `data_format`, so the three spatial dimensions
        have sizes 7, 9 and 13 for both layouts.
        """
        if data_format == "channels_first":
            return np.random.rand(3, 5, 7, 9, 13)
        return np.random.rand(3, 7, 9, 13, 5)

    @staticmethod
    def _reference_crop(inputs, bounds, data_format):
        """Slice `inputs` along its three spatial axes as a reference result.

        Args:
            inputs: 5D NumPy array laid out according to `data_format`.
            bounds: three `(start, stop)` pairs, one per spatial dimension.
            data_format: `"channels_first"` or `"channels_last"`.

        Returns:
            The sliced array converted with `ops.convert_to_tensor`.
        """
        spatial = tuple(slice(start, stop) for start, stop in bounds)
        if data_format == "channels_first":
            # (batch, channels, dim1, dim2, dim3)
            index = (slice(None), slice(None)) + spatial
        else:
            # (batch, dim1, dim2, dim3, channels)
            index = (slice(None),) + spatial + (slice(None),)
        return ops.convert_to_tensor(inputs[index])

    @parameterized.product(
        (
            {"dim1_cropping": (1, 2), "dim1_expected": (1, 5)},  # both
            {"dim1_cropping": (0, 2), "dim1_expected": (0, 5)},  # right only
            {"dim1_cropping": (1, 0), "dim1_expected": (1, 7)},  # left only
        ),
        (
            {"dim2_cropping": (3, 4), "dim2_expected": (3, 5)},  # both
            {"dim2_cropping": (0, 4), "dim2_expected": (0, 5)},  # right only
            {"dim2_cropping": (3, 0), "dim2_expected": (3, 9)},  # left only
        ),
        (
            {"dim3_cropping": (5, 6), "dim3_expected": (5, 7)},  # both
            {"dim3_cropping": (0, 6), "dim3_expected": (0, 7)},  # right only
            {"dim3_cropping": (5, 0), "dim3_expected": (5, 13)},  # left only
        ),
        (
            {"data_format": "channels_first"},
            {"data_format": "channels_last"},
        ),
    )
    def test_cropping_3d(
        self,
        dim1_cropping,
        dim2_cropping,
        dim3_cropping,
        data_format,
        dim1_expected,
        dim2_expected,
        dim3_expected,
    ):
        inputs = self._make_inputs(data_format)
        expected_output = self._reference_crop(
            inputs,
            (dim1_expected, dim2_expected, dim3_expected),
            data_format,
        )
        cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
        self.run_layer_test(
            layers.Cropping3D,
            init_kwargs={"cropping": cropping, "data_format": data_format},
            input_data=inputs,
            expected_output=expected_output,
        )

    @parameterized.product(
        (
            # same cropping values given as three (left, right) pairs
            {
                "cropping": ((2, 2), (2, 2), (2, 2)),
                "expected": ((2, 5), (2, 7), (2, 11)),
            },
            # same cropping values given as one int per spatial dimension
            {"cropping": (2, 2, 2), "expected": ((2, 5), (2, 7), (2, 11))},
            # same cropping values given as a single integer
            {"cropping": 2, "expected": ((2, 5), (2, 7), (2, 11))},
        ),
        (
            {"data_format": "channels_first"},
            {"data_format": "channels_last"},
        ),
    )
    def test_cropping_3d_with_same_cropping(
        self, cropping, data_format, expected
    ):
        inputs = self._make_inputs(data_format)
        expected_output = self._reference_crop(inputs, expected, data_format)
        self.run_layer_test(
            layers.Cropping3D,
            init_kwargs={"cropping": cropping, "data_format": data_format},
            input_data=inputs,
            expected_output=expected_output,
        )

    @pytest.mark.skipif(
        not backend.DYNAMIC_SHAPES_OK,
        reason="Backend does not support dynamic shapes",
    )
    def test_cropping_3d_with_dynamic_batch_size(self):
        input_layer = layers.Input(batch_shape=(None, 7, 9, 13, 5))
        cropped = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer)
        # 7 - 3 = 4, 9 - 7 = 2, 13 - 11 = 2; batch stays dynamic.
        self.assertEqual(cropped.shape, (None, 4, 2, 2, 5))

    @parameterized.product(
        (
            # Each case removes at least the full extent of one spatial
            # dimension (sizes 7, 9 and 13 respectively).
            {"cropping": ((3, 6), (0, 0), (0, 0))},
            {"cropping": ((0, 0), (5, 8), (0, 0))},
            {"cropping": ((0, 0), (0, 0), (7, 6))},
        ),
        (
            {"data_format": "channels_first"},
            {"data_format": "channels_last"},
        ),
    )
    def test_cropping_3d_errors_if_cropping_more_than_available(
        self, cropping, data_format
    ):
        # Lay the input out to match `data_format` so that, in both formats,
        # each cropping targets the intended spatial axis rather than erroring
        # by coincidence on a differently-sized one.
        if data_format == "channels_first":
            batch_shape = (3, 5, 7, 9, 13)
        else:
            batch_shape = (3, 7, 9, 13, 5)
        input_layer = layers.Input(batch_shape=batch_shape)
        with self.assertRaises(ValueError):
            layers.Cropping3D(cropping=cropping, data_format=data_format)(
                input_layer
            )

    def test_cropping_3d_errors_if_cropping_argument_invalid(self):
        with self.assertRaises(ValueError):
            layers.Cropping3D(cropping=(1,))
        with self.assertRaises(ValueError):
            layers.Cropping3D(cropping=(1, 2))
        with self.assertRaises(ValueError):
            layers.Cropping3D(cropping=(1, 2, 3, 4))
        with self.assertRaises(ValueError):
            layers.Cropping3D(cropping="1")