keras/keras_core/layers/convolutional/base_conv.py
2023-06-10 15:23:02 -07:00

280 lines
12 KiB
Python

"""Keras base class for convolution layers."""
from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
from keras_core.utils.argument_validation import standardize_padding
from keras_core.utils.argument_validation import standardize_tuple
class BaseConv(Layer):
    """Abstract N-D convolution layer (private, used as implementation base).

    This layer creates a convolution kernel that is convolved (actually
    cross-correlated) with the layer input to produce a tensor of outputs. If
    `use_bias` is True (and a `bias_initializer` is provided), a bias vector is
    created and added to the outputs. Finally, if `activation` is not `None`, it
    is applied to the outputs as well.

    Note: layer attributes cannot be modified after the layer has been called
    once (except the `trainable` attribute).

    Args:
        rank: int, the rank of the convolution, e.g. 2 for 2D convolution.
        filters: int, the dimension of the output space (the number of filters
            in the convolution). Must be strictly positive when not `None`.
        kernel_size: int or tuple/list of `rank` integers, specifying the size
            of the convolution window. Must not contain 0.
        strides: int or tuple/list of `rank` integers, specifying the stride
            length of the convolution. If only one int is specified, the same
            stride size will be used for all dimensions. `strides > 1` is
            incompatible with `dilation_rate > 1`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
            height/width dimension as the input. When `rank == 1` the value
            is validated with `allow_causal=True`, so `"causal"` is presumably
            accepted as well.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.
        dilation_rate: int or tuple/list of `rank` integers, specifying the
            dilation rate to use for dilated convolution. If only one int is
            specified, the same dilation rate will be used for all dimensions.
        groups: A positive int specifying the number of groups in which the
            input is split along the channel axis. Each group is convolved
            separately with `filters // groups` filters. The output is the
            concatenation of all the `groups` results along the channel axis.
            Input channels and `filters` must both be divisible by `groups`.
            A falsy value (`None` or `0`) is treated as `1`.
        activation: Activation function. If `None`, no activation is applied.
        use_bias: bool, if `True`, bias will be added to the output.
        kernel_initializer: Initializer for the convolution kernel. If `None`,
            the default initializer (`"glorot_uniform"`) will be used.
        bias_initializer: Initializer for the bias vector. If `None`, the
            default initializer (`"zeros"`) will be used.
        kernel_regularizer: Optional regularizer for the convolution kernel.
        bias_regularizer: Optional regularizer for the bias vector.
        activity_regularizer: Optional regularizer function for the output.
        kernel_constraint: Optional projection function to be applied to the
            kernel after being updated by an `Optimizer` (e.g. used to implement
            norm constraints or value constraints for layer weights). The
            function must take as input the unprojected variable and must return
            the projected variable (which must have the same shape). Constraints
            are not safe to use when doing asynchronous distributed training.
        bias_constraint: Optional projection function to be applied to the
            bias after being updated by an `Optimizer`.
        trainable: bool, forwarded to the base `Layer` constructor.
        name: Optional string name, forwarded to the base `Layer` constructor.
        **kwargs: Additional keyword arguments forwarded to the base `Layer`
            constructor.

    Raises:
        ValueError: If `filters` is not strictly positive, if `groups` is
            negative, if `filters` is not divisible by `groups`, if
            `kernel_size` or `strides` contains 0, or if `strides > 1` is
            combined with `dilation_rate > 1`.
    """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        groups=1,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        **kwargs,
    ):
        super().__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=activity_regularizer,
            **kwargs,
        )
        self.rank = rank
        self.filters = filters
        # A falsy `groups` (None / 0) falls back to a regular convolution.
        self.groups = groups or 1
        # Scalars are expanded to `rank`-length tuples by the helpers.
        self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size")
        self.strides = standardize_tuple(strides, rank, "strides")
        self.dilation_rate = standardize_tuple(
            dilation_rate, rank, "dilation_rate"
        )
        # Causal padding is only allowed for 1D convolutions.
        self.padding = standardize_padding(padding, allow_causal=rank == 1)
        self.data_format = standardize_data_format(data_format)
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        # Inputs need a batch axis, a channel axis and `rank` spatial axes.
        # The spec is refined with the channel count in `build()`.
        self.input_spec = InputSpec(min_ndim=self.rank + 2)

        if self.filters is not None and self.filters <= 0:
            raise ValueError(
                "Invalid value for argument `filters`. Expected a strictly "
                f"positive value. Received filters={self.filters}."
            )

        # Negative values would otherwise slip through the divisibility
        # checks (e.g. `8 % -2 == 0`) and fail much later with an obscure
        # error.
        if self.groups <= 0:
            raise ValueError(
                "The number of groups must be a positive integer. "
                f"Received: groups={self.groups}."
            )

        if self.filters is not None and self.filters % self.groups != 0:
            raise ValueError(
                "The number of filters must be evenly divisible by the "
                f"number of groups. Received: groups={self.groups}, "
                f"filters={self.filters}."
            )

        if not all(self.kernel_size):
            raise ValueError(
                "The argument `kernel_size` cannot contain 0. Received "
                f"kernel_size={self.kernel_size}."
            )

        if not all(self.strides):
            raise ValueError(
                "The argument `strides` cannot contains 0. Received "
                f"strides={self.strides}"
            )

        if max(self.strides) > 1 and max(self.dilation_rate) > 1:
            raise ValueError(
                "`strides > 1` not supported in conjunction with "
                f"`dilation_rate > 1`. Received: strides={self.strides} and "
                f"dilation_rate={self.dilation_rate}"
            )

    def build(self, input_shape):
        """Create the kernel (and optional bias) for the given input shape.

        Args:
            input_shape: Shape of the input; the channel dimension is read
                from the last axis for `"channels_last"` and from axis 1
                otherwise.

        Raises:
            ValueError: If the channel dimension is undefined (`None`) or is
                not divisible by `groups`.
        """
        if self.data_format == "channels_last":
            channel_axis = -1
            input_channel = input_shape[-1]
        else:
            channel_axis = 1
            input_channel = input_shape[1]

        # The kernel shape depends on the channel count, so it must be known.
        # Without this check the modulo below fails with an opaque TypeError.
        if input_channel is None:
            raise ValueError(
                "The channel dimension of the inputs should be defined. "
                f"Received input shape {input_shape}, where axis "
                f"{channel_axis} is the channel dimension and is `None`."
            )

        self.input_spec = InputSpec(
            min_ndim=self.rank + 2, axes={channel_axis: input_channel}
        )
        if input_channel % self.groups != 0:
            raise ValueError(
                "The number of input channels must be evenly divisible by "
                f"the number of groups. Received groups={self.groups}, but the "
                f"input has {input_channel} channels (full input shape is "
                f"{input_shape})."
            )
        # Kernel layout: spatial dims, input channels per group, filters.
        kernel_shape = self.kernel_size + (
            input_channel // self.groups,
            self.filters,
        )

        # compute_output_shape contains some validation logic for the input
        # shape, and makes sure the output shape has all positive dimensions.
        self.compute_output_shape(input_shape)

        self.kernel = self.add_weight(
            name="kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype,
            )
        else:
            self.bias = None
        self.built = True

    def convolution_op(self, inputs, kernel):
        """Apply `ops.conv` to `inputs` with `kernel` and the layer settings."""
        return ops.conv(
            inputs,
            kernel,
            strides=list(self.strides),
            padding=self.padding,
            dilation_rate=self.dilation_rate,
            data_format=self.data_format,
        )

    def call(self, inputs):
        """Convolve `inputs`, then add the bias and apply the activation.

        The bias is only added when `use_bias` is true and the activation is
        only applied when it is not `None`.
        """
        outputs = self.convolution_op(
            inputs,
            self.kernel,
        )
        if self.use_bias:
            # Reshape the bias so it broadcasts along the channel axis.
            if self.data_format == "channels_last":
                bias_shape = (1,) * (self.rank + 1) + (self.filters,)
            else:
                bias_shape = (1, self.filters) + (1,) * self.rank
            bias = ops.reshape(self.bias, bias_shape)
            # Out-of-place add: never mutate the tensor handed back by
            # `convolution_op` on backends whose tensors support in-place `+=`.
            outputs = outputs + bias

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape computed by `compute_conv_output_shape`."""
        return compute_conv_output_shape(
            input_shape,
            self.filters,
            self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

    def get_config(self):
        """Return the layer config as a dict.

        Extends the base `Layer` config with the convolution arguments.
        `rank` is not part of the returned config.
        """
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "strides": self.strides,
                "padding": self.padding,
                "data_format": self.data_format,
                "dilation_rate": self.dilation_rate,
                "groups": self.groups,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": regularizers.serialize(
                    self.kernel_regularizer
                ),
                "bias_regularizer": regularizers.serialize(
                    self.bias_regularizer
                ),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(
                    self.kernel_constraint
                ),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config