Add keras_core.layers.Cropping2D. (#147)

hertschuh 2023-05-11 15:58:31 -07:00 committed by Francois Chollet
parent bb96cd5ad1
commit 0b299f226c
6 changed files with 317 additions and 5 deletions

@@ -89,6 +89,7 @@ from keras_core.layers.regularization.spatial_dropout import SpatialDropout1D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
from keras_core.layers.reshaping.cropping1d import Cropping1D
from keras_core.layers.reshaping.cropping2d import Cropping2D
from keras_core.layers.reshaping.flatten import Flatten
from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
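The re-export above makes the new layer available from the public `keras_core.layers` namespace as well as from its defining module; a minimal import sketch (illustrative, not part of the diff):

# Both paths should resolve to the same class once this re-export is in place.
from keras_core import layers
from keras_core.layers.reshaping.cropping2d import Cropping2D

assert layers.Cropping2D is Cropping2D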

@@ -13,7 +13,7 @@ class Cropping1D(Layer):
>>> input_shape = (2, 3, 2)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> print(x)
>>> x
[[[ 0 1]
[ 2 3]
[ 4 5]]
@@ -21,7 +21,7 @@ class Cropping1D(Layer):
[ 8 9]
[10 11]]]
>>> y = keras_core.layers.Cropping1D(cropping=1)(x)
>>> print(y)
>>> y
[[[2 3]]
[[8 9]]]
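For reference, the doctest output above corresponds to trimming one step from each end of the steps axis; a numpy-only sketch of the same result (illustrative, not part of the diff):

import numpy as np

x = np.arange(np.prod((2, 3, 2))).reshape((2, 3, 2))
# Cropping1D(cropping=1) trims one step from each end of axis 1 (the steps axis).
y = x[:, 1:-1, :]
print(y)  # [[[2 3]]
          #  [[8 9]]]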

@@ -58,5 +58,5 @@ class CroppingTest(testing.TestCase):
def test_cropping_1d_errors_if_cropping_more_than_available(self):
with self.assertRaises(ValueError):
input_layer = layers.Input(shape=(3, 4, 7))
input_layer = layers.Input(batch_shape=(3, 5, 7))
layers.Cropping1D(cropping=(2, 3))(input_layer)

@@ -0,0 +1,212 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Cropping2D")
class Cropping2D(Layer):
"""Cropping layer for 2D input (e.g. picture).
It crops along spatial dimensions, i.e. height and width.
Examples:
>>> input_shape = (2, 28, 28, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> y = keras_core.layers.Cropping2D(cropping=((2, 2), (4, 4)))(x)
>>> y.shape
(2, 24, 20, 3)
Args:
cropping: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric cropping is applied to height and
width.
- If tuple of 2 ints: interpreted as two different symmetric
cropping values for height and width:
`(symmetric_height_crop, symmetric_width_crop)`.
- If tuple of 2 tuples of 2 ints: interpreted as
`((top_crop, bottom_crop), (left_crop, right_crop))`
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch_size, height, width, channels)` while `"channels_first"`
corresponds to inputs with shape
`(batch_size, channels, height, width)`.
When unspecified, uses the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json` (if it exists). Defaults to
`"channels_last"`.
Input shape:
4D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, height, width, channels)`
- If `data_format` is `"channels_first"`:
`(batch_size, channels, height, width)`
Output shape:
4D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, cropped_height, cropped_width, channels)`
- If `data_format` is `"channels_first"`:
`(batch_size, channels, cropped_height, cropped_width)`
"""
def __init__(
self, cropping=((0, 0), (0, 0)), data_format=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(cropping, int):
self.cropping = ((cropping, cropping), (cropping, cropping))
elif hasattr(cropping, "__len__"):
if len(cropping) != 2:
raise ValueError(
"`cropping` should have two elements. "
f"Received: cropping={cropping}."
)
height_cropping = cropping[0]
if isinstance(height_cropping, int):
height_cropping = (height_cropping, height_cropping)
width_cropping = cropping[1]
if isinstance(width_cropping, int):
width_cropping = (width_cropping, width_cropping)
self.cropping = (height_cropping, width_cropping)
else:
raise ValueError(
"`cropping` should be either an int, a tuple of 2 ints "
"(symmetric_height_crop, symmetric_width_crop), "
"or a tuple of 2 tuples of 2 ints "
"((top_crop, bottom_crop), (left_crop, right_crop)). "
f"Received: cropping={cropping}."
)
self.input_spec = InputSpec(ndim=4)
def compute_output_shape(self, input_shape):
if self.data_format == "channels_first":
if (
input_shape[2] is not None
and sum(self.cropping[0]) >= input_shape[2]
) or (
input_shape[3] is not None
and sum(self.cropping[1]) >= input_shape[3]
):
raise ValueError(
"Values in `cropping` argument should be greater than the "
"corresponding spatial dimension of the input. Received: "
f"input_shape={input_shape}, cropping={self.cropping}"
)
return (
input_shape[0],
input_shape[1],
input_shape[2] - self.cropping[0][0] - self.cropping[0][1]
if input_shape[2] is not None
else None,
input_shape[3] - self.cropping[1][0] - self.cropping[1][1]
if input_shape[3] is not None
else None,
)
else:
if (
input_shape[1] is not None
and sum(self.cropping[0]) >= input_shape[1]
) or (
input_shape[2] is not None
and sum(self.cropping[1]) >= input_shape[2]
):
raise ValueError(
"Values in `cropping` argument should be greater than the "
"corresponding spatial dimension of the input. Received: "
f"input_shape={input_shape}, cropping={self.cropping}"
)
return (
input_shape[0],
input_shape[1] - self.cropping[0][0] - self.cropping[0][1]
if input_shape[1] is not None
else None,
input_shape[2] - self.cropping[1][0] - self.cropping[1][1]
if input_shape[2] is not None
else None,
input_shape[3],
)
def call(self, inputs):
if self.data_format == "channels_first":
if (
inputs.shape[2] is not None
and sum(self.cropping[0]) >= inputs.shape[2]
) or (
inputs.shape[3] is not None
and sum(self.cropping[1]) >= inputs.shape[3]
):
raise ValueError(
"Values in `cropping` argument should be greater than the "
"corresponding spatial dimension of the input. Received: "
f"inputs.shape={inputs.shape}, cropping={self.cropping}"
)
if self.cropping[0][1] == self.cropping[1][1] == 0:
return inputs[
:, :, self.cropping[0][0] :, self.cropping[1][0] :
]
elif self.cropping[0][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] :,
self.cropping[1][0] : -self.cropping[1][1],
]
elif self.cropping[1][1] == 0:
return inputs[
:,
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] :,
]
return inputs[
:,
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] : -self.cropping[1][1],
]
else:
if (
inputs.shape[1] is not None
and sum(self.cropping[0]) >= inputs.shape[1]
) or (
inputs.shape[2] is not None
and sum(self.cropping[1]) >= inputs.shape[2]
):
raise ValueError(
"Values in `cropping` argument should be greater than the "
"corresponding spatial dimension of the input. Received: "
f"inputs.shape={inputs.shape}, cropping={self.cropping}"
)
if self.cropping[0][1] == self.cropping[1][1] == 0:
return inputs[
:, self.cropping[0][0] :, self.cropping[1][0] :, :
]
elif self.cropping[0][1] == 0:
return inputs[
:,
self.cropping[0][0] :,
self.cropping[1][0] : -self.cropping[1][1],
:,
]
elif self.cropping[1][1] == 0:
return inputs[
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] :,
:,
]
return inputs[
:,
self.cropping[0][0] : -self.cropping[0][1],
self.cropping[1][0] : -self.cropping[1][1],
:,
]
def get_config(self):
config = {"cropping": self.cropping, "data_format": self.data_format}
base_config = super().get_config()
return {**base_config, **config}
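For context on what the new layer computes, here is a small usage sketch (not part of the commit) that mirrors the docstring example; it assumes the active backend's output can be converted back to numpy with np.array:

import numpy as np
from keras_core import layers

x = np.arange(np.prod((2, 28, 28, 3))).reshape((2, 28, 28, 3))
layer = layers.Cropping2D(cropping=((2, 2), (4, 4)))  # channels_last by default
y = layer(x)
print(y.shape)  # (2, 24, 20, 3): 28 - 2 - 2 rows, 28 - 4 - 4 columns

# For channels_last inputs, `call` reduces to slicing the two spatial axes.
expected = x[:, 2:-2, 4:-4, :]
np.testing.assert_allclose(np.array(y), expected)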

@@ -0,0 +1,99 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase, parameterized.TestCase):
@parameterized.product(
(
# different cropping values
{"cropping": ((1, 2), (3, 4)), "expected_ranges": ((1, 5), (3, 5))},
# same cropping values with 2 tuples
{"cropping": ((2, 2), (2, 2)), "expected_ranges": ((2, 5), (2, 7))},
# same cropping values with 1 tuple
{"cropping": (2, 2), "expected_ranges": ((2, 5), (2, 7))},
# same cropping values with an integer
{"cropping": 2, "expected_ranges": ((2, 5), (2, 7))},
# cropping right only in both dimensions
{"cropping": ((0, 2), (0, 4)), "expected_ranges": ((0, 5), (0, 5))},
# cropping left only in both dimensions
{"cropping": ((1, 0), (3, 0)), "expected_ranges": ((1, 7), (3, 9))},
# cropping left only in rows dimension
{"cropping": ((1, 0), (3, 4)), "expected_ranges": ((1, 7), (3, 5))},
# cropping left only in cols dimension
{"cropping": ((1, 2), (3, 0)), "expected_ranges": ((1, 5), (3, 9))},
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_cropping_2d(self, cropping, data_format, expected_ranges):
if data_format == "channels_first":
inputs = np.random.rand(3, 5, 7, 9)
expected_output = ops.convert_to_tensor(
inputs[
:,
:,
expected_ranges[0][0] : expected_ranges[0][1],
expected_ranges[1][0] : expected_ranges[1][1],
]
)
else:
inputs = np.random.rand(3, 7, 9, 5)
expected_output = ops.convert_to_tensor(
inputs[
:,
expected_ranges[0][0] : expected_ranges[0][1],
expected_ranges[1][0] : expected_ranges[1][1],
:,
]
)
self.run_layer_test(
layers.Cropping2D,
init_kwargs={"cropping": cropping, "data_format": data_format},
input_data=inputs,
expected_output=expected_output,
)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_cropping_2d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 7, 9, 5))
cropped = layers.Cropping2D(((1, 2), (3, 4)))(input_layer)
self.assertEqual(cropped.shape, (None, 4, 2, 5))
@parameterized.product(
(
{"cropping": ((3, 6), (0, 0))},
{"cropping": ((0, 0), (5, 4))},
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_cropping_2d_errors_if_cropping_more_than_available(
self, cropping, data_format
):
input_layer = layers.Input(batch_shape=(3, 7, 9, 5))
with self.assertRaises(ValueError):
layers.Cropping2D(cropping=cropping, data_format=data_format)(
input_layer
)
def test_cropping_2d_errors_if_cropping_argument_invalid(self):
with self.assertRaises(ValueError):
layers.Cropping2D(cropping=(1,))
with self.assertRaises(ValueError):
layers.Cropping2D(cropping=(1, 2, 3))
with self.assertRaises(ValueError):
layers.Cropping2D(cropping="1")

@@ -14,13 +14,13 @@ class UpSampling1D(Layer):
>>> input_shape = (2, 2, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> print(x)
>>> x
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
>>> y = keras_core.layers.UpSampling1D(size=2)(x)
>>> print(y)
>>> y
[[[ 0. 1. 2.]
[ 0. 1. 2.]
[ 3. 4. 5.]