support dilation_rate is a tuple or list for conv_transpose in torch backend (#275)
Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
parent
69ed8781fd
commit
60d1b04718
@ -419,10 +419,13 @@ def conv_transpose(
|
||||
kernel_spatial_shape = kernel.shape[2:]
|
||||
padding_arg = []
|
||||
output_padding_arg = []
|
||||
if isinstance(dilation_rate, int):
|
||||
dilation_rate = [dilation_rate] * len(kernel_spatial_shape)
|
||||
for i, value in enumerate(padding_values):
|
||||
total_padding = value[0] + value[1]
|
||||
padding_arg.append(
|
||||
dilation_rate * (kernel_spatial_shape[i] - 1) - total_padding // 2
|
||||
dilation_rate[i] * (kernel_spatial_shape[i] - 1)
|
||||
- total_padding // 2
|
||||
)
|
||||
if total_padding % 2 == 0:
|
||||
output_padding_arg.append(0)
|
||||
|
Loading…
Reference in New Issue
Block a user