diff --git a/keras_core/layers/convolutional/conv1d.py b/keras_core/layers/convolutional/conv1d.py
index a8761d8fd..f3eccbf39 100644
--- a/keras_core/layers/convolutional/conv1d.py
+++ b/keras_core/layers/convolutional/conv1d.py
@@ -1,3 +1,4 @@
+import keras_core.operations as ops
 from keras_core.api_export import keras_core_export
 from keras_core.layers.convolutional.base_conv import BaseConv
 
@@ -22,10 +23,15 @@ class Conv1D(BaseConv):
         strides: int or tuple/list of 1 integer, specifying the stride length
             of the convolution. `stride value != 1` is incompatible with
             `dilation_rate != 1`.
-        padding: string, either `"valid"` or `"same"` (case-insensitive).
+        padding: string, `"valid"`, `"same"` or `"causal"` (case-insensitive).
             `"valid"` means no padding. `"same"` results in padding evenly to
             the left/right or up/down of the input such that output has the same
-            height/width dimension as the input.
+            height/width dimension as the input. `"causal"` results in causal
+            (dilated) convolutions, e.g. `output[t]` does not depend on
+            `input[t+1:]`. Useful when modeling temporal data where the model
+            should not violate the temporal order.
+            See [WaveNet: A Generative Model for Raw Audio, section 2.1](
+            https://arxiv.org/abs/1609.03499).
         data_format: string, either `"channels_last"` or `"channels_first"`.
             The ordering of the dimensions in the inputs. `"channels_last"`
             corresponds to inputs with shape `(batch, steps, features)`
@@ -126,3 +132,39 @@ class Conv1D(BaseConv):
             bias_constraint=bias_constraint,
             **kwargs
         )
+
+    def _compute_causal_padding(self):
+        left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
+        if self.data_format == "channels_last":
+            causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
+        else:
+            causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+        return causal_padding
+
+    def call(self, inputs):
+        padding = self.padding
+        if self.padding == "causal":
+            # Apply causal padding to inputs.
+            inputs = ops.pad(inputs, self._compute_causal_padding())
+            padding = "valid"
+
+        outputs = ops.conv(
+            inputs,
+            self.kernel,
+            strides=list(self.strides),
+            padding=padding,
+            dilation_rate=self.dilation_rate,
+            data_format=self.data_format,
+        )
+
+        if self.use_bias:
+            if self.data_format == "channels_last":
+                bias_shape = (1,) * (self.rank + 1) + (self.filters,)
+            else:
+                bias_shape = (1, self.filters) + (1,) * self.rank
+            bias = ops.reshape(self.bias, bias_shape)
+            outputs += bias
+
+        if self.activation is not None:
+            return self.activation(outputs)
+        return outputs
diff --git a/keras_core/layers/convolutional/conv_test.py b/keras_core/layers/convolutional/conv_test.py
index f05d3ccc1..13781131f 100644
--- a/keras_core/layers/convolutional/conv_test.py
+++ b/keras_core/layers/convolutional/conv_test.py
@@ -30,6 +30,17 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
             "input_shape": (3, 4, 4),
             "output_shape": (3, 4, 6),
         },
+        {
+            "filters": 6,
+            "kernel_size": 2,
+            "strides": 1,
+            "padding": "causal",
+            "data_format": "channels_last",
+            "dilation_rate": (2,),
+            "groups": 2,
+            "input_shape": (3, 4, 4),
+            "output_shape": (3, 4, 6),
+        },
         {
             "filters": 6,
             "kernel_size": 2,
@@ -249,6 +260,15 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
             "dilation_rate": (2,),
             "groups": 2,
         },
+        {
+            "filters": 6,
+            "kernel_size": 2,
+            "strides": 1,
+            "padding": "causal",
+            "data_format": "channels_last",
+            "dilation_rate": (2,),
+            "groups": 2,
+        },
         {
             "filters": 6,
             "kernel_size": (2,),
diff --git a/keras_core/operations/operation_utils.py b/keras_core/operations/operation_utils.py
index 3ee11ed46..37c274a28 100644
--- a/keras_core/operations/operation_utils.py
+++ b/keras_core/operations/operation_utils.py
@@ -103,7 +103,7 @@ def compute_conv_output_shape(
                     f"`kernel shape={kernel_shape}`, "
                     f"`dilation_rate={dilation_rate}`."
                 )
-    elif padding == "same":
+    elif padding == "same" or padding == "causal":
         output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
     output_spatial_shape = tuple([int(i) for i in output_spatial_shape])
     if data_format == "channels_last":
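
Usage sketch (illustration only, not part of the patch): the snippet below mirrors the new "causal" case added to ConvBasicTest above, with the same shapes and arguments, and then demonstrates the causality property described in the docstring. The np.array conversion of the layer output is an assumption that the active backend's tensors convert cleanly to NumPy.

import numpy as np
from keras_core import layers

# Same configuration as the new "causal" entry in ConvBasicTest.
x = np.random.rand(3, 4, 4).astype("float32")
layer = layers.Conv1D(
    filters=6,
    kernel_size=2,
    strides=1,
    padding="causal",
    data_format="channels_last",
    dilation_rate=(2,),
    groups=2,
)
y = np.array(layer(x))
assert y.shape == (3, 4, 6)  # causal padding keeps the time dimension at 4

# Causal padding adds dilation_rate * (kernel_size - 1) = 2 * (2 - 1) = 2
# zeros on the left of the time axis and then runs a "valid" convolution,
# so output[:, t, :] depends only on x[:, :t + 1, :]. Perturbing the last
# timestep must therefore leave all earlier outputs unchanged.
x_perturbed = x.copy()
x_perturbed[:, -1, :] += 1.0
y_perturbed = np.array(layer(x_perturbed))
np.testing.assert_allclose(y[:, :-1], y_perturbed[:, :-1], rtol=1e-5, atol=1e-5)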