Add Conv layers (#89)

* initials

* more

* something

* add docstrings

* fix some docstrings

* fix comments
Chen Qian 2023-05-05 22:13:13 -07:00 committed by Francois Chollet
parent 9bd7a380de
commit 43e33ab9ab
11 changed files with 1250 additions and 77 deletions

@@ -213,7 +213,7 @@ def conv(
kernel,
strides=1,
padding="valid",
data_format="channel_last",
data_format="channels_last",
dilation_rate=1,
):
num_spatial_dims = inputs.ndim - 2
@@ -234,6 +234,18 @@ def conv(
data_format,
include_batch_and_channels=False,
)
if data_format == "channels_last":
channels = inputs.shape[-1]
else:
channels = inputs.shape[1]
kernel_in_channels = kernel.shape[-2]
if channels % kernel_in_channels > 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"kernel's in_channels. Received input channels {channels} and "
f"kernel in_channels {kernel_in_channels}. "
)
feature_group_count = channels // kernel_in_channels
return jax.lax.conv_general_dilated(
inputs,
kernel,
@@ -241,6 +253,7 @@ def conv(
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
)
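
An illustrative sketch (not part of this commit) of how `feature_group_count` expresses grouped convolution in `jax.lax.conv_general_dilated`, as computed above; the shapes are assumptions chosen for this example:

import jax
import jax.numpy as jnp

x = jnp.ones((2, 8, 8, 4))  # NHWC input with 4 channels
k = jnp.ones((3, 3, 2, 6))  # HWIO kernel with in_channels=2: 4 // 2 = 2 groups
y = jax.lax.conv_general_dilated(
    x,
    k,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    feature_group_count=x.shape[-1] // k.shape[-2],  # 2 groups
)
print(y.shape)  # (2, 8, 8, 6)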
@@ -249,7 +262,7 @@ def depthwise_conv(
kernel,
strides=1,
padding="valid",
data_format="channel_last",
data_format="channels_last",
dilation_rate=1,
):
num_spatial_dims = inputs.ndim - 2

@@ -63,6 +63,10 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor):
def numpy(self): # noqa: F811
return self.value.numpy()
@property
def shape(self):
return tf.TensorShape(super().shape)
# Overload native accessor.
def __tf_tensor__(self, dtype=None, name=None):
return tf.convert_to_tensor(self.value, dtype=dtype, name=name)

@@ -203,22 +203,32 @@ def conv(
data_format="channel_last",
dilation_rate=1,
):
"""General N-D convolution function.
def _conv():
tf_data_format = _convert_data_format(data_format, len(inputs.shape))
return tf.nn.convolution(
inputs,
kernel,
strides,
padding.upper(),
data_format=tf_data_format,
dilations=dilation_rate,
)
Arg:
"""
# This wrapper exists because in TensorFlow, `groups > 1` does not work on
# CPU for `tf.nn.convolution`, but the XLA-compiled version does.
@tf.function(jit_compile=True)
def _conv_xla():
return _conv()
data_format = _convert_data_format(data_format, len(inputs.shape))
padding = padding.upper()
return tf.nn.convolution(
inputs,
kernel,
strides,
padding,
data_format=data_format,
dilations=dilation_rate,
)
if data_format == "channels_last":
channels = inputs.shape[-1]
else:
channels = inputs.shape[1]
if channels != kernel.shape[-2]:
# If the kernel's in_channels does not match the input's channels, the
# convolution is grouped, so use the XLA-compiled path.
return _conv_xla()
return _conv()
def depthwise_conv(

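An illustrative sketch (not part of this commit) of the workaround above; it assumes a TensorFlow build with XLA support:

import tensorflow as tf

x = tf.ones((2, 8, 8, 4))  # NHWC input with 4 channels
k = tf.ones((3, 3, 2, 6))  # in_channels=2 != 4, i.e. a 2-group convolution

@tf.function(jit_compile=True)
def grouped_conv(inputs, kernel):
    # Eager `tf.nn.convolution` rejects groups > 1 on CPU, but the
    # XLA-compiled version handles the grouping.
    return tf.nn.convolution(inputs, kernel, strides=1, padding="SAME")

print(grouped_conv(x, k).shape)  # (2, 8, 8, 6)
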
@@ -1,4 +1,7 @@
from keras_core.layers.activations.activation import Activation
from keras_core.layers.convolutional.conv1d import Conv1D
from keras_core.layers.convolutional.conv2d import Conv2D
from keras_core.layers.convolutional.conv3d import Conv3D
from keras_core.layers.core.dense import Dense
from keras_core.layers.core.einsum_dense import EinsumDense
from keras_core.layers.core.embedding import Embedding

@@ -0,0 +1,282 @@
"""Keras base class for convolution layers."""
from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import image_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
class BaseConv(Layer):
"""Abstract N-D convolution layer (private, used as implementation base).
This layer creates a convolution kernel that is convolved
(actually cross-correlated) with the layer input to produce a tensor of
outputs. If `use_bias` is True (and a `bias_initializer` is provided),
a bias vector is created and added to the outputs. Finally, if
`activation` is not `None`, it is applied to the outputs as well.
Note: layer attributes cannot be modified after the layer has been called
once (except the `trainable` attribute).
Args:
rank: int, the rank of the convolution, e.g. 2 for 2D convolution.
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the convolution. If only one int is specified, the same stride
size will be used for all dimensions. `stride value != 1` is
incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch_size, ..., channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, ...)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
"""
def __init__(
self,
rank,
filters,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
groups=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
**kwargs,
):
super().__init__(
trainable=trainable,
name=name,
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs,
)
self.rank = rank
self.filters = filters
self.groups = groups or 1
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * self.rank
self.kernel_size = kernel_size
if isinstance(strides, int):
strides = (strides,) * self.rank
self.strides = strides
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * self.rank
self.dilation_rate = dilation_rate
self.padding = padding
self.data_format = (
image_data_format() if data_format is None else data_format
)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=self.rank + 2)
if self.filters is not None and self.filters <= 0:
raise ValueError(
"Invalid value for argument `filters`. Expected a strictly "
f"positive value. Received filters={self.filters}."
)
if self.filters is not None and self.filters % self.groups != 0:
raise ValueError(
"The number of filters must be evenly divisible by the "
f"number of groups. Received: groups={self.groups}, "
f"filters={self.filters}."
)
if not all(self.kernel_size):
raise ValueError(
"The argument `kernel_size` cannot contain 0(s). Received: "
f"{self.kernel_size}"
)
if not all(self.strides):
raise ValueError(
"The argument `strides` cannot contains 0(s). Received: "
f"{self.strides}"
)
if max(self.strides) > 1 and max(self.dilation_rate) > 1:
raise ValueError(
"`strides > 1` not supported in conjunction with "
f"`dilation_rate > 1`. Received: strides={self.strides} and "
f"dilation_rate={self.dilation_rate}"
)
def build(self, input_shape):
if self.data_format == "channels_last":
channel_axis = -1
input_channel = input_shape[-1]
else:
channel_axis = 1
input_channel = input_shape[1]
self.input_spec = InputSpec(
min_ndim=self.rank + 2, axes={channel_axis: input_channel}
)
if input_channel % self.groups != 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"the number of groups. Received groups={self.groups}, but the "
f"input has {input_channel} channels (full input shape is "
f"{input_shape})."
)
kernel_shape = self.kernel_size + (
input_channel // self.groups,
self.filters,
)
# compute_output_shape contains some validation logic for the input
# shape, and makes sure the output shape has all positive dimensions.
self.compute_output_shape(input_shape)
self.kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
trainable=True,
dtype=self.dtype,
)
if self.use_bias:
self.bias = self.add_weight(
name="bias",
shape=(self.filters,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
trainable=True,
dtype=self.dtype,
)
else:
self.bias = None
self.built = True
def call(self, inputs):
outputs = ops.conv(
inputs,
self.kernel,
strides=list(self.strides),
padding=self.padding,
dilation_rate=self.dilation_rate,
data_format=self.data_format,
)
if self.use_bias:
if self.data_format == "channels_last":
bias_shape = (1,) * (self.rank + 1) + (self.filters,)
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
if self.activation is not None:
return self.activation(outputs)
return outputs
def compute_output_shape(self, input_shape):
return compute_conv_output_shape(
input_shape,
self.filters,
self.kernel_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
def get_config(self):
config = super().get_config()
config.update(
{
"filters": self.filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"data_format": self.data_format,
"dilation_rate": self.dilation_rate,
"groups": self.groups,
"activation": activations.serialize(self.activation),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"bias_initializer": initializers.serialize(
self.bias_initializer
),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"bias_regularizer": regularizers.serialize(
self.bias_regularizer
),
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(
self.kernel_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
}
)
return config
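
An illustrative sketch (not part of this commit) of the kernel shape that `build()` above creates for a grouped convolution; the numbers are assumptions for this example:

rank = 2
kernel_size = (3, 3)
input_channel = 8
groups = 4
filters = 16

# Mirrors `build()`: each group convolves input_channel // groups channels.
kernel_shape = kernel_size + (input_channel // groups, filters)
print(kernel_shape)  # (3, 3, 2, 16)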

@@ -0,0 +1,129 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv import BaseConv
@keras_core_export(
["keras_core.layers.Conv1D", "keras_core.layers.Convolution1D"]
)
class Conv1D(BaseConv):
"""1D convolution layer (e.g. temporal convolution).
This layer creates a convolution kernel that is convolved with the layer
input over a single spatial (or temporal) dimension to produce a tensor of
outputs. If `use_bias` is True, a bias vector is created and added to the
outputs. Finally, if `activation` is not `None`, it is applied to the
outputs as well.
Args:
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of 1 integer, specifying the size of the
convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the convolution. `stride value != 1` is incompatible with
`dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 1 integer, specifying the dilation
rate to use for dilated convolution.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, steps, channels)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, filters, new_steps)`
Returns:
A 3D tensor representing `activation(conv1d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
Examples:
>>> # The inputs are 128-length vectors with 10 timesteps, and the
>>> # batch size is 4.
>>> input_shape = (4, 10, 128)
>>> x = np.random.rand(4, 10, 128)
>>> y = keras_core.layers.Conv1D(32, 3, activation='relu')(x)
>>> print(y.shape)
(4, 8, 32)
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
groups=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=1,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@@ -0,0 +1,128 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv import BaseConv
@keras_core_export(
["keras_core.layers.Conv2D", "keras_core.layers.Convolution2D"]
)
class Conv2D(BaseConv):
"""2D convolution layer.
This layer creates a convolution kernel that is convolved with the layer
input over two spatial dimensions (height and width) to produce a tensor of
outputs. If `use_bias` is True, a bias vector is created and added to the
outputs. Finally, if `activation` is not `None`, it is applied to the
outputs as well.
Args:
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of 2 integers, specifying the size of the
convolution window.
strides: int or tuple/list of 2 integers, specifying the stride length
of the convolution. `stride value != 1` is incompatible with
`dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape
`(batch_size, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
dilation_rate: int or tuple/list of 2 integers, specifying the dilation
rate to use for dilated convolution.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, height, width, channels)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`
Returns:
A 4D tensor representing `activation(conv2d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
Examples:
>>> x = np.random.rand(4, 10, 10, 128)
>>> y = keras_core.layers.Conv2D(32, 3, activation='relu')(x)
>>> print(y.shape)
(4, 8, 8, 32)
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
groups=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)
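
An illustrative usage sketch (not part of this commit) of the `groups` argument documented above, assuming the `keras_core` package is installed:

import numpy as np
import keras_core

x = np.random.rand(2, 10, 10, 8)  # 8 input channels
layer = keras_core.layers.Conv2D(filters=16, kernel_size=3, groups=2)
y = layer(x)
print(y.shape)  # (2, 8, 8, 16)
print(layer.kernel.shape)  # (3, 3, 4, 16): in_channels // groups = 4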

@@ -0,0 +1,134 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv import BaseConv
@keras_core_export(
["keras_core.layers.Conv3D", "keras_core.layers.Convolution3D"]
)
class Conv3D(BaseConv):
"""3D convolution layer.
This layer creates a convolution kernel that is convolved with the layer
input over three spatial dimensions to produce a tensor of
outputs. If `use_bias` is True, a bias vector is created and added to the
outputs. Finally, if `activation` is not `None`, it is applied to the
outputs as well.
Args:
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of 3 integers, specifying the size of the
convolution window.
strides: int or tuple/list of 3 integers, specifying the stride length
of the convolution. `stride value != 1` is incompatible with
`dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
It defaults to the `image_data_format` value found in your Keras
config file at `~/.keras/keras.json`. If you never set it, then it
will be `"channels_last"`.
dilation_rate: int or tuple/list of 3 integers, specifying the dilation
rate to use for dilated convolution.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
- If `data_format="channels_first"`:
5D tensor with shape:
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
filters)`
- If `data_format="channels_first"`:
5D tensor with shape:
`(batch_size, channels, new_spatial_dim1, new_spatial_dim2,
new_spatial_dim3)`
Returns:
A 5D tensor representing `activation(conv3d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
Examples:
>>> x = np.random.rand(4, 10, 10, 10, 128)
>>> y = keras_core.layers.Conv3D(32, 3, activation='relu')(x)
>>> print(y.shape)
(4, 8, 8, 8, 32)
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
groups=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=3,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@@ -0,0 +1,456 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class ConvBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
"input_shape": (3, 5, 4),
"output_shape": (3, 4, 5),
},
{
"filters": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2,),
"groups": 2,
"input_shape": (3, 4, 4),
"output_shape": (3, 4, 6),
},
{
"filters": 6,
"kernel_size": 2,
"strides": (2,),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 2,
"input_shape": (3, 5, 4),
"output_shape": (3, 2, 6),
},
)
def test_conv1d_basic(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
groups,
input_shape,
output_shape,
):
self.run_layer_test(
layers.Conv1D,
init_kwargs={
"filters": filters,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
"groups": groups,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
"input_shape": (3, 5, 5, 4),
"output_shape": (3, 4, 4, 5),
},
{
"filters": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2, 2),
"groups": 2,
"input_shape": (3, 4, 4, 4),
"output_shape": (3, 4, 4, 6),
},
{
"filters": 6,
"kernel_size": (2, 2),
"strides": (2, 1),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1),
"groups": 2,
"input_shape": (3, 5, 5, 4),
"output_shape": (3, 2, 4, 6),
},
)
def test_conv2d_basic(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
groups,
input_shape,
output_shape,
):
self.run_layer_test(
layers.Conv2D,
init_kwargs={
"filters": filters,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
"groups": groups,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
"input_shape": (3, 5, 5, 5, 4),
"output_shape": (3, 4, 4, 4, 5),
},
{
"filters": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2, 2, 2),
"groups": 2,
"input_shape": (3, 4, 4, 4, 4),
"output_shape": (3, 4, 4, 4, 6),
},
{
"filters": 6,
"kernel_size": (2, 2, 3),
"strides": (2, 1, 2),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1, 1),
"groups": 2,
"input_shape": (3, 5, 5, 5, 4),
"output_shape": (3, 2, 4, 2, 6),
},
)
def test_conv3d_basic(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
groups,
input_shape,
output_shape,
):
self.run_layer_test(
layers.Conv3D,
init_kwargs={
"filters": filters,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
"groups": groups,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
def test_bad_init_args(self):
# `filters` is not positive.
with self.assertRaises(ValueError):
layers.Conv1D(filters=0, kernel_size=1)
# `kernel_size` has 0.
with self.assertRaises(ValueError):
layers.Conv2D(filters=2, kernel_size=(1, 0))
# `strides` has 0.
with self.assertRaises(ValueError):
layers.Conv2D(filters=2, kernel_size=(2, 2), strides=(1, 0))
# `dilation_rate > 1` while `strides > 1`.
with self.assertRaises(ValueError):
layers.Conv2D(
filters=2, kernel_size=(2, 2), strides=2, dilation_rate=(2, 1)
)
# `filters` cannot be divided by `groups`.
with self.assertRaises(ValueError):
layers.Conv2D(filters=5, kernel_size=(2, 2), groups=2)
class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
},
{
"filters": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2,),
"groups": 2,
},
{
"filters": 6,
"kernel_size": (2,),
"strides": (2,),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 2,
},
)
def test_conv1d(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
groups,
):
layer = layers.Conv1D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
tf_keras_layer = tf.keras.layers.Conv1D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
inputs = np.random.normal(size=[2, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
},
{
"filters": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2, 2),
"groups": 2,
},
{
"filters": 6,
"kernel_size": (2, 2),
"strides": (2, 1),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1),
"groups": 2,
},
)
def test_conv2d(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
groups,
):
layer = layers.Conv2D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
tf_keras_layer = tf.keras.layers.Conv2D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
inputs = np.random.normal(size=[2, 8, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
},
{
"filters": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2, 2, 2),
"groups": 2,
},
{
"filters": 6,
"kernel_size": (2, 2, 3),
"strides": (2, 1, 2),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1, 1),
"groups": 2,
},
)
def test_conv3d(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
groups,
):
layer = layers.Conv3D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
tf_keras_layer = tf.keras.layers.Conv3D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
inputs = np.random.normal(size=[2, 8, 8, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@@ -467,72 +467,22 @@ class Conv(Operation):
return backend.nn.conv(
inputs,
kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
def compute_output_spec(self, inputs, kernel):
output_shape = operation_utils.compute_conv_output_shape(
inputs.shape,
kernel.shape[-1],
kernel.shape[:-2],
self.strides,
self.padding,
self.data_format,
self.dilation_rate,
)
def compute_output_spec(self, inputs, kernel):
input_shape = inputs.shape
if self.data_format == "channels_last":
spatial_shape = input_shape[1:-1]
else:
spatial_shape = input_shape[2:]
if len(kernel.shape) != len(input_shape):
raise ValueError(
"Kernel shape must have the same length as input, but received "
f"kernel of shape {kernel.shape} and "
f"input of shape {input_shape}."
)
if isinstance(self.dilation_rate, int):
dilation_rate = (self.dilation_rate,) * len(spatial_shape)
else:
dilation_rate = self.dilation_rate
if len(dilation_rate) != len(spatial_shape):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and "
f"input of shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])
dilation_rate = np.array(dilation_rate)
if self.padding == "valid":
output_spatial_shape = (
np.floor(
(
spatial_shape
- dilation_rate * (kernel_spatial_shape - 1)
- 1
)
/ self.strides
)
+ 1
)
negative_in_shape = np.all(output_spatial_shape < 0)
if negative_in_shape:
raise ValueError(
"Computed output size would be negative. Received "
f"`inputs shape={inputs.shape}`, "
f"`kernel spatial size={kernel.size}`, "
f"`dilation_rate={self.dilation_rate}`."
)
elif self.padding == "same":
output_spatial_shape = (
np.floor((spatial_shape - 1) / self.strides) + 1
)
output_spatial_shape = [int(i) for i in output_spatial_shape]
if self.data_format == "channels_last":
output_shape = (
[inputs.shape[0]] + output_spatial_shape + [kernel.shape[-1]]
)
else:
output_shape = [
inputs.shape[0],
kernel.shape[-1],
] + output_spatial_shape
return KerasTensor(output_shape, dtype=inputs.dtype)

@@ -47,3 +47,67 @@ def compute_pooling_output_shape(
input_shape_origin[1],
) + output_spatial_shape
return output_shape
def compute_conv_output_shape(
input_shape,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
"""Compute the output shape of conv ops."""
if data_format == "channels_last":
spatial_shape = input_shape[1:-1]
kernel_shape = kernel_size + (input_shape[-1], filters)
else:
spatial_shape = input_shape[2:]
kernel_shape = kernel_size + (input_shape[1], filters)
if len(kernel_shape) != len(input_shape):
raise ValueError(
"Kernel shape must have the same length as input, but received "
f"kernel of shape {kernel_shape} and "
f"input of shape {input_shape}."
)
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * len(spatial_shape)
if isinstance(strides, int):
strides = (strides,) * len(spatial_shape)
if len(dilation_rate) != len(spatial_shape):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={dilation_rate}` and "
f"input of shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel_shape[:-2])
dilation_rate = np.array(dilation_rate)
if padding == "valid":
output_spatial_shape = (
np.floor(
(spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
/ strides
)
+ 1
)
negative_in_shape = np.any(output_spatial_shape < 0)
if negative_in_shape:
raise ValueError(
"Computed output size would be negative. Received "
f"`inputs shape={input_shape}`, "
f"`kernel shape={kernel_shape}`, "
f"`dilation_rate={dilation_rate}`."
)
elif padding == "same":
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
output_spatial_shape = tuple([int(i) for i in output_spatial_shape])
if data_format == "channels_last":
output_shape = (
(input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
)
else:
output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
return output_shape
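
An illustrative check (not part of this commit): with `padding="valid"`, the formula above gives `floor((n - d * (k - 1) - 1) / s) + 1` per spatial dimension.

from keras_core.operations.operation_utils import compute_conv_output_shape

# 10x10 RGB images, five 3x3 filters, stride 2, valid padding:
# floor((10 - 1 * (3 - 1) - 1) / 2) + 1 = 4 per spatial dimension.
print(
    compute_conv_output_shape(
        input_shape=(None, 10, 10, 3),
        filters=5,
        kernel_size=(3, 3),
        strides=2,
        padding="valid",
    )
)  # (None, 4, 4, 5)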