240 lines
8.1 KiB
Python
240 lines
8.1 KiB
Python
import math
|
|
|
|
import numpy as np
|
|
from tensorflow import nest
|
|
|
|
|
|
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Compute the output shape of pooling ops.

    Args:
        input_shape: Full input shape including batch and channel
            dimensions. Entries may be `None` for unknown dimensions.
        pool_size: Size of the pooling window per spatial dimension.
        strides: Step of the pooling window, or `None` to reuse
            `pool_size`.
        padding: Either `"valid"` or `"same"`.
        data_format: `"channels_last"` or `"channels_first"`.

    Returns:
        Tuple holding the output shape; unknown spatial dimensions
        stay `None`.

    Raises:
        ValueError: If `padding` is invalid, or a computed spatial
            dimension would be negative under `"valid"` padding.
    """
    if strides is None:
        strides = pool_size
    original_shape = list(input_shape)
    shape_arr = np.array(input_shape)
    spatial = (
        shape_arr[1:-1] if data_format == "channels_last" else shape_arr[2:]
    )
    # Remember which spatial dims are unknown and patch them with -1 so
    # the numpy arithmetic below does not choke on `None`.
    unknown = []
    for idx in range(len(spatial)):
        if spatial[idx] is None:
            spatial[idx] = -1
            unknown.append(idx)
    window = np.array(pool_size)
    if padding == "valid":
        out_spatial = np.floor((spatial - window) / strides) + 1
        for idx in range(len(out_spatial)):
            # Skip dims that were `None`: their -1 placeholder makes the
            # result negative by construction.
            if idx not in unknown and out_spatial[idx] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received: "
                    f"`inputs.shape={shape_arr}` and `pool_size={window}`."
                )
    elif padding == "same":
        out_spatial = np.floor((spatial - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )
    out_spatial = [int(dim) for dim in out_spatial]
    for idx in unknown:
        out_spatial[idx] = None
    out_spatial = tuple(out_spatial)
    if data_format == "channels_last":
        return (original_shape[0],) + out_spatial + (original_shape[-1],)
    return (original_shape[0], original_shape[1]) + out_spatial
|
|
|
|
|
|
def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops.

    Args:
        input_shape: Full input shape including batch and channel
            dimensions. Entries may be `None` for unknown dimensions.
        filters: Number of output channels.
        kernel_size: Tuple of kernel spatial dimensions.
        strides: Int or tuple of per-dimension strides.
        padding: One of `"valid"`, `"same"` or `"causal"` (`"causal"`
            uses the same output-size formula as `"same"`).
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of per-dimension dilation rates.

    Returns:
        Tuple holding the output shape; unknown spatial dimensions
        stay `None`.

    Raises:
        ValueError: If kernel/dilation ranks are inconsistent with the
            input, if `padding` is invalid, or if a computed spatial
            dimension would be negative under `"valid"` padding.
    """
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
        kernel_shape = kernel_size + (input_shape[-1], filters)
    else:
        spatial_shape = input_shape[2:]
        kernel_shape = kernel_size + (input_shape[1], filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    # Broadcast scalar strides/dilation across all spatial dimensions.
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * len(spatial_shape)
    if isinstance(strides, int):
        strides = (strides,) * len(spatial_shape)
    if len(dilation_rate) != len(spatial_shape):
        raise ValueError(
            "Dilation must be None, scalar or tuple/list of length of "
            "inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )
    none_dims = []
    spatial_shape = np.array(spatial_shape)
    for i in range(len(spatial_shape)):
        if spatial_shape[i] is None:
            # Set `None` shape to a manual value so that we can run numpy
            # computation on `spatial_shape`.
            spatial_shape[i] = -1
            none_dims.append(i)

    kernel_spatial_shape = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        output_spatial_shape = (
            np.floor(
                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
                / strides
            )
            + 1
        )
        for i in range(len(output_spatial_shape)):
            # Dims that were `None` are negative by construction (-1
            # placeholder); only genuine dims are validated.
            if i not in none_dims and output_spatial_shape[i] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received "
                    f"`inputs shape={input_shape}`, "
                    f"`kernel shape={kernel_shape}`, "
                    f"`dilation_rate={dilation_rate}`."
                )
    elif padding == "same" or padding == "causal":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        # Previously an unrecognized `padding` fell through both branches
        # and crashed below with a confusing `NameError`; fail fast with an
        # explicit error instead (matching `compute_pooling_output_shape`).
        raise ValueError(
            "`padding` must be either `'valid'`, `'same'` or `'causal'`. "
            f"Received {padding}."
        )
    output_spatial_shape = [int(i) for i in output_spatial_shape]
    for i in none_dims:
        output_spatial_shape[i] = None
    output_spatial_shape = tuple(output_spatial_shape)
    if data_format == "channels_last":
        output_shape = (
            (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
        )
    else:
        output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
    return output_shape
|
|
|
|
|
|
def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
    """Converts `-1` in `new_shape` to either an actual dimension or `None`.

    This utility does not special case the 0th dimension (batch size).

    Args:
        input_shape: Shape of the tensor being reshaped; entries may be
            `None` for unknown dimensions.
        new_shape: Requested shape, with at most one `-1` wildcard.
        new_shape_arg_name: Argument name used in error messages.

    Returns:
        The resolved output shape.

    Raises:
        ValueError: If `new_shape` has more than one `-1`, or the total
            number of elements cannot be preserved.
    """
    num_unknown = new_shape.count(-1)
    if num_unknown > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
        )

    # A `None` in the input makes the total element count unknowable, so
    # the best we can do is map each `-1` wildcard to `None`.
    if None in input_shape:
        return tuple(None if dim == -1 else dim for dim in new_shape)

    total_size = math.prod(input_shape)
    # Fully-specified target: just validate the element count.
    if num_unknown == 0:
        if total_size != math.prod(new_shape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
            )
        return new_shape

    # Exactly one -1: infer it from the product of the known dimensions.
    known_size = 1
    wildcard_index = None
    for index, dim in enumerate(new_shape):
        if dim == -1:
            wildcard_index = index
        else:
            known_size *= dim

    if known_size == 0 or total_size % known_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot by divided by the specified dimensions in "
            f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
            f"{new_shape_arg_name}={new_shape}"
        )

    resolved = list(new_shape)
    resolved[wildcard_index] = total_size // known_size
    return tuple(resolved)
|
|
|
|
|
|
def reduce_shape(shape, axis=None, keepdims=False):
    """Compute the static output shape of a reduction over `axis`.

    Args:
        shape: Input shape as a tuple/list of dimensions.
        axis: Iterable of ints (negative values allowed) naming the axes
            being reduced, or `None` to reduce over all axes.
        keepdims: If True, reduced axes are kept with size 1 instead of
            being removed.

    Returns:
        The reduced shape as a list.
    """
    shape = list(shape)
    if axis is None:
        # Reducing every axis yields a scalar (empty shape), or an
        # all-ones shape when keepdims is requested.
        # Bug fix: the original built `[1 for _ in range(shape)]`, which
        # raised TypeError (`range` got a list) whenever keepdims=True.
        return [1] * len(shape) if keepdims else []

    if keepdims:
        for ax in axis:
            shape[ax] = 1
        return shape

    # Remove reduced axes by index (normalizing negative axes) rather
    # than marking them with a -1 sentinel and filtering, which would
    # also silently drop genuine -1 dimensions.
    reduced = {ax % len(shape) for ax in axis}
    return [dim for i, dim in enumerate(shape) if i not in reduced]
|
|
|
|
|
|
def get_source_inputs(tensor):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will always be a list of tensors
    (potentially with 1 element).

    Args:
        tensor: The tensor to start from.

    Returns:
        List of input tensors.
    """
    # NOTE(review): despite the docstring, a tensor without Keras history
    # is returned bare (not wrapped in a list) — confirm callers rely on
    # this before changing it.
    if not hasattr(tensor, "_keras_history"):
        return tensor

    operation, node_index, _ = tensor._keras_history
    if not operation or not operation._inbound_nodes:
        return [tensor]

    node = operation._inbound_nodes[node_index]
    if node.is_input:
        # Reached an input node, stop recursion.
        return nest.flatten(node.input_tensors)

    # Recurse through every upstream tensor, de-duplicating results by
    # object identity to avoid input redundancy.
    sources = []
    for upstream in node.input_tensors:
        for candidate in get_source_inputs(upstream):
            if all(candidate is not existing for existing in sources):
                sources.append(candidate)
    return sources
|