Add UpSampling3D layer and its test. (#179)

* Commit a new file up_sampling3d.py

* Complete UpSampling3D and its test.

* Addresses comments.

* Last comment
This commit is contained in:
Rick Chao 2023-05-18 11:17:59 -07:00 committed by Francois Chollet
parent 04436b0da6
commit 9b0184f644
3 changed files with 262 additions and 0 deletions

@ -103,6 +103,7 @@ from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.reshaping.up_sampling3d import UpSampling3D
from keras_core.layers.reshaping.zero_padding3d import ZeroPadding3D
from keras_core.layers.rnn.bidirectional import Bidirectional
from keras_core.layers.rnn.conv_lstm1d import ConvLSTM1D

@ -0,0 +1,136 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.utils import argument_validation
@keras_core_export("keras_core.layers.UpSampling3D")
class UpSampling3D(Layer):
    """Upsampling layer for 3D inputs.

    Enlarges the three spatial dimensions of a 5D input by repeating their
    entries: the 1st, 2nd and 3rd spatial dimensions are repeated `size[0]`,
    `size[1]` and `size[2]` times respectively.

    Examples:

    >>> input_shape = (2, 1, 2, 1, 3)
    >>> x = np.ones(input_shape)
    >>> y = keras_core.layers.UpSampling3D(size=(2, 2, 2))(x)
    >>> y.shape
    (2, 2, 4, 2, 3)

    Args:
        size: Int, or tuple of 3 integers.
            The upsampling factors for dim1, dim2 and dim3.
        data_format: A string,
            one of `"channels_last"` (default) or `"channels_first"`.
            The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            When unspecified, uses
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json` (if exists) else `"channels_last"`.

    Input shape:
        5D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch_size, dim1, dim2, dim3, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch_size, channels, dim1, dim2, dim3)`

    Output shape:
        5D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch_size, upsampled_dim1, upsampled_dim2, upsampled_dim3,
            channels)`
        - If `data_format` is `"channels_first"`:
            `(batch_size, channels, upsampled_dim1, upsampled_dim2,
            upsampled_dim3)`
    """

    def __init__(self, size=(2, 2, 2), data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.data_format = backend.config.standardize_data_format(data_format)
        self.size = argument_validation.standardize_tuple(size, 3, "size")
        self.input_spec = InputSpec(ndim=5)

    def compute_output_shape(self, input_shape):
        """Return the 5D output shape for a given 5D `input_shape`.

        Each spatial entry is multiplied by the matching factor in
        `self.size`; an entry that is `None` (unknown) stays `None`.
        Batch and channel entries are passed through unchanged.
        """
        channels_first = self.data_format == "channels_first"
        # Spatial axes are 2..4 for channels_first and 1..3 otherwise.
        first_spatial_axis = 2 if channels_first else 1

        upsampled = []
        for offset in range(3):
            extent = input_shape[first_spatial_axis + offset]
            upsampled.append(
                None if extent is None else self.size[offset] * extent
            )

        if channels_first:
            return (input_shape[0], input_shape[1], *upsampled)
        return (input_shape[0], *upsampled, input_shape[4])

    def call(self, inputs):
        """Upsample `inputs` along its three spatial axes."""
        return self._resize_volumes(
            x=inputs,
            depth_factor=self.size[0],
            height_factor=self.size[1],
            width_factor=self.size[2],
            data_format=self.data_format,
        )

    def get_config(self):
        """Return the base layer config extended with `size`/`data_format`."""
        merged = dict(super().get_config())
        merged.update(size=self.size, data_format=self.data_format)
        return merged

    def _resize_volumes(
        self, x, depth_factor, height_factor, width_factor, data_format
    ):
        """Resizes the volume contained in a 5D tensor.

        The three spatial axes are repeated one after another with
        `ops.repeat`, in depth, height, width order.

        Args:
            x: Tensor or variable to resize.
            depth_factor: Positive integer.
            height_factor: Positive integer.
            width_factor: Positive integer.
            data_format: One of `"channels_first"`, `"channels_last"`.

        Returns:
            A tensor.

        Raises:
            ValueError: if `data_format` is neither
                `channels_last` or `channels_first`.
        """
        if data_format == "channels_first":
            first_spatial_axis = 2
        elif data_format == "channels_last":
            first_spatial_axis = 1
        else:
            raise ValueError("Invalid data_format: " + str(data_format))

        output = x
        factors = (depth_factor, height_factor, width_factor)
        for offset, factor in enumerate(factors):
            output = ops.repeat(
                output, factor, axis=first_spatial_axis + offset
            )
        return output

@ -0,0 +1,125 @@
import numpy as np
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import testing
class UpSampling3dTest(testing.TestCase, parameterized.TestCase):
    """Tests for `layers.UpSampling3D`."""

    @parameterized.product(
        data_format=["channels_first", "channels_last"],
        length_dim1=[2, 3],
        length_dim2=[2],
        length_dim3=[3],
    )
    def test_upsampling_3d(
        self, data_format, length_dim1, length_dim2, length_dim3
    ):
        """Check shapes and values against `np.repeat` for both formats."""
        channels_first = data_format == "channels_first"
        num_samples = 2
        stack_size = 2
        spatial_lengths = (10, 11, 12)
        factors = (length_dim1, length_dim2, length_dim3)
        # Index of the first spatial axis in the 5D input.
        first_spatial_axis = 2 if channels_first else 1

        if channels_first:
            input_shape = (num_samples, stack_size, *spatial_lengths)
            expected_output_shape = (2, 2, 20, 22, 24)
        else:
            input_shape = (num_samples, *spatial_lengths, stack_size)
            expected_output_shape = (2, 20, 22, 24, 2)
        inputs = np.random.rand(*input_shape)

        # Basic test with a fixed upsampling factor of 2 on every axis.
        self.run_layer_test(
            layers.UpSampling3D,
            init_kwargs={"size": (2, 2, 2), "data_format": data_format},
            input_shape=inputs.shape,
            expected_output_shape=expected_output_shape,
            expected_output_dtype="float32",
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=False,
        )

        layer = layers.UpSampling3D(size=factors, data_format=data_format)
        layer.build(inputs.shape)
        np_output = layer(inputs=backend.Variable(inputs))

        # Every spatial axis must grow by its own factor.
        for offset, (factor, length) in enumerate(
            zip(factors, spatial_lengths)
        ):
            assert np_output.shape[first_spatial_axis + offset] == (
                factor * length
            )

        # Compare with numpy: repeat each spatial axis in turn.
        expected_out = inputs
        for offset, factor in enumerate(factors):
            expected_out = np.repeat(
                expected_out, factor, axis=first_spatial_axis + offset
            )

        np.testing.assert_allclose(np_output, expected_out)

    def test_upsampling_3d_correctness(self):
        """Check exact output values for a small hand-computed example."""
        input_shape = (2, 1, 2, 1, 3)
        x = np.arange(np.prod(input_shape)).reshape(input_shape)

        # Each channel vector appears twice along the width axis.
        row_a = [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
        row_b = [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]]
        row_c = [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]]
        row_d = [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]]
        # Each row appears twice along the height axis.
        slab_first = [row_a, row_a, row_b, row_b]
        slab_second = [row_c, row_c, row_d, row_d]
        # Each slab appears twice along the depth axis, one pair per sample.
        expected = np.array(
            [
                [slab_first, slab_first],
                [slab_second, slab_second],
            ]
        )

        np.testing.assert_array_equal(
            layers.UpSampling3D(size=(2, 2, 2))(x),
            expected,
        )