Add causal padding support to conv1d (#152)

Chen Qian 2023-05-11 18:33:09 -07:00 committed by Francois Chollet
parent 61ae4f2f59
commit 69a3430260
3 changed files with 65 additions and 3 deletions

@@ -1,3 +1,4 @@
import keras_core.operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv import BaseConv
@@ -22,10 +23,15 @@ class Conv1D(BaseConv):
        strides: int or tuple/list of 1 integer, specifying the stride length
            of the convolution. `stride value != 1` is incompatible with
            `dilation_rate != 1`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
        padding: string, `"valid"`, `"same"` or `"causal"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
            height/width dimension as the input.
            height/width dimension as the input. `"causal"` results in causal
            (dilated) convolutions, i.e. `output[t]` does not depend on
            `input[t+1:]`. Useful when modeling temporal data where the model
            should not violate the temporal order.
            See [WaveNet: A Generative Model for Raw Audio, section 2.1](
            https://arxiv.org/abs/1609.03499).
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, steps, features)`
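
For context (not part of this diff), a minimal usage sketch assuming the public export path `keras_core.layers.Conv1D` and the same configuration as the new test cases below; causal padding pads only the left of the time axis, so the steps dimension is preserved:

# Illustration only; `keras_core.layers.Conv1D` is assumed to be the
# exported name of the class changed in this diff.
import numpy as np
from keras_core import layers

x = np.random.rand(3, 4, 4).astype("float32")  # (batch, steps, features)
layer = layers.Conv1D(
    filters=6, kernel_size=2, padding="causal", dilation_rate=2
)
y = layer(x)
# Only the left side of the steps axis is padded, so
# y.shape == (3, 4, 6): steps unchanged, channels == filters.
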
@@ -126,3 +132,39 @@ class Conv1D(BaseConv):
            bias_constraint=bias_constraint,
            **kwargs
        )

    def _compute_causal_padding(self):
        left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
        if self.data_format == "channels_last":
            causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
        else:
            causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
        return causal_padding

    def call(self, inputs):
        padding = self.padding
        if self.padding == "causal":
            # Apply causal padding to inputs.
            inputs = ops.pad(inputs, self._compute_causal_padding())
            padding = "valid"
        outputs = ops.conv(
            inputs,
            self.kernel,
            strides=list(self.strides),
            padding=padding,
            dilation_rate=self.dilation_rate,
            data_format=self.data_format,
        )
        if self.use_bias:
            if self.data_format == "channels_last":
                bias_shape = (1,) * (self.rank + 1) + (self.filters,)
            else:
                bias_shape = (1, self.filters) + (1,) * self.rank
            bias = ops.reshape(self.bias, bias_shape)
            outputs += bias
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
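
The new `call` path implements `"causal"` by left-padding the time axis by `dilation_rate * (kernel_size - 1)` and then running an ordinary `"valid"` convolution. A small numpy sketch of that arithmetic (illustration only, using the kernel_size=2, dilation_rate=2 configuration exercised in the tests below):

# Illustration only: mirrors _compute_causal_padding with plain numpy.
import numpy as np

kernel_size, dilation_rate = (2,), (2,)
left_pad = dilation_rate[0] * (kernel_size[0] - 1)  # 2 * (2 - 1) = 2
x = np.zeros((3, 4, 4))  # (batch, steps, features), channels_last
padded = np.pad(x, [[0, 0], [left_pad, 0], [0, 0]])
assert padded.shape == (3, 6, 4)
# A "valid" convolution with effective kernel size
# dilation_rate[0] * (kernel_size[0] - 1) + 1 = 3 then produces
# 6 - 3 + 1 = 4 output steps, so the steps dimension is preserved and
# output[t] never sees input[t+1:].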

@@ -30,6 +30,17 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
            "input_shape": (3, 4, 4),
            "output_shape": (3, 4, 6),
        },
        {
            "filters": 6,
            "kernel_size": 2,
            "strides": 1,
            "padding": "causal",
            "data_format": "channels_last",
            "dilation_rate": (2,),
            "groups": 2,
            "input_shape": (3, 4, 4),
            "output_shape": (3, 4, 6),
        },
        {
            "filters": 6,
            "kernel_size": 2,
@@ -249,6 +260,15 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
            "dilation_rate": (2,),
            "groups": 2,
        },
        {
            "filters": 6,
            "kernel_size": 2,
            "strides": 1,
            "padding": "causal",
            "data_format": "channels_last",
            "dilation_rate": (2,),
            "groups": 2,
        },
        {
            "filters": 6,
            "kernel_size": (2,),

@@ -103,7 +103,7 @@ def compute_conv_output_shape(
                f"`kernel shape={kernel_shape}`, "
                f"`dilation_rate={dilation_rate}`."
            )
    elif padding == "same":
    elif padding == "same" or padding == "causal":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    output_spatial_shape = tuple([int(i) for i in output_spatial_shape])
    if data_format == "channels_last":
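
With this change, `"causal"` reuses the `"same"` shape rule `floor((spatial_shape - 1) / strides) + 1`; for the causal test case above (4 steps, stride 1) that gives `floor(3 / 1) + 1 = 4`, matching the expected output shape `(3, 4, 6)`. A throwaway check of the same arithmetic (illustration only):

# Illustration only: evaluates the formula above for the causal test case.
import numpy as np

spatial_shape = np.array([4])
strides = np.array([1])
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
assert tuple(int(i) for i in output_spatial_shape) == (4,)
# With data_format="channels_last" the full output shape is
# (batch,) + (4,) + (filters,) == (3, 4, 6) for the test input (3, 4, 4).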