keras/keras_core/operations/operation_utils.py
import math
import numpy as np
from tensorflow import nest
def compute_pooling_output_shape(
input_shape,
pool_size,
strides,
padding="valid",
data_format="channels_last",
):
"""Compute the output shape of pooling ops."""
strides = pool_size if strides is None else strides
input_shape_origin = list(input_shape)
input_shape = np.array(input_shape)
if data_format == "channels_last":
spatial_shape = input_shape[1:-1]
else:
spatial_shape = input_shape[2:]
none_dims = []
for i in range(len(spatial_shape)):
if spatial_shape[i] is None:
# Set `None` shape to a manual value so that we can run numpy
# computation on `spatial_shape`.
spatial_shape[i] = -1
none_dims.append(i)
pool_size = np.array(pool_size)
if padding == "valid":
output_spatial_shape = (
np.floor((spatial_shape - pool_size) / strides) + 1
)
for i in range(len(output_spatial_shape)):
if i not in none_dims and output_spatial_shape[i] < 0:
raise ValueError(
"Computed output size would be negative. Received: "
f"`inputs.shape={input_shape}` and `pool_size={pool_size}`."
)
elif padding == "same":
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
else:
raise ValueError(
"`padding` must be either `'valid'` or `'same'`. Received "
f"{padding}."
)
output_spatial_shape = [int(i) for i in output_spatial_shape]
for i in none_dims:
output_spatial_shape[i] = None
output_spatial_shape = tuple(output_spatial_shape)
if data_format == "channels_last":
output_shape = (
(input_shape_origin[0],)
+ output_spatial_shape
+ (input_shape_origin[-1],)
)
else:
output_shape = (
input_shape_origin[0],
input_shape_origin[1],
) + output_spatial_shape
return output_shape
def compute_conv_output_shape(
input_shape,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
"""Compute the output shape of conv ops."""
if data_format == "channels_last":
spatial_shape = input_shape[1:-1]
kernel_shape = kernel_size + (input_shape[-1], filters)
else:
spatial_shape = input_shape[2:]
kernel_shape = kernel_size + (input_shape[1], filters)
if len(kernel_shape) != len(input_shape):
raise ValueError(
"Kernel shape must have the same length as input, but received "
f"kernel of shape {kernel_shape} and "
f"input of shape {input_shape}."
)
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * len(spatial_shape)
if isinstance(strides, int):
strides = (strides,) * len(spatial_shape)
if len(dilation_rate) != len(spatial_shape):
        raise ValueError(
            "`dilation_rate` must be an int or a tuple/list with the same "
            "length as the inputs' spatial shape, but received "
            f"`dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )
none_dims = []
spatial_shape = np.array(spatial_shape)
for i in range(len(spatial_shape)):
if spatial_shape[i] is None:
# Set `None` shape to a manual value so that we can run numpy
# computation on `spatial_shape`.
spatial_shape[i] = -1
none_dims.append(i)
kernel_spatial_shape = np.array(kernel_shape[:-2])
dilation_rate = np.array(dilation_rate)
if padding == "valid":
output_spatial_shape = (
np.floor(
(spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
/ strides
)
+ 1
)
for i in range(len(output_spatial_shape)):
if i not in none_dims and output_spatial_shape[i] < 0:
raise ValueError(
"Computed output size would be negative. Received "
f"`inputs shape={input_shape}`, "
f"`kernel shape={kernel_shape}`, "
f"`dilation_rate={dilation_rate}`."
)
elif padding == "same" or padding == "causal":
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
output_spatial_shape = [int(i) for i in output_spatial_shape]
for i in none_dims:
output_spatial_shape[i] = None
output_spatial_shape = tuple(output_spatial_shape)
if data_format == "channels_last":
output_shape = (
(input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
)
else:
output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
return output_shape
def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
"""Converts `-1` in `new_shape` to either an actual dimension or `None`.
This utility does not special case the 0th dimension (batch size).
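
    A small sketch of the `-1` inference (24 elements, one dim left free):

    >>> compute_reshape_output_shape((2, 12), (2, -1, 3), "new_shape")
    (2, 4, 3)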
"""
unknown_dim_count = new_shape.count(-1)
if unknown_dim_count > 1:
raise ValueError(
"There must be at most one unknown dimension (-1) in "
f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
)
# If there is a None in input_shape, we can't infer what the -1 is
if None in input_shape:
return tuple(dim if dim != -1 else None for dim in new_shape)
input_size = math.prod(input_shape)
    # If `new_shape` is fully defined, validate and return it.
if unknown_dim_count == 0:
if input_size != math.prod(new_shape):
raise ValueError(
"The total size of the tensor must be unchanged. Received: "
f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
)
return new_shape
# We have one -1 in new_shape, compute the actual value
known_output_size = 1
unknown_dim_index = None
for index, dim in enumerate(new_shape):
if dim == -1:
unknown_dim_index = index
else:
known_output_size *= dim
if known_output_size == 0 or input_size % known_output_size != 0:
raise ValueError(
"The total size of the tensor must be unchanged, however, the "
"input size cannot by divided by the specified dimensions in "
f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
f"{new_shape_arg_name}={new_shape}"
)
output_shape = list(new_shape)
output_shape[unknown_dim_index] = input_size // known_output_size
return tuple(output_shape)
def reduce_shape(shape, axis=None, keepdims=False):
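    """Computes the output shape of a reduction over `axis`.

    A minimal sketch (reduce axis 1 of a rank-3 shape; note this utility
    returns a list):

    >>> reduce_shape((2, 3, 4), axis=(1,), keepdims=True)
    [2, 1, 4]
    >>> reduce_shape((2, 3, 4), axis=(1,))
    [2, 4]
    """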
shape = list(shape)
if axis is None:
if keepdims:
            output_shape = [1 for _ in range(len(shape))]
else:
output_shape = []
return output_shape
if keepdims:
for ax in axis:
shape[ax] = 1
return shape
else:
for ax in axis:
shape[ax] = -1
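        # Drop the axes marked with the -1 sentinel.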
output_shape = list(filter((-1).__ne__, shape))
return output_shape
def get_source_inputs(tensor):
"""Returns the list of input tensors necessary to compute `tensor`.
Output will always be a list of tensors
(potentially with 1 element).
Args:
tensor: The tensor to start from.
Returns:
List of input tensors.
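
    Example (illustrative sketch; `Input` and `Dense` stand in for
    functional-API layers and are assumptions, not part of this module):

    ```python
    inputs = Input(shape=(10,))
    x = Dense(5)(inputs)
    get_source_inputs(x)  # -> [inputs]
    ```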
"""
if not hasattr(tensor, "_keras_history"):
return tensor
operation, node_index, _ = tensor._keras_history
if not operation or not operation._inbound_nodes:
return [tensor]
else:
node = operation._inbound_nodes[node_index]
if node.is_input:
# Reached input node, stop recursion.
return nest.flatten(node.input_tensors)
else:
source_tensors = []
            for input_tensor in node.input_tensors:
                previous_sources = get_source_inputs(input_tensor)
# Avoid input redundancy.
for x in previous_sources:
if all(x is not t for t in source_tensors):
source_tensors.append(x)
return source_tensors