keras/keras_core/operations/operation_utils.py
2023-05-11 18:33:09 -07:00

164 lines
5.7 KiB
Python

import math
import numpy as np
def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Compute the output shape of pooling ops.

    Args:
        input_shape: Input shape including the batch and channel axes.
            Any dimension may be `None` (unknown); unknown spatial
            dimensions produce `None` in the output.
        pool_size: Int or sequence of ints, the pooling window size per
            spatial dimension.
        strides: Int, sequence of ints, or `None`. `None` means
            `pool_size` is used as the stride.
        padding: Either `"valid"` or `"same"`.
        data_format: `"channels_last"` means channels are the last axis;
            any other value is treated as channels on axis 1.

    Returns:
        A tuple with the batch and channel dimensions copied from
        `input_shape` and the pooled spatial dimensions.

    Raises:
        ValueError: If `padding` is not `"valid"` or `"same"`, or if any
            known spatial output dimension would be negative with
            `"valid"` padding.
    """
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    if data_format == "channels_last":
        spatial_dims = input_shape_origin[1:-1]
    else:
        spatial_dims = input_shape_origin[2:]
    # numpy cannot do arithmetic on `None`, so unknown spatial dims get a
    # placeholder value here and are restored to `None` at the end.
    none_dims = {i for i, dim in enumerate(spatial_dims) if dim is None}
    spatial_shape = np.array(
        [1 if dim is None else dim for dim in spatial_dims]
    )
    pool_size = np.array(pool_size)
    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        # A single negative (known) dimension already makes the shape
        # invalid; checking only whether *all* are negative misses it.
        if any(
            size < 0
            for i, size in enumerate(output_spatial_shape)
            if i not in none_dims
        ):
            raise ValueError(
                "Computed output size would be negative. Received: "
                f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
            )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'` or `'same'`. Received "
            f"{padding}."
        )
    output_spatial_shape = tuple(
        None if i in none_dims else int(size)
        for i, size in enumerate(output_spatial_shape)
    )
    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape
def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops.

    Args:
        input_shape: Input shape including the batch and channel axes.
            Spatial dimensions may be `None`; the matching output
            dimensions are then `None`.
        filters: Number of output channels.
        kernel_size: Int or sequence of ints, the kernel size per spatial
            dimension. An int is used for every spatial dimension.
        strides: Int or sequence of ints.
        padding: `"valid"`, `"same"` or `"causal"`.
        data_format: `"channels_last"` means channels are the last axis;
            any other value is treated as channels on axis 1.
        dilation_rate: Int or sequence of ints, one per spatial dimension.

    Returns:
        A tuple: batch dimension, spatial output dimensions, and `filters`,
        with the channel axis placed according to `data_format`.

    Raises:
        ValueError: If the kernel rank doesn't match the input rank, the
            dilation rate has the wrong length, `padding` is unknown, or a
            known spatial output dimension would be negative.
    """
    if data_format == "channels_last":
        spatial_dims = list(input_shape[1:-1])
        input_channels = input_shape[-1]
    else:
        spatial_dims = list(input_shape[2:])
        input_channels = input_shape[1]
    # Accept an int (applied to every spatial dim) or any sequence, like
    # `strides` and `dilation_rate` below.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * len(spatial_dims)
    kernel_shape = tuple(kernel_size) + (input_channels, filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * len(spatial_dims)
    if isinstance(strides, int):
        strides = (strides,) * len(spatial_dims)
    if len(dilation_rate) != len(spatial_dims):
        raise ValueError(
            "Dilation must be None, scalar or tuple/list of length of "
            "inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )
    # numpy cannot do arithmetic on `None`, so unknown spatial dims get a
    # placeholder value here and are restored to `None` at the end.
    none_dims = {i for i, dim in enumerate(spatial_dims) if dim is None}
    spatial_shape = np.array(
        [1 if dim is None else dim for dim in spatial_dims]
    )
    kernel_spatial_shape = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        output_spatial_shape = (
            np.floor(
                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
                / strides
            )
            + 1
        )
        # A single negative (known) dimension already makes the shape
        # invalid; checking only whether *all* are negative misses it.
        if any(
            size < 0
            for i, size in enumerate(output_spatial_shape)
            if i not in none_dims
        ):
            raise ValueError(
                "Computed output size would be negative. Received "
                f"`inputs shape={input_shape}`, "
                f"`kernel shape={kernel_shape}`, "
                f"`dilation_rate={dilation_rate}`."
            )
    elif padding in ("same", "causal"):
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        # Previously fell through to an UnboundLocalError.
        raise ValueError(
            "`padding` must be one of `'valid'`, `'same'` or `'causal'`. "
            f"Received {padding}."
        )
    output_spatial_shape = tuple(
        None if i in none_dims else int(size)
        for i, size in enumerate(output_spatial_shape)
    )
    if data_format == "channels_last":
        output_shape = (
            (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
        )
    else:
        output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
    return output_shape
def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
    """Converts `-1` in `new_shape` to either an actual dimension or `None`.

    This utility does not special case the 0th dimension (batch size).

    Args:
        input_shape: Shape of the tensor being reshaped. May contain `None`.
        new_shape: Requested shape (list or tuple); at most one entry may
            be `-1`.
        new_shape_arg_name: Name of the caller's argument, used only in
            error messages.

    Returns:
        A tuple: `new_shape` with its `-1` entry replaced by the inferred
        size, or by `None` when `input_shape` contains `None`.

    Raises:
        ValueError: If `new_shape` has more than one `-1`, contains any
            other negative dimension, or is incompatible with the total
            size of `input_shape`.
    """
    unknown_dim_count = new_shape.count(-1)
    if unknown_dim_count > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
        )
    # Only -1 has a meaning; other negative sizes would otherwise slip
    # through whenever their product happened to match.
    if any(dim is not None and dim < -1 for dim in new_shape):
        raise ValueError(
            f"Dimensions in {new_shape_arg_name} must be non-negative or -1. "
            f"Received: {new_shape_arg_name}={new_shape}."
        )
    # If there is a None in input_shape, we can't infer what the -1 is
    if None in input_shape:
        return tuple(dim if dim != -1 else None for dim in new_shape)
    input_size = math.prod(input_shape)
    # If the new_shape is fully defined, return it (as a tuple, like the
    # other return paths, so callers can concatenate it with tuples).
    if unknown_dim_count == 0:
        if input_size != math.prod(new_shape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
            )
        return tuple(new_shape)
    # We have exactly one -1 in new_shape: infer it from the remaining dims.
    unknown_dim_index = new_shape.index(-1)
    known_output_size = math.prod(dim for dim in new_shape if dim != -1)
    if known_output_size == 0 or input_size % known_output_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot be divided by the specified dimensions in "
            f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
            f"{new_shape_arg_name}={new_shape}"
        )
    output_shape = list(new_shape)
    output_shape[unknown_dim_index] = input_size // known_output_size
    return tuple(output_shape)