keras/keras_core/layers/reshaping/reshape_test.py
2023-05-01 10:21:40 -07:00

76 lines
2.2 KiB
Python

import pytest
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ReshapeTest(testing.TestCase):
    """Output-shape tests for `layers.Reshape`."""

    def test_reshape(self):
        # Each row is (target_shape, input_shape, expected_output_shape).
        # The first three rows use fully specified targets; the remaining
        # rows place a single `-1` entry at different positions of the target.
        shape_cases = (
            ((8, 1), (3, 2, 4), (3, 8, 1)),
            ((8,), (3, 2, 4), (3, 8)),
            ((2, 4), (3, 8), (3, 2, 4)),
            ((-1, 1), (3, 2, 4), (3, 8, 1)),
            ((1, -1), (3, 2, 4), (3, 1, 8)),
            ((-1,), (3, 2, 4), (3, 8)),
            ((2, -1), (3, 2, 4), (3, 2, 4)),
        )
        for target, in_shape, out_shape in shape_cases:
            self.run_layer_test(
                layers.Reshape,
                init_kwargs={"target_shape": target},
                input_shape=in_shape,
                expected_output_shape=out_shape,
            )

    @pytest.mark.skipif(
        not backend.DYNAMIC_SHAPES_OK,
        reason="Backend does not support dynamic shapes",
    )
    def test_reshape_with_dynamic_batch_size(self):
        # The symbolic input is declared without a batch size; the reshaped
        # tensor is expected to report `None` as its leading dimension.
        symbolic_in = layers.Input(shape=(2, 4))
        symbolic_out = layers.Reshape((8,))(symbolic_in)
        self.assertEqual(symbolic_out.shape, (None, 8))

    @pytest.mark.skipif(
        not backend.DYNAMIC_SHAPES_OK,
        reason="Backend does not support dynamic shapes",
    )
    def test_reshape_sets_static_shape(self):
        # The input has a fixed batch of 2 and an unknown (`None`) second
        # dimension. The output shape is expected to be fully defined, with
        # the batch dimension preserved through the reshape.
        symbolic_in = layers.Input(batch_shape=(2, None))
        symbolic_out = layers.Reshape((3, 5))(symbolic_in)
        self.assertEqual(symbolic_out.shape, (2, 3, 5))