Add causal padding support to conv1d (#152)
This commit is contained in:
parent
61ae4f2f59
commit
69a3430260
@ -1,3 +1,4 @@
|
||||
import keras_core.operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.convolutional.base_conv import BaseConv
|
||||
|
||||
@ -22,10 +23,15 @@ class Conv1D(BaseConv):
|
||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||
of the convolution. `stride value != 1` is incompatible with
|
||||
`dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
padding: string, `"valid"`, `"same"` or `"causal"`(case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
height/width dimension as the input.
|
||||
height/width dimension as the input. `"causal"` results in causal
|
||||
(dilated) convolutions, e.g. `output[t]` does not depend on
|
||||
`input[t+1:]`. Useful when modeling temporal data where the model
|
||||
should not violate the temporal order.
|
||||
See [WaveNet: A Generative Model for Raw Audio, section2.1](
|
||||
https://arxiv.org/abs/1609.03499).
|
||||
data_format: string, either `"channels_last"` or `"channels_first"`.
|
||||
The ordering of the dimensions in the inputs. `"channels_last"`
|
||||
corresponds to inputs with shape `(batch, steps, features)`
|
||||
@ -126,3 +132,39 @@ class Conv1D(BaseConv):
|
||||
bias_constraint=bias_constraint,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _compute_causal_padding(self):
|
||||
left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
|
||||
if self.data_format == "channels_last":
|
||||
causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
|
||||
else:
|
||||
causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
|
||||
return causal_padding
|
||||
|
||||
def call(self, inputs):
|
||||
padding = self.padding
|
||||
if self.padding == "causal":
|
||||
# Apply causal padding to inputs.
|
||||
inputs = ops.pad(inputs, self._compute_causal_padding())
|
||||
padding = "valid"
|
||||
|
||||
outputs = ops.conv(
|
||||
inputs,
|
||||
self.kernel,
|
||||
strides=list(self.strides),
|
||||
padding=padding,
|
||||
dilation_rate=self.dilation_rate,
|
||||
data_format=self.data_format,
|
||||
)
|
||||
|
||||
if self.use_bias:
|
||||
if self.data_format == "channels_last":
|
||||
bias_shape = (1,) * (self.rank + 1) + (self.filters,)
|
||||
else:
|
||||
bias_shape = (1, self.filters) + (1,) * self.rank
|
||||
bias = ops.reshape(self.bias, bias_shape)
|
||||
outputs += bias
|
||||
|
||||
if self.activation is not None:
|
||||
return self.activation(outputs)
|
||||
return outputs
|
||||
|
@ -30,6 +30,17 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
|
||||
"input_shape": (3, 4, 4),
|
||||
"output_shape": (3, 4, 6),
|
||||
},
|
||||
{
|
||||
"filters": 6,
|
||||
"kernel_size": 2,
|
||||
"strides": 1,
|
||||
"padding": "causal",
|
||||
"data_format": "channels_last",
|
||||
"dilation_rate": (2,),
|
||||
"groups": 2,
|
||||
"input_shape": (3, 4, 4),
|
||||
"output_shape": (3, 4, 6),
|
||||
},
|
||||
{
|
||||
"filters": 6,
|
||||
"kernel_size": 2,
|
||||
@ -249,6 +260,15 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
||||
"dilation_rate": (2,),
|
||||
"groups": 2,
|
||||
},
|
||||
{
|
||||
"filters": 6,
|
||||
"kernel_size": 2,
|
||||
"strides": 1,
|
||||
"padding": "causal",
|
||||
"data_format": "channels_last",
|
||||
"dilation_rate": (2,),
|
||||
"groups": 2,
|
||||
},
|
||||
{
|
||||
"filters": 6,
|
||||
"kernel_size": (2,),
|
||||
|
@ -103,7 +103,7 @@ def compute_conv_output_shape(
|
||||
f"`kernel shape={kernel_shape}`, "
|
||||
f"`dilation_rate={dilation_rate}`."
|
||||
)
|
||||
elif padding == "same":
|
||||
elif padding == "same" or padding == "causal":
|
||||
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
|
||||
output_spatial_shape = tuple([int(i) for i in output_spatial_shape])
|
||||
if data_format == "channels_last":
|
||||
|
Loading…
Reference in New Issue
Block a user