keras/keras_core/layers/pooling/base_pooling.py

77 lines
2.2 KiB
Python
Raw Normal View History

from keras_core import operations as ops
2023-06-02 01:48:52 +00:00
from keras_core.backend import image_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_pooling_output_shape
class BasePooling(Layer):
    """Shared implementation behind the max/average pooling layers.

    Args:
        pool_size: Pooling window size, forwarded unchanged to the pooling op.
        strides: Pooling strides. When `None`, `pool_size` is used instead.
        pool_dimensions: Integer; the layer's `input_spec` requires inputs
            with `pool_dimensions + 2` dimensions.
        pool_mode: Either `"max"` or `"average"`. Any other value makes
            `call()` raise a `ValueError`.
        padding: Padding argument forwarded unchanged to the pooling op.
        data_format: Data format forwarded to the pooling op. When `None`,
            the value returned by `image_data_format()` is used.
        name: Layer name, forwarded to `Layer.__init__`.
        **kwargs: Extra keyword arguments forwarded to `Layer.__init__`.
    """

    def __init__(
        self,
        pool_size,
        strides,
        pool_dimensions,
        pool_mode="max",
        padding="valid",
        data_format=None,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        # Strides fall back to the pooling window when not given.
        if strides is None:
            strides = pool_size
        # Only query the backend default when the caller did not choose one.
        if data_format is None:
            data_format = image_data_format()

        self.pool_size = pool_size
        self.strides = strides
        self.pool_mode = pool_mode
        self.padding = padding
        self.data_format = data_format
        self.input_spec = InputSpec(ndim=pool_dimensions + 2)

    def call(self, inputs):
        """Apply the pooling op selected by `pool_mode` to `inputs`.

        Raises:
            ValueError: If `pool_mode` is neither `"max"` nor `"average"`.
        """
        mode = self.pool_mode
        if mode == "max":
            pool_fn = ops.max_pool
        elif mode == "average":
            pool_fn = ops.average_pool
        else:
            raise ValueError(
                "`pool_mode` must be either 'max' or 'average'. Received: "
                f"{self.pool_mode}."
            )
        # Both ops receive the same keyword arguments.
        return pool_fn(
            inputs,
            pool_size=self.pool_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )

    def compute_output_shape(self, input_shape):
        """Return the output shape via `compute_pooling_output_shape`."""
        pooling_args = (
            self.pool_size,
            self.strides,
            self.padding,
            self.data_format,
        )
        return compute_pooling_output_shape(input_shape, *pooling_args)

    def get_config(self):
        """Return the base layer config extended with the pooling settings."""
        config = super().get_config()
        # NOTE(review): `pool_mode` and `pool_dimensions` are not serialized
        # here; presumably subclasses hard-code them - confirm before adding.
        config.update(
            pool_size=self.pool_size,
            padding=self.padding,
            strides=self.strides,
            data_format=self.data_format,
        )
        return config