Add keras.layers.Reshape. (#58)

This commit is contained in:
hertschuh 2023-05-01 10:21:40 -07:00 committed by Francois Chollet
parent f71b4d85b9
commit b9b4749e33
4 changed files with 291 additions and 1010 deletions

@ -13,3 +13,4 @@ from keras_core.layers.regularization.gaussian_noise import GaussianNoise
from keras_core.layers.regularization.spatial_dropout import SpatialDropout1D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
from keras_core.layers.reshaping.reshape import Reshape

@ -0,0 +1,113 @@
import math
from keras_core import operations as ops
from keras_core.layers.layer import Layer
class Reshape(Layer):
"""Layer that reshapes inputs into the given shape.
Args:
target_shape: Target shape. Tuple of integers, does not include the
samples dimension (batch size).
Input shape:
Arbitrary, although all dimensions in the input shape must be
known/fixed. Use the keyword argument `input_shape` (tuple of integers,
does not include the samples/batch size axis) when using this layer as
the first layer in a model.
Output shape:
`(batch_size, *target_shape)`
Example:
>>> # as first layer in a Sequential model
>>> model = keras_core.Sequential()
>>> model.add(keras_core.layers.Reshape((3, 4), input_shape=(12,)))
>>> # model.output_shape == (None, 3, 4), `None` is the batch size.
>>> model.output_shape
(None, 3, 4)
>>> # as intermediate layer in a Sequential model
>>> model.add(keras_core.layers.Reshape((6, 2)))
>>> model.output_shape
(None, 6, 2)
>>> # also supports shape inference using `-1` as dimension
>>> model.add(keras_core.layers.Reshape((-1, 2, 2)))
>>> model.output_shape
(None, 3, 2, 2)
"""
def __init__(self, target_shape, name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
self.target_shape = tuple(target_shape)
def _fix_unknown_dimension(self, input_shape, output_shape):
"""Find and replace a missing dimension in an output shape.
Args:
input_shape: Shape of tensor being reshaped as a tuple of ints.
output_shape: Desired shape of the tensor as a tuple of ints. It
contains at most a single `-1` which indicates a dimension that
should be derived from the input shape.
Returns:
The new output shape as a tuple of ints with a -1 replaced with its
computed value.
Raises:
ValueError: If the total tensor size of the output_shape is
different than the input_shape, or more than one unknown
dimension is specified.
"""
msg = (
"total size of new tensor must be unchanged, "
f"input_shape={input_shape},output_shape={output_shape}"
)
known_output_size, unknown_dim_index = 1, None
for index, dim in enumerate(output_shape):
if dim == -1:
if unknown_dim_index is None:
unknown_dim_index = index
else:
raise ValueError(
"There must be at most one unknown dimension in "
f"output_shape. Received: output_shape={output_shape}."
)
else:
known_output_size *= dim
input_size = math.prod(input_shape)
if unknown_dim_index is not None:
if known_output_size == 0 or input_size % known_output_size != 0:
raise ValueError(msg)
result = list(output_shape)
result[unknown_dim_index] = input_size // known_output_size
return tuple(result)
elif input_size != known_output_size:
raise ValueError(msg)
return output_shape
def compute_output_shape(self, input_shape):
output_shape = (input_shape[0],)
if None in input_shape[1:]:
# input shape (partially) unknown? replace -1's with None's
output_shape += tuple(
s if s != -1 else None for s in self.target_shape
)
else:
output_shape += self._fix_unknown_dimension(
input_shape[1:], self.target_shape
)
return output_shape
def call(self, inputs):
return ops.reshape(inputs, (inputs.shape[0],) + self.target_shape)
def get_config(self):
config = {"target_shape": self.target_shape}
base_config = super().get_config()
return {**base_config, **config}

@ -0,0 +1,75 @@
import pytest
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ReshapeTest(testing.TestCase):
def test_reshape(self):
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (8, 1)},
input_shape=(3, 2, 4),
expected_output_shape=(3, 8, 1),
)
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (8,)},
input_shape=(3, 2, 4),
expected_output_shape=(3, 8),
)
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (2, 4)},
input_shape=(3, 8),
expected_output_shape=(3, 2, 4),
)
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (-1, 1)},
input_shape=(3, 2, 4),
expected_output_shape=(3, 8, 1),
)
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (1, -1)},
input_shape=(3, 2, 4),
expected_output_shape=(3, 1, 8),
)
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (-1,)},
input_shape=(3, 2, 4),
expected_output_shape=(3, 8),
)
self.run_layer_test(
layers.Reshape,
init_kwargs={"target_shape": (2, -1)},
input_shape=(3, 2, 4),
expected_output_shape=(3, 2, 4),
)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_reshape_with_dynamic_batch_size(self):
input_layer = layers.Input(shape=(2, 4))
reshaped = layers.Reshape((8,))(input_layer)
self.assertEqual(reshaped.shape, (None, 8))
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_reshape_sets_static_shape(self):
input_layer = layers.Input(batch_shape=(2, None))
reshaped = layers.Reshape((3, 5))(input_layer)
# Also make sure the batch dim is not lost after reshape.
self.assertEqual(reshaped.shape, (2, 3, 5))

File diff suppressed because it is too large Load Diff