Add causal padding support to conv1d (#152)
This commit is contained in:
parent
61ae4f2f59
commit
69a3430260
@ -1,3 +1,4 @@
|
|||||||
|
import keras_core.operations as ops
|
||||||
from keras_core.api_export import keras_core_export
|
from keras_core.api_export import keras_core_export
|
||||||
from keras_core.layers.convolutional.base_conv import BaseConv
|
from keras_core.layers.convolutional.base_conv import BaseConv
|
||||||
|
|
||||||
@ -22,10 +23,15 @@ class Conv1D(BaseConv):
|
|||||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||||
of the convolution. `stride value != 1` is incompatible with
|
of the convolution. `stride value != 1` is incompatible with
|
||||||
`dilation_rate != 1`.
|
`dilation_rate != 1`.
|
||||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
padding: string, `"valid"`, `"same"` or `"causal"`(case-insensitive).
|
||||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||||
the left/right or up/down of the input such that output has the same
|
the left/right or up/down of the input such that output has the same
|
||||||
height/width dimension as the input.
|
height/width dimension as the input. `"causal"` results in causal
|
||||||
|
(dilated) convolutions, e.g. `output[t]` does not depend on
|
||||||
|
`input[t+1:]`. Useful when modeling temporal data where the model
|
||||||
|
should not violate the temporal order.
|
||||||
|
See [WaveNet: A Generative Model for Raw Audio, section2.1](
|
||||||
|
https://arxiv.org/abs/1609.03499).
|
||||||
data_format: string, either `"channels_last"` or `"channels_first"`.
|
data_format: string, either `"channels_last"` or `"channels_first"`.
|
||||||
The ordering of the dimensions in the inputs. `"channels_last"`
|
The ordering of the dimensions in the inputs. `"channels_last"`
|
||||||
corresponds to inputs with shape `(batch, steps, features)`
|
corresponds to inputs with shape `(batch, steps, features)`
|
||||||
@ -126,3 +132,39 @@ class Conv1D(BaseConv):
|
|||||||
bias_constraint=bias_constraint,
|
bias_constraint=bias_constraint,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _compute_causal_padding(self):
|
||||||
|
left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
|
||||||
|
if self.data_format == "channels_last":
|
||||||
|
causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
|
||||||
|
else:
|
||||||
|
causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
|
||||||
|
return causal_padding
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
padding = self.padding
|
||||||
|
if self.padding == "causal":
|
||||||
|
# Apply causal padding to inputs.
|
||||||
|
inputs = ops.pad(inputs, self._compute_causal_padding())
|
||||||
|
padding = "valid"
|
||||||
|
|
||||||
|
outputs = ops.conv(
|
||||||
|
inputs,
|
||||||
|
self.kernel,
|
||||||
|
strides=list(self.strides),
|
||||||
|
padding=padding,
|
||||||
|
dilation_rate=self.dilation_rate,
|
||||||
|
data_format=self.data_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_bias:
|
||||||
|
if self.data_format == "channels_last":
|
||||||
|
bias_shape = (1,) * (self.rank + 1) + (self.filters,)
|
||||||
|
else:
|
||||||
|
bias_shape = (1, self.filters) + (1,) * self.rank
|
||||||
|
bias = ops.reshape(self.bias, bias_shape)
|
||||||
|
outputs += bias
|
||||||
|
|
||||||
|
if self.activation is not None:
|
||||||
|
return self.activation(outputs)
|
||||||
|
return outputs
|
||||||
|
@ -30,6 +30,17 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
|
|||||||
"input_shape": (3, 4, 4),
|
"input_shape": (3, 4, 4),
|
||||||
"output_shape": (3, 4, 6),
|
"output_shape": (3, 4, 6),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"filters": 6,
|
||||||
|
"kernel_size": 2,
|
||||||
|
"strides": 1,
|
||||||
|
"padding": "causal",
|
||||||
|
"data_format": "channels_last",
|
||||||
|
"dilation_rate": (2,),
|
||||||
|
"groups": 2,
|
||||||
|
"input_shape": (3, 4, 4),
|
||||||
|
"output_shape": (3, 4, 6),
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"filters": 6,
|
"filters": 6,
|
||||||
"kernel_size": 2,
|
"kernel_size": 2,
|
||||||
@ -249,6 +260,15 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
|||||||
"dilation_rate": (2,),
|
"dilation_rate": (2,),
|
||||||
"groups": 2,
|
"groups": 2,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"filters": 6,
|
||||||
|
"kernel_size": 2,
|
||||||
|
"strides": 1,
|
||||||
|
"padding": "causal",
|
||||||
|
"data_format": "channels_last",
|
||||||
|
"dilation_rate": (2,),
|
||||||
|
"groups": 2,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"filters": 6,
|
"filters": 6,
|
||||||
"kernel_size": (2,),
|
"kernel_size": (2,),
|
||||||
|
@ -103,7 +103,7 @@ def compute_conv_output_shape(
|
|||||||
f"`kernel shape={kernel_shape}`, "
|
f"`kernel shape={kernel_shape}`, "
|
||||||
f"`dilation_rate={dilation_rate}`."
|
f"`dilation_rate={dilation_rate}`."
|
||||||
)
|
)
|
||||||
elif padding == "same":
|
elif padding == "same" or padding == "causal":
|
||||||
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
|
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
|
||||||
output_spatial_shape = tuple([int(i) for i in output_spatial_shape])
|
output_spatial_shape = tuple([int(i) for i in output_spatial_shape])
|
||||||
if data_format == "channels_last":
|
if data_format == "channels_last":
|
||||||
|
Loading…
Reference in New Issue
Block a user