Add keras_core.layers.Cropping3D. (#163)

Authored by hertschuh on 2023-05-12 15:56:58 -07:00, committed by Francois Chollet
parent 0e4bf48c5b
commit b26160dd30
3 changed files with 444 additions and 0 deletions

keras_core/layers/__init__.py
@@ -94,6 +94,7 @@ from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
from keras_core.layers.reshaping.cropping1d import Cropping1D
from keras_core.layers.reshaping.cropping2d import Cropping2D
from keras_core.layers.reshaping.cropping3d import Cropping3D
from keras_core.layers.reshaping.flatten import Flatten
from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
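With Cropping3D now exported from keras_core.layers, the layer can be used directly from the public namespace. A minimal usage sketch (assuming keras_core and NumPy are installed; it mirrors the docstring example in the new module below):

import numpy as np
import keras_core

# Crop 2, 4, and 2 units symmetrically from the three spatial dimensions
# of a (batch, dim1, dim2, dim3, channels) input.
input_shape = (2, 28, 28, 10, 3)
x = np.arange(np.prod(input_shape)).reshape(input_shape)
y = keras_core.layers.Cropping3D(cropping=(2, 4, 2))(x)
print(y.shape)  # (2, 24, 20, 6, 3)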

keras_core/layers/reshaping/cropping3d.py
@@ -0,0 +1,282 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Cropping3D")
class Cropping3D(Layer):
"""Cropping layer for 3D data (e.g. spatial or spatio-temporal).
Examples:
>>> input_shape = (2, 28, 28, 10, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> y = keras_core.layers.Cropping3D(cropping=(2, 4, 2))(x)
>>> y.shape
(2, 24, 20, 6, 3)
Args:
cropping: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric cropping is applied to depth, height,
and width.
- If tuple of 3 ints: interpreted as three different symmetric
cropping values for depth, height, and width:
`(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.
- If tuple of 3 tuples of 2 ints: interpreted as
`((left_dim1_crop, right_dim1_crop), (left_dim2_crop,
right_dim2_crop), (left_dim3_crop, right_dim3_crop))`
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            When unspecified, uses the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json` (if it exists).
            Defaults to `"channels_last"`.
Input shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_axis_to_crop, second_axis_to_crop,
third_axis_to_crop, channels)`
- If `data_format` is `"channels_first"`:
`(batch_size, channels, first_axis_to_crop, second_axis_to_crop,
third_axis_to_crop)`
Output shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_cropped_axis, second_cropped_axis,
third_cropped_axis, channels)`
- If `data_format` is `"channels_first"`:
`(batch_size, channels, first_cropped_axis, second_cropped_axis,
third_cropped_axis)`
"""
def __init__(
self,
cropping=((1, 1), (1, 1), (1, 1)),
data_format=None,
name=None,
dtype=None,
):
super().__init__(name=name, dtype=dtype)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(cropping, int):
self.cropping = (
(cropping, cropping),
(cropping, cropping),
(cropping, cropping),
)
elif hasattr(cropping, "__len__"):
if len(cropping) != 3:
raise ValueError(
f"`cropping` should have 3 elements. Received: {cropping}."
)
dim1_cropping = cropping[0]
if isinstance(dim1_cropping, int):
dim1_cropping = (dim1_cropping, dim1_cropping)
dim2_cropping = cropping[1]
if isinstance(dim2_cropping, int):
dim2_cropping = (dim2_cropping, dim2_cropping)
dim3_cropping = cropping[2]
if isinstance(dim3_cropping, int):
dim3_cropping = (dim3_cropping, dim3_cropping)
self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
else:
            raise ValueError(
                "`cropping` should be either an int, a tuple of 3 ints "
                "(symmetric_dim1_crop, symmetric_dim2_crop, "
                "symmetric_dim3_crop), "
                "or a tuple of 3 tuples of 2 ints "
                "((left_dim1_crop, right_dim1_crop),"
                " (left_dim2_crop, right_dim2_crop),"
                " (left_dim3_crop, right_dim3_crop)). "
                f"Received: {cropping}."
            )
self.input_spec = InputSpec(ndim=5)
def compute_output_shape(self, input_shape):
if self.data_format == "channels_first":
spatial_dims = list(input_shape[2:5])
else:
spatial_dims = list(input_shape[1:4])
for index in range(0, 3):
if spatial_dims[index] is None:
continue
spatial_dims[index] -= sum(self.cropping[index])
if spatial_dims[index] <= 0:
                raise ValueError(
                    "Values in `cropping` argument should be smaller than the "
                    "corresponding spatial dimension of the input. Received: "
                    f"input_shape={input_shape}, cropping={self.cropping}"
                )
if self.data_format == "channels_first":
return (input_shape[0], input_shape[1], *spatial_dims)
else:
return (input_shape[0], *spatial_dims, input_shape[4])
def call(self, inputs):
if self.data_format == "channels_first":
spatial_dims = list(inputs.shape[2:5])
else:
spatial_dims = list(inputs.shape[1:4])
for index in range(0, 3):
if spatial_dims[index] is None:
continue
spatial_dims[index] -= sum(self.cropping[index])
if spatial_dims[index] <= 0:
                raise ValueError(
                    "Values in `cropping` argument should be smaller than the "
                    "corresponding spatial dimension of the input. Received: "
                    f"inputs.shape={inputs.shape}, cropping={self.cropping}"
                )
if self.data_format == "channels_first":
if (
self.cropping[0][1]
== self.cropping[1][1]
== self.cropping[2][1]
== 0
):
return inputs[
:,
:,
self.cropping[0][0] :,
self.cropping[1][0] :,
self.cropping[2][0] :,
]
elif self.cropping[0][1] == self.cropping[1][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] :,
self.cropping[1][0] :,
self.cropping[2][0] : -self.cropping[2][1],
]
elif self.cropping[1][1] == self.cropping[2][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] :,
self.cropping[2][0] :,
]
elif self.cropping[0][1] == self.cropping[2][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] :,
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] :,
]
elif self.cropping[0][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] :,
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] : -self.cropping[2][1],
]
elif self.cropping[1][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] :,
self.cropping[2][0] : -self.cropping[2][1],
]
elif self.cropping[2][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] :,
]
return inputs[
:,
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] : -self.cropping[2][1],
]
else:
if (
self.cropping[0][1]
== self.cropping[1][1]
== self.cropping[2][1]
== 0
):
return inputs[
:,
self.cropping[0][0] :,
self.cropping[1][0] :,
self.cropping[2][0] :,
:,
]
elif self.cropping[0][1] == self.cropping[1][1] == 0:
return inputs[
:,
self.cropping[0][0] :,
self.cropping[1][0] :,
self.cropping[2][0] : -self.cropping[2][1],
:,
]
elif self.cropping[1][1] == self.cropping[2][1] == 0:
return inputs[
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] :,
self.cropping[2][0] :,
:,
]
elif self.cropping[0][1] == self.cropping[2][1] == 0:
return inputs[
:,
self.cropping[0][0] :,
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] :,
:,
]
elif self.cropping[0][1] == 0:
return inputs[
:,
self.cropping[0][0] :,
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] : -self.cropping[2][1],
:,
]
elif self.cropping[1][1] == 0:
return inputs[
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] :,
self.cropping[2][0] : -self.cropping[2][1],
:,
]
elif self.cropping[2][1] == 0:
return inputs[
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] :,
:,
]
return inputs[
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] : -self.cropping[1][1],
self.cropping[2][0] : -self.cropping[2][1],
:,
]
def get_config(self):
config = {"cropping": self.cropping, "data_format": self.data_format}
base_config = super().get_config()
return {**base_config, **config}
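The constructor normalizes every accepted form of `cropping` to three (left, right) pairs, and `get_config` records both `cropping` and `data_format`. A rough sketch of that equivalence and of a config round-trip (it assumes the default `"channels_last"` data format and the `from_config` classmethod inherited from the base `Layer`):

import numpy as np
from keras_core import layers

x = np.zeros((2, 10, 12, 14, 3))  # (batch, dim1, dim2, dim3, channels)

# An int, a tuple of 3 ints, and a tuple of 3 pairs describing the same
# symmetric cropping all produce the same output shape.
for cropping in (2, (2, 2, 2), ((2, 2), (2, 2), (2, 2))):
    layer = layers.Cropping3D(cropping=cropping)
    assert tuple(layer(x).shape) == (2, 6, 8, 10, 3)

# Asymmetric crops, rebuilt from the serialized config.
layer = layers.Cropping3D(cropping=((1, 2), (0, 3), (4, 0)))
clone = layers.Cropping3D.from_config(layer.get_config())
assert tuple(clone(x).shape) == tuple(layer(x).shape) == (2, 7, 9, 10, 3)

The eight slicing branches in `call` exist because a zero right crop cannot be written as a negative stop index (`-0` is just `0`, which would yield an empty slice), so each combination of zero right crops gets its own open-ended slice.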

keras_core/layers/reshaping/cropping3d_test.py
@@ -0,0 +1,161 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase, parameterized.TestCase):
@parameterized.product(
(
{"dim1_cropping": (1, 2), "dim1_expected": (1, 5)}, # both
            {"dim1_cropping": (0, 2), "dim1_expected": (0, 5)},  # right only
            {"dim1_cropping": (1, 0), "dim1_expected": (1, 7)},  # left only
),
(
{"dim2_cropping": (3, 4), "dim2_expected": (3, 5)}, # both
            {"dim2_cropping": (0, 4), "dim2_expected": (0, 5)},  # right only
            {"dim2_cropping": (3, 0), "dim2_expected": (3, 9)},  # left only
),
(
{"dim3_cropping": (5, 6), "dim3_expected": (5, 7)}, # both
            {"dim3_cropping": (0, 6), "dim3_expected": (0, 7)},  # right only
            {"dim3_cropping": (5, 0), "dim3_expected": (5, 13)},  # left only
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_cropping_3d(
self,
dim1_cropping,
dim2_cropping,
dim3_cropping,
data_format,
dim1_expected,
dim2_expected,
dim3_expected,
):
if data_format == "channels_first":
inputs = np.random.rand(3, 5, 7, 9, 13)
expected_output = ops.convert_to_tensor(
inputs[
:,
:,
dim1_expected[0] : dim1_expected[1],
dim2_expected[0] : dim2_expected[1],
dim3_expected[0] : dim3_expected[1],
]
)
else:
inputs = np.random.rand(3, 7, 9, 13, 5)
expected_output = ops.convert_to_tensor(
inputs[
:,
dim1_expected[0] : dim1_expected[1],
dim2_expected[0] : dim2_expected[1],
dim3_expected[0] : dim3_expected[1],
:,
]
)
cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
self.run_layer_test(
layers.Cropping3D,
init_kwargs={"cropping": cropping, "data_format": data_format},
input_data=inputs,
expected_output=expected_output,
)
@parameterized.product(
(
# same cropping values with 3 tuples
{
"cropping": ((2, 2), (2, 2), (2, 2)),
"expected": ((2, 5), (2, 7), (2, 11)),
},
# same cropping values with 1 tuple
{"cropping": (2, 2, 2), "expected": ((2, 5), (2, 7), (2, 11))},
# same cropping values with an integer
{"cropping": 2, "expected": ((2, 5), (2, 7), (2, 11))},
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_cropping_3d_with_same_cropping(
self, cropping, data_format, expected
):
if data_format == "channels_first":
inputs = np.random.rand(3, 5, 7, 9, 13)
expected_output = ops.convert_to_tensor(
inputs[
:,
:,
expected[0][0] : expected[0][1],
expected[1][0] : expected[1][1],
expected[2][0] : expected[2][1],
]
)
else:
inputs = np.random.rand(3, 7, 9, 13, 5)
expected_output = ops.convert_to_tensor(
inputs[
:,
expected[0][0] : expected[0][1],
expected[1][0] : expected[1][1],
expected[2][0] : expected[2][1],
:,
]
)
self.run_layer_test(
layers.Cropping3D,
init_kwargs={"cropping": cropping, "data_format": data_format},
input_data=inputs,
expected_output=expected_output,
)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_cropping_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 7, 9, 13, 5))
        cropped = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer)
        self.assertEqual(cropped.shape, (None, 4, 2, 2, 5))
@parameterized.product(
(
{"cropping": ((3, 6), (0, 0), (0, 0))},
{"cropping": ((0, 0), (5, 8), (0, 0))},
{"cropping": ((0, 0), (0, 0), (7, 6))},
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_cropping_3d_errors_if_cropping_more_than_available(
self, cropping, data_format
):
input_layer = layers.Input(batch_shape=(3, 7, 9, 13, 5))
with self.assertRaises(ValueError):
layers.Cropping3D(cropping=cropping, data_format=data_format)(
input_layer
)
def test_cropping_3d_errors_if_cropping_argument_invalid(self):
with self.assertRaises(ValueError):
layers.Cropping3D(cropping=(1,))
with self.assertRaises(ValueError):
layers.Cropping3D(cropping=(1, 2))
with self.assertRaises(ValueError):
layers.Cropping3D(cropping=(1, 2, 3, 4))
with self.assertRaises(ValueError):
layers.Cropping3D(cropping="1")
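The dynamic-shape test above only exercises an unknown batch dimension. `compute_output_shape` also leaves unknown spatial dimensions as `None` (the `continue` skips both the subtraction and the validation), which a direct call can illustrate (a small sketch assuming the default `"channels_last"` data format; nothing in the tests above calls `compute_output_shape` by hand):

from keras_core import layers

layer = layers.Cropping3D(cropping=((1, 2), (3, 4), (5, 6)))

# Known axes shrink by the total cropping on that axis; None axes stay None.
print(layer.compute_output_shape((None, 7, None, 13, 5)))
# -> (None, 4, None, 2, 5)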