"""Keras base class for convolution layers."""

from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
from keras_core.utils.argument_validation import standardize_padding
from keras_core.utils.argument_validation import standardize_tuple

class BaseConv(Layer):
    """Abstract N-D convolution layer (private, used as implementation base).

    This layer creates a convolution kernel that is convolved (actually
    cross-correlated) with the layer input to produce a tensor of outputs. If
    `use_bias` is True (and a `bias_initializer` is provided), a bias vector is
    created and added to the outputs. Finally, if `activation` is not `None`, it
    is applied to the outputs as well.

    Note: layer attributes cannot be modified after the layer has been called
    once (except the `trainable` attribute).

    Args:
        rank: int, the rank of the convolution, e.g. 2 for 2D convolution.
        filters: int, the dimension of the output space (the number of filters
            in the convolution).
        kernel_size: int or tuple/list of `rank` integers, specifying the size
            of the convolution window.
        strides: int or tuple/list of `rank` integers, specifying the stride
            length of the convolution. If only one int is specified, the same
            stride size will be used for all dimensions. `strides > 1` is
            incompatible with `dilation_rate > 1`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
            height/width dimension as the input.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.
        dilation_rate: int or tuple/list of `rank` integers, specifying the
            dilation rate to use for dilated convolution. If only one int is
            specified, the same dilation rate will be used for all dimensions.
        groups: A positive int specifying the number of groups in which the
            input is split along the channel axis. Each group is convolved
            separately with `filters // groups` filters. The output is the
            concatenation of all the `groups` results along the channel axis.
            Input channels and `filters` must both be divisible by `groups`.
        activation: Activation function. If `None`, no activation is applied.
        use_bias: bool, if `True`, bias will be added to the output.
        kernel_initializer: Initializer for the convolution kernel. If `None`,
            the default initializer (`"glorot_uniform"`) will be used.
        bias_initializer: Initializer for the bias vector. If `None`, the
            default initializer (`"zeros"`) will be used.
        kernel_regularizer: Optional regularizer for the convolution kernel.
        bias_regularizer: Optional regularizer for the bias vector.
        activity_regularizer: Optional regularizer function for the output.
        kernel_constraint: Optional projection function to be applied to the
            kernel after being updated by an `Optimizer` (e.g. used to implement
            norm constraints or value constraints for layer weights). The
            function must take as input the unprojected variable and must return
            the projected variable (which must have the same shape). Constraints
            are not safe to use when doing asynchronous distributed training.
        bias_constraint: Optional projection function to be applied to the
            bias after being updated by an `Optimizer`.
    """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        groups=1,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        **kwargs,
    ):
        super().__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=activity_regularizer,
            **kwargs,
        )
        self.rank = rank
        self.filters = filters
        # Guard against `groups=None` / `groups=0`: fall back to 1 group.
        self.groups = groups or 1
        # Normalize scalar-or-sequence arguments to tuples of length `rank`.
        self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size")
        self.strides = standardize_tuple(strides, rank, "strides")
        self.dilation_rate = standardize_tuple(
            dilation_rate, rank, "dilation_rate"
        )
        # `"causal"` padding is only meaningful for 1D convolutions.
        self.padding = standardize_padding(padding, allow_causal=rank == 1)
        self.data_format = standardize_data_format(data_format)
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        # Inputs must have at least `rank` spatial dims + batch + channels.
        self.input_spec = InputSpec(min_ndim=self.rank + 2)

        if self.filters is not None and self.filters <= 0:
            raise ValueError(
                "Invalid value for argument `filters`. Expected a strictly "
                f"positive value. Received filters={self.filters}."
            )

        if self.filters is not None and self.filters % self.groups != 0:
            raise ValueError(
                "The number of filters must be evenly divisible by the "
                f"number of groups. Received: groups={self.groups}, "
                f"filters={self.filters}."
            )

        if not all(self.kernel_size):
            raise ValueError(
                "The argument `kernel_size` cannot contain 0. Received "
                f"kernel_size={self.kernel_size}."
            )

        if not all(self.strides):
            raise ValueError(
                "The argument `strides` cannot contain 0. Received "
                f"strides={self.strides}."
            )

        if max(self.strides) > 1 and max(self.dilation_rate) > 1:
            raise ValueError(
                "`strides > 1` not supported in conjunction with "
                f"`dilation_rate > 1`. Received: strides={self.strides} and "
                f"dilation_rate={self.dilation_rate}"
            )

    def build(self, input_shape):
        """Create the kernel (and optional bias) variables.

        Args:
            input_shape: Shape tuple of the input, including the batch dim.

        Raises:
            ValueError: If the input channel count is not divisible by
                `groups`, or if the resulting output shape is invalid.
        """
        if self.data_format == "channels_last":
            channel_axis = -1
            input_channel = input_shape[-1]
        else:
            channel_axis = 1
            input_channel = input_shape[1]
        # Pin the channel axis now that the input channel count is known.
        self.input_spec = InputSpec(
            min_ndim=self.rank + 2, axes={channel_axis: input_channel}
        )
        if input_channel % self.groups != 0:
            raise ValueError(
                "The number of input channels must be evenly divisible by "
                f"the number of groups. Received groups={self.groups}, but the "
                f"input has {input_channel} channels (full input shape is "
                f"{input_shape})."
            )
        # Kernel layout: spatial dims + (in_channels_per_group, filters).
        kernel_shape = self.kernel_size + (
            input_channel // self.groups,
            self.filters,
        )

        # compute_output_shape contains some validation logic for the input
        # shape, and make sure the output shape has all positive dimensions.
        self.compute_output_shape(input_shape)

        self.kernel = self.add_weight(
            name="kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype,
            )
        else:
            self.bias = None
        self.built = True

    def convolution_op(self, inputs, kernel):
        """Apply the backend convolution op with this layer's settings."""
        return ops.conv(
            inputs,
            kernel,
            strides=list(self.strides),
            padding=self.padding,
            dilation_rate=self.dilation_rate,
            data_format=self.data_format,
        )

    def call(self, inputs):
        outputs = self.convolution_op(
            inputs,
            self.kernel,
        )
        if self.use_bias:
            # Reshape the bias so it broadcasts along the channel axis of
            # `outputs` for either data format.
            if self.data_format == "channels_last":
                bias_shape = (1,) * (self.rank + 1) + (self.filters,)
            else:
                bias_shape = (1, self.filters) + (1,) * self.rank
            bias = ops.reshape(self.bias, bias_shape)
            outputs += bias

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape` (also validates it)."""
        return compute_conv_output_shape(
            input_shape,
            self.filters,
            self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "strides": self.strides,
                "padding": self.padding,
                "data_format": self.data_format,
                "dilation_rate": self.dilation_rate,
                "groups": self.groups,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": regularizers.serialize(
                    self.kernel_regularizer
                ),
                "bias_regularizer": regularizers.serialize(
                    self.bias_regularizer
                ),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(
                    self.kernel_constraint
                ),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config