keras/keras_core/layers/pooling/base_pooling.py
Chen Qian 841b8d702d Add max and average pooling layer (#66)
* Add max and poolig layer

* fix tests

* handle TF transpose

* renaming

* rename tests

* Fix comments

* Move out the shape computation logic
2023-05-02 15:17:59 -07:00

77 lines
2.2 KiB
Python

from keras_core import operations as ops
from keras_core.backend import image_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_pooling_output_shape
class BasePooling(Layer):
"""Base pooling layer."""
def __init__(
self,
pool_size,
strides,
pool_dimensions,
pool_mode="max",
padding="valid",
data_format=None,
name=None,
**kwargs,
):
super().__init__(name=name, **kwargs)
self.pool_size = pool_size
self.strides = pool_size if strides is None else strides
self.pool_mode = pool_mode
self.padding = padding
self.data_format = (
image_data_format() if data_format is None else data_format
)
self.input_spec = InputSpec(ndim=pool_dimensions + 2)
def call(self, inputs):
if self.pool_mode == "max":
return ops.max_pool(
inputs,
pool_size=self.pool_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
)
elif self.pool_mode == "average":
return ops.average_pool(
inputs,
pool_size=self.pool_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
)
else:
raise ValueError(
"`pool_mode` must be either 'max' or 'average'. Received: "
f"{self.pool_mode}."
)
def compute_output_shape(self, input_shape):
return compute_pooling_output_shape(
input_shape,
self.pool_size,
self.strides,
self.padding,
self.data_format,
)
def get_config(self):
config = super().get_config()
config.update(
{
"pool_size": self.pool_size,
"padding": self.padding,
"strides": self.strides,
"data_format": self.data_format,
}
)
return config