Add max and average pooling layer (#66)

* Add max and average pooling layer

* fix tests

* handle TF transpose

* renaming

* rename tests

* Fix comments

* Move out the shape computation logic
This commit is contained in:
Chen Qian 2023-05-02 15:17:59 -07:00 committed by Francois Chollet
parent e253911d18
commit 841b8d702d
15 changed files with 1144 additions and 58 deletions

@ -69,6 +69,36 @@ def log_softmax(x, axis=None):
return tf.nn.log_softmax(x, axis=axis)
def _transpose_spatial_inputs(inputs):
    """Move the channel axis of a `channels_first` tensor to the last axis.

    Tensorflow pooling does not support `channels_first` format, so the
    inputs are transposed to `channels_last` before pooling.

    Args:
        inputs: Tensor laid out as `(batch, channels, *spatial)` with 1, 2
            or 3 spatial dimensions (rank 3, 4 or 5).

    Returns:
        The tensor transposed to `(batch, *spatial, channels)`.

    Raises:
        ValueError: If the rank of `inputs` is not 3, 4 or 5.
    """
    rank = len(inputs.shape)
    num_spatial_dims = rank - 2
    if num_spatial_dims not in (1, 2, 3):
        # It is the rank (number of axes) that is constrained, not the
        # shape values themselves.
        raise ValueError(
            "Pooling inputs must have rank 3, 4 or 5, corresponding to 1D, "
            f"2D and 3D inputs. But received shape: {inputs.shape}."
        )
    # Batch axis stays first, spatial axes shift left by one, channels last.
    return tf.transpose(inputs, (0, *range(2, rank), 1))
def _transpose_spatial_outputs(outputs):
    """Undo the transpose performed by `_transpose_spatial_inputs`.

    Args:
        outputs: Tensor laid out as `(batch, *spatial, channels)`.

    Returns:
        The tensor transposed back to `(batch, channels, *spatial)` when it
        has rank 3, 4 or 5; any other rank is returned unchanged.
    """
    rank = len(outputs.shape)
    if rank in (3, 4, 5):
        channel_axis = rank - 1
        # Batch axis stays first, channels move to axis 1, spatial axes
        # follow in their original order.
        outputs = tf.transpose(
            outputs, (0, channel_axis, *range(1, channel_axis))
        )
    return outputs
def max_pool(
inputs,
pool_size,
@ -78,8 +108,22 @@ def max_pool(
):
strides = pool_size if strides is None else strides
padding = padding.upper()
data_format = _convert_data_format(data_format, len(inputs.shape))
return tf.nn.max_pool(inputs, pool_size, strides, padding, data_format)
tf_data_format = _convert_data_format("channels_last", len(inputs.shape))
if data_format == "channels_first":
# Tensorflow pooling does not support `channels_first` format, so
# we need to transpose to `channels_last` format.
inputs = _transpose_spatial_inputs(inputs)
outputs = tf.nn.max_pool(
inputs,
pool_size,
strides,
padding,
tf_data_format,
)
if data_format == "channels_first":
outputs = _transpose_spatial_outputs(outputs)
return outputs
def average_pool(
@ -91,8 +135,22 @@ def average_pool(
):
strides = pool_size if strides is None else strides
padding = padding.upper()
data_format = _convert_data_format(data_format, len(inputs.shape))
return tf.nn.avg_pool(inputs, pool_size, strides, padding, data_format)
tf_data_format = _convert_data_format("channels_last", len(inputs.shape))
if data_format == "channels_first":
# Tensorflow pooling does not support `channels_first` format, so
# we need to transpose to `channels_last` format.
inputs = _transpose_spatial_inputs(inputs)
outputs = tf.nn.avg_pool(
inputs,
pool_size,
strides,
padding,
tf_data_format,
)
if data_format == "channels_first":
outputs = _transpose_spatial_outputs(outputs)
return outputs
def _convert_data_format(data_format, ndim):

@ -8,6 +8,12 @@ from keras_core.layers.merging.add import Add
from keras_core.layers.merging.add import add
from keras_core.layers.merging.subtract import Subtract
from keras_core.layers.merging.subtract import subtract
from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
from keras_core.layers.pooling.average_pooling3d import AveragePooling3D
from keras_core.layers.pooling.max_pooling1d import MaxPooling1D
from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
from keras_core.layers.regularization.activity_regularization import (
ActivityRegularization,
)

@ -0,0 +1,92 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_pooling import BasePooling
@keras_core_export(
    ["keras_core.layers.AveragePooling1D", "keras_core.layers.AvgPool1D"]
)
class AveragePooling1D(BasePooling):
    """Average pooling for temporal data.

    Downsamples the input representation by taking the average value over the
    window defined by `pool_size`. The window is shifted by `strides`. The
    resulting output when using the `"valid"` padding option has a shape of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int, size of the average pooling window.
        strides: int or None. Specifies how much the pooling window moves
            for each pooling step. If None, it will default to `pool_size`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, steps)`.

    Output shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, downsampled_steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, downsampled_steps)`.

    Examples:

    `strides=1` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> avg_pool_1d = keras_core.layers.AveragePooling1D(pool_size=2,
    ...    strides=1, padding="valid")
    >>> avg_pool_1d(x)

    `strides=2` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> avg_pool_1d = keras_core.layers.AveragePooling1D(pool_size=2,
    ...    strides=2, padding="valid")
    >>> avg_pool_1d(x)

    `strides=1` and `padding="same"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> avg_pool_1d = keras_core.layers.AveragePooling1D(pool_size=2,
    ...    strides=1, padding="same")
    >>> avg_pool_1d(x)
    """

    def __init__(
        self,
        pool_size,
        # Optional, as documented: `BasePooling` substitutes `pool_size`
        # when `strides` is None.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=1,
            pool_mode="average",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )

@ -0,0 +1,109 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_pooling import BasePooling
@keras_core_export(
    ["keras_core.layers.AveragePooling2D", "keras_core.layers.AvgPool2D"]
)
class AveragePooling2D(BasePooling):
    """Average pooling operation for 2D spatial data.

    Downsamples the input along its spatial dimensions (height and width)
    by taking the average value over an input window
    (of size defined by `pool_size`) for each channel of the input.
    The window is shifted by `strides` along each dimension.

    The resulting output when using the `"valid"` padding option has a spatial
    shape (number of rows or columns) of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
    (when `input_shape >= pool_size`)

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int or tuple of 2 integers, factors by which to downscale
            (dim1, dim2). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 2 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            4D tensor with shape `(batch_size, height, width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape `(batch_size, channels, height, width)`.

    Output shape:
        - If `data_format="channels_last"`:
            4D tensor with shape
            `(batch_size, pooled_height, pooled_width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape
            `(batch_size, channels, pooled_height, pooled_width)`.

    Examples:

    `strides=(1, 1)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> avg_pool_2d = keras_core.layers.AveragePooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="valid")
    >>> avg_pool_2d(x)

    `strides=(2, 2)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3., 4.],
    ...               [5., 6., 7., 8.],
    ...               [9., 10., 11., 12.]])
    >>> x = np.reshape(x, [1, 3, 4, 1])
    >>> avg_pool_2d = keras_core.layers.AveragePooling2D(pool_size=(2, 2),
    ...    strides=(2, 2), padding="valid")
    >>> avg_pool_2d(x)

    `stride=(1, 1)` and `padding="same"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> avg_pool_2d = keras_core.layers.AveragePooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="same")
    >>> avg_pool_2d(x)
    """

    def __init__(
        self,
        pool_size,
        # Optional, as documented: `BasePooling` substitutes `pool_size`
        # when `strides` is None.
        strides=None,
        padding="valid",
        # None defers to the global `image_data_format`, as the docstring
        # promises and as the other pooling layers do.
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=2,
            pool_mode="average",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )

@ -0,0 +1,85 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_pooling import BasePooling
@keras_core_export(
    ["keras_core.layers.AveragePooling3D", "keras_core.layers.AvgPool3D"]
)
class AveragePooling3D(BasePooling):
    """Average pooling operation for 3D data (spatial or spatio-temporal).

    Downsamples the input along its spatial dimensions (depth, height, and
    width) by taking the average value over an input window (of size defined by
    `pool_size`) for each channel of the input. The window is shifted by
    `strides` along each dimension.

    Args:
        pool_size: int or tuple of 3 integers, factors by which to downscale
            (dim1, dim2, dim3). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 3 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape
            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while
            `"channels_first"` corresponds to inputs with shape
            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            It defaults to the `image_data_format` value found in your Keras
            config file at `~/.keras/keras.json`. If you never set it, then it
            will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

    Output shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`

    Example:

    ```python
    depth = 30
    height = 30
    width = 30
    channels = 3

    inputs = keras_core.layers.Input(shape=(depth, height, width, channels))
    layer = keras_core.layers.AveragePooling3D(pool_size=3)
    outputs = layer(inputs)  # Shape: (batch_size, 10, 10, 10, 3)
    ```
    """

    def __init__(
        self,
        pool_size,
        # Optional so the documented `AveragePooling3D(pool_size=3)` call
        # works; `BasePooling` substitutes `pool_size` when this is None.
        strides=None,
        padding="valid",
        # None defers to the global `image_data_format`, as the docstring
        # promises and as the other pooling layers do.
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=3,
            pool_mode="average",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )

@ -0,0 +1,181 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class AveragePoolingBasicTest(testing.TestCase, parameterized.TestCase):
    """Output-shape and bookkeeping checks for the average pooling layers."""

    def _check_pooling_layer(
        self,
        layer_cls,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        # Pooling layers own no weights, add no losses and do not propagate
        # masks; only the output shape varies between the cases.
        layer_config = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }
        self.run_layer_test(
            layer_cls,
            init_kwargs=layer_config,
            input_shape=input_shape,
            expected_output_shape=output_shape,
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 4), (3, 5, 4)),
        ((2,), (2,), "valid", "channels_last", (3, 5, 4), (3, 2, 4)),
    )
    def test_average_pooling1d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_pooling_layer(
            layers.AveragePooling1D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 4), (3, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
        ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
    )
    def test_average_pooling2d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_pooling_layer(
            layers.AveragePooling2D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)),
        (
            (2, 3, 2),
            (2, 2, 1),
            "valid",
            "channels_last",
            (3, 5, 5, 5, 4),
            (3, 2, 2, 4, 4),
        ),
    )
    def test_average_pooling3d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_pooling_layer(
            layers.AveragePooling3D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )
class AveragePoolingCorrectnessTest(testing.TestCase, parameterized.TestCase):
    """Compares average pooling outputs with the `tf.keras` layers."""

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2,), (2,), "valid", "channels_last"),
    )
    def test_average_pooling1d(self, pool_size, strides, padding, data_format):
        # `np.float` was removed in NumPy 1.24; the builtin `float` is the
        # type it used to alias.
        inputs = np.arange(24, dtype=float).reshape((2, 3, 4))
        layer_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }

        layer = layers.AveragePooling1D(**layer_kwargs)
        tf_keras_layer = tf.keras.layers.AveragePooling1D(**layer_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        ((2, 3), (2, 2), "same", "channels_last"),
    )
    def test_average_pooling2d(self, pool_size, strides, padding, data_format):
        # See test_average_pooling1d for why `float` replaces `np.float`.
        inputs = np.arange(300, dtype=float).reshape((3, 5, 5, 4))
        layer_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }

        layer = layers.AveragePooling2D(**layer_kwargs)
        tf_keras_layer = tf.keras.layers.AveragePooling2D(**layer_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2, 3, 2), (2, 2, 1), "valid", "channels_last"),
    )
    def test_average_pooling3d(self, pool_size, strides, padding, data_format):
        # See test_average_pooling1d for why `float` replaces `np.float`.
        inputs = np.arange(240, dtype=float).reshape((2, 3, 4, 5, 2))
        layer_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }

        layer = layers.AveragePooling3D(**layer_kwargs)
        tf_keras_layer = tf.keras.layers.AveragePooling3D(**layer_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

@ -0,0 +1,76 @@
from keras_core import operations as ops
from keras_core.backend import image_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_pooling_output_shape
class BasePooling(Layer):
    """Base pooling layer.

    Stores the pooling configuration and dispatches to `ops.max_pool` or
    `ops.average_pool` according to `pool_mode`.

    Args:
        pool_size: Size of the pooling window, forwarded to the pooling op.
        strides: Step of the pooling window. If None, `pool_size` is used.
        pool_dimensions: Number of spatial dimensions; inputs must have rank
            `pool_dimensions + 2`.
        pool_mode: Either `"max"` or `"average"`.
        padding: Either `"valid"` or `"same"` (case-insensitive).
        data_format: Either `"channels_last"` or `"channels_first"`. If None,
            the value returned by `image_data_format()` is used.
        name: Name forwarded to `Layer`.
        **kwargs: Extra keyword arguments forwarded to `Layer`.
    """

    def __init__(
        self,
        pool_size,
        strides,
        pool_dimensions,
        pool_mode="max",
        padding="valid",
        data_format=None,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        self.pool_size = pool_size
        self.strides = pool_size if strides is None else strides
        self.pool_mode = pool_mode
        # The layers document `padding` as case-insensitive, but
        # `compute_pooling_output_shape` only recognises the lowercase
        # spellings, so normalise once here.
        self.padding = padding.lower() if isinstance(padding, str) else padding
        self.data_format = (
            image_data_format() if data_format is None else data_format
        )

        self.input_spec = InputSpec(ndim=pool_dimensions + 2)

    def call(self, inputs):
        """Apply the configured pooling op to `inputs`.

        Raises:
            ValueError: If `pool_mode` is neither `"max"` nor `"average"`.
        """
        if self.pool_mode == "max":
            return ops.max_pool(
                inputs,
                pool_size=self.pool_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )
        elif self.pool_mode == "average":
            return ops.average_pool(
                inputs,
                pool_size=self.pool_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )
        else:
            raise ValueError(
                "`pool_mode` must be either 'max' or 'average'. Received: "
                f"{self.pool_mode}."
            )

    def compute_output_shape(self, input_shape):
        """Return the pooled output shape for `input_shape`."""
        return compute_pooling_output_shape(
            input_shape,
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )

    def get_config(self):
        """Return the layer config extended with the pooling arguments."""
        config = super().get_config()
        config.update(
            {
                "pool_size": self.pool_size,
                "padding": self.padding,
                "strides": self.strides,
                "data_format": self.data_format,
            }
        )
        return config

@ -0,0 +1,93 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_pooling import BasePooling
@keras_core_export(
    ["keras_core.layers.MaxPooling1D", "keras_core.layers.MaxPool1D"]
)
class MaxPooling1D(BasePooling):
    """Max pooling operation for 1D temporal data.

    Downsamples the input representation by taking the maximum value over a
    spatial window of size `pool_size`. The window is shifted by `strides`.

    The resulting output when using the `"valid"` padding option has a shape of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`.

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int, size of the max pooling window.
        strides: int or None. Specifies how much the pooling window moves
            for each pooling step. If None, it will default to `pool_size`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, steps)`.

    Output shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, downsampled_steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, downsampled_steps)`.

    Examples:

    `strides=1` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> max_pool_1d = keras_core.layers.MaxPooling1D(pool_size=2,
    ...    strides=1, padding="valid")
    >>> max_pool_1d(x)

    `strides=2` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> max_pool_1d = keras_core.layers.MaxPooling1D(pool_size=2,
    ...    strides=2, padding="valid")
    >>> max_pool_1d(x)

    `strides=1` and `padding="same"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> max_pool_1d = keras_core.layers.MaxPooling1D(pool_size=2,
    ...    strides=1, padding="same")
    >>> max_pool_1d(x)
    """

    def __init__(
        self,
        pool_size,
        # Optional, as documented: `BasePooling` substitutes `pool_size`
        # when `strides` is None.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=1,
            pool_mode="max",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )

@ -0,0 +1,109 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_pooling import BasePooling
@keras_core_export(
    ["keras_core.layers.MaxPooling2D", "keras_core.layers.MaxPool2D"]
)
class MaxPooling2D(BasePooling):
    """Max pooling operation for 2D spatial data.

    Downsamples the input along its spatial dimensions (height and width)
    by taking the maximum value over an input window
    (of size defined by `pool_size`) for each channel of the input.
    The window is shifted by `strides` along each dimension.

    The resulting output when using the `"valid"` padding option has a spatial
    shape (number of rows or columns) of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
    (when `input_shape >= pool_size`)

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int or tuple of 2 integers, factors by which to downscale
            (dim1, dim2). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 2 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            4D tensor with shape `(batch_size, height, width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape `(batch_size, channels, height, width)`.

    Output shape:
        - If `data_format="channels_last"`:
            4D tensor with shape
            `(batch_size, pooled_height, pooled_width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape
            `(batch_size, channels, pooled_height, pooled_width)`.

    Examples:

    `strides=(1, 1)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> max_pool_2d = keras_core.layers.MaxPooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="valid")
    >>> max_pool_2d(x)

    `strides=(2, 2)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3., 4.],
    ...               [5., 6., 7., 8.],
    ...               [9., 10., 11., 12.]])
    >>> x = np.reshape(x, [1, 3, 4, 1])
    >>> max_pool_2d = keras_core.layers.MaxPooling2D(pool_size=(2, 2),
    ...    strides=(2, 2), padding="valid")
    >>> max_pool_2d(x)

    `stride=(1, 1)` and `padding="same"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> max_pool_2d = keras_core.layers.MaxPooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="same")
    >>> max_pool_2d(x)
    """

    def __init__(
        self,
        pool_size,
        # Optional, as documented: `BasePooling` substitutes `pool_size`
        # when `strides` is None.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=2,
            pool_mode="max",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )

@ -0,0 +1,85 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_pooling import BasePooling
@keras_core_export(
    ["keras_core.layers.MaxPooling3D", "keras_core.layers.MaxPool3D"]
)
class MaxPooling3D(BasePooling):
    """Max pooling operation for 3D data (spatial or spatio-temporal).

    Downsamples the input along its spatial dimensions (depth, height, and
    width) by taking the maximum value over an input window (of size defined by
    `pool_size`) for each channel of the input. The window is shifted by
    `strides` along each dimension.

    Args:
        pool_size: int or tuple of 3 integers, factors by which to downscale
            (dim1, dim2, dim3). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 3 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape
            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while
            `"channels_first"` corresponds to inputs with shape
            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            It defaults to the `image_data_format` value found in your Keras
            config file at `~/.keras/keras.json`. If you never set it, then it
            will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

    Output shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`

    Example:

    ```python
    depth = 30
    height = 30
    width = 30
    channels = 3

    inputs = keras_core.layers.Input(shape=(depth, height, width, channels))
    layer = keras_core.layers.MaxPooling3D(pool_size=3)
    outputs = layer(inputs)  # Shape: (batch_size, 10, 10, 10, 3)
    ```
    """

    def __init__(
        self,
        pool_size,
        # Optional so the documented `MaxPooling3D(pool_size=3)` call works;
        # `BasePooling` substitutes `pool_size` when this is None.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=3,
            pool_mode="max",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )

@ -0,0 +1,181 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class MaxPoolingBasicTest(testing.TestCase, parameterized.TestCase):
    """Output-shape and bookkeeping checks for the max pooling layers."""

    def _check_pooling_layer(
        self,
        layer_cls,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        # Pooling layers own no weights, add no losses and do not propagate
        # masks; only the output shape varies between the cases.
        layer_config = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }
        self.run_layer_test(
            layer_cls,
            init_kwargs=layer_config,
            input_shape=input_shape,
            expected_output_shape=output_shape,
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 4), (3, 5, 4)),
        ((2,), (2,), "valid", "channels_last", (3, 5, 4), (3, 2, 4)),
    )
    def test_max_pooling1d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_pooling_layer(
            layers.MaxPooling1D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 4), (3, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
        ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
    )
    def test_max_pooling2d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_pooling_layer(
            layers.MaxPooling2D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)),
        (
            (2, 3, 2),
            (2, 2, 1),
            "valid",
            "channels_last",
            (3, 5, 5, 5, 4),
            (3, 2, 2, 4, 4),
        ),
    )
    def test_max_pooling3d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_pooling_layer(
            layers.MaxPooling3D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )
class MaxPoolingCorrectnessTest(testing.TestCase, parameterized.TestCase):
    """Compares max pooling outputs with the `tf.keras` layers."""

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2,), (2,), "valid", "channels_last"),
    )
    def test_max_pooling1d(self, pool_size, strides, padding, data_format):
        # `np.float` was removed in NumPy 1.24; the builtin `float` is the
        # type it used to alias.
        inputs = np.arange(24, dtype=float).reshape((2, 3, 4))
        layer_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }

        layer = layers.MaxPooling1D(**layer_kwargs)
        tf_keras_layer = tf.keras.layers.MaxPool1D(**layer_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        ((2, 3), (2, 2), "same", "channels_last"),
    )
    def test_max_pooling2d(self, pool_size, strides, padding, data_format):
        # See test_max_pooling1d for why `float` replaces `np.float`.
        inputs = np.arange(300, dtype=float).reshape((3, 5, 5, 4))
        layer_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }

        layer = layers.MaxPooling2D(**layer_kwargs)
        tf_keras_layer = tf.keras.layers.MaxPool2D(**layer_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2, 3, 2), (2, 2, 1), "valid", "channels_last"),
    )
    def test_max_pooling3d(self, pool_size, strides, padding, data_format):
        # See test_max_pooling1d for why `float` replaces `np.float`.
        inputs = np.arange(240, dtype=float).reshape((2, 3, 4, 5, 2))
        layer_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }

        layer = layers.MaxPooling3D(**layer_kwargs)
        tf_keras_layer = tf.keras.layers.MaxPool3D(**layer_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

@ -9,6 +9,7 @@ from keras_core.backend import is_tensor
from keras_core.backend import name_scope
from keras_core.backend import random
from keras_core.backend import shape
from keras_core.operations import operation_utils
from keras_core.operations.math import * # noqa: F403
from keras_core.operations.nn import * # noqa: F403
from keras_core.operations.numpy import * # noqa: F403

@ -37,6 +37,7 @@ from keras_core.backend import any_symbolic_tensors
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_output_shape,
)
from keras_core.operations import operation_utils
from keras_core.operations.operation import Operation
@ -290,34 +291,13 @@ class MaxPool(Operation):
)
def compute_output_spec(self, inputs):
    """Return the symbolic output spec of max pooling applied to `inputs`.

    The shape arithmetic is delegated to
    `operation_utils.compute_pooling_output_shape`, which also substitutes
    `pool_size` when `strides` is None.

    Args:
        inputs: Object exposing `shape` and `dtype`.

    Returns:
        A `KerasTensor` with the pooled shape and the dtype of `inputs`.
    """
    output_shape = operation_utils.compute_pooling_output_shape(
        inputs.shape,
        self.pool_size,
        self.strides,
        self.padding,
        self.data_format,
    )
    return KerasTensor(output_shape, dtype=inputs.dtype)
@ -394,32 +374,13 @@ class AveragePool(Operation):
)
def compute_output_spec(self, inputs):
    """Return the symbolic output spec of average pooling applied to `inputs`.

    The shape arithmetic is delegated to
    `operation_utils.compute_pooling_output_shape`, which also substitutes
    `pool_size` when `strides` is None.

    Args:
        inputs: Object exposing `shape` and `dtype`.

    Returns:
        A `KerasTensor` with the pooled shape and the dtype of `inputs`.
    """
    output_shape = operation_utils.compute_pooling_output_shape(
        inputs.shape,
        self.pool_size,
        self.strides,
        self.padding,
        self.data_format,
    )
    return KerasTensor(output_shape, dtype=inputs.dtype)

@ -0,0 +1,49 @@
import numpy as np
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Compute the output shape of pooling ops.

    Args:
        input_shape: Sequence of ints (or None for unknown dimensions)
            describing the full input shape, including batch and channels.
        pool_size: int or sequence of ints, the pooling window size.
        strides: int, sequence of ints, or None. If None, `pool_size` is
            used.
        padding: Either `"valid"` or `"same"`.
        data_format: `"channels_last"` places the spatial dimensions at
            `input_shape[1:-1]`; any other value places them at
            `input_shape[2:]`.

    Returns:
        Tuple with the pooled output shape. Unknown (None) spatial
        dimensions stay None.

    Raises:
        ValueError: If `padding` is not `"valid"` or `"same"`, or if any
            known spatial output dimension would be negative.
    """
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    # Kept as an array only so the error message below renders as before.
    input_shape = np.array(input_shape)
    if data_format == "channels_last":
        spatial_dims = input_shape_origin[1:-1]
    else:
        spatial_dims = input_shape_origin[2:]

    # Unknown dimensions cannot take part in numpy arithmetic: compute with a
    # placeholder and restore None afterwards.
    unknown = np.array([dim is None for dim in spatial_dims], dtype=bool)
    spatial_shape = np.array([0 if dim is None else dim for dim in spatial_dims])
    pool_size = np.array(pool_size)

    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        # A single negative dimension already makes the output invalid.
        negative_in_shape = np.any((output_spatial_shape < 0) & ~unknown)
        if negative_in_shape:
            raise ValueError(
                "Computed output size would be negative. Received: "
                f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
            )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )

    unknown = np.broadcast_to(unknown, output_spatial_shape.shape)
    output_spatial_shape = tuple(
        None if is_unknown else int(size)
        for size, is_unknown in zip(output_spatial_shape, unknown)
    )
    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape