from tensorflow import nest

from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.rnn.stacked_rnn_cells import StackedRNNCells
from keras_core.saving import serialization_lib
from keras_core.utils import tracking


class DropoutRNNCellMixin:
    """Placeholder mixin for RNN cells that implement dropout.

    Cells inheriting from this mixin are expected to cache their dropout
    masks and to expose `reset_dropout_mask()` and
    `reset_recurrent_dropout_mask()`, which
    `RNN._maybe_reset_dropout_masks()` calls after each batch.
    """

    pass


@keras_core_export("keras_core.layers.RNN")
class RNN(Layer):
"""Base class for recurrent layers.
Args:
        cell: An RNN cell instance or a list of RNN cell instances.
            An RNN cell is a class that has:
            - A `call(input_at_t, states_at_t)` method, returning
              `(output_at_t, states_at_t_plus_1)`.
- A `state_size` attribute. This can be a single integer
(single state) in which case it is the size of the recurrent
state. This can also be a list of integers
(one size per state).
            - An `output_size` attribute, a single integer.
            - A `get_initial_state(batch_size=None)` method that creates a
              tensor meant to be fed to `call()` as the initial state, if
              the user didn't specify any initial state via other means.
              The returned initial state should have shape
              `(batch_size, cell.state_size)`. The cell might choose to
              create a tensor full of zeros, or other values based on the
              cell's implementation. If this method is not implemented by
              the cell, the RNN layer will create a zero-filled tensor
              with shape `(batch_size, cell.state_size)`.
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked on top of each other in the RNN, resulting in an
efficient stacked RNN.
return_sequences: Boolean (default `False`). Whether to return the last
output in the output sequence, or the full sequence.
return_state: Boolean (default `False`).
Whether to return the last state in addition to the output.
go_backwards: Boolean (default `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
        stateful: Boolean (default `False`). If `True`, the last state
            for each sample at index `i` in a batch will be used as the
            initial state for the sample of index `i` in the following
            batch.
        unroll: Boolean (default `False`).
            If `True`, the network will be unrolled, else a symbolic loop
            will be used. Unrolling can speed up an RNN, although it tends
            to be more memory-intensive. Unrolling is only suitable for
            short sequences.
zero_output_for_mask: Boolean (default `False`).
Whether the output should use zeros for the masked timesteps.
Note that this field is only used when `return_sequences`
is `True` and `mask` is provided.
            It can be useful if you want to reuse the raw output sequence
            of the RNN without interference from the masked timesteps,
            e.g. when merging bidirectional RNNs.

Call arguments:
inputs: Input tensor.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
mask: Binary tensor of shape `[batch_size, timesteps]`
indicating whether a given timestep should be masked.
An individual `True` entry indicates that the corresponding
timestep should be utilized, while a `False` entry indicates
that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed
to the cell when calling it.
This is for use with cells that use dropout.

Input shape:
3-D tensor with shape `(batch_size, timesteps, features)`.

    Output shape:
        - If `return_state`: a list of tensors. The first tensor is
          the output. The remaining tensors are the last states,
          each with shape `(batch_size, state_size)`, where `state_size`
          could be a high-dimensional tensor shape.
        - If `return_sequences`: 3D tensor with shape
          `(batch_size, timesteps, output_size)`.
        - Else: 2D tensor with shape `(batch_size, output_size)`.

Masking:
This layer supports masking for input data with a variable number
of timesteps. To introduce masks to your data,
use a `keras_core.layers.Embedding` layer with the `mask_zero` parameter
set to `True`.

Note on using statefulness in RNNs:
You can set RNN layers to be 'stateful', which means that the states
computed for the samples in one batch will be reused as initial states
for the samples in the next batch. This assumes a one-to-one mapping
between samples in different successive batches.
To enable statefulness:
- Specify `stateful=True` in the layer constructor.
        - Specify a fixed batch size for your model. For a `Sequential`
          model, pass `batch_input_shape=(...)` to the first layer. For a
          functional model with one or more `Input` layers, pass
          `batch_shape=(...)` to all the first layers. This is the
          expected shape of your inputs *including the batch size*. It
          should be a tuple of integers, e.g. `(32, 10, 100)`.
- Specify `shuffle=False` when calling `fit()`.
To reset the states of your model, call `.reset_states()` on either
a specific layer, or on your entire model.
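
    A minimal sketch of a stateful setup (shapes are hypothetical, and
    `cell` stands for any RNN cell as described above):

    ```python
    x = keras_core.Input(batch_shape=(32, 10, 100))  # fixed batch size
    layer = RNN(cell, stateful=True)
    y = layer(x)
    # ... later, between independent sequences:
    layer.reset_state()
    ```
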
Note on specifying the initial state of RNNs:
You can specify the initial state of RNN layers symbolically by
calling them with the keyword argument `initial_state`. The value of
`initial_state` should be a tensor or list of tensors representing
the initial state of the RNN layer.
        You can reset the accumulated states of a stateful RNN layer to
        all-zeros by calling `reset_state()` on the layer (or its
        compatibility alias `reset_states()`).

Examples:
```python
    import keras_core
    from keras_core.layers import RNN
    from keras_core import operations as ops

    # First, let's define an RNN Cell, as a layer subclass.
    class MinimalRNNCell(keras_core.layers.Layer):

        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.state_size = units

        def build(self, input_shape):
            self.kernel = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer='uniform',
                name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            prev_output = states[0]
            h = ops.matmul(inputs, self.kernel)
            output = h + ops.matmul(prev_output, self.recurrent_kernel)
            return output, [output]

    # Let's use this cell in an RNN layer:
cell = MinimalRNNCell(32)
x = keras_core.Input((None, 5))
layer = RNN(cell)
y = layer(x)
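
    # You can also pass an initial state explicitly when calling the
    # layer eagerly. A sketch (the batch size of 2 is arbitrary):
    initial_state = [ops.zeros((2, 32))]
    y = layer(ops.ones((2, 10, 5)), initial_state=initial_state)
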
# Here's how to use the cell to build a stacked RNN:
cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras_core.Input((None, 5))
layer = RNN(cells)
y = layer(x)
```
"""
def __init__(
self,
cell,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
zero_output_for_mask=False,
**kwargs,
):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
if "call" not in dir(cell):
raise ValueError(
"Argument `cell` should have a `call` method. "
f"Received: cell={cell}"
)
if "state_size" not in dir(cell):
raise ValueError(
"The RNN cell should have a `state_size` attribute "
"(single integer or list of integers, "
"one integer per RNN state). "
f"Received: cell={cell}"
)
super().__init__(**kwargs)
        # If True, the output for masked timesteps will be zeros, whereas
        # in the False case, the output from the previous timestep is
        # returned for masked timesteps.
self.zero_output_for_mask = zero_output_for_mask
self.cell = cell
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
self.supports_masking = True
self.input_spec = None
self.states = None
state_size = getattr(self.cell, "state_size", None)
if state_size is None:
raise ValueError(
"state_size must be specified as property on the RNN cell."
)
if not isinstance(state_size, (list, int)):
raise ValueError(
"state_size must be an integer, or a list of integers "
"(one for each state tensor)."
)
if isinstance(state_size, int):
self.state_size = [state_size]
self.single_state = True
else:
self.state_size = state_size
self.single_state = False

def compute_output_shape(self, sequences_shape, initial_state_shape=None):
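        # A worked example (sketch): with `state_size=[32]` and
        # `sequences_shape=(8, 10, 5)`, the output shape is `(8, 10, 32)`
        # when `return_sequences=True`, else `(8, 32)`; if `return_state`
        # is also `True`, an `(8, 32)` state shape is appended per state.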
state_shape = [(sequences_shape[0], d) for d in self.state_size]
output_size = getattr(self.cell, "output_size", None)
if output_size is None:
output_size = self.state_size[0]
if not isinstance(output_size, int):
raise ValueError("output_size must be an integer.")
if self.return_sequences:
output_shape = (sequences_shape[0], sequences_shape[1], output_size)
else:
output_shape = (sequences_shape[0], output_size)
if self.return_state:
return output_shape, *state_shape
return output_shape

def compute_mask(self, _, mask):
# Time step masks must be the same for each input.
# This is because the mask for an RNN is of size [batch, time_steps, 1],
# and specifies which time steps should be skipped, and a time step
# must be skipped for all inputs.
mask = nest.flatten(mask)[0]
output_mask = mask if self.return_sequences else None
if self.return_state:
state_mask = [None for _ in self.state_size]
return [output_mask] + state_mask
else:
return output_mask

def build(self, sequences_shape, initial_state_shape=None):
# Build cell (if layer).
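        # The cell is called one timestep at a time, so its input shape
        # is the sequence shape without the time dimension.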
step_input_shape = (sequences_shape[0],) + tuple(sequences_shape[2:])
if isinstance(self.cell, Layer) and not self.cell.built:
self.cell.build(step_input_shape)
self.cell.built = True
if self.stateful:
if self.states is not None:
self.reset_state()
else:
if sequences_shape[0] is None:
                    raise ValueError(
                        "When using `stateful=True` in an RNN, the "
                        "batch size must be static. Found dynamic "
                        f"batch size: sequence.shape={sequences_shape}"
                    )
self._create_state_variables(sequences_shape[0])
self.built = True

@tracking.no_automatic_dependency_tracking
def _create_state_variables(self, batch_size):
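        # Stateful RNNs keep their state in non-trainable variables that
        # persist across batches; `call()` updates them via `assign()` and
        # `reset_state()` zeroes them out.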
self.states = nest.map_structure(
lambda value: backend.Variable(
value, trainable=False, dtype=self.variable_dtype
),
self.get_initial_state(batch_size),
)

def get_initial_state(self, batch_size):
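        # Prefer the cell's own `get_initial_state()` if it defines one;
        # otherwise fall back to zero-filled states, one per entry in
        # `self.state_size`.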
get_initial_state_fn = getattr(self.cell, "get_initial_state", None)
if get_initial_state_fn:
init_state = get_initial_state_fn(batch_size=batch_size)
else:
return [
ops.zeros((batch_size, d), dtype=self.compute_dtype)
for d in self.state_size
]
        # The RNN expects the states in a list, even if it's a single state.
if not nest.is_nested(init_state):
init_state = [init_state]
        # Force the state to be a list in case it is a namedtuple,
        # e.g. `LSTMStateTuple`.
return list(init_state)

def reset_states(self):
# Compatibility alias.
self.reset_state()

def reset_state(self):
if self.states is not None:
for v in self.states:
v.assign(ops.zeros_like(v))

def inner_loop(self, sequences, initial_state, mask, training=False):
cell_kwargs = {}
if isinstance(self.cell, Layer) and self.cell._call_has_training_arg():
cell_kwargs["training"] = training
def step(inputs, states):
output, new_states = self.cell(inputs, states, **cell_kwargs)
if not nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
if not nest.is_nested(initial_state):
initial_state = [initial_state]
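        # Delegate the timestep iteration to the backend implementation,
        # which handles masking, unrolling, and backward traversal.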
return backend.rnn(
step,
sequences,
initial_state,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
input_length=sequences.shape[1],
zero_output_for_mask=self.zero_output_for_mask,
return_all_outputs=self.return_sequences,
)

def call(
self,
sequences,
initial_state=None,
mask=None,
training=False,
):
timesteps = sequences.shape[1]
if self.unroll and timesteps is None:
raise ValueError(
"Cannot unroll a RNN if the "
"time dimension is undefined. \n"
"- If using a Sequential model, "
"specify the time dimension by passing "
"an `Input()` as your first layer.\n"
"- If using the functional API, specify "
"the time dimension by passing a `shape` "
"or `batch_shape` argument to your `Input()`."
)
if initial_state is None:
if self.stateful:
initial_state = self.states
else:
initial_state = self.get_initial_state(
batch_size=ops.shape(sequences)[0]
)
        # The RNN expects the states in a list, even if it's a single state.
if not nest.is_nested(initial_state):
initial_state = [initial_state]
initial_state = list(initial_state)
# Cast states to compute dtype.
# Note that states may be deeply nested
# (e.g. in the stacked cells case).
initial_state = nest.map_structure(
lambda x: ops.cast(x, dtype=self.compute_dtype), initial_state
)
last_output, outputs, states = self.inner_loop(
sequences=sequences,
initial_state=initial_state,
mask=mask,
training=training,
)
self._maybe_reset_dropout_masks(self.cell)
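        # For stateful RNNs, write the final states back into the
        # persistent state variables so the next batch resumes from them.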
if self.stateful:
for self_state, state in zip(
nest.flatten(self.states), nest.flatten(states)
):
self_state.assign(state)
if self.return_sequences:
output = outputs
else:
output = last_output
if self.return_state:
if len(states) == 1:
state = states[0]
return output, state
return output, *states
return output

def _maybe_reset_dropout_masks(self, cell):
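        # Dropout-enabled cells cache their dropout masks for the duration
        # of a call; clear the caches so fresh masks are sampled next time.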
if isinstance(cell, DropoutRNNCellMixin):
cell.reset_dropout_mask()
cell.reset_recurrent_dropout_mask()
if isinstance(cell, StackedRNNCells):
for c in cell.cells:
self._maybe_reset_dropout_masks(c)

def get_config(self):
config = {
"return_sequences": self.return_sequences,
"return_state": self.return_state,
"go_backwards": self.go_backwards,
"stateful": self.stateful,
"unroll": self.unroll,
"zero_output_for_mask": self.zero_output_for_mask,
}
config["cell"] = serialization_lib.serialize_keras_object(self.cell)
base_config = super().get_config()
return {**base_config, **config}

@classmethod
def from_config(cls, config, custom_objects=None):
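        # The cell was serialized with `serialize_keras_object()` in
        # `get_config()`, so rebuild it before constructing the layer.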
cell = serialization_lib.deserialize_keras_object(
config.pop("cell"), custom_objects=custom_objects
)
layer = cls(cell, **config)
return layer