114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
import math
|
|
|
|
from keras_core import operations as ops
|
|
from keras_core.layers.layer import Layer
|
|
|
|
|
|
class Reshape(Layer):
    """Layer that reshapes inputs into the given shape.

    Args:
        target_shape: Target shape. Tuple of integers, does not include the
            samples dimension (batch size). A single entry may be `-1`, in
            which case that dimension is derived from the input shape.
        name: Optional name, forwarded to the base `Layer` constructor.
        dtype: Optional dtype, forwarded to the base `Layer` constructor.
        **kwargs: Any further keyword arguments are forwarded unchanged to
            the base `Layer` constructor (for instance the `input_shape`
            argument used in the example below, provided the base layer
            accepts it).

    Input shape:
        Arbitrary, although all dimensions in the input shape must be
        known/fixed. Use the keyword argument `input_shape` (tuple of integers,
        does not include the samples/batch size axis) when using this layer as
        the first layer in a model.

    Output shape:
        `(batch_size, *target_shape)`

    Example:

    >>> # as first layer in a Sequential model
    >>> model = keras_core.Sequential()
    >>> model.add(keras_core.layers.Reshape((3, 4), input_shape=(12,)))
    >>> # model.output_shape == (None, 3, 4), `None` is the batch size.
    >>> model.output_shape
    (None, 3, 4)

    >>> # as intermediate layer in a Sequential model
    >>> model.add(keras_core.layers.Reshape((6, 2)))
    >>> model.output_shape
    (None, 6, 2)

    >>> # also supports shape inference using `-1` as dimension
    >>> model.add(keras_core.layers.Reshape((-1, 2, 2)))
    >>> model.output_shape
    (None, 3, 2, 2)
    """

    def __init__(self, target_shape, name=None, dtype=None, **kwargs):
        # `**kwargs` is forwarded so that base-layer keywords (such as the
        # `input_shape` argument shown in the class docstring) are no longer
        # rejected by this constructor's own signature.
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Stored as a tuple so it can be concatenated with the batch
        # dimension in `compute_output_shape` and `call`.
        self.target_shape = tuple(target_shape)

    def _fix_unknown_dimension(self, input_shape, output_shape):
        """Find and replace a missing dimension in an output shape.

        Args:
            input_shape: Shape of tensor being reshaped as a tuple of ints.
            output_shape: Desired shape of the tensor as a tuple of ints. It
                contains at most a single `-1` which indicates a dimension
                that should be derived from the input shape.

        Returns:
            The new output shape as a tuple of ints with a -1 replaced with
            its computed value. When `output_shape` contains no `-1` it is
            returned as-is.

        Raises:
            ValueError: If the total tensor size of the output_shape is
                different than the input_shape, or more than one unknown
                dimension is specified.
        """
        msg = (
            "total size of new tensor must be unchanged, "
            f"input_shape={input_shape},output_shape={output_shape}"
        )

        # Multiply together every explicitly given dimension and remember
        # where the (single) `-1` placeholder sits, if there is one.
        known_output_size, unknown_dim_index = 1, None
        for index, dim in enumerate(output_shape):
            if dim == -1:
                if unknown_dim_index is None:
                    unknown_dim_index = index
                else:
                    raise ValueError(
                        "There must be at most one unknown dimension in "
                        f"output_shape. Received: output_shape={output_shape}."
                    )
            else:
                known_output_size *= dim

        input_size = math.prod(input_shape)
        if unknown_dim_index is not None:
            # The zero check avoids a division by zero below; the modulo
            # check ensures the inferred dimension is a whole number.
            if known_output_size == 0 or input_size % known_output_size != 0:
                raise ValueError(msg)
            result = list(output_shape)
            result[unknown_dim_index] = input_size // known_output_size
            return tuple(result)
        elif input_size != known_output_size:
            raise ValueError(msg)
        return output_shape

    def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape`, batch dimension first.

        Args:
            input_shape: Full input shape; index 0 is treated as the batch
                dimension and is carried over unchanged.

        Returns:
            A tuple `(input_shape[0], *resolved_target_shape)`. If any
            non-batch input dimension is `None`, every `-1` in
            `target_shape` is reported as `None`; otherwise a `-1` is
            replaced by its computed value.

        Raises:
            ValueError: Propagated from `_fix_unknown_dimension` when the
                non-batch input dimensions are all known and are
                incompatible with `target_shape`.
        """
        output_shape = (input_shape[0],)
        if None in input_shape[1:]:
            # input shape (partially) unknown? replace -1's with None's
            output_shape += tuple(
                s if s != -1 else None for s in self.target_shape
            )
        else:
            output_shape += self._fix_unknown_dimension(
                input_shape[1:], self.target_shape
            )
        return output_shape

    def call(self, inputs):
        """Reshape `inputs` to `(inputs.shape[0],) + target_shape`."""
        # NOTE(review): this relies on the static `inputs.shape[0]`; if the
        # batch dimension can be `None` at execution time, confirm that
        # `ops.reshape` accepts it on every backend in use.
        return ops.reshape(inputs, (inputs.shape[0],) + self.target_shape)

    def get_config(self):
        """Return the base layer config extended with `target_shape`."""
        config = {"target_shape": self.target_shape}
        base_config = super().get_config()
        # Entries from `config` take precedence over same-named base entries.
        return {**base_config, **config}
|