Add max and average pooling layer (#66)
* Add max and pooling layer * fix tests * handle TF transpose * renaming * rename tests * Fix comments * Move out the shape computation logic
This commit is contained in:
parent
e253911d18
commit
841b8d702d
@ -69,6 +69,36 @@ def log_softmax(x, axis=None):
|
||||
return tf.nn.log_softmax(x, axis=axis)
|
||||
|
||||
|
||||
def _transpose_spatial_inputs(inputs):
    """Move the channel axis of a `channels_first` tensor to the last axis.

    Tensorflow pooling does not support the `channels_first` format, so the
    pooling ops transpose their inputs to `channels_last` before pooling.

    Args:
        inputs: Tensor of rank 3, 4 or 5 (1D, 2D or 3D pooling inputs) whose
            axis 1 is the channel axis.

    Returns:
        `inputs` transposed so that axis 1 becomes the last axis while the
        batch axis and the order of the spatial axes are preserved.

    Raises:
        ValueError: If `inputs` does not have 1, 2 or 3 spatial dimensions.
    """
    rank = len(inputs.shape)
    num_spatial_dims = rank - 2
    if num_spatial_dims not in (1, 2, 3):
        raise ValueError(
            "Pooling inputs must have rank 3, 4 or 5, corresponding to 1D, 2D "
            f"and 3D inputs. But received shape: {inputs.shape}."
        )
    # Batch stays first, spatial axes shift forward, channels go last:
    # (0, 2, 1), (0, 2, 3, 1) or (0, 2, 3, 4, 1).
    permutation = (0, *range(2, rank), 1)
    return tf.transpose(inputs, permutation)
|
||||
|
||||
|
||||
def _transpose_spatial_outputs(outputs):
    """Undo the transpose performed by `_transpose_spatial_inputs`.

    Moves the last axis of a rank 3, 4 or 5 tensor back to axis 1, keeping
    the batch axis first and the spatial axes in order. Tensors of any other
    rank are returned unchanged.

    Args:
        outputs: Pooled tensor whose last axis is the channel axis.

    Returns:
        The tensor with its last axis moved to position 1.
    """
    rank = len(outputs.shape)
    if rank not in (3, 4, 5):
        return outputs
    last_axis = rank - 1
    # (0, 2, 1), (0, 3, 1, 2) or (0, 4, 1, 2, 3).
    return tf.transpose(outputs, (0, last_axis, *range(1, last_axis)))
|
||||
|
||||
|
||||
def max_pool(
|
||||
inputs,
|
||||
pool_size,
|
||||
@ -78,8 +108,22 @@ def max_pool(
|
||||
):
|
||||
strides = pool_size if strides is None else strides
|
||||
padding = padding.upper()
|
||||
data_format = _convert_data_format(data_format, len(inputs.shape))
|
||||
return tf.nn.max_pool(inputs, pool_size, strides, padding, data_format)
|
||||
tf_data_format = _convert_data_format("channels_last", len(inputs.shape))
|
||||
if data_format == "channels_first":
|
||||
# Tensorflow pooling does not support `channels_first` format, so
|
||||
# we need to transpose to `channels_last` format.
|
||||
inputs = _transpose_spatial_inputs(inputs)
|
||||
|
||||
outputs = tf.nn.max_pool(
|
||||
inputs,
|
||||
pool_size,
|
||||
strides,
|
||||
padding,
|
||||
tf_data_format,
|
||||
)
|
||||
if data_format == "channels_first":
|
||||
outputs = _transpose_spatial_outputs(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def average_pool(
|
||||
@ -91,8 +135,22 @@ def average_pool(
|
||||
):
|
||||
strides = pool_size if strides is None else strides
|
||||
padding = padding.upper()
|
||||
data_format = _convert_data_format(data_format, len(inputs.shape))
|
||||
return tf.nn.avg_pool(inputs, pool_size, strides, padding, data_format)
|
||||
tf_data_format = _convert_data_format("channels_last", len(inputs.shape))
|
||||
if data_format == "channels_first":
|
||||
# Tensorflow pooling does not support `channels_first` format, so
|
||||
# we need to transpose to `channels_last` format.
|
||||
inputs = _transpose_spatial_inputs(inputs)
|
||||
|
||||
outputs = tf.nn.avg_pool(
|
||||
inputs,
|
||||
pool_size,
|
||||
strides,
|
||||
padding,
|
||||
tf_data_format,
|
||||
)
|
||||
if data_format == "channels_first":
|
||||
outputs = _transpose_spatial_outputs(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def _convert_data_format(data_format, ndim):
|
||||
|
@ -8,6 +8,12 @@ from keras_core.layers.merging.add import Add
|
||||
from keras_core.layers.merging.add import add
|
||||
from keras_core.layers.merging.subtract import Subtract
|
||||
from keras_core.layers.merging.subtract import subtract
|
||||
from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
|
||||
from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
|
||||
from keras_core.layers.pooling.average_pooling3d import AveragePooling3D
|
||||
from keras_core.layers.pooling.max_pooling1d import MaxPooling1D
|
||||
from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
|
||||
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
|
||||
from keras_core.layers.regularization.activity_regularization import (
|
||||
ActivityRegularization,
|
||||
)
|
||||
|
0
keras_core/layers/pooling/__init__.py
Normal file
0
keras_core/layers/pooling/__init__.py
Normal file
92
keras_core/layers/pooling/average_pooling1d.py
Normal file
92
keras_core/layers/pooling/average_pooling1d.py
Normal file
@ -0,0 +1,92 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.pooling.base_pooling import BasePooling
|
||||
|
||||
|
||||
@keras_core_export(
    ["keras_core.layers.AveragePooling1D", "keras_core.layers.AvgPool1D"]
)
class AveragePooling1D(BasePooling):
    """Average pooling for temporal data.

    Downsamples the input representation by taking the average value over the
    window defined by `pool_size`. The window is shifted by `strides`.

    The resulting output when using the `"valid"` padding option has a shape
    of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
    (when `input_shape >= pool_size`)

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int, size of the average pooling window.
        strides: int or None. Specifies how much the pooling window moves
            for each pooling step. If None, it will default to `pool_size`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, steps)`.

    Output shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, downsampled_steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, downsampled_steps)`.

    Examples:

    `strides=1` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> avg_pool_1d = keras_core.layers.AveragePooling1D(pool_size=2,
    ...    strides=1, padding="valid")
    >>> avg_pool_1d(x)

    `strides=2` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> avg_pool_1d = keras_core.layers.AveragePooling1D(pool_size=2,
    ...    strides=2, padding="valid")
    >>> avg_pool_1d(x)

    `strides=1` and `padding="same"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> avg_pool_1d = keras_core.layers.AveragePooling1D(pool_size=2,
    ...    strides=1, padding="same")
    >>> avg_pool_1d(x)
    """

    def __init__(
        self,
        pool_size,
        # `None` is the documented way to request `strides == pool_size`;
        # the base class performs that substitution.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=1,
            pool_mode="average",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )
|
109
keras_core/layers/pooling/average_pooling2d.py
Normal file
109
keras_core/layers/pooling/average_pooling2d.py
Normal file
@ -0,0 +1,109 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.pooling.base_pooling import BasePooling
|
||||
|
||||
|
||||
@keras_core_export(
    ["keras_core.layers.AveragePooling2D", "keras_core.layers.AvgPool2D"]
)
class AveragePooling2D(BasePooling):
    """Average pooling operation for 2D spatial data.

    Downsamples the input along its spatial dimensions (height and width)
    by taking the average value over an input window
    (of size defined by `pool_size`) for each channel of the input.
    The window is shifted by `strides` along each dimension.

    The resulting output when using the `"valid"` padding option has a spatial
    shape (number of rows or columns) of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
    (when `input_shape >= pool_size`)

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int or tuple of 2 integers, factors by which to downscale
            (dim1, dim2). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 2 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            4D tensor with shape `(batch_size, height, width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape `(batch_size, channels, height, width)`.

    Output shape:
        - If `data_format="channels_last"`:
            4D tensor with shape
            `(batch_size, pooled_height, pooled_width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape
            `(batch_size, channels, pooled_height, pooled_width)`.

    Examples:

    `strides=(1, 1)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> avg_pool_2d = keras_core.layers.AveragePooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="valid")
    >>> avg_pool_2d(x)

    `strides=(2, 2)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3., 4.],
    ...              [5., 6., 7., 8.],
    ...              [9., 10., 11., 12.]])
    >>> x = np.reshape(x, [1, 3, 4, 1])
    >>> avg_pool_2d = keras_core.layers.AveragePooling2D(pool_size=(2, 2),
    ...    strides=(2, 2), padding="valid")
    >>> avg_pool_2d(x)

    `strides=(1, 1)` and `padding="same"`:

    >>> x = np.array([[1., 2., 3.],
    ...                  [4., 5., 6.],
    ...                  [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> avg_pool_2d = keras_core.layers.AveragePooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="same")
    >>> avg_pool_2d(x)
    """

    def __init__(
        self,
        pool_size,
        # `None` is the documented way to request `strides == pool_size`;
        # the base class performs that substitution.
        strides=None,
        padding="valid",
        # `None` lets the base class fall back to `image_data_format()`, as
        # the docstring promises, instead of pinning `"channels_last"`.
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=2,
            pool_mode="average",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )
|
85
keras_core/layers/pooling/average_pooling3d.py
Normal file
85
keras_core/layers/pooling/average_pooling3d.py
Normal file
@ -0,0 +1,85 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.pooling.base_pooling import BasePooling
|
||||
|
||||
|
||||
@keras_core_export(
    ["keras_core.layers.AveragePooling3D", "keras_core.layers.AvgPool3D"]
)
class AveragePooling3D(BasePooling):
    """Average pooling operation for 3D data (spatial or spatio-temporal).

    Downsamples the input along its spatial dimensions (depth, height, and
    width) by taking the average value over an input window (of size defined by
    `pool_size`) for each channel of the input. The window is shifted by
    `strides` along each dimension.

    Args:
        pool_size: int or tuple of 3 integers, factors by which to downscale
            (dim1, dim2, dim3). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 3 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape
            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while
            `"channels_first"` corresponds to inputs with shape
            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            It defaults to the `image_data_format` value found in your Keras
            config file at `~/.keras/keras.json`. If you never set it, then it
            will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

    Output shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`

    Example:

    ```python
    depth = 30
    height = 30
    width = 30
    channels = 3

    inputs = keras_core.layers.Input(shape=(depth, height, width, channels))
    layer = keras_core.layers.AveragePooling3D(pool_size=3)
    outputs = layer(inputs)  # Shape: (batch_size, 10, 10, 10, 3)
    ```
    """

    def __init__(
        self,
        pool_size,
        # `None` is the documented way to request `strides == pool_size`; the
        # example above relies on being able to omit this argument.
        strides=None,
        padding="valid",
        # `None` lets the base class fall back to `image_data_format()`, as
        # the docstring promises, instead of pinning `"channels_last"`.
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=3,
            pool_mode="average",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )
|
181
keras_core/layers/pooling/average_pooling_test.py
Normal file
181
keras_core/layers/pooling/average_pooling_test.py
Normal file
@ -0,0 +1,181 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
class AveragePoolingBasicTest(testing.TestCase, parameterized.TestCase):
    """Output-shape and bookkeeping checks for the average pooling layers."""

    def _check_layer(
        self,
        layer_cls,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        # Shared driver: pooling layers carry no weights, no losses and do not
        # support masking, so only the constructor arguments and the shapes
        # vary between cases.
        init_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }
        self.run_layer_test(
            layer_cls,
            init_kwargs=init_kwargs,
            input_shape=input_shape,
            expected_output_shape=output_shape,
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 4), (3, 5, 4)),
        ((2,), (2,), "valid", "channels_last", (3, 5, 4), (3, 2, 4)),
    )
    def test_average_pooling1d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_layer(
            layers.AveragePooling1D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 4), (3, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
        ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
    )
    def test_average_pooling2d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_layer(
            layers.AveragePooling2D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)),
        (
            (2, 3, 2),
            (2, 2, 1),
            "valid",
            "channels_last",
            (3, 5, 5, 5, 4),
            (3, 2, 2, 4, 4),
        ),
    )
    def test_average_pooling3d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._check_layer(
            layers.AveragePooling3D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )
|
||||
|
||||
|
||||
class AveragePoolingCorrectnessTest(testing.TestCase, parameterized.TestCase):
    """Compare average pooling layers numerically against `tf.keras`."""

    def _assert_matches_tf_keras(
        self, layer_cls, tf_keras_layer_cls, inputs, **init_kwargs
    ):
        # Build both layers from the same arguments and compare their outputs
        # on the same inputs.
        layer = layer_cls(**init_kwargs)
        tf_keras_layer = tf_keras_layer_cls(**init_kwargs)

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2,), (2,), "valid", "channels_last"),
    )
    def test_average_pooling1d(self, pool_size, strides, padding, data_format):
        # `np.float` was removed from NumPy; `np.float64` is the dtype that
        # alias produced.
        inputs = np.arange(24, dtype=np.float64).reshape((2, 3, 4))
        self._assert_matches_tf_keras(
            layers.AveragePooling1D,
            tf.keras.layers.AveragePooling1D,
            inputs,
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        ((2, 3), (2, 2), "same", "channels_last"),
    )
    def test_average_pooling2d(self, pool_size, strides, padding, data_format):
        inputs = np.arange(300, dtype=np.float64).reshape((3, 5, 5, 4))
        self._assert_matches_tf_keras(
            layers.AveragePooling2D,
            tf.keras.layers.AveragePooling2D,
            inputs,
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2, 3, 2), (2, 2, 1), "valid", "channels_last"),
    )
    def test_average_pooling3d(self, pool_size, strides, padding, data_format):
        inputs = np.arange(240, dtype=np.float64).reshape((2, 3, 4, 5, 2))
        self._assert_matches_tf_keras(
            layers.AveragePooling3D,
            tf.keras.layers.AveragePooling3D,
            inputs,
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )
|
76
keras_core/layers/pooling/base_pooling.py
Normal file
76
keras_core/layers/pooling/base_pooling.py
Normal file
@ -0,0 +1,76 @@
|
||||
from keras_core import operations as ops
|
||||
from keras_core.backend import image_data_format
|
||||
from keras_core.layers.input_spec import InputSpec
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.operations.operation_utils import compute_pooling_output_shape
|
||||
|
||||
|
||||
class BasePooling(Layer):
    """Shared implementation behind the max and average pooling layers.

    Stores the pooling hyper-parameters, dispatches `call` to
    `ops.max_pool` or `ops.average_pool` depending on `pool_mode`, and
    delegates shape inference to `compute_pooling_output_shape`.

    Args:
        pool_size: Size of the pooling window.
        strides: Step of the pooling window. `None` means "use `pool_size`".
        pool_dimensions: Number of spatial dimensions; inputs must have rank
            `pool_dimensions + 2`.
        pool_mode: Either `"max"` or `"average"`. Checked when the layer is
            called, not at construction time.
        padding: Padding mode forwarded to the pooling op.
        data_format: Dimension ordering forwarded to the pooling op. `None`
            means "use `image_data_format()`".
        name: Layer name, forwarded to `Layer`.
        **kwargs: Extra keyword arguments forwarded to `Layer`.
    """

    def __init__(
        self,
        pool_size,
        strides,
        pool_dimensions,
        pool_mode="max",
        padding="valid",
        data_format=None,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        # Resolve the `None` placeholders before storing anything.
        if strides is None:
            strides = pool_size
        if data_format is None:
            data_format = image_data_format()

        self.pool_size = pool_size
        self.strides = strides
        self.pool_mode = pool_mode
        self.padding = padding
        self.data_format = data_format

        # Batch axis + channel axis + one axis per pooled dimension.
        self.input_spec = InputSpec(ndim=pool_dimensions + 2)

    def call(self, inputs):
        mode = self.pool_mode
        if mode not in ("max", "average"):
            raise ValueError(
                "`pool_mode` must be either 'max' or 'average'. Received: "
                f"{self.pool_mode}."
            )
        pool_fn = ops.max_pool if mode == "max" else ops.average_pool
        return pool_fn(
            inputs,
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )

    def compute_output_shape(self, input_shape):
        pooling_args = (
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )
        return compute_pooling_output_shape(input_shape, *pooling_args)

    def get_config(self):
        base_config = super().get_config()
        pooling_config = {
            "pool_size": self.pool_size,
            "padding": self.padding,
            "strides": self.strides,
            "data_format": self.data_format,
        }
        return {**base_config, **pooling_config}
|
93
keras_core/layers/pooling/max_pooling1d.py
Normal file
93
keras_core/layers/pooling/max_pooling1d.py
Normal file
@ -0,0 +1,93 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.pooling.base_pooling import BasePooling
|
||||
|
||||
|
||||
@keras_core_export(
    ["keras_core.layers.MaxPooling1D", "keras_core.layers.MaxPool1D"]
)
class MaxPooling1D(BasePooling):
    """Max pooling operation for 1D temporal data.

    Downsamples the input representation by taking the maximum value over a
    spatial window of size `pool_size`. The window is shifted by `strides`.

    The resulting output when using the `"valid"` padding option has a shape
    of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
    (when `input_shape >= pool_size`)

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int, size of the max pooling window.
        strides: int or None. Specifies how much the pooling window moves
            for each pooling step. If None, it will default to `pool_size`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, steps)`.

    Output shape:
        - If `data_format="channels_last"`:
            3D tensor with shape `(batch_size, downsampled_steps, features)`.
        - If `data_format="channels_first"`:
            3D tensor with shape `(batch_size, features, downsampled_steps)`.

    Examples:

    `strides=1` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> max_pool_1d = keras_core.layers.MaxPooling1D(pool_size=2,
    ...    strides=1, padding="valid")
    >>> max_pool_1d(x)

    `strides=2` and `padding="valid"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> max_pool_1d = keras_core.layers.MaxPooling1D(pool_size=2,
    ...    strides=2, padding="valid")
    >>> max_pool_1d(x)

    `strides=1` and `padding="same"`:

    >>> x = np.array([1., 2., 3., 4., 5.])
    >>> x = np.reshape(x, [1, 5, 1])
    >>> max_pool_1d = keras_core.layers.MaxPooling1D(pool_size=2,
    ...    strides=1, padding="same")
    >>> max_pool_1d(x)
    """

    def __init__(
        self,
        pool_size,
        # `None` is the documented way to request `strides == pool_size`;
        # the base class performs that substitution.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=1,
            pool_mode="max",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )
|
109
keras_core/layers/pooling/max_pooling2d.py
Normal file
109
keras_core/layers/pooling/max_pooling2d.py
Normal file
@ -0,0 +1,109 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.pooling.base_pooling import BasePooling
|
||||
|
||||
|
||||
@keras_core_export(
    ["keras_core.layers.MaxPooling2D", "keras_core.layers.MaxPool2D"]
)
class MaxPooling2D(BasePooling):
    """Max pooling operation for 2D spatial data.

    Downsamples the input along its spatial dimensions (height and width)
    by taking the maximum value over an input window
    (of size defined by `pool_size`) for each channel of the input.
    The window is shifted by `strides` along each dimension.

    The resulting output when using the `"valid"` padding option has a spatial
    shape (number of rows or columns) of:
    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`
    (when `input_shape >= pool_size`)

    The resulting output shape when using the `"same"` padding option is:
    `output_shape = math.floor((input_shape - 1) / strides) + 1`

    Args:
        pool_size: int or tuple of 2 integers, factors by which to downscale
            (dim1, dim2). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 2 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            4D tensor with shape `(batch_size, height, width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape `(batch_size, channels, height, width)`.

    Output shape:
        - If `data_format="channels_last"`:
            4D tensor with shape
            `(batch_size, pooled_height, pooled_width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape
            `(batch_size, channels, pooled_height, pooled_width)`.

    Examples:

    `strides=(1, 1)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> max_pool_2d = keras_core.layers.MaxPooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="valid")
    >>> max_pool_2d(x)

    `strides=(2, 2)` and `padding="valid"`:

    >>> x = np.array([[1., 2., 3., 4.],
    ...               [5., 6., 7., 8.],
    ...               [9., 10., 11., 12.]])
    >>> x = np.reshape(x, [1, 3, 4, 1])
    >>> max_pool_2d = keras_core.layers.MaxPooling2D(pool_size=(2, 2),
    ...    strides=(2, 2), padding="valid")
    >>> max_pool_2d(x)

    `strides=(1, 1)` and `padding="same"`:

    >>> x = np.array([[1., 2., 3.],
    ...               [4., 5., 6.],
    ...               [7., 8., 9.]])
    >>> x = np.reshape(x, [1, 3, 3, 1])
    >>> max_pool_2d = keras_core.layers.MaxPooling2D(pool_size=(2, 2),
    ...    strides=(1, 1), padding="same")
    >>> max_pool_2d(x)
    """

    def __init__(
        self,
        pool_size,
        # `None` is the documented way to request `strides == pool_size`;
        # the base class performs that substitution.
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=2,
            pool_mode="max",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )
|
85
keras_core/layers/pooling/max_pooling3d.py
Normal file
85
keras_core/layers/pooling/max_pooling3d.py
Normal file
@ -0,0 +1,85 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.pooling.base_pooling import BasePooling
|
||||
|
||||
|
||||
@keras_core_export(
    ["keras_core.layers.MaxPooling3D", "keras_core.layers.MaxPool3D"]
)
class MaxPooling3D(BasePooling):
    """Max pooling operation for 3D data (spatial or spatio-temporal).

    Downsamples the input along its spatial dimensions (depth, height, and
    width) by taking the maximum value over an input window (of size defined by
    `pool_size`) for each channel of the input. The window is shifted by
    `strides` along each dimension.

    Args:
        pool_size: int or tuple of 3 integers, factors by which to downscale
            (dim1, dim2, dim3). If only one integer is specified, the same
            window length will be used for all dimensions.
        strides: int or tuple of 3 integers, or None. Strides values. If None,
            it will default to `pool_size`. If only one int is specified, the
            same stride size will be used for all dimensions.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
            height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape
            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while
            `"channels_first"` corresponds to inputs with shape
            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            It defaults to the `image_data_format` value found in your Keras
            config file at `~/.keras/keras.json`. If you never set it, then it
            will be `"channels_last"`.

    Input shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

    Output shape:
        - If `data_format="channels_last"`:
            5D tensor with shape:
            `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
        - If `data_format="channels_first"`:
            5D tensor with shape:
            `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`

    Example:

    ```python
    depth = 30
    height = 30
    width = 30
    channels = 3

    inputs = keras_core.layers.Input(shape=(depth, height, width, channels))
    layer = keras_core.layers.MaxPooling3D(pool_size=3)
    outputs = layer(inputs)  # Shape: (batch_size, 10, 10, 10, 3)
    ```
    """

    def __init__(
        self,
        pool_size,
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        # `strides` is optional: the documented contract (and the example
        # above) allow omitting it, in which case `None` is forwarded.
        super().__init__(
            pool_size,
            strides,
            pool_dimensions=3,
            pool_mode="max",
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs,
        )
|
181
keras_core/layers/pooling/max_pooling_test.py
Normal file
181
keras_core/layers/pooling/max_pooling_test.py
Normal file
@ -0,0 +1,181 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
class MaxPoolingBasicTest(testing.TestCase, parameterized.TestCase):
    """Generic layer checks (output shape, weights, losses) for max pooling."""

    def _run_pooling_layer_test(
        self,
        layer_cls,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        # All three dimensionalities share exactly the same expectations;
        # only the layer class and the shapes differ.
        init_kwargs = {
            "pool_size": pool_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
        }
        self.run_layer_test(
            layer_cls,
            init_kwargs=init_kwargs,
            input_shape=input_shape,
            expected_output_shape=output_shape,
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 4), (3, 5, 4)),
        ((2,), (2,), "valid", "channels_last", (3, 5, 4), (3, 2, 4)),
    )
    def test_max_pooling1d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._run_pooling_layer_test(
            layers.MaxPooling1D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 4), (3, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
        ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
    )
    def test_max_pooling2d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._run_pooling_layer_test(
            layers.MaxPooling2D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )

    @parameterized.parameters(
        (2, 1, "valid", "channels_last", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)),
        (2, 1, "same", "channels_first", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)),
        (
            (2, 3, 2),
            (2, 2, 1),
            "valid",
            "channels_last",
            (3, 5, 5, 5, 4),
            (3, 2, 2, 4, 4),
        ),
    )
    def test_max_pooling3d(
        self,
        pool_size,
        strides,
        padding,
        data_format,
        input_shape,
        output_shape,
    ):
        self._run_pooling_layer_test(
            layers.MaxPooling3D,
            pool_size,
            strides,
            padding,
            data_format,
            input_shape,
            output_shape,
        )
|
||||
|
||||
|
||||
class MaxPoolingCorrectnessTest(testing.TestCase, parameterized.TestCase):
    """Numerical checks of the max pooling layers against `tf.keras`.

    Each test feeds the same input to the keras_core layer and to the
    equivalent `tf.keras.layers` layer built with identical arguments, and
    asserts that both produce the same values.
    """

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2,), (2,), "valid", "channels_last"),
    )
    def test_max_pooling1d(self, pool_size, strides, padding, data_format):
        # `np.float` was removed from NumPy (it was an alias of the builtin
        # `float`, i.e. float64); spell the dtype out explicitly.
        inputs = np.arange(24, dtype="float64").reshape((2, 3, 4))

        layer = layers.MaxPooling1D(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )
        tf_keras_layer = tf.keras.layers.MaxPool1D(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        ((2, 3), (2, 2), "same", "channels_last"),
    )
    def test_max_pooling2d(self, pool_size, strides, padding, data_format):
        inputs = np.arange(300, dtype="float64").reshape((3, 5, 5, 4))

        layer = layers.MaxPooling2D(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )
        tf_keras_layer = tf.keras.layers.MaxPool2D(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)

    @parameterized.parameters(
        (2, 1, "valid", "channels_last"),
        (2, 1, "same", "channels_first"),
        ((2, 3, 2), (2, 2, 1), "valid", "channels_last"),
    )
    def test_max_pooling3d(self, pool_size, strides, padding, data_format):
        inputs = np.arange(240, dtype="float64").reshape((2, 3, 4, 5, 2))

        layer = layers.MaxPooling3D(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )
        tf_keras_layer = tf.keras.layers.MaxPool3D(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
        )

        outputs = layer(inputs)
        expected = tf_keras_layer(inputs)
        self.assertAllClose(outputs, expected)
|
@ -9,6 +9,7 @@ from keras_core.backend import is_tensor
|
||||
from keras_core.backend import name_scope
|
||||
from keras_core.backend import random
|
||||
from keras_core.backend import shape
|
||||
from keras_core.operations import operation_utils
|
||||
from keras_core.operations.math import * # noqa: F403
|
||||
from keras_core.operations.nn import * # noqa: F403
|
||||
from keras_core.operations.numpy import * # noqa: F403
|
||||
|
@ -37,6 +37,7 @@ from keras_core.backend import any_symbolic_tensors
|
||||
from keras_core.backend.common.backend_utils import (
|
||||
compute_conv_transpose_output_shape,
|
||||
)
|
||||
from keras_core.operations import operation_utils
|
||||
from keras_core.operations.operation import Operation
|
||||
|
||||
|
||||
@ -290,34 +291,13 @@ class MaxPool(Operation):
|
||||
)
|
||||
|
||||
def compute_output_spec(self, inputs):
    """Build the symbolic output spec of this pooling op for `inputs`.

    The output shape is derived from `inputs.shape` and this op's
    `pool_size`, `strides`, `padding` and `data_format` by the shared
    helper `operation_utils.compute_pooling_output_shape`; the dtype of
    `inputs` is carried over unchanged.

    Returns:
        A `KerasTensor` with the computed shape and `inputs.dtype`.
    """
    # The shape arithmetic lives in `operation_utils` so that max and
    # average pooling stay in sync; the former inline computation was
    # redundant because its result was overwritten by this call.
    output_shape = operation_utils.compute_pooling_output_shape(
        inputs.shape,
        self.pool_size,
        self.strides,
        self.padding,
        self.data_format,
    )
    return KerasTensor(output_shape, dtype=inputs.dtype)
|
||||
|
||||
|
||||
@ -394,32 +374,13 @@ class AveragePool(Operation):
|
||||
)
|
||||
|
||||
def compute_output_spec(self, inputs):
    """Build the symbolic output spec of this pooling op for `inputs`.

    The output shape is derived from `inputs.shape` and this op's
    `pool_size`, `strides`, `padding` and `data_format` by the shared
    helper `operation_utils.compute_pooling_output_shape`; the dtype of
    `inputs` is carried over unchanged.

    Returns:
        A `KerasTensor` with the computed shape and `inputs.dtype`.
    """
    # The shape arithmetic lives in `operation_utils` so that max and
    # average pooling stay in sync; the former inline computation was
    # redundant because its result was overwritten by this call.
    output_shape = operation_utils.compute_pooling_output_shape(
        inputs.shape,
        self.pool_size,
        self.strides,
        self.padding,
        self.data_format,
    )
    return KerasTensor(output_shape, dtype=inputs.dtype)
|
||||
|
||||
|
||||
|
49
keras_core/operations/operation_utils.py
Normal file
49
keras_core/operations/operation_utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Compute the output shape of pooling ops.

    Args:
        input_shape: sequence describing the full input shape, including the
            batch and channel axes. Entries may be `None` for dimensions
            that are not known yet; such dimensions stay `None` in the
            result.
        pool_size: int or sequence of ints, the pooling window size along
            each spatial axis. A single int is applied to every axis.
        strides: int, sequence of ints, or `None`. `None` means the strides
            are equal to `pool_size`.
        padding: `"valid"` or `"same"`.
        data_format: `"channels_last"` selects `input_shape[1:-1]` as the
            spatial axes; any other value selects `input_shape[2:]`
            (i.e. channels-first layout).

    Returns:
        Tuple with the output shape, laid out like `input_shape`.

    Raises:
        ValueError: if `padding` is neither `"valid"` nor `"same"`, or if,
            with `"valid"` padding, any known spatial output size would be
            negative.
    """
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    # Only used to render the error message below.
    input_shape = np.array(input_shape)
    if data_format == "channels_last":
        spatial_dims = input_shape_origin[1:-1]
    else:
        spatial_dims = input_shape_origin[2:]

    # Unknown (`None`) spatial dims cannot take part in the arithmetic:
    # substitute a placeholder and restore `None` in the final result.
    unknown_axes = {i for i, dim in enumerate(spatial_dims) if dim is None}
    spatial_shape = np.array(
        [-1 if dim is None else dim for dim in spatial_dims]
    )
    pool_size = np.array(pool_size)

    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        # Every known axis must be checked on its own: one negative size is
        # enough to make the output shape invalid.
        for axis, size in enumerate(output_spatial_shape):
            if axis not in unknown_axes and size < 0:
                raise ValueError(
                    "Computed output size would be negative. Received: "
                    f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
                )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )

    output_spatial_shape = tuple(
        None if axis in unknown_axes else int(size)
        for axis, size in enumerate(output_spatial_shape)
    )
    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape
|
Loading…
Reference in New Issue
Block a user