Merge branches 'main' and 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-05-10 21:16:27 -07:00
parent 8b45efc5ee
commit b27593e5b0
16 changed files with 1308 additions and 24 deletions

@ -33,16 +33,17 @@ def _compute_conv_transpose_output_length(
def compute_conv_transpose_output_shape(
inputs,
kernel,
input_shape,
kernel_size,
filters,
strides,
padding,
output_padding=None,
data_format="channels_last",
dilation_rate=1,
):
num_spatial_dims = len(inputs.shape) - 2
kernel_spatial_shape = kernel.shape[:-2]
num_spatial_dims = len(input_shape) - 2
kernel_spatial_shape = kernel_size
if isinstance(output_padding, int):
output_padding = (output_padding,) * len(kernel_spatial_shape)
@ -52,9 +53,9 @@ def compute_conv_transpose_output_shape(
dilation_rate = (dilation_rate,) * num_spatial_dims
if data_format == "channels_last":
inputs_spatial_shape = inputs.shape[1:-1]
input_spatial_shape = input_shape[1:-1]
else:
inputs_spatial_shape = inputs.shape[2:]
input_spatial_shape = input_shape[2:]
output_shape = []
for i in range(num_spatial_dims):
@ -63,7 +64,7 @@ def compute_conv_transpose_output_shape(
)
output_shape.append(
_compute_conv_transpose_output_length(
inputs_spatial_shape[i],
input_spatial_shape[i],
kernel_spatial_shape[i],
padding=padding,
output_padding=current_output_padding,
@ -73,7 +74,7 @@ def compute_conv_transpose_output_shape(
)
if data_format == "channels_last":
output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]]
output_shape = [input_shape[0]] + output_shape + [filters]
else:
output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape
output_shape = [input_shape[0], filters] + output_shape
return output_shape

@ -1,10 +1,12 @@
"""Tests for tf.distribute related functionality under tf implementation."""
import numpy as np
import pytest
import tensorflow as tf
from tensorflow.python.eager import context
from keras_core import backend
from keras_core.backend.tensorflow import trainer as tf_trainer
from keras_core import layers
from keras_core import models
from keras_core import testing
@ -79,3 +81,41 @@ class DistributeTest(testing.TestCase):
self.assertEqual(result.values[1].shape, [8, 2])
self.assertNotAllClose(result.values[0], result.values[1])
self.assertAllClose(result.values[0], tf.zeros([8, 2]))
def test_epoch_iterator(self):
x = np.random.random((100, 16))
y = np.random.random((100, 4))
sample_weight = np.random.random((100,))
batch_size = 16
shuffle = True
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
epoch_iterator = tf_trainer.TFEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
shuffle=shuffle,
distribute_strategy=strategy
)
steps_seen = []
for step, data_iterator in epoch_iterator.enumerate_epoch():
steps_seen.append(step)
batch = next(data_iterator)
self.assertEqual(len(batch), 3)
x, y, sample_weight = batch
self.assertTrue(
isinstance(x,
tf.types.experimental.distributed.PerReplica))
# Make sure the local batch size is 8
if step < 6:
self.assertEqual(x.values[0].shape, [8, 16])
self.assertEqual(y.values[0].shape, [8, 4])
self.assertEqual(sample_weight.values[0].shape, [8])
else:
# Last partial batch
self.assertEqual(x.values[0].shape, [2, 16])
self.assertEqual(y.values[0].shape, [2, 4])
self.assertEqual(sample_weight.values[0].shape, [2])
self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])

@ -360,9 +360,12 @@ def conv_transpose(
dilation_rate=1,
):
tf_data_format = _convert_data_format(data_format, len(inputs.shape))
kernel_size = kernel.shape[:-2]
filters = kernel.shape[-2]
output_shape = compute_conv_transpose_output_shape(
inputs,
kernel,
inputs.shape,
kernel_size,
filters,
strides,
padding,
output_padding,

@ -243,6 +243,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
steps_per_epoch=steps_per_epoch,
shuffle=shuffle,
class_weight=class_weight,
distribute_strategy=self.distribute_strategy,
)
# Container that configures and calls callbacks.
@ -285,6 +286,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
y=val_y,
sample_weight=val_sample_weight,
batch_size=validation_batch_size or batch_size,
distribute_strategy=self.distribute_strategy,
)
val_logs = self.evaluate(
x=val_x,
@ -346,6 +348,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
distribute_strategy=self.distribute_strategy,
)
# Container that configures and calls callbacks.
@ -385,6 +388,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
distribute_strategy=self.distribute_strategy,
)
# Container that configures and calls callbacks.
@ -428,20 +432,23 @@ class TensorFlowTrainer(base_trainer.Trainer):
class TFEpochIterator(EpochIterator):
def __init__(self, *args, **kwargs):
def __init__(self, distribute_strategy=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self._distribute_strategy = distribute_strategy
self._steps_seen = 0
def enumerate_epoch(self):
if self.steps_per_epoch:
if not self._current_iterator:
self._current_iterator = iter(
self.data_adapter.get_tf_dataset()
)
self._distribute_strategy.experimental_distribute_dataset(
self.data_adapter.get_tf_dataset()))
for step in range(self.steps_per_epoch):
yield step, self._current_iterator
else:
iterator = iter(self.data_adapter.get_tf_dataset())
iterator = iter(
self._distribute_strategy.experimental_distribute_dataset(
self.data_adapter.get_tf_dataset()))
if self.num_batches:
for step in range(self.num_batches):
yield step, iterator

@ -34,7 +34,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
def test_reduces_lr_with_model_fit(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10
patience=1, factor=0.1, monitor="val_loss", min_delta=100
)
self.model.fit(
@ -49,7 +49,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
def test_throws_when_optimizer_has_schedule(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10
patience=1, factor=0.1, monitor="val_loss", min_delta=100
)
self.model.compile(
@ -75,7 +75,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
def test_verbose_logging(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10, verbose=1
patience=1, factor=0.1, monitor="val_loss", min_delta=100, verbose=1
)
io_utils.disable_interactive_logging()
@ -111,7 +111,11 @@ class ReduceLROnPlateauTest(testing.TestCase):
def test_cooldown(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=10, cooldown=2
patience=1,
factor=0.1,
monitor="val_loss",
min_delta=100,
cooldown=2,
)
self.model.fit(

@ -1,8 +1,11 @@
from keras_core.layers.activations.activation import Activation
from keras_core.layers.attention.attention import Attention
from keras_core.layers.convolutional.conv1d import Conv1D
from keras_core.layers.convolutional.conv1d_transpose import Conv1DTranspose
from keras_core.layers.convolutional.conv2d import Conv2D
from keras_core.layers.convolutional.conv2d_transpose import Conv2DTranspose
from keras_core.layers.convolutional.conv3d import Conv3D
from keras_core.layers.convolutional.conv3d_transpose import Conv3DTranspose
from keras_core.layers.core.dense import Dense
from keras_core.layers.core.einsum_dense import EinsumDense
from keras_core.layers.core.embedding import Embedding
@ -78,6 +81,7 @@ from keras_core.layers.regularization.gaussian_noise import GaussianNoise
from keras_core.layers.regularization.spatial_dropout import SpatialDropout1D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
from keras_core.layers.reshaping.cropping1d import Cropping1D
from keras_core.layers.reshaping.flatten import Flatten
from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector

@ -0,0 +1,255 @@
"""Keras base class for transpose convolution layers."""
from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import standardize_data_format
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_output_shape,
)
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
class BaseConvTranspose(Layer):
"""Abstract N-D transpose convolution layer.
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.
Args:
rank: int, the rank of the transposed convolution, e.g. 2 for 2D
transposed convolution.
filters: int, the dimension of the output space (the number of filters
in the transposed convolution).
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the transposed convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the transposed convolution. If only one int is specified, the
same stride size will be used for all dimensions.
`stride value != 1` is incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
"""
def __init__(
self,
rank,
filters,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
**kwargs,
):
super().__init__(
trainable=trainable,
name=name,
activity_regularizer=activity_regularizer,
**kwargs,
)
self.rank = rank
self.filters = filters
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * self.rank
self.kernel_size = kernel_size
if isinstance(strides, int):
strides = (strides,) * self.rank
self.strides = strides
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * self.rank
self.dilation_rate = dilation_rate
self.padding = padding
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=self.rank + 2)
self.data_format = self.data_format
if self.filters is not None and self.filters <= 0:
raise ValueError(
"Invalid value for argument `filters`. Expected a strictly "
f"positive value. Received filters={self.filters}."
)
if not all(self.kernel_size):
raise ValueError(
"The argument `kernel_size` cannot contain 0. Received: "
f"{self.kernel_size}"
)
if not all(self.strides):
raise ValueError(
"The argument `strides` cannot contains 0. Received: "
f"{self.strides}"
)
if max(self.strides) > 1 and max(self.dilation_rate) > 1:
raise ValueError(
"`strides > 1` not supported in conjunction with "
f"`dilation_rate > 1`. Received: strides={self.strides} and "
f"dilation_rate={self.dilation_rate}"
)
def build(self, input_shape):
if self.data_format == "channels_last":
channel_axis = -1
input_channel = input_shape[-1]
else:
channel_axis = 1
input_channel = input_shape[1]
self.input_spec = InputSpec(
min_ndim=self.rank + 2, axes={channel_axis: input_channel}
)
kernel_shape = self.kernel_size + (
self.filters,
input_channel,
)
self.kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
trainable=True,
dtype=self.dtype,
)
if self.use_bias:
self.bias = self.add_weight(
name="bias",
shape=(self.filters,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
trainable=True,
dtype=self.dtype,
)
else:
self.bias = None
self.built = True
def call(self, inputs):
outputs = ops.conv_transpose(
inputs,
self.kernel,
strides=list(self.strides),
padding=self.padding,
dilation_rate=self.dilation_rate,
data_format=self.data_format,
)
if self.use_bias:
if self.data_format == "channels_last":
bias_shape = (1,) * (self.rank + 1) + (self.filters,)
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
if self.activation is not None:
return self.activation(outputs)
return outputs
def compute_output_shape(self, input_shape):
return compute_conv_transpose_output_shape(
input_shape,
self.kernel_size,
self.filters,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
def get_config(self):
config = super().get_config()
config.update(
{
"filters": self.filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"data_format": self.data_format,
"dilation_rate": self.dilation_rate,
"activation": activations.serialize(self.activation),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"bias_initializer": initializers.serialize(
self.bias_initializer
),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"bias_regularizer": regularizers.serialize(
self.bias_regularizer
),
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(
self.kernel_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
}
)
return config

@ -80,8 +80,7 @@ class Conv1D(BaseConv):
>>> # The inputs are 128-length vectors with 10 timesteps, and the
>>> # batch size is 4.
>>> input_shape = (4, 10, 128)
>>> x = np.random.normal(4, 10, 128)
>>> x = np.random.rand(4, 10, 128)
>>> y = keras_core.layers.Conv1D(32, 3, activation='relu')(x)
>>> print(y.shape)
(4, 8, 32)

@ -0,0 +1,136 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv_transpose import (
BaseConvTranspose,
)
@keras_core_export(
[
"keras_core.layers.Conv1DTranspose",
"keras_core.layers.Convolution1DTranspose",
]
)
class Conv1DTranspose(BaseConvTranspose):
"""1D transposed convolution layer.
The need for transposed convolutions generally arise from the desire to use
a transformation going in the opposite direction of a normal convolution,
i.e., from something that has the shape of the output of some convolution
to something that has the shape of its input while maintaining a
connectivity pattern that is compatible with said convolution.
Args:
filters: int, the dimension of the output space (the number of filters
in the transpose convolution).
kernel_size: int or tuple/list of 1 integer, specifying the size of the
transposed convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the transposed convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 1 integers, specifying the dilation
rate to use for dilated transposed convolution.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, steps, channels)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, channels, steps)`
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, channels)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, channels, new_steps)`
Returns:
A 3D tensor representing
`activation(conv1d_transpose(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
References:
- [A guide to convolution arithmetic for deep learning](
https://arxiv.org/abs/1603.07285v1)
- [Deconvolutional Networks](
https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)
Examples:
>>> x = np.random.rand(4, 10, 128)
>>> y = keras_core.layers.Conv1DTranspose(32, 3, 2, activation='relu')(x)
>>> print(y.shape)
(4, 21, 32)
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=1,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@ -80,7 +80,7 @@ class Conv2D(BaseConv):
Examples:
>>> x = np.random.normal(4, 10, 10, 128)
>>> x = np.random.rand(4, 10, 10, 128)
>>> y = keras_core.layers.Conv2D(32, 3, activation='relu')(x)
>>> print(y.shape)
(4, 8, 8, 32)

@ -0,0 +1,133 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv_transpose import (
BaseConvTranspose,
)
@keras_core_export(
[
"keras_core.layers.Conv2DTranspose",
"keras_core.layers.Convolution2DTranspose",
]
)
class Conv2DTranspose(BaseConvTranspose):
"""2D transposed convolution layer.
The need for transposed convolutions generally arise from the desire to use
a transformation going in the opposite direction of a normal convolution,
i.e., from something that has the shape of the output of some convolution
to something that has the shape of its input while maintaining a
connectivity pattern that is compatible with said convolution.
Args:
filters: int, the dimension of the output space (the number of filters
in the transposed convolution).
kernel_size: int or tuple/list of 1 integer, specifying the size of the
transposed convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the transposed convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape
`(batch_size, channels, height, width)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
dilation_rate: int or tuple/list of 1 integers, specifying the dilation
rate to use for dilated transposed convolution.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, height, width, channels)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, channels, height, width)`
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width filters)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`
Returns:
A 4D tensor representing
`activation(conv2d_transpose(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
References:
- [A guide to convolution arithmetic for deep learning](
https://arxiv.org/abs/1603.07285v1)
- [Deconvolutional Networks](
https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)
Examples:
>>> x = np.random.rand(4, 10, 8, 128)
>>> y = keras_core.layers.Conv2DTranspose(32, 2, 2, activation='relu')(x)
>>> print(y.shape)
(4, 20, 16, 32)
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@ -0,0 +1,138 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.convolutional.base_conv_transpose import (
BaseConvTranspose,
)
@keras_core_export(
[
"keras_core.layers.Conv3DTranspose",
"keras_core.layers.Convolution3DTranspose",
]
)
class Conv3DTranspose(BaseConvTranspose):
"""3D transposed convolution layer.
The need for transposed convolutions generally arise from the desire to use
a transformation going in the opposite direction of a normal convolution,
i.e., from something that has the shape of the output of some convolution
to something that has the shape of its input while maintaining a
connectivity pattern that is compatible with said convolution.
Args:
filters: int, the dimension of the output space (the number of filters
in the transposed convolution).
kernel_size: int or tuple/list of 1 integer, specifying the size of the
transposed convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the transposed convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`.
It defaults to the `image_data_format` value found in your Keras
config file at `~/.keras/keras.json`. If you never set it, then it
will be `"channels_last"`.
dilation_rate: int or tuple/list of 1 integers, specifying the dilation
rate to use for dilated transposed convolution.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
the projected variable (which must have the same shape). Constraints
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
Input shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
- If `data_format="channels_first"`:
5D tensor with shape:
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
channels)`
- If `data_format="channels_first"`:
5D tensor with shape:
`(batch_size, channels, new_spatial_dim1, new_spatial_dim2,
new_spatial_dim3)`
Returns:
A 5D tensor representing `activation(conv3d(inputs, kernel) + bias)`.
Raises:
ValueError: when both `strides > 1` and `dilation_rate > 1`.
References:
- [A guide to convolution arithmetic for deep learning](
https://arxiv.org/abs/1603.07285v1)
- [Deconvolutional Networks](
https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)
Examples:
>>> x = np.random.rand(4, 10, 8, 12, 128)
>>> y = keras_core.layers.Conv3DTranspose(32, 2, 2, activation='relu')(x)
>>> print(y.shape)
(4, 20, 16, 24, 32)
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
):
super().__init__(
rank=3,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@ -0,0 +1,420 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 2,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (2, 8, 4),
"output_shape": (2, 16, 5),
},
{
"filters": 6,
"kernel_size": 2,
"strides": 3,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (1,),
"input_shape": (2, 8, 4),
"output_shape": (2, 24, 6),
},
{
"filters": 6,
"kernel_size": (2,),
"strides": (2,),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (2, 8, 4),
"output_shape": (2, 16, 6),
},
)
def test_conv1d_transpose_basic(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
input_shape,
output_shape,
):
self.run_layer_test(
layers.Conv1DTranspose,
init_kwargs={
"filters": filters,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 2,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (2, 8, 8, 4),
"output_shape": (2, 16, 16, 5),
},
{
"filters": 6,
"kernel_size": 2,
"strides": 3,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (1, 1),
"input_shape": (2, 8, 8, 4),
"output_shape": (2, 24, 24, 6),
},
{
"filters": 6,
"kernel_size": (2, 3),
"strides": (2, 1),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1),
"input_shape": (2, 8, 8, 4),
"output_shape": (2, 16, 10, 6),
},
)
def test_conv2d_transpose_basic(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
input_shape,
output_shape,
):
self.run_layer_test(
layers.Conv2DTranspose,
init_kwargs={
"filters": filters,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 2,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
"input_shape": (2, 8, 8, 8, 4),
"output_shape": (2, 16, 16, 16, 5),
},
{
"filters": 6,
"kernel_size": 2,
"strides": 3,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (1, 1, 1),
"input_shape": (2, 8, 8, 8, 4),
"output_shape": (2, 24, 24, 24, 6),
},
{
"filters": 6,
"kernel_size": (2, 2, 3),
"strides": (2, 1, 2),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1, 1),
"input_shape": (2, 8, 8, 8, 4),
"output_shape": (2, 16, 9, 17, 6),
},
)
def test_conv3d_transpose_basic(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
input_shape,
output_shape,
):
self.run_layer_test(
layers.Conv3DTranspose,
init_kwargs={
"filters": filters,
"kernel_size": kernel_size,
"strides": strides,
"padding": padding,
"data_format": data_format,
"dilation_rate": dilation_rate,
},
input_shape=input_shape,
expected_output_shape=output_shape,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
def test_bad_init_args(self):
# `filters` is not positive.
with self.assertRaises(ValueError):
layers.Conv1DTranspose(filters=0, kernel_size=1)
# `kernel_size` has 0.
with self.assertRaises(ValueError):
layers.Conv2DTranspose(filters=2, kernel_size=(1, 0))
# `strides` has 0.
with self.assertRaises(ValueError):
layers.Conv2DTranspose(
filters=2, kernel_size=(2, 2), strides=(1, 0)
)
# `dilation_rate > 1` while `strides > 1`.
with self.assertRaises(ValueError):
layers.Conv2DTranspose(
filters=2, kernel_size=(2, 2), strides=2, dilation_rate=(2, 1)
)
class ConvTransposeCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 2,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
{
"filters": 6,
"kernel_size": 2,
"strides": 3,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (1,),
},
{
"filters": 6,
"kernel_size": (2,),
"strides": (2,),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
)
def test_conv1d_transpose(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
):
layer = layers.Conv1DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
tf_keras_layer = tf.keras.layers.Conv1DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
inputs = np.random.normal(size=[2, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 2,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
{
"filters": 6,
"kernel_size": 2,
"strides": 3,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (1, 1),
},
{
"filters": 6,
"kernel_size": (2, 3),
"strides": (2, 1),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1),
},
)
def test_conv2d_transpose(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
):
layer = layers.Conv2DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
tf_keras_layer = tf.keras.layers.Conv2DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
inputs = np.random.normal(size=[2, 8, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
@parameterized.parameters(
{
"filters": 5,
"kernel_size": 2,
"strides": 2,
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": 1,
},
{
"filters": 6,
"kernel_size": 2,
"strides": 3,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": (1, 1, 1),
},
{
"filters": 6,
"kernel_size": (2, 2, 3),
"strides": (2, 1, 2),
"padding": "valid",
"data_format": "channels_last",
"dilation_rate": (1, 1, 1),
},
)
def test_conv3d_transpose(
self,
filters,
kernel_size,
strides,
padding,
data_format,
dilation_rate,
):
layer = layers.Conv3DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
tf_keras_layer = tf.keras.layers.Conv3DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
inputs = np.random.normal(size=[2, 8, 8, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)
kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)
layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@ -0,0 +1,79 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Cropping1D")
class Cropping1D(Layer):
"""Cropping layer for 1D input (e.g. temporal sequence).
It crops along the time dimension (axis 1).
Examples:
>>> input_shape = (2, 3, 2)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> print(x)
[[[ 0 1]
[ 2 3]
[ 4 5]]
[[ 6 7]
[ 8 9]
[10 11]]]
>>> y = keras_core.layers.Cropping1D(cropping=1)(x)
>>> print(y)
[[[2 3]]
[[8 9]]]
Args:
cropping: Integer or tuple of integers of length 2.
How many units should be trimmed off at the beginning and end of
the cropping dimension (axis 1).
If a single int is provided, the same value will be used for both.
Input shape:
3D tensor with shape `(batch_size, axis_to_crop, features)`
Output shape:
3D tensor with shape `(batch_size, cropped_axis, features)`
"""
def __init__(self, cropping=(1, 1), name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
if isinstance(cropping, int):
cropping = (cropping, cropping)
self.cropping = cropping
self.input_spec = InputSpec(ndim=3)
def compute_output_shape(self, input_shape):
if input_shape[1] is not None:
length = input_shape[1] - self.cropping[0] - self.cropping[1]
if length <= 0:
raise ValueError(
"`cropping` parameter of `Cropping1D` layer must be "
"greater than the input length. Received: input_shape="
f"{input_shape}, cropping={self.cropping}"
)
else:
length = None
return (input_shape[0], length, input_shape[2])
def call(self, inputs):
if (
inputs.shape[1] is not None
and sum(self.cropping) >= inputs.shape[1]
):
raise ValueError(
"`cropping` parameter of `Cropping1D` layer must be "
"greater than the input length. Received: inputs.shape="
f"{inputs.shape}, cropping={self.cropping}"
)
if self.cropping[1] == 0:
return inputs[:, self.cropping[0] :, :]
else:
return inputs[:, self.cropping[0] : -self.cropping[1], :]
def get_config(self):
config = {"cropping": self.cropping}
base_config = super().get_config()
return {**base_config, **config}

@ -0,0 +1,62 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
from keras_core import operations as ops
from keras_core import testing
class CroppingTest(testing.TestCase):
def test_cropping_1d(self):
inputs = np.random.rand(3, 5, 7)
# Cropping with different values on the left and the right.
self.run_layer_test(
layers.Cropping1D,
init_kwargs={"cropping": (1, 2)},
input_data=inputs,
expected_output=ops.convert_to_tensor(inputs[:, 1:3, :]),
)
# Same cropping on the left and the right.
self.run_layer_test(
layers.Cropping1D,
init_kwargs={"cropping": (1, 1)},
input_data=inputs,
expected_output=ops.convert_to_tensor(inputs[:, 1:4, :]),
)
# Same cropping on the left and the right provided as an int.
self.run_layer_test(
layers.Cropping1D,
init_kwargs={"cropping": 1},
input_data=inputs,
expected_output=ops.convert_to_tensor(inputs[:, 1:4, :]),
)
# Cropping on the right only.
self.run_layer_test(
layers.Cropping1D,
init_kwargs={"cropping": (0, 1)},
input_data=inputs,
expected_output=ops.convert_to_tensor(inputs[:, 0:4, :]),
)
# Cropping on the left only.
self.run_layer_test(
layers.Cropping1D,
init_kwargs={"cropping": (1, 0)},
input_data=inputs,
expected_output=ops.convert_to_tensor(inputs[:, 1:5, :]),
)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_cropping_1d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 5, 7))
permuted = layers.Cropping1D((1, 2))(input_layer)
self.assertEqual(permuted.shape, (None, 2, 7))
def test_cropping_1d_errors_if_cropping_more_than_available(self):
with self.assertRaises(ValueError):
input_layer = layers.Input(shape=(3, 4, 7))
layers.Cropping1D(cropping=(2, 3))(input_layer)

@ -831,9 +831,12 @@ class ConvTranspose(Operation):
)
def compute_output_spec(self, inputs, kernel):
kernel_size = kernel.shape[:-2]
filters = kernel.shape[-2]
output_shape = compute_conv_transpose_output_shape(
inputs,
kernel,
inputs.shape,
kernel_size,
filters,
self.strides,
self.padding,
self.output_padding,