Add kears_core.layers.ZeroPadding2D. (#187)

Minor tweaks to documentation and tests of Cropping and ZeroPadding3D layers.
This commit is contained in:
hertschuh 2023-05-18 16:15:01 -07:00 committed by Francois Chollet
parent 688d303bb7
commit ff0f78d114
10 changed files with 218 additions and 23 deletions

@ -105,6 +105,7 @@ from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.reshaping.up_sampling3d import UpSampling3D
from keras_core.layers.reshaping.zero_padding2d import ZeroPadding2D
from keras_core.layers.reshaping.zero_padding3d import ZeroPadding3D
from keras_core.layers.rnn.bidirectional import Bidirectional
from keras_core.layers.rnn.conv_lstm1d import ConvLSTM1D

@ -7,7 +7,7 @@ from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase):
class Cropping1DTest(testing.TestCase):
def test_cropping_1d(self):
inputs = np.random.rand(3, 5, 7)
@ -51,10 +51,10 @@ class CroppingTest(testing.TestCase):
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_cropping_1d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 5, 7))
permuted = layers.Cropping1D((1, 2))(input_layer)
self.assertEqual(permuted.shape, (None, 2, 7))
def test_cropping_1d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, None, 7))
cropped = layers.Cropping1D((1, 2))(input_layer)
self.assertEqual(cropped.shape, (1, None, 7))
def test_cropping_1d_errors_if_cropping_more_than_available(self):
with self.assertRaises(ValueError):

@ -26,7 +26,7 @@ class Cropping2D(Layer):
cropping values for height and width:
`(symmetric_height_crop, symmetric_width_crop)`.
- If tuple of 2 tuples of 2 ints: interpreted as
`((top_crop, bottom_crop), (left_crop, right_crop))`
`((top_crop, bottom_crop), (left_crop, right_crop))`.
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape

@ -8,7 +8,7 @@ from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase, parameterized.TestCase):
class Cropping2DTest(testing.TestCase, parameterized.TestCase):
@parameterized.product(
(
# different cropping values
@ -66,10 +66,10 @@ class CroppingTest(testing.TestCase, parameterized.TestCase):
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_cropping_2d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 7, 9, 5))
permuted = layers.Cropping2D(((1, 2), (3, 4)))(input_layer)
self.assertEqual(permuted.shape, (None, 4, 2, 5))
def test_cropping_2d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, 7, None, 5))
cropped = layers.Cropping2D(((1, 2), (3, 4)))(input_layer)
self.assertEqual(cropped.shape, (1, 4, None, 5))
@parameterized.product(
(

@ -25,7 +25,7 @@ class Cropping3D(Layer):
`(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.
- If tuple of 3 tuples of 2 ints: interpreted as
`((left_dim1_crop, right_dim1_crop), (left_dim2_crop,
right_dim2_crop), (left_dim3_crop, right_dim3_crop))`
right_dim2_crop), (left_dim3_crop, right_dim3_crop))`.
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape

@ -8,7 +8,7 @@ from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase, parameterized.TestCase):
class Cropping3DTest(testing.TestCase, parameterized.TestCase):
@parameterized.product(
(
{"dim1_cropping": (1, 2), "dim1_expected": (1, 5)}, # both
@ -125,10 +125,10 @@ class CroppingTest(testing.TestCase, parameterized.TestCase):
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_cropping_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 7, 9, 13, 5))
permuted = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(permuted.shape, (None, 4, 2, 2, 5))
def test_cropping_3d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, 7, None, 13, 5))
cropped = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(cropped.shape, (1, 4, None, 2, 5))
@parameterized.product(
(

@ -0,0 +1,118 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.ZeroPadding2D")
class ZeroPadding2D(Layer):
"""Zero-padding layer for 2D input (e.g. picture).
This layer can add rows and columns of zeros at the top, bottom, left and
right side of an image tensor.
Examples:
>>> input_shape = (1, 1, 2, 2)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> x
[[[[0 1]
[2 3]]]]
>>> y = keras_core.layers.ZeroPadding2D(padding=1)(x)
>>> y
[[[[0 0]
[0 0]
[0 0]
[0 0]]
[[0 0]
[0 1]
[2 3]
[0 0]]
[[0 0]
[0 0]
[0 0]
[0 0]]]]
Args:
padding: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric padding is applied to height and width.
- If tuple of 2 ints: interpreted as two different symmetric padding
values for height and width:
`(symmetric_height_pad, symmetric_width_pad)`.
- If tuple of 2 tuples of 2 ints: interpreted as
`((top_pad, bottom_pad), (left_pad, right_pad))`.
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch_size, height, width, channels)` while `"channels_first"`
corresponds to inputs with shape
`(batch_size, channels, height, width)`.
When unspecified, uses `image_data_format` value found in your Keras
config file at `~/.keras/keras.json` (if exists). Defaults to
`"channels_last"`.
Input shape:
4D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, height, width, channels)`
- If `data_format` is `"channels_first"`:
`(batch_size, channels, height, width)`
Output shape:
4D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, padded_height, padded_width, channels)`
- If `data_format` is `"channels_first"`:
`(batch_size, channels, padded_height, padded_width)`
"""
def __init__(self, padding=(1, 1), data_format=None, **kwargs):
super().__init__(**kwargs)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(padding, int):
self.padding = ((padding, padding), (padding, padding))
elif hasattr(padding, "__len__"):
if len(padding) != 2:
raise ValueError(
"`padding` should have two elements. "
f"Received: padding={padding}."
)
height_padding = padding[0]
if isinstance(height_padding, int):
height_padding = (height_padding, height_padding)
width_padding = padding[1]
if isinstance(width_padding, int):
width_padding = (width_padding, width_padding)
self.padding = (height_padding, width_padding)
else:
raise ValueError(
"`padding` should be either an int, a tuple of 2 ints "
"(symmetric_height_crop, symmetric_width_crop), "
"or a tuple of 2 tuples of 2 ints "
"((top_crop, bottom_crop), (left_crop, right_crop)). "
f"Received: padding={padding}."
)
self.input_spec = InputSpec(ndim=4)
def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
spatial_dims_offset = 2 if self.data_format == "channels_first" else 1
for index in range(0, 2):
if output_shape[index + spatial_dims_offset] is not None:
output_shape[index + spatial_dims_offset] += (
self.padding[index][0] + self.padding[index][1]
)
return tuple(output_shape)
def call(self, inputs):
if self.data_format == "channels_first":
all_dims_padding = ((0, 0), (0, 0), *self.padding)
else:
all_dims_padding = ((0, 0), *self.padding, (0, 0))
return ops.pad(inputs, all_dims_padding)
def get_config(self):
config = {"padding": self.padding, "data_format": self.data_format}
base_config = super().get_config()
return {**base_config, **config}

@ -0,0 +1,76 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ZeroPadding2DTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("channels_first", "channels_first"), ("channels_last", "channels_last")
)
def test_zero_padding_2d(self, data_format):
inputs = np.random.rand(1, 2, 3, 4)
outputs = layers.ZeroPadding2D(
padding=((1, 2), (3, 4)), data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, :, index, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 1:-2, 3:-4], inputs)
else:
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, index, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, index, :], 0.0)
self.assertAllClose(outputs[:, 1:-2, 3:-4, :], inputs)
@parameterized.product(
(
{"padding": ((2, 2), (2, 2))}, # 2 tuples
{"padding": (2, 2)}, # 1 tuple
{"padding": 2}, # 1 int
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_zero_padding_2d_with_same_padding(self, padding, data_format):
inputs = np.random.rand(1, 2, 3, 4)
outputs = layers.ZeroPadding2D(
padding=padding, data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, :, index, :], 0.0)
self.assertAllClose(outputs[:, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 2:-2, 2:-2], inputs)
else:
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, index, :], 0.0)
self.assertAllClose(outputs[:, 2:-2, 2:-2, :], inputs)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_zero_padding_2d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, 2, None, 4))
padded = layers.ZeroPadding2D(((1, 2), (3, 4)))(input_layer)
self.assertEqual(padded.shape, (1, 5, None, 4))
def test_zero_padding_2d_errors_if_padding_argument_invalid(self):
with self.assertRaises(ValueError):
layers.ZeroPadding2D(padding=(1,))
with self.assertRaises(ValueError):
layers.ZeroPadding2D(padding=(1, 2, 3))
with self.assertRaises(ValueError):
layers.ZeroPadding2D(padding="1")

@ -26,7 +26,7 @@ class ZeroPadding3D(Layer):
`(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
- If tuple of 3 tuples of 2 ints: interpreted as
`((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
right_dim2_pad), (left_dim3_pad, right_dim3_pad))`.
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape

@ -7,7 +7,7 @@ from keras_core import layers
from keras_core import testing
class ZeroPaddingTest(testing.TestCase, parameterized.TestCase):
class ZeroPadding3DTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("channels_first", "channels_first"), ("channels_last", "channels_last")
)
@ -68,10 +68,10 @@ class ZeroPaddingTest(testing.TestCase, parameterized.TestCase):
not backend.DYNAMIC_BATCH_SIZE_OK,
reason="Backend does not support dynamic batch sizes",
)
def test_zero_padding_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5))
permuted = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(permuted.shape, (None, 5, 10, 15, 5))
def test_zero_padding_3d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, 2, None, 4, 5))
padded = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(padded.shape, (1, 5, None, 15, 5))
def test_zero_padding_3d_errors_if_padding_argument_invalid(self):
with self.assertRaises(ValueError):