Add depthwise conv layer (#140)

Chen Qian 2023-05-10 21:33:59 -07:00 committed by Francois Chollet
parent c9dc3e7a53
commit a3f224d2eb
9 changed files with 888 additions and 86 deletions

@@ -6,18 +6,17 @@ import tensorflow as tf
from tensorflow.python.eager import context
from keras_core import backend
from keras_core.backend.tensorflow import trainer as tf_trainer
from keras_core import layers
from keras_core import models
from keras_core import testing
from keras_core.backend.tensorflow import trainer as tf_trainer
@pytest.mark.skipif(
backend.backend() != 'tensorflow',
reason='The distribute test can only run with TF backend.',
backend.backend() != "tensorflow",
reason="The distribute test can only run with TF backend.",
)
class DistributeTest(testing.TestCase):
def setUp(self):
super().setUp()
# Need at least 2 devices for distribution related tests.
@@ -32,22 +31,23 @@ class DistributeTest(testing.TestCase):
)
def test_variable_creation(self):
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
with strategy.scope():
dense = layers.Dense(2)
dense.build([4, 2])
self.assertIsInstance(dense.kernel, backend.KerasVariable)
self.assertIsInstance(dense.kernel.value,
tf.distribute.DistributedValues)
self.assertIn('MirroredVariable', dense.kernel.value.__class__.__name__)
self.assertIsInstance(
dense.kernel.value, tf.distribute.DistributedValues
)
self.assertIn("MirroredVariable", dense.kernel.value.__class__.__name__)
self.assertIsInstance(dense.bias, backend.KerasVariable)
self.assertIsInstance(dense.bias.value, tf.distribute.DistributedValues)
self.assertIn('MirroredVariable', dense.bias.value.__class__.__name__)
self.assertIn("MirroredVariable", dense.bias.value.__class__.__name__)
def test_strategy_run(self):
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
with strategy.scope():
inputs = layers.Input(shape=[4])
@@ -56,8 +56,9 @@ class DistributeTest(testing.TestCase):
model = models.Functional(inputs, output)
self.assertIsInstance(dense.kernel, backend.KerasVariable)
self.assertIsInstance(dense.kernel.value,
tf.distribute.DistributedValues)
self.assertIsInstance(
dense.kernel.value, tf.distribute.DistributedValues
)
def input_fn(ctx):
if ctx.replica_id_in_sync_group == 1:
@@ -65,8 +66,9 @@ class DistributeTest(testing.TestCase):
else:
return tf.zeros([8, 4])
distributed_inputs = strategy.\
experimental_distribute_values_from_function(input_fn)
distributed_inputs = (
strategy.experimental_distribute_values_from_function(input_fn)
)
@tf.function
def run_fn(data):
@@ -74,8 +76,9 @@ class DistributeTest(testing.TestCase):
result = strategy.run(run_fn, args=(distributed_inputs,))
self.assertIsInstance(result,
tf.types.experimental.distributed.PerReplica)
self.assertIsInstance(
result, tf.types.experimental.distributed.PerReplica
)
self.assertLen(result.values, 2)
self.assertEqual(result.values[0].shape, [8, 2])
self.assertEqual(result.values[1].shape, [8, 2])
@@ -89,7 +92,7 @@ class DistributeTest(testing.TestCase):
batch_size = 16
shuffle = True
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
epoch_iterator = tf_trainer.TFEpochIterator(
x=x,
@@ -97,7 +100,7 @@ class DistributeTest(testing.TestCase):
sample_weight=sample_weight,
batch_size=batch_size,
shuffle=shuffle,
distribute_strategy=strategy
distribute_strategy=strategy,
)
steps_seen = []
for step, data_iterator in epoch_iterator.enumerate_epoch():
@@ -106,8 +109,8 @@ class DistributeTest(testing.TestCase):
self.assertEqual(len(batch), 3)
x, y, sample_weight = batch
self.assertTrue(
isinstance(x,
tf.types.experimental.distributed.PerReplica))
isinstance(x, tf.types.experimental.distributed.PerReplica)
)
# Make sure the local batch size is 8
if step < 6:
self.assertEqual(x.values[0].shape, [8, 16])

@@ -245,7 +245,9 @@ def depthwise_conv(
"`inputs` rank must be 3 (1D conv) or 4 (2D conv). Received: "
"{inputs.ndim}."
)
tf_data_format = _convert_data_format(data_format, len(inputs.shape))
# Because we use `tf.nn.depthwise_conv2d` for both 1D and 2D convs, we set
# `tf_data_format` using 2D conv format.
tf_data_format = _convert_data_format(data_format, 4)
padding = padding.upper()
if isinstance(strides, int):
strides = (strides,) * num_spatial_dims
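The comment above notes that `tf.nn.depthwise_conv2d` backs both the 1D and 2D paths. A minimal sketch of the standard trick this implies, with illustrative shapes (not code from this commit): insert a dummy height axis on the input and kernel, run the 2D op, then squeeze the axis back out.

```python
import tensorflow as tf

x = tf.random.normal([2, 10, 4])      # (batch, steps, channels)
kernel = tf.random.normal([3, 4, 2])  # (k, channels, depth_multiplier)

# Add a size-1 "height" axis so the 1D conv can reuse the 2D op.
x2d = tf.expand_dims(x, axis=1)       # (batch, 1, steps, channels)
k2d = tf.expand_dims(kernel, axis=0)  # (1, k, channels, depth_multiplier)

y2d = tf.nn.depthwise_conv2d(
    x2d, k2d, strides=[1, 1, 1, 1], padding="VALID"
)
y = tf.squeeze(y2d, axis=1)           # (batch, new_steps, channels * 2)
print(y.shape)                        # (2, 8, 8)
```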

@@ -442,13 +442,17 @@ class TFEpochIterator(EpochIterator):
if not self._current_iterator:
self._current_iterator = iter(
self._distribute_strategy.experimental_distribute_dataset(
self.data_adapter.get_tf_dataset()))
self.data_adapter.get_tf_dataset()
)
)
for step in range(self.steps_per_epoch):
yield step, self._current_iterator
else:
iterator = iter(
self._distribute_strategy.experimental_distribute_dataset(
self.data_adapter.get_tf_dataset()))
self.data_adapter.get_tf_dataset()
)
)
if self.num_batches:
for step in range(self.num_batches):
yield step, iterator

@@ -6,6 +6,8 @@ from keras_core.layers.convolutional.conv2d import Conv2D
from keras_core.layers.convolutional.conv2d_transpose import Conv2DTranspose
from keras_core.layers.convolutional.conv3d import Conv3D
from keras_core.layers.convolutional.conv3d_transpose import Conv3DTranspose
from keras_core.layers.convolutional.depthwise_conv1d import DepthwiseConv1D
from keras_core.layers.convolutional.depthwise_conv2d import DepthwiseConv2D
from keras_core.layers.core.dense import Dense
from keras_core.layers.core.einsum_dense import EinsumDense
from keras_core.layers.core.embedding import Embedding

@@ -0,0 +1,279 @@
"""Keras base class for depthwise convolution layers."""
from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
class BaseDepthwiseConv(Layer):
"""Abstract N-D depthwise convolution layer.
Depthwise convolution is a type of convolution in which each input channel
is convolved with a different kernel (called a depthwise kernel). You can
understand depthwise convolution as the first step in a depthwise separable
convolution.
It is implemented via the following steps:
- Split the input into individual channels.
- Convolve each channel with an individual depthwise kernel with
`depth_multiplier` output channels.
- Concatenate the convolved outputs along the channels axis.
Unlike a regular convolution, depthwise convolution does not mix information
across different input channels.
The `depth_multiplier` argument determines how many filters are applied to
one input channel. As such, it controls the number of output channels that
are generated per input channel in the depthwise step.
Args:
rank: int, the rank of the convolution, e.g. 2 for 2D convolution.
depth_multiplier: The number of depthwise convolution output channels
for each input channel. The total number of depthwise convolution
output channels will be equal to `input_channel * depth_multiplier`.
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the depthwise convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the depthwise convolution. If only one int is specified, the same
stride size will be used for all dimensions. `stride value != 1` is
incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, ..., channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, ...)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
"""
def __init__(
self,
rank,
depth_multiplier,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
**kwargs,
):
super().__init__(
trainable=trainable,
name=name,
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs,
)
self.rank = rank
self.depth_multiplier = depth_multiplier
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * self.rank
self.kernel_size = kernel_size
if isinstance(strides, int):
strides = (strides,) * self.rank
self.strides = strides
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * self.rank
self.dilation_rate = dilation_rate
self.padding = padding
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=self.rank + 2)
if self.depth_multiplier is not None and self.depth_multiplier <= 0:
raise ValueError(
"Invalid value for argument `depth_multiplier`. Expected a "
"strictly positive value. Received "
f"depth_multiplier={self.depth_multiplier}."
)
if not all(self.kernel_size):
raise ValueError(
"The argument `kernel_size` cannot contain 0(s). Received: "
f"{self.kernel_size}"
)
if not all(self.strides):
raise ValueError(
"The argument `strides` cannot contains 0(s). Received: "
f"{self.strides}"
)
if max(self.strides) > 1 and max(self.dilation_rate) > 1:
raise ValueError(
"`strides > 1` not supported in conjunction with "
f"`dilation_rate > 1`. Received: strides={self.strides} and "
f"dilation_rate={self.dilation_rate}"
)
def build(self, input_shape):
if self.data_format == "channels_last":
channel_axis = -1
input_channel = input_shape[-1]
else:
channel_axis = 1
input_channel = input_shape[1]
self.input_spec = InputSpec(
min_ndim=self.rank + 2, axes={channel_axis: input_channel}
)
kernel_shape = self.kernel_size + (
input_channel,
self.depth_multiplier,
)
self.kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
trainable=True,
dtype=self.dtype,
)
if self.use_bias:
self.bias = self.add_weight(
name="bias",
shape=(self.depth_multiplier * input_channel,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
trainable=True,
dtype=self.dtype,
)
else:
self.bias = None
self.built = True
def _get_input_channel(self, input_shape):
if self.data_format == "channels_last":
input_channel = input_shape[-1]
else:
input_channel = input_shape[1]
return input_channel
def call(self, inputs):
input_channel = self._get_input_channel(inputs.shape)
outputs = ops.depthwise_conv(
inputs,
self.kernel,
strides=self.strides,
padding=self.padding,
dilation_rate=self.dilation_rate,
data_format=self.data_format,
)
if self.use_bias:
if self.data_format == "channels_last":
bias_shape = (1,) * (self.rank + 1) + (
self.depth_multiplier * input_channel,
)
else:
bias_shape = (1, self.depth_multiplier * input_channel) + (
1,
) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
if self.activation is not None:
return self.activation(outputs)
return outputs
def compute_output_shape(self, input_shape):
input_channel = self._get_input_channel(input_shape)
return compute_conv_output_shape(
input_shape,
self.depth_multiplier * input_channel,
self.kernel_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
def get_config(self):
config = super().get_config()
config.update(
{
"depth_multiplier": self.depth_multiplier,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"data_format": self.data_format,
"dilation_rate": self.dilation_rate,
"activation": activations.serialize(self.activation),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"bias_initializer": initializers.serialize(
self.bias_initializer
),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"bias_regularizer": regularizers.serialize(
self.bias_regularizer
),
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(
self.kernel_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
}
)
return config
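The docstring above describes depthwise convolution as split / convolve per channel / concatenate. A NumPy reference sketch of that definition for the 1D, `"valid"`, stride-1 case; it assumes the `c * depth_multiplier + m` output-channel ordering used by `tf.nn.depthwise_conv2d` and is illustrative, not part of the layer:

```python
import numpy as np

def depthwise_conv1d_ref(x, kernel):
    # x: (steps, channels); kernel: (k, channels, depth_multiplier)
    k, channels, mult = kernel.shape
    out_steps = x.shape[0] - k + 1
    out = np.zeros((out_steps, channels * mult))
    for c in range(channels):           # split input into channels
        for m in range(mult):           # `mult` filters per channel
            for t in range(out_steps):  # sliding-window correlation
                out[t, c * mult + m] = x[t:t + k, c] @ kernel[:, c, m]
    return out                          # concatenated along channels

x = np.random.rand(10, 4)
kernel = np.random.rand(3, 4, 2)
print(depthwise_conv1d_ref(x, kernel).shape)  # (8, 8)
```

Note that no channel mixing occurs: output channel `c * mult + m` depends only on input channel `c`.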

@@ -0,0 +1,136 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_depthwise_conv import (
BaseDepthwiseConv,
)
@keras_core_export("keras_core.layers.DepthwiseConv1D")
class DepthwiseConv1D(BaseDepthwiseConv):
"""1D depthwise convolution layer.
Depthwise convolution is a type of convolution in which each input channel
is convolved with a different kernel (called a depthwise kernel). You can
understand depthwise convolution as the first step in a depthwise separable
convolution.
It is implemented via the following steps:
- Split the input into individual channels.
- Convolve each channel with an individual depthwise kernel with
`depth_multiplier` output channels.
- Concatenate the convolved outputs along the channels axis.
Unlike a regular 1D convolution, depthwise convolution does not mix
information across different input channels.
The `depth_multiplier` argument determines how many filters are applied to
one input channel. As such, it controls the number of output channels that
are generated per input channel in the depthwise step.
Args:
depth_multiplier: The number of depthwise convolution output channels
for each input channel. The total number of depthwise convolution
output channels will be equal to `input_channel * depth_multiplier`.
kernel_size: int or tuple/list of 1 integer, specifying the size of the
depthwise convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 1 integer, specifying the dilation
rate to use for dilated convolution.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, steps, channels)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape:
`(batch_shape, new_steps, channels * depth_multiplier)`
- If `data_format="channels_first"`:
A 3D tensor with shape:
`(batch_shape, channels * depth_multiplier, new_steps)`
Returns:
A 3D tensor representing
`activation(depthwise_conv1d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
Examples:
>>> x = np.random.rand(4, 10, 12)
>>> y = keras_core.layers.DepthwiseConv1D(3, 3, 2, activation='relu')(x)
>>> print(y.shape)
(4, 4, 36)
"""
def __init__(
self,
depth_multiplier,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=1,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)
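The shape in the docstring example follows from standard convolution arithmetic; a quick check of `(4, 4, 36)`:

```python
# DepthwiseConv1D(3, 3, 2) on input (4, 10, 12): depth_multiplier=3,
# kernel_size=3, strides=2, "valid" padding.
new_steps = (10 - 3) // 2 + 1        # 4
out_channels = 12 * 3                # 36
print((4, new_steps, out_channels))  # (4, 4, 36)
```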

@@ -0,0 +1,136 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_depthwise_conv import (
BaseDepthwiseConv,
)
@keras_core_export("keras_core.layers.DepthwiseConv2D")
class DepthwiseConv2D(BaseDepthwiseConv):
"""2D depthwise convolution layer.
Depthwise convolution is a type of convolution in which each input channel
is convolved with a different kernel (called a depthwise kernel). You can
understand depthwise convolution as the first step in a depthwise separable
convolution.
It is implemented via the following steps:
- Split the input into individual channels.
- Convolve each channel with an individual depthwise kernel with
`depth_multiplier` output channels.
- Concatenate the convolved outputs along the channels axis.
Unlike a regular 2D convolution, depthwise convolution does not mix
information across different input channels.
The `depth_multiplier` argument determines how many filters are applied to
one input channel. As such, it controls the number of output channels that
are generated per input channel in the depthwise step.
Args:
depth_multiplier: The number of depthwise convolution output channels
for each input channel. The total number of depthwise convolution
output channels will be equal to `input_channel * depth_multiplier`.
kernel_size: int or tuple/list of 2 integers, specifying the size of the
depthwise convolution window.
strides: int or tuple/list of 2 integers, specifying the stride length
of the depthwise convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 2 integers, specifying the dilation
rate to use for dilated convolution.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, height, width, channels)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape:
`(batch_size, new_height, new_width, channels * depth_multiplier)`
- If `data_format="channels_first"`:
A 4D tensor with shape:
`(batch_size, channels * depth_multiplier, new_height, new_width)`
Returns:
A 4D tensor representing
`activation(depthwise_conv2d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
Examples:
>>> x = np.random.rand(4, 10, 10, 12)
>>> y = keras_core.layers.DepthwiseConv2D(3, 3, activation='relu')(x)
>>> print(y.shape)
(4, 8, 8, 36)
"""
def __init__(
self,
depth_multiplier,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=2,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)
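The docstring calls depthwise convolution the first step of a depthwise separable convolution. A sketch of the full pattern, pairing the new layer with a 1x1 `Conv2D` for the pointwise step; shapes are illustrative:

```python
import numpy as np
from keras_core import layers

x = np.random.rand(4, 10, 10, 12).astype("float32")

# Depthwise step: spatial filtering per channel, no channel mixing.
dw = layers.DepthwiseConv2D(depth_multiplier=1, kernel_size=3)(x)
print(dw.shape)  # (4, 8, 8, 12)

# Pointwise step: a 1x1 convolution mixes the channels back together.
pw = layers.Conv2D(filters=32, kernel_size=1)(dw)
print(pw.shape)  # (4, 8, 8, 32)
```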

@@ -0,0 +1,294 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class DepthwiseConvBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
{
"depth_multiplier": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (3, 5, 4),
"output_shape": (3, 4, 20),
},
{
"depth_multiplier": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2,),
"input_shape": (3, 4, 4),
"output_shape": (3, 4, 24),
},
{
"depth_multiplier": 6,
"kernel_size": 2,
"strides": (2,),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (3, 5, 4),
"output_shape": (3, 2, 24),
},
)
def test_depthwise_conv1d_basic(
self,
depth_multiplier,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
input_shape,
output_shape,
):
self.run_layer_test(
layers.DepthwiseConv1D,
init_kwargs={
"depth_multiplier": depth_multiplier,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
@parameterized.parameters(
{
"depth_multiplier": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (3, 5, 5, 4),
"output_shape": (3, 4, 4, 20),
},
{
"depth_multiplier": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2, 2),
"input_shape": (3, 4, 4, 4),
"output_shape": (3, 4, 4, 24),
},
{
"depth_multiplier": 6,
"kernel_size": (2, 2),
"strides": (2, 2),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1),
"input_shape": (3, 5, 5, 4),
"output_shape": (3, 2, 2, 24),
},
)
def test_depthwise_conv2d_basic(
self,
depth_multiplier,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
input_shape,
output_shape,
):
self.run_layer_test(
layers.DepthwiseConv2D,
init_kwargs={
"depth_multiplier": depth_multiplier,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
def test_bad_init_args(self):
# `depth_multiplier` is not positive.
with self.assertRaises(ValueError):
layers.DepthwiseConv1D(depth_multiplier=0, kernel_size=1)
# `kernel_size` has 0.
with self.assertRaises(ValueError):
layers.DepthwiseConv2D(depth_multiplier=2, kernel_size=(1, 0))
# `strides` has 0.
with self.assertRaises(ValueError):
layers.DepthwiseConv2D(
depth_multiplier=2, kernel_size=(2, 2), strides=(1, 0)
)
# `dilation_rate > 1` while `strides > 1`.
with self.assertRaises(ValueError):
layers.DepthwiseConv2D(
depth_multiplier=2,
kernel_size=(2, 2),
strides=2,
dilation_rate=(2, 1),
)
class DepthwiseConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
{
"depth_multiplier": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
{
"depth_multiplier": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2,),
},
{
"depth_multiplier": 6,
"kernel_size": (2,),
"strides": (2,),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
)
def test_depthwise_conv1d(
self,
depth_multiplier,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
):
layer = layers.DepthwiseConv1D(
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
tf_keras_layer = tf.keras.layers.DepthwiseConv1D(
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
inputs = np.random.normal(size=[2, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(depth_multiplier * 4,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.depthwise_kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
@parameterized.parameters(
{
"depth_multiplier": 5,
"kernel_size": 2,
"strides": 1,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
{
"depth_multiplier": 6,
"kernel_size": 2,
"strides": 1,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (2, 2),
},
{
"depth_multiplier": 6,
"kernel_size": (2, 2),
"strides": (2, 2),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1),
},
)
def test_depthwise_conv2d(
self,
depth_multiplier,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
):
layer = layers.DepthwiseConv2D(
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
tf_keras_layer = tf.keras.layers.DepthwiseConv2D(
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
inputs = np.random.normal(size=[2, 8, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(depth_multiplier * 4,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.depthwise_kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@@ -29,8 +29,6 @@ in_top_k
ctc ??
"""
import numpy as np
from keras_core import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
@@ -565,67 +563,15 @@ class DepthwiseConv(Operation):
)
def compute_output_spec(self, inputs, kernel):
input_shape = inputs.shape
if self.data_format == "channels_last":
spatial_shape = input_shape[1:-1]
else:
spatial_shape = input_shape[2:]
if len(kernel.shape) != len(inputs.shape):
raise ValueError(
"Kernel shape must have the same length as input, but received "
f"kernel of shape {kernel.shape} and "
f"input of shape {input_shape}."
)
if isinstance(self.dilation_rate, int):
dilation_rate = (self.dilation_rate,) * len(spatial_shape)
else:
dilation_rate = self.dilation_rate
if len(dilation_rate) != len(spatial_shape):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and input of "
f"shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])
dilation_rate = np.array(self.dilation_rate)
if self.padding == "valid":
output_spatial_shape = (
np.floor(
(
spatial_shape
- dilation_rate * (kernel_spatial_shape - 1)
- 1
)
/ self.strides
)
+ 1
)
negative_in_shape = np.all(output_spatial_shape < 0)
if negative_in_shape:
raise ValueError(
"Computed output size would be negative. Received "
f"`inputs shape={inputs.shape}`, "
f"`kernel spatial size={kernel.size}`, "
f"`dilation_rate={self.dilation_rate}`."
)
elif self.padding == "same":
output_spatial_shape = (
np.floor((spatial_shape - 1) / self.strides) + 1
)
output_spatial_shape = [int(i) for i in output_spatial_shape]
output_channels = kernel.shape[-1] * kernel.shape[-2]
if self.data_format == "channels_last":
output_shape = (
[input_shape[0]] + output_spatial_shape + [output_channels]
)
else:
output_shape = [
input_shape[0],
output_channels,
] + output_spatial_shape
output_shape = operation_utils.compute_conv_output_shape(
inputs.shape,
kernel.shape[-1] * kernel.shape[-2],
kernel.shape[:-2],
self.strides,
self.padding,
self.data_format,
self.dilation_rate,
)
return KerasTensor(output_shape, dtype=inputs.dtype)
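The refactor above replaces the inlined shape math with `operation_utils.compute_conv_output_shape`. For reference, a sketch of the per-dimension arithmetic the removed code implemented (and which the helper presumably centralizes):

```python
import math

def conv_output_length(input_len, kernel_len, stride, padding, dilation=1):
    effective_k = dilation * (kernel_len - 1) + 1  # dilated kernel extent
    if padding == "valid":
        return math.floor((input_len - effective_k) / stride) + 1
    if padding == "same":
        return math.floor((input_len - 1) / stride) + 1
    raise ValueError(f"Unknown padding: {padding}")

print(conv_output_length(10, 3, 2, "valid"))  # 4
print(conv_output_length(10, 3, 2, "same"))   # 5
```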