Add Conv LSTM layers

Francois Chollet 2023-05-17 16:06:01 -07:00
parent 9cecfef6a2
commit 4679cdd3ab
36 changed files with 1425 additions and 417 deletions

@ -1,7 +1,11 @@
import jax.numpy as jnp
from keras_core.backend.jax.core import convert_to_tensor
def add(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.add(x1, x2)
@ -20,14 +24,20 @@ def einsum(subscripts, *operands, **kwargs):
def subtract(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.subtract(x1, x2)
def matmul(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.matmul(x1, x2)
def multiply(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.multiply(x1, x2)
@ -498,6 +508,8 @@ def where(condition, x1, x2):
def divide(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.divide(x1, x2)
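
The numpy ops in this JAX backend file all follow the same wrapper pattern: convert every argument with convert_to_tensor, then dispatch to the matching jnp function. A minimal, self-contained sketch of that pattern; to_tensor below is a simplified stand-in for the real convert_to_tensor, which also handles dtypes and Keras variables:

import jax.numpy as jnp

def to_tensor(x):
    # Simplified stand-in for keras_core.backend.jax.core.convert_to_tensor.
    return jnp.asarray(x)

def add(x1, x2):
    # Accept Python scalars, lists, or arrays; normalize, then call the jnp op.
    x1 = to_tensor(x1)
    x2 = to_tensor(x2)
    return jnp.add(x1, x2)

print(add([1.0, 2.0], 3.0))  # [4. 5.]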

@ -481,28 +481,22 @@ class JAXTrainer(base_trainer.Trainer):
)
def _predict_step(trainable_variables, non_trainable_variables, data):
outputs, _ = self.stateless_call(
trainable_variables, non_trainable_variables, data[0]
)
return outputs
return [
self.stateless_call(
trainable_variables, non_trainable_variables, data[0]
)
]
def _predict_multi_step(
trainable_variables, non_trainable_variables, data
):
outputs = _predict_step(
trainable_variables, non_trainable_variables, data[:1]
)
for single_step_data in data[1:]:
step_outputs = _predict_step(
outputs = []
for single_step_data in data:
outputs += _predict_step(
trainable_variables,
non_trainable_variables,
[single_step_data],
)
outputs = tf.nest.map_structure(
lambda t1, t2: jax.numpy.concatenate([t1, t2]),
outputs,
step_outputs,
)
return outputs
if self.steps_per_execution > 1:
@ -525,30 +519,29 @@ class JAXTrainer(base_trainer.Trainer):
callbacks.on_predict_begin()
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
return outputs
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
batch_outputs = predict_step(
multi_step_return_values = predict_step(
trainable_variables, non_trainable_variables, x
)
outputs = append_to_outputs(batch_outputs, outputs)
for batch_outputs, _ in multi_step_return_values:
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(
batch_output
),
outputs,
batch_outputs,
)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(

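Both trainers use the same accumulation idiom for predict(): the first batch's outputs are wrapped leaf-by-leaf into lists, later batches are appended to those lists, and everything is concatenated once at the end of the epoch. A standalone sketch of that idiom with plain NumPy arrays (the dict structure and shapes are illustrative only):

import numpy as np
import tensorflow as tf

def append_to_outputs(batch_outputs, outputs):
    # First batch: wrap every leaf of the (possibly nested) outputs in a list.
    if outputs is None:
        return tf.nest.map_structure(lambda b: [b], batch_outputs)
    # Later batches: append each leaf to the matching per-leaf list in place.
    tf.__internal__.nest.map_structure_up_to(
        batch_outputs,
        lambda acc, b: acc.append(b),
        outputs,
        batch_outputs,
    )
    return outputs

outputs = None
for batch_outputs in [{"y": np.ones((2, 4))}, {"y": np.zeros((3, 4))}]:
    outputs = append_to_outputs(batch_outputs, outputs)

# End of epoch: concatenate each per-leaf list into a single array.
result = tf.__internal__.nest.map_structure_up_to(
    batch_outputs, np.concatenate, outputs
)
print(result["y"].shape)  # (5, 4)
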
@ -115,13 +115,13 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def multi_step_on_iterator(iterator):
def mutli_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
train_function = multi_step_on_iterator
train_function = mutli_step_on_iterator
else:
train_function = one_step_on_iterator
@ -157,13 +157,13 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def multi_step_on_iterator(iterator):
def mutli_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
test_function = multi_step_on_iterator
test_function = mutli_step_on_iterator
else:
test_function = one_step_on_iterator
@ -186,8 +186,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
one_step_on_data, jit_compile=True, reduce_retracing=True
)
def one_step_on_data_distributed(data):
data = data[0]
def one_step_on_iterator(iterator):
"""Runs a single predict step given a Dataset iterator."""
data = next(iterator)
outputs = self.distribute_strategy.run(
one_step_on_data, args=(data,)
)
@ -196,21 +197,20 @@ class TensorFlowTrainer(base_trainer.Trainer):
self.distribute_strategy,
reduction=self.distribute_reduction_method,
)
return outputs
return [outputs]
def multi_step_on_data(data):
outputs = one_step_on_data_distributed(data[:1])
for single_step_data in data[1:]:
step_outputs = one_step_on_data_distributed([single_step_data])
outputs = tf.nest.map_structure(
lambda t1, t2: concat([t1, t2]), outputs, step_outputs
)
def mutli_step_on_iterator(iterator):
outputs = []
for _ in tf.range(self.steps_per_execution):
outputs += one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
predict_function = multi_step_on_data
# TODO(haifengj): Use multi_step_on_iterator.
# predict_function = mutli_step_on_iterator
predict_function = one_step_on_iterator
else:
predict_function = one_step_on_data_distributed
predict_function = one_step_on_iterator
if not self.run_eagerly:
predict_function = tf.function(
@ -431,48 +431,28 @@ class TensorFlowTrainer(base_trainer.Trainer):
model=self,
)
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
return outputs
def get_data(iterator):
"""Returns data for the next execution."""
data = []
for _ in range(self.steps_per_execution):
try:
single_step_data = next(iterator)
except (StopIteration, tf.errors.OutOfRangeError) as e:
if len(data) > 0:
# Suppress the error when still have remaining data.
return data
else:
# Re-raise the error for
# TFEpochIterator.catch_stop_iteration() to catch when
# no data left.
raise e
data.append(single_step_data)
return data
self.make_predict_function()
callbacks.on_predict_begin()
outputs = None
with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_predict_batch_begin(step)
data = get_data(iterator)
batch_outputs = self.predict_function(data)
outputs = append_to_outputs(batch_outputs, outputs)
multi_batch_outputs = self.predict_function(iterator)
for batch_outputs in multi_batch_outputs:
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(
batch_output
),
outputs,
batch_outputs,
)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(

@ -103,8 +103,9 @@ from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.reshaping.zero_padding3d import ZeroPadding3D
from keras_core.layers.rnn.bidirectional import Bidirectional
from keras_core.layers.rnn.conv_lstm1d import ConvLSTM1D
from keras_core.layers.rnn.conv_lstm2d import ConvLSTM2D
from keras_core.layers.rnn.gru import GRU
from keras_core.layers.rnn.lstm import LSTM
from keras_core.layers.rnn.rnn import RNN

@ -9,6 +9,7 @@ from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
from keras_core.utils.argument_validation import standardize_padding
from keras_core.utils.argument_validation import standardize_tuple
@ -109,7 +110,7 @@ class BaseConv(Layer):
self.dilation_rate = standardize_tuple(
dilation_rate, rank, "dilation_rate"
)
self.padding = padding
self.padding = standardize_padding(padding, allow_causal=rank == 1)
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias

@ -11,6 +11,7 @@ from keras_core.backend.common.backend_utils import (
)
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.utils.argument_validation import standardize_padding
from keras_core.utils.argument_validation import standardize_tuple
@ -103,7 +104,7 @@ class BaseConvTranspose(Layer):
self.dilation_rate = standardize_tuple(
dilation_rate, rank, "dilation_rate"
)
self.padding = padding
self.padding = standardize_padding(padding)
if output_padding is None:
self.output_padding = None
else:

@ -9,6 +9,7 @@ from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
from keras_core.utils.argument_validation import standardize_padding
from keras_core.utils.argument_validation import standardize_tuple
@ -114,7 +115,7 @@ class BaseDepthwiseConv(Layer):
self.dilation_rate = standardize_tuple(
dilation_rate, rank, "dilation_rate"
)
self.padding = padding
self.padding = standardize_padding(padding)
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias

@ -10,6 +10,7 @@ from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
from keras_core.utils.argument_validation import standardize_padding
from keras_core.utils.argument_validation import standardize_tuple
@ -115,7 +116,7 @@ class BaseSeparableConv(Layer):
self.dilation_rate = standardize_tuple(
dilation_rate, rank, "dilation_rate"
)
self.padding = padding
self.padding = standardize_padding(padding)
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
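
The four convolution base layers above now validate `padding` through standardize_padding instead of storing the raw string. The real helper lives in keras_core.utils.argument_validation; the sketch below is only a plausible reconstruction, assuming it lower-cases the value and restricts it to "valid"/"same", plus "causal" when allow_causal is passed (as in the rank-1 BaseConv case):

def standardize_padding_sketch(padding, allow_causal=False):
    # Hypothetical reconstruction; the actual keras_core helper may differ.
    if not isinstance(padding, str):
        raise ValueError(f"`padding` must be a string. Received: {padding}")
    padding = padding.lower()
    allowed = {"valid", "same", "causal"} if allow_causal else {"valid", "same"}
    if padding not in allowed:
        raise ValueError(
            f"`padding` must be one of {sorted(allowed)}. Received: {padding}"
        )
    return padding

print(standardize_padding_sketch("SAME"))                       # same
print(standardize_padding_sketch("causal", allow_causal=True))  # causal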

@ -335,6 +335,15 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
"dilation_rate": 1,
"groups": 1,
},
{
"filters": 4,
"kernel_size": 3,
"strides": 2,
"padding": "same",
"data_format": "channels_last",
"dilation_rate": 1,
"groups": 1,
},
{
"filters": 6,
"kernel_size": 2,
@ -346,7 +355,7 @@ class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
},
{
"filters": 6,
"kernel_size": (2, 2),
"kernel_size": (4, 3),
"strides": (2, 1),
"padding": "valid",
"data_format": "channels_last",

@ -571,7 +571,9 @@ class Layer(Operation):
else:
# Use compute_output_shape() to return the right output spec
call_spec = CallSpec(self.call, args, kwargs)
shapes_dict = get_shapes_dict(self.compute_output_shape, call_spec)
shapes_dict = get_shapes_dict(
self.compute_output_shape, call_spec, self.__class__
)
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
@ -723,7 +725,7 @@ class Layer(Operation):
def _maybe_build(self, call_spec):
if not self.built:
shapes_dict = get_shapes_dict(self.build, call_spec)
shapes_dict = get_shapes_dict(self.build, call_spec, self.__class__)
self._build_shapes_dict = shapes_dict
failure = False
if len(shapes_dict) == 1:
@ -967,17 +969,17 @@ def get_arguments_dict(fn, args, kwargs):
return arg_dict
def get_shapes_dict(target_fn, call_spec):
def get_shapes_dict(target_fn, call_spec, cls):
"""Convert the call() arguments dict into a dict of input shape arguments.
Example:
```
>>> get_shapes_dict(self.build, call_spec)
>>> get_shapes_dict(self.build, call_spec, cls)
{"input_a_shape": (2, 3)}
```
"""
expected_names = check_shapes_signature(target_fn, call_spec)
expected_names = check_shapes_signature(target_fn, call_spec, cls)
shapes_dict = {}
for k, v in call_spec.tensor_arguments_dict.items():
if k == "mask" or k.startswith("mask_"):
@ -997,7 +999,7 @@ def get_shapes_dict(target_fn, call_spec):
return shapes_dict
def check_shapes_signature(target_fn, call_spec):
def check_shapes_signature(target_fn, call_spec, cls):
"""Asserts that the argument names in `target_fn` match arguments in `call`.
We use this to check that `build()` and `compute_output_shape()` arguments
@ -1037,13 +1039,15 @@ def check_shapes_signature(target_fn, call_spec):
)
if not name.endswith("_shape"):
raise ValueError(
f"{error_preamble} Received `{method_name}()` argument "
f"{error_preamble} For layer '{cls.__name__}', "
f"Received `{method_name}()` argument "
f"`{name}`, which does not end in `_shape`."
)
expected_call_arg = name.removesuffix("_shape")
if expected_call_arg not in call_spec.arguments_dict:
raise ValueError(
f"{error_preamble} Received `{method_name}()` argument "
f"{error_preamble} For layer '{cls.__name__}', "
f"received `{method_name}()` argument "
f"`{name}`, but `call()` does not have argument "
f"`{expected_call_arg}`."
)
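
The extra cls argument only improves the error messages; the rule being enforced is unchanged: every tensor argument of build() or compute_output_shape() must be named <call_arg>_shape for some argument of call(). A hypothetical two-input layer illustrating the convention (the layer and its argument names are illustrative, not taken from this diff):

from keras_core import layers

class PairwiseSum(layers.Layer):
    def call(self, query, value):
        return query + value

    # OK: `query_shape` / `value_shape` map onto call()'s `query` / `value`.
    def build(self, query_shape, value_shape):
        pass

# A `build(self, query_size)` (missing the `_shape` suffix) or a
# `build(self, key_shape)` (no matching `key` argument in call()) would now
# raise a ValueError naming the layer class, e.g. "For layer 'PairwiseSum', ...".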

@ -1,122 +0,0 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.ZeroPadding3D")
class ZeroPadding3D(Layer):
"""Zero-padding layer for 3D data (spatial or spatio-temporal).
Examples:
>>> input_shape = (1, 1, 2, 2, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> y = keras_core.layers.ZeroPadding3D(padding=2)(x)
>>> y.shape
(1, 5, 6, 6, 3)
Args:
padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding is applied to depth, height,
and width.
- If tuple of 3 ints: interpreted as three different symmetric
padding values for depth, height, and width:
`(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
- If tuple of 3 tuples of 2 ints: interpreted as
`((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
When unspecified, uses `image_data_format` value found in your Keras
config file at `~/.keras/keras.json` (if exists). Defaults to
`"channels_last"`.
Input shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_axis_to_pad, second_axis_to_pad,
third_axis_to_pad, depth)`
- If `data_format` is `"channels_first"`:
`(batch_size, depth, first_axis_to_pad, second_axis_to_pad,
third_axis_to_pad)`
Output shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_padded_axis, second_padded_axis,
third_padded_axis, depth)`
- If `data_format` is `"channels_first"`:
`(batch_size, depth, first_padded_axis, second_padded_axis,
third_padded_axis)`
"""
def __init__(
self,
padding=((1, 1), (1, 1), (1, 1)),
data_format=None,
name=None,
dtype=None,
):
super().__init__(name=name, dtype=dtype)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(padding, int):
self.padding = (
(padding, padding),
(padding, padding),
(padding, padding),
)
elif hasattr(padding, "__len__"):
if len(padding) != 3:
raise ValueError(
f"`padding` should have 3 elements. Received: {padding}."
)
dim1_padding = padding[0]
if isinstance(dim1_padding, int):
dim1_padding = (dim1_padding, dim1_padding)
dim2_padding = padding[1]
if isinstance(dim2_padding, int):
dim2_padding = (dim2_padding, dim2_padding)
dim3_padding = padding[2]
if isinstance(dim3_padding, int):
dim3_padding = (dim3_padding, dim3_padding)
self.padding = (dim1_padding, dim2_padding, dim3_padding)
else:
raise ValueError(
"`padding` should be either an int, a tuple of 3 ints "
"(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), "
"or a tuple of 3 tuples of 2 ints "
"((left_dim1_pad, right_dim1_pad),"
" (left_dim2_pad, right_dim2_pad),"
" (left_dim3_pad, right_dim2_pad)). "
f"Received: {padding}."
)
self.input_spec = InputSpec(ndim=5)
def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
spatial_dims_offset = 2 if self.data_format == "channels_first" else 1
for index in range(0, 3):
if output_shape[index + spatial_dims_offset] is not None:
output_shape[index + spatial_dims_offset] += (
self.padding[index][0] + self.padding[index][1]
)
return tuple(output_shape)
def call(self, inputs):
if self.data_format == "channels_first":
all_dims_padding = ((0, 0), (0, 0), *self.padding)
else:
all_dims_padding = ((0, 0), *self.padding, (0, 0))
return ops.pad(inputs, all_dims_padding)
def get_config(self):
config = {"padding": self.padding, "data_format": self.data_format}
base_config = super().get_config()
return {**base_config, **config}

@ -1,84 +0,0 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ZeroPaddingTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("channels_first", "channels_first"), ("channels_last", "channels_last")
)
def test_zero_padding_3d(self, data_format):
inputs = np.random.rand(1, 2, 3, 4, 5)
outputs = layers.ZeroPadding3D(
padding=((1, 2), (3, 4), (0, 2)), data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
for index in [-1, -2]:
self.assertAllClose(outputs[:, :, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 1:-2, 3:-4, 0:-2], inputs)
else:
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, index, :, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
for index in [-1, -2]:
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, 1:-2, 3:-4, 0:-2, :], inputs)
@parameterized.product(
(
{"padding": ((2, 2), (2, 2), (2, 2))}, # 3 tuples
{"padding": (2, 2, 2)}, # 1 tuple
{"padding": 2}, # 1 int
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_zero_padding_3d_with_same_padding(self, padding, data_format):
inputs = np.random.rand(1, 2, 3, 4, 5)
outputs = layers.ZeroPadding3D(
padding=padding, data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, :, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 2:-2, 2:-2, 2:-2], inputs)
else:
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, index, :, :, :], 0.0)
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_zero_padding_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5))
permuted = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(permuted.shape, (None, 5, 10, 15, 5))
def test_zero_padding_3d_errors_if_padding_argument_invalid(self):
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1,))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1, 2))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1, 2, 3, 4))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding="1")

@ -170,8 +170,8 @@ class Bidirectional(Wrapper):
f'"{backward_value}" for backward layer'
)
def compute_output_shape(self, sequence_shape, initial_state_shape=None):
output_shape = self.forward_layer.compute_output_shape(sequence_shape)
def compute_output_shape(self, sequences_shape, initial_state_shape=None):
output_shape = self.forward_layer.compute_output_shape(sequences_shape)
if self.return_state:
output_shape, state_shape = output_shape[0], output_shape[1:]
@ -191,7 +191,7 @@ class Bidirectional(Wrapper):
def call(
self,
sequence,
sequences,
initial_state=None,
mask=None,
training=None,
@ -207,12 +207,12 @@ class Bidirectional(Wrapper):
# array. They are only passed in from kwarg initial_state, and
# should be passed to forward/backward layer via kwarg
# initial_state as well.
forward_inputs, backward_inputs = sequence, sequence
forward_inputs, backward_inputs = sequences, sequences
half = len(initial_state) // 2
forward_state = initial_state[:half]
backward_state = initial_state[half:]
else:
forward_inputs, backward_inputs = sequence, sequence
forward_inputs, backward_inputs = sequences, sequences
forward_state, backward_state = None, None
y = self.forward_layer(
@ -261,12 +261,12 @@ class Bidirectional(Wrapper):
self.forward_layer.reset_state()
self.backward_layer.reset_state()
def build(self, sequence_shape, initial_state_shape=None):
self.forward_layer.build(sequence_shape)
self.backward_layer.build(sequence_shape)
def build(self, sequences_shape, initial_state_shape=None):
self.forward_layer.build(sequences_shape)
self.backward_layer.build(sequences_shape)
self.built = True
def compute_mask(self, sequence, mask):
def compute_mask(self, _, mask):
if isinstance(mask, list):
mask = mask[0]
if self.return_sequences:

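The argument rename (sequence to sequences) in Bidirectional does not change how the wrapper is used. A minimal usage sketch; the shapes are illustrative, and the default merge mode is assumed to concatenate the forward and backward outputs:

import numpy as np
from keras_core import layers

x = np.random.rand(4, 10, 8)  # (batch, timesteps, features)
bidir = layers.Bidirectional(layers.LSTM(16))
y = bidir(x)
print(y.shape)  # (4, 32): forward and backward states concatenated
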
@ -0,0 +1,689 @@
from tensorflow import nest
from keras_core import activations
from keras_core import backend
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.layers.rnn.dropout_rnn_cell import DropoutRNNCell
from keras_core.layers.rnn.rnn import RNN
from keras_core.operations import operation_utils
from keras_core.utils import argument_validation
class ConvLSTMCell(Layer, DropoutRNNCell):
"""Cell class for the ConvLSTM layer.
Args:
rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions.
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of n integers, specifying the
dimensions of the convolution window.
strides: An integer or tuple/list of n integers, specifying the strides
of the convolution. Specifying any stride value != 1
is incompatible with specifying any `dilation_rate` value != 1.
padding: One of `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly
to the left/right or up/down of the input such that output
has the same height/width dimension as the input.
data_format: A string, one of `channels_last` (default) or
`channels_first`. When unspecified, uses
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json` (if exists) else 'channels_last'.
Defaults to 'channels_last'.
dilation_rate: An integer or tuple/list of n integers, specifying the
dilation rate to use for dilated convolution.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any `strides` value != 1.
activation: Activation function. If `None`, no activation is applied.
recurrent_activation: Activation function to use for the recurrent step.
use_bias: Boolean, (default `True`), whether the layer
should use a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`"glorot_uniform"`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent
state. Default: `"orthogonal"`.
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
unit_forget_bias: Boolean (default `True`). If `True`,
add 1 to the bias of the forget gate at initialization.
Setting it to `True` will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et al.](
https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector.
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector.
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
seed: Random seed for dropout.
Call arguments:
inputs: A (2+ `rank`)D tensor.
states: List of state tensors corresponding to the previous timestep.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
"""
def __init__(
self,
rank,
filters,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
seed=None,
**kwargs,
):
super().__init__(**kwargs)
self.seed = seed
self.seed_generator = backend.random.SeedGenerator(seed=seed)
self.rank = rank
if self.rank > 3:
raise ValueError(
f"Rank {rank} convolutions are not currently "
f"implemented. Received: rank={rank}"
)
self.filters = filters
self.kernel_size = argument_validation.standardize_tuple(
kernel_size, self.rank, "kernel_size"
)
self.strides = argument_validation.standardize_tuple(
strides, self.rank, "strides", allow_zero=True
)
self.padding = argument_validation.standardize_padding(padding)
self.data_format = backend.standardize_data_format(data_format)
self.dilation_rate = argument_validation.standardize_tuple(
dilation_rate, self.rank, "dilation_rate"
)
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.unit_forget_bias = unit_forget_bias
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.dropout = min(1.0, max(0.0, dropout))
self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
self.input_spec = InputSpec(ndim=rank + 2)
self.state_size = -1 # Custom, defined in methods
def build(self, input_shape):
if self.data_format == "channels_first":
channel_axis = 1
self.spatial_dims = input_shape[2:]
else:
channel_axis = -1
self.spatial_dims = input_shape[1:-1]
if None in self.spatial_dims:
raise ValueError(
"ConvLSTM layers only support static "
"input shapes for the spatial dimension. "
f"Received invalid input shape: input_shape={input_shape}"
)
if input_shape[channel_axis] is None:
raise ValueError(
"The channel dimension of the inputs (last axis) should be "
"defined. Found None. Full input shape received: "
f"input_shape={input_shape}"
)
self.input_spec = InputSpec(
ndim=self.rank + 3, shape=(None,) + input_shape[1:]
)
input_dim = input_shape[channel_axis]
self.input_dim = input_dim
self.kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
recurrent_kernel_shape = self.kernel_size + (
self.filters,
self.filters * 4,
)
self.kernel = self.add_weight(
shape=self.kernel_shape,
initializer=self.kernel_initializer,
name="kernel",
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)
self.recurrent_kernel = self.add_weight(
shape=recurrent_kernel_shape,
initializer=self.recurrent_initializer,
name="recurrent_kernel",
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
)
if self.use_bias:
if self.unit_forget_bias:
def bias_initializer(_, *args, **kwargs):
return ops.concatenate(
[
self.bias_initializer(
(self.filters,), *args, **kwargs
),
initializers.get("ones")(
(self.filters,), *args, **kwargs
),
self.bias_initializer(
(self.filters * 2,), *args, **kwargs
),
]
)
else:
bias_initializer = self.bias_initializer
self.bias = self.add_weight(
shape=(self.filters * 4,),
name="bias",
initializer=bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
)
else:
self.bias = None
self.built = True
def call(self, inputs, states, training=False):
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
dp_mask = self.get_dropout_mask(inputs)
rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
if training and 0.0 < self.dropout < 1.0:
inputs *= dp_mask
if training and 0.0 < self.recurrent_dropout < 1.0:
h_tm1 *= rec_dp_mask
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
(kernel_i, kernel_f, kernel_c, kernel_o) = ops.split(
self.kernel, 4, axis=self.rank + 1
)
(
recurrent_kernel_i,
recurrent_kernel_f,
recurrent_kernel_c,
recurrent_kernel_o,
) = ops.split(self.recurrent_kernel, 4, axis=self.rank + 1)
if self.use_bias:
bias_i, bias_f, bias_c, bias_o = ops.split(self.bias, 4)
else:
bias_i, bias_f, bias_c, bias_o = None, None, None, None
x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
i = self.recurrent_activation(x_i + h_i)
f = self.recurrent_activation(x_f + h_f)
c = f * c_tm1 + i * self.activation(x_c + h_c)
o = self.recurrent_activation(x_o + h_o)
h = o * self.activation(c)
return h, [h, c]
def compute_output_shape(self, inputs_shape, states_shape=None):
conv_output_shape = operation_utils.compute_conv_output_shape(
inputs_shape,
self.filters,
self.kernel_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
return conv_output_shape, [conv_output_shape, conv_output_shape]
def get_initial_state(self, batch_size=None):
if self.data_format == "channels_last":
input_shape = (batch_size,) + self.spatial_dims + (self.input_dim,)
else:
input_shape = (batch_size, self.input_dim) + self.spatial_dims
state_shape = self.compute_output_shape(input_shape)[0]
return [
ops.zeros(state_shape, dtype=self.compute_dtype),
ops.zeros(state_shape, dtype=self.compute_dtype),
]
def input_conv(self, x, w, b=None, padding="valid"):
conv_out = ops.conv(
x,
w,
strides=self.strides,
padding=padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
if b is not None:
if self.data_format == "channels_last":
bias_shape = (1,) * (self.rank + 1) + (self.filters,)
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(b, bias_shape)
conv_out += bias
return conv_out
def recurrent_conv(self, x, w):
strides = argument_validation.standardize_tuple(
1, self.rank, "strides", allow_zero=True
)
conv_out = ops.conv(
x, w, strides=strides, padding="same", data_format=self.data_format
)
return conv_out
def get_config(self):
config = {
"filters": self.filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"data_format": self.data_format,
"dilation_rate": self.dilation_rate,
"activation": activations.serialize(self.activation),
"recurrent_activation": activations.serialize(
self.recurrent_activation
),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"recurrent_initializer": initializers.serialize(
self.recurrent_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"unit_forget_bias": self.unit_forget_bias,
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"recurrent_regularizer": regularizers.serialize(
self.recurrent_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"recurrent_constraint": constraints.serialize(
self.recurrent_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}
class ConvLSTM(RNN):
"""Abstract N-D Convolutional LSTM layer (used as implementation base).
Similar to an LSTM layer, but the input transformations
and recurrent transformations are both convolutional.
Args:
rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions.
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of n integers, specifying the
dimensions of the convolution window.
strides: An integer or tuple/list of n integers,
specifying the strides of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: One of `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, time, ..., channels)`
while `channels_first` corresponds to
inputs with shape `(batch, time, channels, ...)`.
When unspecified, uses
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json` (if exists) else 'channels_last'.
Defaults to 'channels_last'.
dilation_rate: An integer or tuple/list of n integers, specifying
the dilation rate to use for dilated convolution.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any `strides` value != 1.
activation: Activation function to use.
By default hyperbolic tangent activation function is applied
(`tanh(x)`).
recurrent_activation: Activation function to use
for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Use in combination with `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et al., 2015](
http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
the `recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. (default False)
return_state: Boolean. Whether to return the last state
in addition to the output. (default False)
go_backwards: Boolean (default False).
If True, process the input sequence backwards.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
"""
def __init__(
self,
rank,
filters,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
seed=None,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
**kwargs,
):
cell = ConvLSTMCell(
rank=rank,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
unit_forget_bias=unit_forget_bias,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
seed=seed,
name="conv_lstm_cell",
dtype=kwargs.get("dtype"),
)
super().__init__(
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
**kwargs,
)
self.input_spec = InputSpec(ndim=rank + 3)
def call(self, sequences, initial_state=None, mask=None, training=False):
return super().call(
sequences, initial_state=initial_state, mask=mask, training=training
)
def compute_output_shape(self, sequences_shape, initial_state_shape=None):
batch_size = sequences_shape[0]
steps = sequences_shape[1]
step_shape = (batch_size,) + sequences_shape[2:]
state_shape = self.cell.compute_output_shape(step_shape)[0][1:]
if self.return_sequences:
output_shape = (
batch_size,
steps,
) + state_shape
else:
output_shape = (batch_size,) + state_shape
if self.return_state:
batched_state_shape = (batch_size,) + state_shape
return output_shape, batched_state_shape, batched_state_shape
return output_shape
def compute_mask(self, _, mask):
mask = nest.flatten(mask)[0]
output_mask = mask if self.return_sequences else None
if self.return_state:
state_mask = [None, None]
return [output_mask] + state_mask
else:
return output_mask
@property
def filters(self):
return self.cell.filters
@property
def kernel_size(self):
return self.cell.kernel_size
@property
def strides(self):
return self.cell.strides
@property
def padding(self):
return self.cell.padding
@property
def data_format(self):
return self.cell.data_format
@property
def dilation_rate(self):
return self.cell.dilation_rate
@property
def activation(self):
return self.cell.activation
@property
def recurrent_activation(self):
return self.cell.recurrent_activation
@property
def use_bias(self):
return self.cell.use_bias
@property
def kernel_initializer(self):
return self.cell.kernel_initializer
@property
def recurrent_initializer(self):
return self.cell.recurrent_initializer
@property
def bias_initializer(self):
return self.cell.bias_initializer
@property
def unit_forget_bias(self):
return self.cell.unit_forget_bias
@property
def kernel_regularizer(self):
return self.cell.kernel_regularizer
@property
def recurrent_regularizer(self):
return self.cell.recurrent_regularizer
@property
def bias_regularizer(self):
return self.cell.bias_regularizer
@property
def kernel_constraint(self):
return self.cell.kernel_constraint
@property
def recurrent_constraint(self):
return self.cell.recurrent_constraint
@property
def bias_constraint(self):
return self.cell.bias_constraint
@property
def dropout(self):
return self.cell.dropout
@property
def recurrent_dropout(self):
return self.cell.recurrent_dropout
def get_config(self):
config = {
"filters": self.filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"padding": self.padding,
"data_format": self.data_format,
"dilation_rate": self.dilation_rate,
"activation": activations.serialize(self.activation),
"recurrent_activation": activations.serialize(
self.recurrent_activation
),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"recurrent_initializer": initializers.serialize(
self.recurrent_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"unit_forget_bias": self.unit_forget_bias,
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"recurrent_regularizer": regularizers.serialize(
self.recurrent_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"recurrent_constraint": constraints.serialize(
self.recurrent_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"seed": self.cell.seed,
}
base_config = super().get_config()
del base_config["cell"]
return {**base_config, **config}
@classmethod
def from_config(cls, config):
return cls(**config)
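
For reference, the gate computations in ConvLSTMCell.call() above correspond to the standard ConvLSTM recurrence, with * denoting a convolution, ⊙ elementwise multiplication, σ the recurrent_activation, φ the activation, and the dropout masks omitted:

\begin{aligned}
i_t &= \sigma(W_i * x_t + U_i * h_{t-1} + b_i) \\
f_t &= \sigma(W_f * x_t + U_f * h_{t-1} + b_f) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \phi(W_c * x_t + U_c * h_{t-1} + b_c) \\
o_t &= \sigma(W_o * x_t + U_o * h_{t-1} + b_o) \\
h_t &= o_t \odot \phi(c_t)
\end{aligned}

The four W, U, and b blocks come from splitting self.kernel, self.recurrent_kernel, and self.bias into four per-gate pieces, exactly as done in call(); the recurrent convolution always uses "same" padding and stride 1 so that h_t keeps the spatial shape of the carry state.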

@ -0,0 +1,181 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.rnn.conv_lstm import ConvLSTM
@keras_core_export("keras_core.layers.ConvLSTM1D")
class ConvLSTM1D(ConvLSTM):
"""1D Convolutional LSTM.
Similar to an LSTM layer, but the input transformations
and recurrent transformations are both convolutional.
Args:
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of 1 integer, specifying the size of
the convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
padding: string, `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the
same height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 1 integer, specifying the dilation
rate to use for dilated convolution.
activation: Activation function to use. By default hyperbolic tangent
activation function is applied (`tanh(x)`).
recurrent_activation: Activation function to use for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel` weights
matrix, used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean. If `True`, add 1 to the bias of
the forget gate at initialization.
Use in combination with `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et al., 2015](
http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to the output of the layer (its "activation").
kernel_constraint: Constraint function applied to the `kernel` weights
matrix.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state in addition
to the output. Default: `False`.
go_backwards: Boolean (default: `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If `True`, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default: `False`).
If `True`, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 4D tensor.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
mask: Binary tensor of shape `(samples, timesteps)` indicating whether a
given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode.
This is only relevant if `dropout` or `recurrent_dropout` are set.
Input shape:
- If `data_format="channels_first"`:
4D tensor with shape: `(samples, time, channels, rows)`
- If `data_format="channels_last"`:
4D tensor with shape: `(samples, time, rows, channels)`
Output shape:
- If `return_state`: a list of tensors. The first tensor is the output.
The remaining tensors are the last states,
each 3D tensor with shape: `(samples, filters, new_rows)` if
`data_format='channels_first'`
or shape: `(samples, new_rows, filters)` if
`data_format='channels_last'`.
`rows` values might have changed due to padding.
- If `return_sequences`: 4D tensor with shape: `(samples, timesteps,
filters, new_rows)` if data_format='channels_first'
or shape: `(samples, timesteps, new_rows, filters)` if
`data_format='channels_last'`.
- Else, 3D tensor with shape: `(samples, filters, new_rows)` if
`data_format='channels_first'`
or shape: `(samples, new_rows, filters)` if
`data_format='channels_last'`.
References:
- [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
(the current implementation does not include the feedback loop on the
cells output).
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
**kwargs
):
super().__init__(
rank=1,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
unit_forget_bias=unit_forget_bias,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
**kwargs
)
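
A minimal usage sketch of the new 1D layer, mirroring the shapes exercised in the test file that follows (input of shape (batch, time, steps, channels)); with padding="same" and the default return_sequences=False the output is the last hidden state:

import numpy as np
from keras_core import layers

x = np.random.rand(3, 2, 4, 3)  # (batch, time, steps, channels)
layer = layers.ConvLSTM1D(filters=5, kernel_size=3, padding="same")
y = layer(x)
print(y.shape)  # (3, 4, 5): spatial dimension preserved by "same" padding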

@ -0,0 +1,67 @@
from keras_core import layers
from keras_core import testing
class ConvLSTM1DTest(testing.TestCase):
def test_basics(self):
self.run_layer_test(
layers.ConvLSTM1D,
init_kwargs={"filters": 5, "kernel_size": 3, "padding": "same"},
input_shape=(3, 2, 4, 3),
expected_output_shape=(3, 4, 5),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
self.run_layer_test(
layers.ConvLSTM1D,
init_kwargs={
"filters": 5,
"kernel_size": 3,
"padding": "valid",
"recurrent_dropout": 0.5,
},
input_shape=(3, 2, 8, 3),
call_kwargs={"training": True},
expected_output_shape=(3, 6, 5),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
self.run_layer_test(
layers.ConvLSTM1D,
init_kwargs={
"filters": 5,
"kernel_size": 3,
"padding": "valid",
"return_sequences": True,
},
input_shape=(3, 2, 8, 3),
expected_output_shape=(3, 2, 6, 5),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
# TODO: correctness testing
# def test_correctness(self):
# sequence = np.arange(120).reshape((2, 3, 4, 5)).astype("float32")
# layer = layers.ConvLSTM1D(
# filters=2,
# kernel_size=3,
# kernel_initializer=initializers.Constant(0.001),
# recurrent_initializer=initializers.Constant(0.0),
# bias_initializer=initializers.Constant(0.3),
# use_bias=False,
# )
# output = layer(sequence)
# self.assertAllClose(
# np.array(
# [
# [[0.49877906, 0.49877906], [0.5447451, 0.5447451]],
# [[0.94260275, 0.94260275], [0.95974874, 0.95974874]],
# ]
# ),
# output,
# )

@ -0,0 +1,181 @@
from keras_core.api_export import keras_core_export
from keras_core.layers.rnn.conv_lstm import ConvLSTM
@keras_core_export("keras_core.layers.ConvLSTM2D")
class ConvLSTM2D(ConvLSTM):
"""2D Convolutional LSTM.
Similar to an LSTM layer, but the input transformations
and recurrent transformations are both convolutional.
Args:
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of 2 integers, specifying the size of the
convolution window.
strides: int or tuple/list of 2 integers, specifying the stride length
of the convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
padding: string, `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
height/width dimension as the input.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 2 integers, specifying the dilation
rate to use for dilated convolution.
activation: Activation function to use. By default hyperbolic tangent
activation function is applied (`tanh(x)`).
recurrent_activation: Activation function to use for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel` weights
matrix, used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean. If `True`, add 1 to the bias of the forget
gate at initialization.
Use in combination with `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et al., 2015](
http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to the output of the layer (its "activation").
kernel_constraint: Constraint function applied to the `kernel` weights
matrix.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state in addition
to the output. Default: `False`.
go_backwards: Boolean (default: `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If `True`, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default: `False`).
If `True`, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 5D tensor.
mask: Binary tensor of shape `(samples, timesteps)` indicating whether a
given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode.
This is only relevant if `dropout` or `recurrent_dropout` are set.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Input shape:
- If `data_format='channels_first'`:
5D tensor with shape: `(samples, time, channels, rows, cols)`
- If `data_format='channels_last'`:
5D tensor with shape: `(samples, time, rows, cols, channels)`
Output shape:
- If `return_state`: a list of tensors. The first tensor is the output.
The remaining tensors are the last states,
each 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if
`data_format='channels_first'`
or shape: `(samples, new_rows, new_cols, filters)` if
`data_format='channels_last'`. `rows` and `cols` values might have
changed due to padding.
- If `return_sequences`: 5D tensor with shape: `(samples, timesteps,
filters, new_rows, new_cols)` if data_format='channels_first'
or shape: `(samples, timesteps, new_rows, new_cols, filters)` if
`data_format='channels_last'`.
- Else, 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if
`data_format='channels_first'`
or shape: `(samples, new_rows, new_cols, filters)` if
`data_format='channels_last'`.
References:
- [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
(the current implementation does not include the feedback loop on the
cells output).
"""
def __init__(
self,
filters,
kernel_size,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
**kwargs
):
super().__init__(
rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
unit_forget_bias=unit_forget_bias,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
**kwargs
)

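For context, a minimal usage sketch of the layer documented above, with `channels_last` shapes matching the docstring (the random input and the chosen filter count are illustrative only):

```python
import numpy as np
from keras_core import layers

# 5D input in channels_last layout: (samples, time, rows, cols, channels).
x = np.random.random((4, 10, 32, 32, 3)).astype("float32")

layer = layers.ConvLSTM2D(filters=16, kernel_size=3, padding="same")
y = layer(x)
print(y.shape)  # (4, 32, 32, 16): return_sequences=False drops the time axis
```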
@ -0,0 +1,72 @@
from keras_core import layers
from keras_core import testing
class ConvLSTM2DTest(testing.TestCase):
def test_basics(self):
self.run_layer_test(
layers.ConvLSTM2D,
init_kwargs={"filters": 5, "kernel_size": 3, "padding": "same"},
input_shape=(3, 2, 4, 4, 3),
expected_output_shape=(3, 4, 4, 5),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
self.run_layer_test(
layers.ConvLSTM2D,
init_kwargs={
"filters": 5,
"kernel_size": 3,
"padding": "valid",
"recurrent_dropout": 0.5,
},
input_shape=(3, 2, 8, 8, 3),
call_kwargs={"training": True},
expected_output_shape=(3, 6, 6, 5),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
self.run_layer_test(
layers.ConvLSTM2D,
init_kwargs={
"filters": 5,
"kernel_size": 3,
"padding": "valid",
"return_sequences": True,
},
input_shape=(3, 2, 8, 8, 3),
expected_output_shape=(3, 2, 6, 6, 5),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
# TODO: correctness testing
# def test_correctness(self):
# sequence = np.arange(480).reshape((2, 3, 4, 4, 5)).astype("float32")
# layer = layers.ConvLSTM2D(
# filters=2,
# kernel_size=3,
# kernel_initializer=initializers.Constant(0.0001),
# recurrent_initializer=initializers.Constant(0.01),
# bias_initializer=initializers.Constant(0.01),
# )
# output = layer(sequence)
# self.assertAllClose(
# np.array(
# [
# [
# [[0.4320268, 0.4320268], [0.4475501, 0.4475501]],
# [[0.49229687, 0.49229687], [0.50656533, 0.50656533]],
# ],
# [
# [[0.8781725, 0.8781725], [0.88340145, 0.88340145]],
# [[0.8988858, 0.8988858], [0.9039862, 0.9039862]],
# ],
# ]
# ),
# output,
# )

@ -32,7 +32,7 @@ class DropoutRNNCell:
def get_recurrent_dropout_mask(self, step_input):
if not hasattr(self, "_recurrent_dropout_mask"):
self._recurrent_dropout_mask = None
if self._recurrent_dropout_mask is None and self.dropout > 0:
if self._recurrent_dropout_mask is None and self.recurrent_dropout > 0:
ones = ops.ones_like(step_input)
self._recurrent_dropout_mask = backend.random.dropout(
ones, rate=self.recurrent_dropout, seed=self.seed_generator

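The condition fix above matters because the recurrent mask is meant to be sampled once and then reused at every timestep. A framework-free toy sketch of that caching contract (the class below is purely illustrative, not the keras_core implementation):

```python
import numpy as np

class RecurrentMaskCache:
    """Toy stand-in for the caching behavior of DropoutRNNCell."""

    def __init__(self, recurrent_dropout):
        self.recurrent_dropout = recurrent_dropout
        self._rng = np.random.default_rng(0)
        self._recurrent_dropout_mask = None

    def get_recurrent_dropout_mask(self, step_input):
        # Sample once, then reuse the same mask at every timestep so the
        # same recurrent units stay dropped for the whole sequence.
        if self._recurrent_dropout_mask is None and self.recurrent_dropout > 0:
            keep = self._rng.random(step_input.shape) >= self.recurrent_dropout
            self._recurrent_dropout_mask = keep / (1.0 - self.recurrent_dropout)
        return self._recurrent_dropout_mask

cache = RecurrentMaskCache(recurrent_dropout=0.5)
h = np.ones((2, 4), dtype="float32")
assert cache.get_recurrent_dropout_mask(h) is cache.get_recurrent_dropout_mask(h)
```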
@ -56,6 +56,7 @@ class GRUCell(Layer, DropoutRNNCell):
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). False = "before",
True = "after" (default and cuDNN compatible).
seed: Random seed for dropout.
Call arguments:
inputs: A 2D tensor, with shape `(batch, features)`.
@ -176,7 +177,7 @@ class GRUCell(Layer, DropoutRNNCell):
self.bias = None
self.built = True
def call(self, inputs, states, training=None):
def call(self, inputs, states, training=False):
h_tm1 = (
states[0] if nest.is_nested(states) else states
) # previous state
@ -319,6 +320,7 @@ class GRUCell(Layer, DropoutRNNCell):
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"reset_after": self.reset_after,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}
@ -412,6 +414,7 @@ class GRU(RNN):
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
seed: Random seed for dropout.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state in addition
@ -419,10 +422,10 @@ class GRU(RNN):
go_backwards: Boolean (default `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If `True`, the last state
stateful: Boolean (default: `False`). If `True`, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
unroll: Boolean (default: `False`).
If `True`, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
@ -466,13 +469,13 @@ class GRU(RNN):
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
seed=None,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
reset_after=True,
seed=None,
**kwargs,
):
cell = GRUCell(
@ -509,7 +512,7 @@ class GRU(RNN):
)
self.input_spec = InputSpec(ndim=3)
def inner_loop(self, sequence, initial_state, mask, training=False):
def inner_loop(self, sequences, initial_state, mask, training=False):
if nest.is_nested(initial_state):
initial_state = initial_state[0]
if nest.is_nested(mask):
@ -522,7 +525,7 @@ class GRU(RNN):
# TF for instance, it will leverage cuDNN when feasible, and
# it will raise NotImplementedError otherwise.
return backend.gru(
sequence,
sequences,
initial_state,
mask,
kernel=self.cell.kernel,
@ -538,12 +541,12 @@ class GRU(RNN):
except NotImplementedError:
pass
return super().inner_loop(
sequence, initial_state, mask=mask, training=training
sequences, initial_state, mask=mask, training=training
)
def call(self, sequence, initial_state=None, mask=None, training=None):
def call(self, sequences, initial_state=None, mask=None, training=False):
return super().call(
sequence, mask=mask, training=training, initial_state=initial_state
sequences, mask=mask, training=training, initial_state=initial_state
)
@property
@ -643,6 +646,7 @@ class GRU(RNN):
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"reset_after": self.reset_after,
"seed": self.cell.seed,
}
base_config = super().get_config()
del base_config["cell"]

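A quick sketch of the updated GRU signature above (`seed` plus `return_sequences`/`return_state`); the shapes in the comments follow from the documented behavior, and the input values are illustrative only:

```python
import numpy as np
from keras_core import layers

sequences = np.random.random((8, 12, 16)).astype("float32")

# `seed` pins the dropout masks; with `return_sequences` and `return_state`
# the layer returns the full output sequence plus the final state.
gru = layers.GRU(
    units=32,
    dropout=0.2,
    recurrent_dropout=0.2,
    seed=1337,
    return_sequences=True,
    return_state=True,
)
outputs, last_state = gru(sequences)
print(outputs.shape)     # (8, 12, 32)
print(last_state.shape)  # (8, 32)
```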
@ -58,6 +58,7 @@ class LSTMCell(Layer, DropoutRNNCell):
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
seed: Random seed for dropout.
Call arguments:
inputs: A 2D tensor, with shape `(batch, features)`.
@ -224,7 +225,7 @@ class LSTMCell(Layer, DropoutRNNCell):
o = self.recurrent_activation(z3)
return c, o
def call(self, inputs, states, training=None):
def call(self, inputs, states, training=False):
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
@ -305,6 +306,7 @@ class LSTMCell(Layer, DropoutRNNCell):
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}
@ -395,14 +397,15 @@ class LSTM(RNN):
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
seed: Random seed for dropout.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state in addition
to the output. Default: `False`.
go_backwards: Boolean (default `False`).
go_backwards: Boolean (default: `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If `True`, the last state
stateful: Boolean (default: `False`). If `True`, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
    unroll: Boolean (default: `False`).
@ -447,12 +450,12 @@ class LSTM(RNN):
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
seed=None,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
seed=None,
**kwargs,
):
cell = LSTMCell(
@ -490,7 +493,7 @@ class LSTM(RNN):
)
self.input_spec = InputSpec(ndim=3)
def inner_loop(self, sequence, initial_state, mask, training=False):
def inner_loop(self, sequences, initial_state, mask, training=False):
if nest.is_nested(mask):
mask = mask[0]
@ -501,7 +504,7 @@ class LSTM(RNN):
# TF for instance, it will leverage cuDNN when feasible, and
# it will raise NotImplementedError otherwise.
return backend.lstm(
sequence,
sequences,
initial_state[0],
initial_state[1],
mask,
@ -517,12 +520,12 @@ class LSTM(RNN):
except NotImplementedError:
pass
return super().inner_loop(
sequence, initial_state, mask=mask, training=training
sequences, initial_state, mask=mask, training=training
)
def call(self, sequence, initial_state=None, mask=None, training=None):
def call(self, sequences, initial_state=None, mask=None, training=False):
return super().call(
sequence, mask=mask, training=training, initial_state=initial_state
sequences, mask=mask, training=training, initial_state=initial_state
)
@property
@ -622,6 +625,7 @@ class LSTM(RNN):
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"seed": self.cell.seed,
}
base_config = super().get_config()
del base_config["cell"]

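For comparison, `return_state` on LSTM yields two state tensors, the hidden state and the cell state, rather than the single GRU state; a minimal sketch with illustrative shapes:

```python
import numpy as np
from keras_core import layers

sequences = np.random.random((4, 20, 8)).astype("float32")

lstm = layers.LSTM(units=16, return_state=True)
# LSTM carries two states: the hidden state h and the cell state c.
last_output, state_h, state_c = lstm(sequences)
print(last_output.shape)  # (4, 16)
print(state_h.shape)      # (4, 16)
print(state_c.shape)      # (4, 16)
```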
@ -239,17 +239,17 @@ class RNN(Layer):
self.state_size = state_size
self.single_state = False
def compute_output_shape(self, sequence_shape, initial_state_shape=None):
state_shape = [(sequence_shape[0], d) for d in self.state_size]
def compute_output_shape(self, sequences_shape, initial_state_shape=None):
state_shape = [(sequences_shape[0], d) for d in self.state_size]
output_size = getattr(self.cell, "output_size", None)
if output_size is None:
output_size = self.state_size[0]
if not isinstance(output_size, int):
raise ValueError("output_size must be an integer.")
if self.return_sequences:
output_shape = (sequence_shape[0], sequence_shape[1], output_size)
output_shape = (sequences_shape[0], sequences_shape[1], output_size)
else:
output_shape = (sequence_shape[0], output_size)
output_shape = (sequences_shape[0], output_size)
if self.return_state:
return output_shape, *state_shape
return output_shape
@ -267,9 +267,9 @@ class RNN(Layer):
else:
return output_mask
def build(self, sequence_shape, initial_state_shape=None):
def build(self, sequences_shape, initial_state_shape=None):
# Build cell (if layer).
step_input_shape = (sequence_shape[0],) + sequence_shape[2:]
step_input_shape = (sequences_shape[0],) + sequences_shape[2:]
if isinstance(self.cell, Layer) and not self.cell.built:
self.cell.build(step_input_shape)
self.cell.built = True
@ -277,13 +277,13 @@ class RNN(Layer):
if self.states is not None:
self.reset_state()
else:
if sequence_shape[0] is None:
if sequences_shape[0] is None:
raise ValueError(
"When using `stateful=True` in a RNN, the "
"batch size must be static. Found dynamic "
f"batch size: sequence.shape={sequence_shape}"
f"batch size: sequence.shape={sequences_shape}"
)
self._create_state_variables(sequence_shape[0])
self._create_state_variables(sequences_shape[0])
self.built = True
@tracking.no_automatic_dependency_tracking
@ -321,7 +321,7 @@ class RNN(Layer):
for v in self.states:
v.assign(ops.zeros_like(v))
def inner_loop(self, sequence, initial_state, mask, training=False):
def inner_loop(self, sequences, initial_state, mask, training=False):
cell_kwargs = {}
if isinstance(self.cell, Layer) and self.cell._call_has_training_arg():
cell_kwargs["training"] = training
@ -337,24 +337,24 @@ class RNN(Layer):
return backend.rnn(
step,
sequence,
sequences,
initial_state,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
input_length=sequence.shape[1],
input_length=sequences.shape[1],
zero_output_for_mask=self.zero_output_for_mask,
return_all_outputs=self.return_sequences,
)
def call(
self,
sequence,
sequences,
initial_state=None,
mask=None,
training=False,
):
timesteps = sequence.shape[1]
timesteps = sequences.shape[1]
if self.unroll and timesteps is None:
raise ValueError(
"Cannot unroll a RNN if the "
@ -372,7 +372,7 @@ class RNN(Layer):
initial_state = self.states
else:
initial_state = self.get_initial_state(
batch_size=ops.shape(sequence)[0]
batch_size=ops.shape(sequences)[0]
)
# RNN expect the states in a list, even if single state.
if not nest.is_nested(initial_state):
@ -387,7 +387,7 @@ class RNN(Layer):
)
last_output, outputs, states = self.inner_loop(
sequence=sequence,
sequences=sequences,
initial_state=initial_state,
mask=mask,
training=training,

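To illustrate the `RNN` wrapper semantics above (`state_size`, the `output_size` fallback to `state_size[0]`, and `return_sequences`), here is a sketch with a deliberately minimal, hypothetical cell; it assumes the standard Keras cell contract of `call(inputs, states) -> (output, new_states)`:

```python
import numpy as np
from keras_core import layers
from keras_core import operations as ops

class MinimalAccumulatorCell(layers.Layer):
    """Hypothetical cell: projects the input and adds the previous state."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Single integer state; `output_size` is left unset, so the RNN
        # wrapper falls back to `state_size[0]` when computing shapes.
        self.state_size = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units), initializer="glorot_uniform"
        )

    def call(self, inputs, states):
        prev = states[0] if isinstance(states, (list, tuple)) else states
        output = ops.matmul(inputs, self.kernel) + prev
        return output, [output]

sequences = np.random.random((2, 5, 3)).astype("float32")
layer = layers.RNN(MinimalAccumulatorCell(4), return_sequences=True)
print(layer(sequences).shape)  # (2, 5, 4)
```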
@ -49,6 +49,7 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
seed: Random seed for dropout.
Call arguments:
sequence: A 2D tensor, with shape `(batch, features)`.
@ -349,9 +350,9 @@ class SimpleRNN(RNN):
)
self.input_spec = [InputSpec(ndim=3)]
def call(self, sequence, initial_state=None, mask=None, training=None):
def call(self, sequences, initial_state=None, mask=None, training=False):
return super().call(
sequence, mask=mask, training=training, initial_state=initial_state
sequences, mask=mask, training=training, initial_state=initial_state
)
@property

@ -21,15 +21,15 @@ class Adadelta(optimizer.Optimizer):
learning rate can be set, as in most other Keras optimizers.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001. Note that `Adadelta`
tends to benefit from higher initial learning rate values compared
to other optimizers. To match the exact form in the original paper,
use 1.0.
learning_rate: Initial value for the learning rate: a floating
point value. Defaults to 0.001. Note that `Adadelta` tends
to benefit from higher initial learning rate values compared to
other optimizers.
To match the exact form in the original paper, use 1.0.
rho: A floating point value. The decay rate. Defaults to 0.95.
epsilon: Small floating point value for maintaining numerical stability.
epsilon: Small floating point value used to maintain numerical
stability. Defaults to 1e-7.
{{base_optimizer_keyword_args}}
Reference:

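A small usage sketch of the note above about Adadelta's preferred learning rate (values mirror the docstring defaults):

```python
from keras_core import optimizers

# Adadelta often needs a larger initial learning rate than other optimizers;
# 1.0 matches the original paper, 0.001 is the documented default.
opt = optimizers.Adadelta(learning_rate=1.0, rho=0.95, epsilon=1e-7)
```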
@ -17,10 +17,8 @@ class Adafactor(optimizer.Optimizer):
last 2 dimensions separately in its accumulator variables.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: Initial value for the learning rate:
a floating point value. Defaults to 0.001.
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
epsilon_1: float, defaults to 1e-30. A small offset to keep denominator
away from 0.
@ -131,7 +129,9 @@ class Adafactor(optimizer.Optimizer):
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
one = ops.cast(1.0, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
if not callable(self._learning_rate) and self.relative_step:
if self.relative_step: # TODO: add learning_rate_schedule logic
# If `relative_step=True` and learning rate is a constant, we
# apply the relative step algorithm.
lr = ops.minimum(lr, 1 / ops.sqrt(local_step))
r = self._r[self._get_variable_index(variable)]

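The `relative_step` branch above clamps the learning rate by `1 / sqrt(local_step)`; a plain-Python sketch of that schedule:

```python
import math

def relative_step_lr(base_lr, step):
    """Mirror of the clamp above: lr = min(lr, 1 / sqrt(local_step))."""
    return min(base_lr, 1.0 / math.sqrt(step))

print([relative_step_lr(0.1, s) for s in (1, 100, 400, 10_000)])
# [0.1, 0.1, 0.05, 0.01] -- the rate decays like 1/sqrt(step) once that
# value falls below the configured learning rate.
```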
@ -14,16 +14,18 @@ class Adagrad(optimizer.Optimizer):
the smaller the updates.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001. Note that `Adagrad` tends
to benefit from higher initial learning rate values compared to
other optimizers. To match the exact form in the original paper,
use 1.0.
initial_accumulator_value: Floating point value. Starting value for the
accumulators (per-parameter momentum values). Must be non-negative.
epsilon: Small floating point value for maintaining numerical stability.
learning_rate: Initial value for the learning rate: a floating point
value. Defaults to 0.001.
Note that `Adagrad` tends to benefit from higher initial
learning rate values compared to other optimizers.
To match the exact form in the original paper, use 1.0.
initial_accumulator_value: Floating point value. Starting value for the
accumulators (per-parameter momentum values). Must be non-negative.
epsilon: Small floating point value used to maintain numerical
stability.
{{base_optimizer_keyword_args}}
Reference:

@ -18,10 +18,9 @@ class Adam(optimizer.Optimizer):
data/parameters*".
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: A floating point value or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to `0.001`.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates. Defaults to

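The reworded `learning_rate` description above (echoed by several of the other optimizer docstrings in this commit) allows either a float or a zero-argument callable; a brief sketch of both forms:

```python
from keras_core import optimizers

# Plain float learning rate.
opt_fixed = optimizers.Adam(learning_rate=0.001)

# Zero-argument callable returning the value to use, as described above
# (a constant here purely for illustration).
opt_callable = optimizers.Adam(learning_rate=lambda: 0.001)
```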
@ -34,10 +34,9 @@ class Adamax(optimizer.Optimizer):
```
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: A floating point value, or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to `0.001`.
beta_1: A float value or a constant float tensor. The exponential decay
rate for the 1st moment estimates.
beta_2: A float value or a constant float tensor. The exponential decay

@ -21,10 +21,9 @@ class AdamW(optimizer.Optimizer):
data/parameters*".
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: A floating point value or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates.

@ -52,10 +52,9 @@ class Ftrl(optimizer.Optimizer):
is replaced with a gradient with shrinkage.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: A floating point value, or a callable that
takes no arguments and returns the actual value to use. The learning
rate. Defaults to `0.001`.
learning_rate_power: A float value, must be less or equal to zero.
Controls how the learning rate decreases during training. Use zero
for a fixed learning rate.

@ -12,10 +12,9 @@ class Nadam(optimizer.Optimizer):
Nesterov momentum.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: A floating point value or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to `0.001`.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates.

@ -5,14 +5,16 @@ from keras_core.optimizers import base_optimizer
if backend.backend() == "tensorflow":
from keras_core.backend.tensorflow import optimizer as tf_optimizer
Optimizer = tf_optimizer.TFOptimizer
BackendOptimizer = tf_optimizer.TFOptimizer
else:
Optimizer = base_optimizer.Optimizer
BackendOptimizer = base_optimizer.Optimizer
keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])(
Optimizer
)
@keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])
class Optimizer(BackendOptimizer):
pass
base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args

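Assuming the hunk above is the `optimizer` module that the other optimizer files import, the net effect is that the exported `Optimizer` is now a concrete class sitting above every built-in optimizer; a quick check:

```python
from keras_core import optimizers
from keras_core.optimizers import optimizer

# `Optimizer` derives from the backend-specific base (TF or the generic one),
# so built-in optimizers all share it as a common ancestor.
sgd = optimizers.SGD(learning_rate=0.01)
assert isinstance(sgd, optimizer.Optimizer)
```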
@ -18,10 +18,8 @@ class RMSprop(optimizer.Optimizer):
gradients, and uses that average to estimate the variance.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate: Initial value for the learning rate: a floating point
value. Defaults to 0.001.
rho: float, defaults to 0.9. Discounting factor for the old gradients.
momentum: float, defaults to 0.0. If not 0.0, the optimizer tracks the
momentum value, with a decay rate equal to `1 - momentum`.

@ -1,7 +1,9 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer
@keras_core_export("keras_core.optimizers.SGD")
class SGD(optimizer.Optimizer):
"""Gradient descent (with momentum) optimizer.
@ -26,16 +28,16 @@ class SGD(optimizer.Optimizer):
```
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.01.
momentum: float hyperparameter >= 0 that accelerates gradient descent in
the relevant direction and dampens oscillations. Defaults to 0,
i.e., vanilla gradient descent.
nesterov: boolean. Whether to apply Nesterov momentum.
Defaults to `False`.
{{base_optimizer_keyword_args}}
learning_rate: A floating point value, or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to 0.01.
momentum: float hyperparameter >= 0 that accelerates gradient descent in
the relevant direction and dampens oscillations. Defaults to 0, i.e.,
vanilla gradient descent.
nesterov: boolean. Whether to apply Nesterov momentum.
Defaults to `False`.
{{base_optimizer_keyword_args}}
"""
def __init__(

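A one-line usage sketch of the SGD arguments documented above:

```python
from keras_core import optimizers

# Momentum accelerates descent along consistent gradient directions;
# nesterov=True applies the Nesterov variant described above.
opt = optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
```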
@ -228,7 +228,6 @@ class TestTrainer(testing.TestCase):
x = np.ones((100, 4))
y = np.ones((100, 1))
batch_size = 16
model = ExampleModel(units=1)
model.compile(
loss="mse",
@ -241,11 +240,8 @@ class TestTrainer(testing.TestCase):
model_2 = ExampleModel(units=1)
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
model_2.fit(x=x, y=y, batch_size=batch_size, verbose=0)
model_2.fit(x=x, y=y, batch_size=16, verbose=0)
self.assertAllClose(model.get_weights(), model_2.get_weights())
self.assertAllClose(
model.predict(x, batch_size=batch_size),
model_2.predict(x, batch_size=batch_size),
)
self.assertAllClose(model.predict(x), model_2.predict(x))
self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))

@ -51,3 +51,20 @@ def standardize_tuple(value, n, name, allow_zero=False):
raise ValueError(error_msg)
return value_tuple
def standardize_padding(value, allow_causal=False):
if isinstance(value, (list, tuple)):
return value
padding = value.lower()
if allow_causal:
allowed_values = {"valid", "same", "causal"}
else:
allowed_values = {"valid", "same"}
if padding not in allowed_values:
raise ValueError(
"The `padding` argument must be a list/tuple or one of "
f"{allowed_values}. "
f"Received: {padding}"
)
return padding
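
A few illustrative calls for the new `standardize_padding` helper; note that the import path below is an assumption, since the hunk does not name the file:

```python
# NOTE: the module path is assumed; the diff header above does not show it.
from keras_core.utils.argument_validation import standardize_padding

print(standardize_padding("SAME"))                       # "same" (lowercased)
print(standardize_padding("causal", allow_causal=True))  # "causal"
print(standardize_padding((1, 2)))                       # tuples pass through untouched

try:
    standardize_padding("causal")  # causal only allowed when allow_causal=True
except ValueError as e:
    print(e)
```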