164 lines
5.7 KiB
Python
164 lines
5.7 KiB
Python
import math
|
|
|
|
import numpy as np
|
|
|
|
|
|
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Compute the output shape of pooling ops.

    Args:
        input_shape: Shape of the input, including the leading batch dim and
            the channel dim. Any entry may be `None` (unknown size).
        pool_size: Pooling window size; an int (broadcast to every spatial
            dim) or a sequence with one entry per spatial dim.
        strides: Stride(s), in the same form as `pool_size`. `None` means
            "use `pool_size`".
        padding: Either `"valid"` or `"same"`.
        data_format: `"channels_last"` treats `input_shape[1:-1]` as the
            spatial dims; any other value treats `input_shape[2:]` as the
            spatial dims (channels first).

    Returns:
        The output shape as a tuple. Batch and channel dims are copied from
        `input_shape`; unknown (`None`) spatial dims stay `None`.

    Raises:
        ValueError: If `padding` is not `"valid"` or `"same"`, or if with
            `"valid"` padding any known spatial output dim would be negative.
    """
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    if data_format == "channels_last":
        spatial_dims = input_shape_origin[1:-1]
    else:
        spatial_dims = input_shape_origin[2:]

    # `None` cannot take part in numpy arithmetic: compute with a placeholder
    # value and put `None` back into the result for those dims at the end.
    unknown_dims = [dim is None for dim in spatial_dims]
    spatial_shape = np.array(
        [-1 if dim is None else dim for dim in spatial_dims]
    )
    pool_size = np.array(pool_size)
    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        # A single negative dim is already invalid, so test each known dim
        # instead of requiring all of them to be negative.
        known_dims = ~np.array(unknown_dims, dtype=bool)
        if np.any((output_spatial_shape < 0) & known_dims):
            raise ValueError(
                "Computed output size would be negative. Received: "
                f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
            )
    elif padding == "same":
        # floor((L - 1) / s) + 1 equals ceil(L / s) for L >= 1.
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )

    output_spatial_shape = [int(dim) for dim in output_spatial_shape]
    for index, is_unknown in enumerate(unknown_dims):
        if is_unknown:
            output_spatial_shape[index] = None
    output_spatial_shape = tuple(output_spatial_shape)

    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape
|
|
|
|
|
|
def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops.

    Args:
        input_shape: Shape of the input, including the leading batch dim and
            the channel dim. Any entry may be `None` (unknown size).
        filters: Number of output channels.
        kernel_size: Sequence (tuple or list) with one kernel extent per
            spatial dim.
        strides: Int (used for every spatial dim) or one value per spatial
            dim.
        padding: `"valid"`, `"same"` or `"causal"`. `"same"` and `"causal"`
            give the same output shape here.
        data_format: `"channels_last"` treats `input_shape[1:-1]` as the
            spatial dims and `input_shape[-1]` as the input channels; any
            other value uses `input_shape[2:]` and `input_shape[1]`.
        dilation_rate: Int (used for every spatial dim) or one value per
            spatial dim.

    Returns:
        The output shape as a tuple: batch dim from `input_shape`, `filters`
        as the channel dim, and the computed spatial dims. Unknown (`None`)
        spatial dims stay `None`.

    Raises:
        ValueError: If the kernel rank does not match the input rank, the
            dilation rate has the wrong length, `padding` is unsupported, or
            with `"valid"` padding any known spatial output dim would be
            negative.
    """
    # Accept lists as well as tuples so the concatenation below works.
    kernel_size = tuple(kernel_size)
    if data_format == "channels_last":
        spatial_dims = list(input_shape[1:-1])
        kernel_shape = kernel_size + (input_shape[-1], filters)
    else:
        spatial_dims = list(input_shape[2:])
        kernel_shape = kernel_size + (input_shape[1], filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * len(spatial_dims)
    if isinstance(strides, int):
        strides = (strides,) * len(spatial_dims)
    if len(dilation_rate) != len(spatial_dims):
        raise ValueError(
            "Dilation must be None, scalar or tuple/list of length of "
            "inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )

    # `None` cannot take part in numpy arithmetic: compute with a placeholder
    # value and put `None` back into the result for those dims at the end.
    unknown_dims = [dim is None for dim in spatial_dims]
    spatial_shape = np.array(
        [-1 if dim is None else dim for dim in spatial_dims]
    )
    kernel_spatial_shape = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        # Effective (dilated) kernel extent is d * (k - 1) + 1.
        output_spatial_shape = (
            np.floor(
                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
                / strides
            )
            + 1
        )
        # A single negative dim is already invalid, so test each known dim
        # instead of requiring all of them to be negative.
        known_dims = ~np.array(unknown_dims, dtype=bool)
        if np.any((output_spatial_shape < 0) & known_dims):
            raise ValueError(
                "Computed output size would be negative. Received "
                f"`inputs shape={input_shape}`, "
                f"`kernel shape={kernel_shape}`, "
                f"`dilation_rate={dilation_rate}`."
            )
    elif padding in ("same", "causal"):
        # floor((L - 1) / s) + 1 equals ceil(L / s) for L >= 1.
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        # Without this branch an unknown padding left `output_spatial_shape`
        # unbound and surfaced as an unrelated UnboundLocalError.
        raise ValueError(
            "`padding` must be either `'valid'`, `'same'` or `'causal'`. "
            f"Received {padding}."
        )

    output_spatial_shape = [int(dim) for dim in output_spatial_shape]
    for index, is_unknown in enumerate(unknown_dims):
        if is_unknown:
            output_spatial_shape[index] = None
    output_spatial_shape = tuple(output_spatial_shape)

    if data_format == "channels_last":
        output_shape = (
            (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
        )
    else:
        output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
    return output_shape
|
|
|
|
|
|
def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
    """Converts `-1` in `new_shape` to either an actual dimension or `None`.

    This utility does not special case the 0th dimension (batch size).

    Args:
        input_shape: Shape of the tensor being reshaped; may contain `None`.
        new_shape: Requested shape (list or tuple). At most one entry may be
            `-1`; other integer entries must be non-negative.
        new_shape_arg_name: Name of the caller's argument, used only in
            error messages.

    Returns:
        - If `input_shape` contains `None`: a tuple of `new_shape` with `-1`
          replaced by `None`.
        - If `new_shape` has no `-1`: `new_shape` itself, unchanged.
        - Otherwise: a tuple with `-1` replaced by the inferred dimension.

    Raises:
        ValueError: If `new_shape` has more than one `-1`, contains a value
            below `-1`, or is incompatible with the total input size.
    """
    unknown_dim_count = new_shape.count(-1)
    if unknown_dim_count > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
        )

    # Values such as -2 are never valid, and a pair of them could otherwise
    # slip through the product-based size check below (e.g. (-2, -3) vs 6).
    if any(dim is not None and dim < -1 for dim in new_shape):
        raise ValueError(
            f"Dimensions in {new_shape_arg_name} must be non-negative or -1. "
            f"Received: {new_shape_arg_name}={new_shape}."
        )

    # If there is a None in input_shape, we can't infer what the -1 is
    if None in input_shape:
        return tuple(dim if dim != -1 else None for dim in new_shape)

    input_size = math.prod(input_shape)
    # If the new_shape is fully defined, return it
    if unknown_dim_count == 0:
        if input_size != math.prod(new_shape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
            )
        return new_shape

    # We have one -1 in new_shape, compute the actual value
    known_output_size = 1
    unknown_dim_index = None
    for index, dim in enumerate(new_shape):
        if dim == -1:
            unknown_dim_index = index
        else:
            known_output_size *= dim

    # A zero-sized known part makes the -1 ambiguous; a non-divisor makes it
    # impossible.
    if known_output_size == 0 or input_size % known_output_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot be divided by the specified dimensions in "
            f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
            f"{new_shape_arg_name}={new_shape}"
        )

    output_shape = list(new_shape)
    output_shape[unknown_dim_index] = input_size // known_output_size
    return tuple(output_shape)
|