Add keras.layers.Reshape. (#58)
This commit is contained in:
parent
f71b4d85b9
commit
b9b4749e33
@ -13,3 +13,4 @@ from keras_core.layers.regularization.gaussian_noise import GaussianNoise
|
|||||||
from keras_core.layers.regularization.spatial_dropout import SpatialDropout1D
|
from keras_core.layers.regularization.spatial_dropout import SpatialDropout1D
|
||||||
from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
|
from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
|
||||||
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
|
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
|
||||||
|
from keras_core.layers.reshaping.reshape import Reshape
|
||||||
|
113
keras_core/layers/reshaping/reshape.py
Normal file
113
keras_core/layers/reshaping/reshape.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
from keras_core import operations as ops
|
||||||
|
from keras_core.layers.layer import Layer
|
||||||
|
|
||||||
|
|
||||||
|
class Reshape(Layer):
|
||||||
|
"""Layer that reshapes inputs into the given shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_shape: Target shape. Tuple of integers, does not include the
|
||||||
|
samples dimension (batch size).
|
||||||
|
|
||||||
|
Input shape:
|
||||||
|
Arbitrary, although all dimensions in the input shape must be
|
||||||
|
known/fixed. Use the keyword argument `input_shape` (tuple of integers,
|
||||||
|
does not include the samples/batch size axis) when using this layer as
|
||||||
|
the first layer in a model.
|
||||||
|
|
||||||
|
Output shape:
|
||||||
|
`(batch_size, *target_shape)`
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> # as first layer in a Sequential model
|
||||||
|
>>> model = keras_core.Sequential()
|
||||||
|
>>> model.add(keras_core.layers.Reshape((3, 4), input_shape=(12,)))
|
||||||
|
>>> # model.output_shape == (None, 3, 4), `None` is the batch size.
|
||||||
|
>>> model.output_shape
|
||||||
|
(None, 3, 4)
|
||||||
|
|
||||||
|
>>> # as intermediate layer in a Sequential model
|
||||||
|
>>> model.add(keras_core.layers.Reshape((6, 2)))
|
||||||
|
>>> model.output_shape
|
||||||
|
(None, 6, 2)
|
||||||
|
|
||||||
|
>>> # also supports shape inference using `-1` as dimension
|
||||||
|
>>> model.add(keras_core.layers.Reshape((-1, 2, 2)))
|
||||||
|
>>> model.output_shape
|
||||||
|
(None, 3, 2, 2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, target_shape, name=None, dtype=None):
|
||||||
|
super().__init__(name=name, dtype=dtype)
|
||||||
|
self.target_shape = tuple(target_shape)
|
||||||
|
|
||||||
|
def _fix_unknown_dimension(self, input_shape, output_shape):
|
||||||
|
"""Find and replace a missing dimension in an output shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_shape: Shape of tensor being reshaped as a tuple of ints.
|
||||||
|
output_shape: Desired shape of the tensor as a tuple of ints. It
|
||||||
|
contains at most a single `-1` which indicates a dimension that
|
||||||
|
should be derived from the input shape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new output shape as a tuple of ints with a -1 replaced with its
|
||||||
|
computed value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the total tensor size of the output_shape is
|
||||||
|
different than the input_shape, or more than one unknown
|
||||||
|
dimension is specified.
|
||||||
|
"""
|
||||||
|
msg = (
|
||||||
|
"total size of new tensor must be unchanged, "
|
||||||
|
f"input_shape={input_shape},output_shape={output_shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
known_output_size, unknown_dim_index = 1, None
|
||||||
|
for index, dim in enumerate(output_shape):
|
||||||
|
if dim == -1:
|
||||||
|
if unknown_dim_index is None:
|
||||||
|
unknown_dim_index = index
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"There must be at most one unknown dimension in "
|
||||||
|
f"output_shape. Received: output_shape={output_shape}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
known_output_size *= dim
|
||||||
|
|
||||||
|
input_size = math.prod(input_shape)
|
||||||
|
if unknown_dim_index is not None:
|
||||||
|
if known_output_size == 0 or input_size % known_output_size != 0:
|
||||||
|
raise ValueError(msg)
|
||||||
|
result = list(output_shape)
|
||||||
|
result[unknown_dim_index] = input_size // known_output_size
|
||||||
|
return tuple(result)
|
||||||
|
elif input_size != known_output_size:
|
||||||
|
raise ValueError(msg)
|
||||||
|
return output_shape
|
||||||
|
|
||||||
|
def compute_output_shape(self, input_shape):
|
||||||
|
output_shape = (input_shape[0],)
|
||||||
|
if None in input_shape[1:]:
|
||||||
|
# input shape (partially) unknown? replace -1's with None's
|
||||||
|
output_shape += tuple(
|
||||||
|
s if s != -1 else None for s in self.target_shape
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_shape += self._fix_unknown_dimension(
|
||||||
|
input_shape[1:], self.target_shape
|
||||||
|
)
|
||||||
|
return output_shape
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
return ops.reshape(inputs, (inputs.shape[0],) + self.target_shape)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = {"target_shape": self.target_shape}
|
||||||
|
base_config = super().get_config()
|
||||||
|
return {**base_config, **config}
|
75
keras_core/layers/reshaping/reshape_test.py
Normal file
75
keras_core/layers/reshaping/reshape_test.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
|
from keras_core import layers
|
||||||
|
from keras_core import testing
|
||||||
|
|
||||||
|
|
||||||
|
class ReshapeTest(testing.TestCase):
|
||||||
|
def test_reshape(self):
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (8, 1)},
|
||||||
|
input_shape=(3, 2, 4),
|
||||||
|
expected_output_shape=(3, 8, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (8,)},
|
||||||
|
input_shape=(3, 2, 4),
|
||||||
|
expected_output_shape=(3, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (2, 4)},
|
||||||
|
input_shape=(3, 8),
|
||||||
|
expected_output_shape=(3, 2, 4),
|
||||||
|
)
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (-1, 1)},
|
||||||
|
input_shape=(3, 2, 4),
|
||||||
|
expected_output_shape=(3, 8, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (1, -1)},
|
||||||
|
input_shape=(3, 2, 4),
|
||||||
|
expected_output_shape=(3, 1, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (-1,)},
|
||||||
|
input_shape=(3, 2, 4),
|
||||||
|
expected_output_shape=(3, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.Reshape,
|
||||||
|
init_kwargs={"target_shape": (2, -1)},
|
||||||
|
input_shape=(3, 2, 4),
|
||||||
|
expected_output_shape=(3, 2, 4),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not backend.DYNAMIC_SHAPES_OK,
|
||||||
|
reason="Backend does not support dynamic shapes",
|
||||||
|
)
|
||||||
|
def test_reshape_with_dynamic_batch_size(self):
|
||||||
|
input_layer = layers.Input(shape=(2, 4))
|
||||||
|
reshaped = layers.Reshape((8,))(input_layer)
|
||||||
|
self.assertEqual(reshaped.shape, (None, 8))
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not backend.DYNAMIC_SHAPES_OK,
|
||||||
|
reason="Backend does not support dynamic shapes",
|
||||||
|
)
|
||||||
|
def test_reshape_sets_static_shape(self):
|
||||||
|
input_layer = layers.Input(batch_shape=(2, None))
|
||||||
|
reshaped = layers.Reshape((3, 5))(input_layer)
|
||||||
|
# Also make sure the batch dim is not lost after reshape.
|
||||||
|
self.assertEqual(reshaped.shape, (2, 3, 5))
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user