Add keras_core.layers.Cropping3D
. (#163)
This commit is contained in:
parent
0e4bf48c5b
commit
b26160dd30
@ -94,6 +94,7 @@ from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
|
||||
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
|
||||
from keras_core.layers.reshaping.cropping1d import Cropping1D
|
||||
from keras_core.layers.reshaping.cropping2d import Cropping2D
|
||||
from keras_core.layers.reshaping.cropping3d import Cropping3D
|
||||
from keras_core.layers.reshaping.flatten import Flatten
|
||||
from keras_core.layers.reshaping.permute import Permute
|
||||
from keras_core.layers.reshaping.repeat_vector import RepeatVector
|
||||
|
282
keras_core/layers/reshaping/cropping3d.py
Normal file
282
keras_core/layers/reshaping/cropping3d.py
Normal file
@ -0,0 +1,282 @@
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.input_spec import InputSpec
|
||||
from keras_core.layers.layer import Layer
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.Cropping3D")
|
||||
class Cropping3D(Layer):
|
||||
"""Cropping layer for 3D data (e.g. spatial or spatio-temporal).
|
||||
|
||||
Examples:
|
||||
|
||||
>>> input_shape = (2, 28, 28, 10, 3)
|
||||
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
|
||||
>>> y = keras_core.layers.Cropping3D(cropping=(2, 4, 2))(x)
|
||||
>>> y.shape
|
||||
(2, 24, 20, 6, 3)
|
||||
|
||||
Args:
|
||||
cropping: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
|
||||
- If int: the same symmetric cropping is applied to depth, height,
|
||||
and width.
|
||||
- If tuple of 3 ints: interpreted as three different symmetric
|
||||
cropping values for depth, height, and width:
|
||||
`(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.
|
||||
- If tuple of 3 tuples of 2 ints: interpreted as
|
||||
`((left_dim1_crop, right_dim1_crop), (left_dim2_crop,
|
||||
right_dim2_crop), (left_dim3_crop, right_dim3_crop))`
|
||||
data_format: A string, one of `"channels_last"` (default) or
|
||||
`"channels_first"`. The ordering of the dimensions in the inputs.
|
||||
`"channels_last"` corresponds to inputs with shape
|
||||
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
|
||||
while `"channels_first"` corresponds to inputs with shape
|
||||
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
|
||||
When unspecified, uses `image_data_format` value found in your Keras
|
||||
config file at `~/.keras/keras.json` (if exists). Defaults to
|
||||
`"channels_last"`.
|
||||
|
||||
Input shape:
|
||||
5D tensor with shape:
|
||||
- If `data_format` is `"channels_last"`:
|
||||
`(batch_size, first_axis_to_crop, second_axis_to_crop,
|
||||
third_axis_to_crop, channels)`
|
||||
- If `data_format` is `"channels_first"`:
|
||||
`(batch_size, channels, first_axis_to_crop, second_axis_to_crop,
|
||||
third_axis_to_crop)`
|
||||
|
||||
Output shape:
|
||||
5D tensor with shape:
|
||||
- If `data_format` is `"channels_last"`:
|
||||
`(batch_size, first_cropped_axis, second_cropped_axis,
|
||||
third_cropped_axis, channels)`
|
||||
- If `data_format` is `"channels_first"`:
|
||||
`(batch_size, channels, first_cropped_axis, second_cropped_axis,
|
||||
third_cropped_axis)`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cropping=((1, 1), (1, 1), (1, 1)),
|
||||
data_format=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
self.data_format = backend.standardize_data_format(data_format)
|
||||
if isinstance(cropping, int):
|
||||
self.cropping = (
|
||||
(cropping, cropping),
|
||||
(cropping, cropping),
|
||||
(cropping, cropping),
|
||||
)
|
||||
elif hasattr(cropping, "__len__"):
|
||||
if len(cropping) != 3:
|
||||
raise ValueError(
|
||||
f"`cropping` should have 3 elements. Received: {cropping}."
|
||||
)
|
||||
dim1_cropping = cropping[0]
|
||||
if isinstance(dim1_cropping, int):
|
||||
dim1_cropping = (dim1_cropping, dim1_cropping)
|
||||
dim2_cropping = cropping[1]
|
||||
if isinstance(dim2_cropping, int):
|
||||
dim2_cropping = (dim2_cropping, dim2_cropping)
|
||||
dim3_cropping = cropping[2]
|
||||
if isinstance(dim3_cropping, int):
|
||||
dim3_cropping = (dim3_cropping, dim3_cropping)
|
||||
self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`cropping` should be either an int, a tuple of 3 ints "
|
||||
"(symmetric_dim1_crop, symmetric_dim2_crop, "
|
||||
"symmetric_dim3_crop), "
|
||||
"or a tuple of 3 tuples of 2 ints "
|
||||
"((left_dim1_crop, right_dim1_crop),"
|
||||
" (left_dim2_crop, right_dim2_crop),"
|
||||
" (left_dim3_crop, right_dim2_crop)). "
|
||||
f"Received: {cropping}."
|
||||
)
|
||||
self.input_spec = InputSpec(ndim=5)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
if self.data_format == "channels_first":
|
||||
spatial_dims = list(input_shape[2:5])
|
||||
else:
|
||||
spatial_dims = list(input_shape[1:4])
|
||||
|
||||
for index in range(0, 3):
|
||||
if spatial_dims[index] is None:
|
||||
continue
|
||||
spatial_dims[index] -= sum(self.cropping[index])
|
||||
if spatial_dims[index] <= 0:
|
||||
raise ValueError(
|
||||
"Values in `cropping` argument should be greater than the "
|
||||
"corresponding spatial dimension of the input. Received: "
|
||||
f"input_shape={input_shape}, cropping={self.cropping}"
|
||||
)
|
||||
|
||||
if self.data_format == "channels_first":
|
||||
return (input_shape[0], input_shape[1], *spatial_dims)
|
||||
else:
|
||||
return (input_shape[0], *spatial_dims, input_shape[4])
|
||||
|
||||
def call(self, inputs):
|
||||
if self.data_format == "channels_first":
|
||||
spatial_dims = list(inputs.shape[2:5])
|
||||
else:
|
||||
spatial_dims = list(inputs.shape[1:4])
|
||||
|
||||
for index in range(0, 3):
|
||||
if spatial_dims[index] is None:
|
||||
continue
|
||||
spatial_dims[index] -= sum(self.cropping[index])
|
||||
if spatial_dims[index] <= 0:
|
||||
raise ValueError(
|
||||
"Values in `cropping` argument should be greater than the "
|
||||
"corresponding spatial dimension of the input. Received: "
|
||||
f"inputs.shape={inputs.shape}, cropping={self.cropping}"
|
||||
)
|
||||
|
||||
if self.data_format == "channels_first":
|
||||
if (
|
||||
self.cropping[0][1]
|
||||
== self.cropping[1][1]
|
||||
== self.cropping[2][1]
|
||||
== 0
|
||||
):
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] :,
|
||||
]
|
||||
elif self.cropping[0][1] == self.cropping[1][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
]
|
||||
elif self.cropping[1][1] == self.cropping[2][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] :,
|
||||
]
|
||||
elif self.cropping[0][1] == self.cropping[2][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] :,
|
||||
]
|
||||
elif self.cropping[0][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
]
|
||||
elif self.cropping[1][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
]
|
||||
elif self.cropping[2][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] :,
|
||||
]
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
]
|
||||
else:
|
||||
if (
|
||||
self.cropping[0][1]
|
||||
== self.cropping[1][1]
|
||||
== self.cropping[2][1]
|
||||
== 0
|
||||
):
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] :,
|
||||
:,
|
||||
]
|
||||
elif self.cropping[0][1] == self.cropping[1][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
:,
|
||||
]
|
||||
elif self.cropping[1][1] == self.cropping[2][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] :,
|
||||
:,
|
||||
]
|
||||
elif self.cropping[0][1] == self.cropping[2][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] :,
|
||||
:,
|
||||
]
|
||||
elif self.cropping[0][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] :,
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
:,
|
||||
]
|
||||
elif self.cropping[1][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] :,
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
:,
|
||||
]
|
||||
elif self.cropping[2][1] == 0:
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] :,
|
||||
:,
|
||||
]
|
||||
return inputs[
|
||||
:,
|
||||
self.cropping[0][0] : -self.cropping[0][1],
|
||||
self.cropping[1][0] : -self.cropping[1][1],
|
||||
self.cropping[2][0] : -self.cropping[2][1],
|
||||
:,
|
||||
]
|
||||
|
||||
def get_config(self):
|
||||
config = {"cropping": self.cropping, "data_format": self.data_format}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
161
keras_core/layers/reshaping/cropping3d_test.py
Normal file
161
keras_core/layers/reshaping/cropping3d_test.py
Normal file
@ -0,0 +1,161 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
class CroppingTest(testing.TestCase, parameterized.TestCase):
|
||||
@parameterized.product(
|
||||
(
|
||||
{"dim1_cropping": (1, 2), "dim1_expected": (1, 5)}, # both
|
||||
{"dim1_cropping": (0, 2), "dim1_expected": (0, 5)}, # left only
|
||||
{"dim1_cropping": (1, 0), "dim1_expected": (1, 7)}, # right only
|
||||
),
|
||||
(
|
||||
{"dim2_cropping": (3, 4), "dim2_expected": (3, 5)}, # both
|
||||
{"dim2_cropping": (0, 4), "dim2_expected": (0, 5)}, # left only
|
||||
{"dim2_cropping": (3, 0), "dim2_expected": (3, 9)}, # right only
|
||||
),
|
||||
(
|
||||
{"dim3_cropping": (5, 6), "dim3_expected": (5, 7)}, # both
|
||||
{"dim3_cropping": (0, 6), "dim3_expected": (0, 7)}, # left only
|
||||
{"dim3_cropping": (5, 0), "dim3_expected": (5, 13)}, # right only
|
||||
),
|
||||
(
|
||||
{"data_format": "channels_first"},
|
||||
{"data_format": "channels_last"},
|
||||
),
|
||||
)
|
||||
def test_cropping_3d(
|
||||
self,
|
||||
dim1_cropping,
|
||||
dim2_cropping,
|
||||
dim3_cropping,
|
||||
data_format,
|
||||
dim1_expected,
|
||||
dim2_expected,
|
||||
dim3_expected,
|
||||
):
|
||||
if data_format == "channels_first":
|
||||
inputs = np.random.rand(3, 5, 7, 9, 13)
|
||||
expected_output = ops.convert_to_tensor(
|
||||
inputs[
|
||||
:,
|
||||
:,
|
||||
dim1_expected[0] : dim1_expected[1],
|
||||
dim2_expected[0] : dim2_expected[1],
|
||||
dim3_expected[0] : dim3_expected[1],
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = np.random.rand(3, 7, 9, 13, 5)
|
||||
expected_output = ops.convert_to_tensor(
|
||||
inputs[
|
||||
:,
|
||||
dim1_expected[0] : dim1_expected[1],
|
||||
dim2_expected[0] : dim2_expected[1],
|
||||
dim3_expected[0] : dim3_expected[1],
|
||||
:,
|
||||
]
|
||||
)
|
||||
|
||||
cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
|
||||
self.run_layer_test(
|
||||
layers.Cropping3D,
|
||||
init_kwargs={"cropping": cropping, "data_format": data_format},
|
||||
input_data=inputs,
|
||||
expected_output=expected_output,
|
||||
)
|
||||
|
||||
@parameterized.product(
|
||||
(
|
||||
# same cropping values with 3 tuples
|
||||
{
|
||||
"cropping": ((2, 2), (2, 2), (2, 2)),
|
||||
"expected": ((2, 5), (2, 7), (2, 11)),
|
||||
},
|
||||
# same cropping values with 1 tuple
|
||||
{"cropping": (2, 2, 2), "expected": ((2, 5), (2, 7), (2, 11))},
|
||||
# same cropping values with an integer
|
||||
{"cropping": 2, "expected": ((2, 5), (2, 7), (2, 11))},
|
||||
),
|
||||
(
|
||||
{"data_format": "channels_first"},
|
||||
{"data_format": "channels_last"},
|
||||
),
|
||||
)
|
||||
def test_cropping_3d_with_same_cropping(
|
||||
self, cropping, data_format, expected
|
||||
):
|
||||
if data_format == "channels_first":
|
||||
inputs = np.random.rand(3, 5, 7, 9, 13)
|
||||
expected_output = ops.convert_to_tensor(
|
||||
inputs[
|
||||
:,
|
||||
:,
|
||||
expected[0][0] : expected[0][1],
|
||||
expected[1][0] : expected[1][1],
|
||||
expected[2][0] : expected[2][1],
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = np.random.rand(3, 7, 9, 13, 5)
|
||||
expected_output = ops.convert_to_tensor(
|
||||
inputs[
|
||||
:,
|
||||
expected[0][0] : expected[0][1],
|
||||
expected[1][0] : expected[1][1],
|
||||
expected[2][0] : expected[2][1],
|
||||
:,
|
||||
]
|
||||
)
|
||||
|
||||
self.run_layer_test(
|
||||
layers.Cropping3D,
|
||||
init_kwargs={"cropping": cropping, "data_format": data_format},
|
||||
input_data=inputs,
|
||||
expected_output=expected_output,
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
def test_cropping_3d_with_dynamic_batch_size(self):
|
||||
input_layer = layers.Input(batch_shape=(None, 7, 9, 13, 5))
|
||||
permuted = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer)
|
||||
self.assertEqual(permuted.shape, (None, 4, 2, 2, 5))
|
||||
|
||||
@parameterized.product(
|
||||
(
|
||||
{"cropping": ((3, 6), (0, 0), (0, 0))},
|
||||
{"cropping": ((0, 0), (5, 8), (0, 0))},
|
||||
{"cropping": ((0, 0), (0, 0), (7, 6))},
|
||||
),
|
||||
(
|
||||
{"data_format": "channels_first"},
|
||||
{"data_format": "channels_last"},
|
||||
),
|
||||
)
|
||||
def test_cropping_3d_errors_if_cropping_more_than_available(
|
||||
self, cropping, data_format
|
||||
):
|
||||
input_layer = layers.Input(batch_shape=(3, 7, 9, 13, 5))
|
||||
with self.assertRaises(ValueError):
|
||||
layers.Cropping3D(cropping=cropping, data_format=data_format)(
|
||||
input_layer
|
||||
)
|
||||
|
||||
def test_cropping_3d_errors_if_cropping_argument_invalid(self):
|
||||
with self.assertRaises(ValueError):
|
||||
layers.Cropping3D(cropping=(1,))
|
||||
with self.assertRaises(ValueError):
|
||||
layers.Cropping3D(cropping=(1, 2))
|
||||
with self.assertRaises(ValueError):
|
||||
layers.Cropping3D(cropping=(1, 2, 3, 4))
|
||||
with self.assertRaises(ValueError):
|
||||
layers.Cropping3D(cropping="1")
|
Loading…
Reference in New Issue
Block a user