keras/keras_core/layers/rnn/gru.py
from tensorflow import nest
from keras_core import activations
from keras_core import backend
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.layers.rnn.dropout_rnn_cell import DropoutRNNCell
from keras_core.layers.rnn.rnn import RNN


@keras_core_export("keras_core.layers.GRUCell")
class GRUCell(Layer, DropoutRNNCell):
"""Cell class for the GRU layer.
This class processes one step within the whole time sequence input, whereas
`keras_core.layer.GRU` processes the whole sequence.
Args:
units: Positive integer, dimensionality of the output space.
        activation: Activation function to use. Default: hyperbolic tangent
            (`tanh`). If you pass `None`, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use for the recurrent step.
            Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
            applied (i.e. "linear" activation: `a(x) = x`).
        use_bias: Boolean (default `True`). Whether the layer
            should use a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`"glorot_uniform"`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation
of the recurrent state. Default: `"orthogonal"`.
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector.
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector.
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). False = "before",
True = "after" (default and cuDNN compatible).
seed: Random seed for dropout.

    Call arguments:
inputs: A 2D tensor, with shape `(batch, features)`.
states: A 2D tensor with shape `(batch, units)`, which is the state
from the previous time step.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.

    Example:

>>> inputs = np.random.random((32, 10, 8))
>>> rnn = keras_core.layers.RNN(keras_core.layers.GRUCell(4))
>>> output = rnn(inputs)
>>> output.shape
(32, 4)

    >>> rnn = keras_core.layers.RNN(
... keras_core.layers.GRUCell(4),
... return_sequences=True,
... return_state=True)
>>> whole_sequence_output, final_state = rnn(inputs)
>>> whole_sequence_output.shape
(32, 10, 4)
>>> final_state.shape
(32, 4)
"""

    def __init__(
self,
units,
activation="tanh",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
reset_after=True,
seed=None,
**kwargs,
):
if units <= 0:
raise ValueError(
"Received an invalid value for argument `units`, "
f"expected a positive integer, got {units}."
)
implementation = kwargs.pop("implementation", 2)
super().__init__(**kwargs)
self.implementation = implementation
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.dropout = min(1.0, max(0.0, dropout))
self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
self.seed = seed
self.seed_generator = backend.random.SeedGenerator(seed=seed)
self.reset_after = reset_after
self.state_size = self.units
self.output_size = self.units

    def build(self, input_shape):
super().build(input_shape)
input_dim = input_shape[-1]
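        # The `kernel` and `recurrent_kernel` store the update (z), reset (r),
        # and candidate (h) sub-kernels concatenated along the last axis,
        # which is why `call()` slices them in blocks of `self.units`.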
self.kernel = self.add_weight(
shape=(input_dim, self.units * 3),
name="kernel",
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 3),
name="recurrent_kernel",
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
)
if self.use_bias:
if not self.reset_after:
bias_shape = (3 * self.units,)
else:
# separate biases for input and recurrent kernels
# Note: the shape is intentionally different from CuDNNGRU
# biases `(2 * 3 * self.units,)`, so that we can distinguish the
# classes when loading and converting saved weights.
bias_shape = (2, 3 * self.units)
self.bias = self.add_weight(
shape=bias_shape,
name="bias",
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
)
else:
self.bias = None
self.built = True

    def call(self, inputs, states, training=False):
h_tm1 = (
states[0] if nest.is_nested(states) else states
) # previous state
dp_mask = self.get_dropout_mask(inputs)
rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
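        # In equation form, the single GRU step computed below is:
        #   z  = recurrent_activation(x_z + recurrent_z)   # update gate
        #   r  = recurrent_activation(x_r + recurrent_r)   # reset gate
        #   hh = activation(x_h + recurrent_h)             # candidate state
        #   h  = z * h_tm1 + (1 - z) * hh                  # new state
        # where the reset gate `r` multiplies the recurrent term either after
        # (`reset_after=True`) or before (`reset_after=False`) the recurrent
        # matmul.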
if self.use_bias:
if not self.reset_after:
input_bias, recurrent_bias = self.bias, None
else:
input_bias, recurrent_bias = (
ops.squeeze(e, axis=0)
for e in ops.split(self.bias, self.bias.shape[0], axis=0)
)
if training and 0.0 < self.dropout < 1.0:
inputs *= dp_mask
if training and 0.0 < self.recurrent_dropout < 1.0:
h_tm1 *= rec_dp_mask
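        # Implementation 1 computes one matmul per gate on explicit slices of
        # the kernels; implementation 2 (the default) fuses all three gates
        # into a single matmul and splits the result. Both compute the same
        # recurrence.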
if self.implementation == 1:
inputs_z = inputs
inputs_r = inputs
inputs_h = inputs
x_z = ops.matmul(inputs_z, self.kernel[:, : self.units])
x_r = ops.matmul(
inputs_r, self.kernel[:, self.units : self.units * 2]
)
x_h = ops.matmul(inputs_h, self.kernel[:, self.units * 2 :])
if self.use_bias:
x_z += input_bias[: self.units]
x_r += input_bias[self.units : self.units * 2]
x_h += input_bias[self.units * 2 :]
h_tm1_z = h_tm1
h_tm1_r = h_tm1
h_tm1_h = h_tm1
recurrent_z = ops.matmul(
h_tm1_z, self.recurrent_kernel[:, : self.units]
)
recurrent_r = ops.matmul(
h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2]
)
if self.reset_after and self.use_bias:
recurrent_z += recurrent_bias[: self.units]
recurrent_r += recurrent_bias[self.units : self.units * 2]
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
# reset gate applied after/before matrix multiplication
if self.reset_after:
recurrent_h = ops.matmul(
h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
)
if self.use_bias:
recurrent_h += recurrent_bias[self.units * 2 :]
recurrent_h = r * recurrent_h
else:
recurrent_h = ops.matmul(
r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
)
hh = self.activation(x_h + recurrent_h)
else:
# inputs projected by all gate matrices at once
matrix_x = ops.matmul(inputs, self.kernel)
if self.use_bias:
# biases: bias_z_i, bias_r_i, bias_h_i
matrix_x += input_bias
x_z, x_r, x_h = ops.split(matrix_x, 3, axis=-1)
if self.reset_after:
# hidden state projected by all gate matrices at once
matrix_inner = ops.matmul(h_tm1, self.recurrent_kernel)
if self.use_bias:
matrix_inner += recurrent_bias
else:
# hidden state projected separately for update/reset and new
matrix_inner = ops.matmul(
h_tm1, self.recurrent_kernel[:, : 2 * self.units]
)
recurrent_z = matrix_inner[:, : self.units]
recurrent_r = matrix_inner[:, self.units : self.units * 2]
recurrent_h = matrix_inner[:, self.units * 2 :]
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
if self.reset_after:
recurrent_h = r * recurrent_h
else:
recurrent_h = ops.matmul(
r * h_tm1, self.recurrent_kernel[:, 2 * self.units :]
)
hh = self.activation(x_h + recurrent_h)
# previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
new_state = [h] if nest.is_nested(states) else h
return h, new_state

    def get_config(self):
config = {
"units": self.units,
"activation": activations.serialize(self.activation),
"recurrent_activation": activations.serialize(
self.recurrent_activation
),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"recurrent_initializer": initializers.serialize(
self.recurrent_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"recurrent_regularizer": regularizers.serialize(
self.recurrent_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"recurrent_constraint": constraints.serialize(
self.recurrent_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"reset_after": self.reset_after,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}

    def get_initial_state(self, batch_size=None):
return [
ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype)
]


@keras_core_export("keras_core.layers.GRU")
class GRU(RNN):
"""Gated Recurrent Unit - Cho et al. 2014.
Based on available runtime hardware and constraints, this layer
will choose different implementations (cuDNN-based or backend-native)
to maximize the performance. If a GPU is available and all
the arguments to the layer meet the requirement of the cuDNN kernel
(see below for details), the layer will use a fast cuDNN implementation
when using the TensorFlow backend.
The requirements to use the cuDNN implementation are:
1. `activation` == `tanh`
2. `recurrent_activation` == `sigmoid`
3. `dropout` == 0 and `recurrent_dropout` == 0
4. `unroll` is `False`
5. `use_bias` is `True`
6. `reset_after` is `True`
    7. Inputs, if masking is used, are strictly right-padded.
8. Eager execution is enabled in the outermost context.
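
    For example, a `GRU` layer created with its default arguments already
    satisfies the layer-level requirements 1-6 above (requirements 7 and 8
    depend on how the layer is called and executed):

    >>> fast_gru = keras_core.layers.GRU(64)  # may use cuDNN on the TF backend
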
    There are two variants of the GRU implementation. The default one is based
    on [v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate applied
    to the hidden state before matrix multiplication. The other one is based
    on the [original version](https://arxiv.org/abs/1406.1078v1) and has the
    order reversed.

    The second variant is compatible with CuDNNGRU (GPU-only) and allows
    inference on CPU. Thus it has separate biases for `kernel` and
    `recurrent_kernel`. To use this variant, set `reset_after=True` and
    `recurrent_activation='sigmoid'`.

    For example:

>>> inputs = np.random.random((32, 10, 8))
>>> gru = keras_core.layers.GRU(4)
>>> output = gru(inputs)
>>> output.shape
(32, 4)

    >>> gru = keras_core.layers.GRU(4, return_sequences=True, return_state=True)
>>> whole_sequence_output, final_state = gru(inputs)
>>> whole_sequence_output.shape
(32, 10, 4)
>>> final_state.shape
(32, 4)
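
    To use the other GRU convention, in which the reset gate is applied before
    the recurrent matrix multiplication, pass `reset_after=False`:

    >>> gru_before = keras_core.layers.GRU(4, reset_after=False)
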
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
Default: sigmoid (`sigmoid`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean (default `True`). Whether the layer
            should use a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`"glorot_uniform"`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent
state. Default: `"orthogonal"`.
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector.
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector.
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
seed: Random seed for dropout.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state in addition
to the output. Default: `False`.
go_backwards: Boolean (default `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default: `False`). If `True`, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default: `False`).
If `True`, the network will be unrolled,
else a symbolic loop will be used.
            Unrolling can speed up an RNN,
            although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). `False` is `"before"`,
`True` is `"after"` (default and cuDNN compatible).

    Call arguments:
inputs: A 3D tensor, with shape `(batch, timesteps, feature)`.
mask: Binary tensor of shape `(samples, timesteps)` indicating whether
a given timestep should be masked (optional).
An individual `True` entry indicates that the corresponding timestep
should be utilized, while a `False` entry indicates that the
corresponding timestep should be ignored. Defaults to `None`.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the
cell when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used (optional). Defaults to `None`.
initial_state: List of initial state tensors to be passed to the first
call of the cell (optional, `None` causes creation
of zero-filled initial state tensors). Defaults to `None`.
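
    A minimal sketch of passing an explicit initial state (the batch size,
    sequence length, and feature size below are illustrative assumptions):

    >>> inputs = np.random.random((32, 10, 8))
    >>> initial_state = [np.zeros((32, 4))]
    >>> gru = keras_core.layers.GRU(4)
    >>> output = gru(inputs, initial_state=initial_state)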
"""

    def __init__(
self,
units,
activation="tanh",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
seed=None,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
reset_after=True,
**kwargs,
):
cell = GRUCell(
units,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
reset_after=reset_after,
dtype=kwargs.get("dtype", None),
trainable=kwargs.get("trainable", True),
name="gru_cell",
seed=seed,
)
super().__init__(
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
activity_regularizer=activity_regularizer,
**kwargs,
)
self.input_spec = InputSpec(ndim=3)

    def inner_loop(self, sequences, initial_state, mask, training=False):
if nest.is_nested(initial_state):
initial_state = initial_state[0]
if nest.is_nested(mask):
mask = mask[0]
if not self.dropout and not self.recurrent_dropout:
try:
                # Backends may provide an optimized implementation of the
                # inner GRU loop. TF, for instance, leverages cuDNN when
                # feasible and raises NotImplementedError otherwise.
return backend.gru(
sequences,
initial_state,
mask,
kernel=self.cell.kernel,
recurrent_kernel=self.cell.recurrent_kernel,
bias=self.cell.bias,
activation=self.cell.activation,
recurrent_activation=self.cell.recurrent_activation,
return_sequences=self.return_sequences,
go_backwards=self.go_backwards,
unroll=self.unroll,
reset_after=self.cell.reset_after,
)
except NotImplementedError:
pass
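        # The backend did not provide a fused GRU; fall back to the generic
        # step-by-step RNN loop implemented by the base class.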
return super().inner_loop(
sequences, initial_state, mask=mask, training=training
)

    def call(self, sequences, initial_state=None, mask=None, training=False):
return super().call(
sequences, mask=mask, training=training, initial_state=initial_state
)
    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def reset_after(self):
        return self.cell.reset_after

def get_config(self):
config = {
"units": self.units,
"activation": activations.serialize(self.activation),
"recurrent_activation": activations.serialize(
self.recurrent_activation
),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"recurrent_initializer": initializers.serialize(
self.recurrent_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"recurrent_regularizer": regularizers.serialize(
self.recurrent_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"recurrent_constraint": constraints.serialize(
self.recurrent_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"reset_after": self.reset_after,
"seed": self.cell.seed,
}
base_config = super().get_config()
del base_config["cell"]
return {**base_config, **config}

    @classmethod
def from_config(cls, config):
return cls(**config)
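

# Note: a GRU layer can be round-tripped through its serialized configuration
# via `get_config()` / `from_config()`, e.g. (illustrative sketch only; the
# clone is built with fresh, unshared weights):
#
#     layer = GRU(4, dropout=0.1)
#     clone = GRU.from_config(layer.get_config())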