support dilation_rate is a tuple or list for conv_transpose in torch backend (#275)

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
Haifeng Jin 2023-06-05 20:47:41 -07:00 committed by Francois Chollet
parent 69ed8781fd
commit 60d1b04718

@ -419,10 +419,13 @@ def conv_transpose(
kernel_spatial_shape = kernel.shape[2:]
padding_arg = []
output_padding_arg = []
if isinstance(dilation_rate, int):
dilation_rate = [dilation_rate] * len(kernel_spatial_shape)
for i, value in enumerate(padding_values):
total_padding = value[0] + value[1]
padding_arg.append(
dilation_rate * (kernel_spatial_shape[i] - 1) - total_padding // 2
dilation_rate[i] * (kernel_spatial_shape[i] - 1)
- total_padding // 2
)
if total_padding % 2 == 0:
output_padding_arg.append(0)