keras/keras_core/layers/rnn/bidirectional.py
import copy
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.core.wrapper import Wrapper
from keras_core.layers.layer import Layer
from keras_core.saving import serialization_lib
@keras_core_export("keras_core.layers.Bidirectional")
class Bidirectional(Wrapper):
"""Bidirectional wrapper for RNNs.
Args:
layer: `keras_core.layers.RNN` instance, such as
`keras_core.layers.LSTM` or `keras_core.layers.GRU`.
It could also be a `keras_core.layers.Layer` instance
that meets the following criteria:
1. Be a sequence-processing layer (accepts 3D+ inputs).
2. Have a `go_backwards`, `return_sequences` and `return_state`
attribute (with the same semantics as for the `RNN` class).
3. Have an `input_spec` attribute.
4. Implement serialization via `get_config()` and `from_config()`.
Note that the recommended way to create new RNN layers is to write a
custom RNN cell and use it with `keras_core.layers.RNN`, instead of
subclassing `keras_core.layers.Layer` directly.
When `return_sequences` is `True`, the output of the masked
timestep will be zero regardless of the layer's original
`zero_output_for_mask` value.
merge_mode: Mode by which outputs of the forward and backward RNNs
will be combined. One of `{"sum", "mul", "concat", "ave", None}`.
If `None`, the outputs will not be combined,
they will be returned as a list. Defaults to `"concat"`.
backward_layer: Optional `keras_core.layers.RNN`,
or `keras_core.layers.Layer` instance to be used to handle
backwards input processing.
If `backward_layer` is not provided, the layer instance passed
as the `layer` argument will be used to generate the backward layer
automatically.
Note that the provided `backward_layer` layer should have properties
matching those of the `layer` argument, in particular
it should have the same values for `stateful`, `return_state`,
`return_sequences`, etc. In addition, `backward_layer`
and `layer` should have different `go_backwards` argument values.
A `ValueError` will be raised if these requirements are not met.
Call arguments:
The call arguments for this layer are the same as those of the
wrapped RNN layer. Beware that when passing the `initial_state`
argument during the call of this layer, the first half of the
entries in the `initial_state` list will be passed to the forward
RNN call and the last half will be passed to the backward RNN call.
Note: instantiating a `Bidirectional` layer from an existing RNN layer
instance will not reuse the weights state of the RNN layer instance -- the
`Bidirectional` layer will have freshly initialized weights.
Examples:
```python
model = Sequential([
Input(shape=(5, 10)),
Bidirectional(LSTM(10, return_sequences=True)),
Bidirectional(LSTM(10)),
Dense(5, activation="softmax"),
])
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
# With custom backward layer
forward_layer = LSTM(10, return_sequences=True)
backward_layer = LSTM(10, activation='relu', return_sequences=True,
go_backwards=True)
model = Sequential([
Input(shape=(5, 10)),
Bidirectional(forward_layer, backward_layer=backward_layer),
Dense(5, activation="softmax"),
])
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
```
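A minimal sketch (assuming the same `Input` and `LSTM` imports as the
examples above) of `merge_mode=None`, which returns the forward and
backward outputs as a list instead of merging them:
```python
inputs = Input(shape=(5, 10))
# With `merge_mode=None`, calling the layer yields two tensors: the
# forward output and the backward output (flipped back to forward
# time order).
forward_out, backward_out = Bidirectional(
LSTM(10, return_sequences=True), merge_mode=None
)(inputs)
# Each output has shape (batch_size, 5, 10).
```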
"""
def __init__(
self,
layer,
merge_mode="concat",
weights=None,
backward_layer=None,
**kwargs,
):
if not isinstance(layer, Layer):
raise ValueError(
"Please initialize `Bidirectional` layer with a "
f"`keras_core.layers.Layer` instance. Received: {layer}"
)
if backward_layer is not None and not isinstance(backward_layer, Layer):
raise ValueError(
"`backward_layer` need to be a `keras_core.layers.Layer` "
f"instance. Received: {backward_layer}"
)
if merge_mode not in ["sum", "mul", "ave", "concat", None]:
raise ValueError(
f"Invalid merge mode. Received: {merge_mode}. "
"Merge mode should be one of "
'{"sum", "mul", "ave", "concat", None}'
)
super().__init__(layer, **kwargs)
# Recreate the forward layer from the original layer config, so that it
# will not carry over any state from the layer.
config = serialization_lib.serialize_keras_object(layer)
config["config"]["name"] = "forward_" + layer.name.removeprefix(
"forward_"
)
self.forward_layer = serialization_lib.deserialize_keras_object(config)
if backward_layer is None:
config = serialization_lib.serialize_keras_object(layer)
config["config"]["go_backwards"] = True
config["config"]["name"] = "backward_" + layer.name.removeprefix(
"backward_"
)
self.backward_layer = serialization_lib.deserialize_keras_object(
config
)
else:
self.backward_layer = backward_layer
self._verify_layer_config()
def force_zero_output_for_mask(layer):
# Force the zero_output_for_mask to be True if returning sequences.
if getattr(layer, "zero_output_for_mask", None) is not None:
layer.zero_output_for_mask = layer.return_sequences
force_zero_output_for_mask(self.forward_layer)
force_zero_output_for_mask(self.backward_layer)
self.merge_mode = merge_mode
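# If initial weights were provided, the first half is assigned to the
# forward layer and the second half to the backward layer.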
if weights:
nw = len(weights)
self.forward_layer.initial_weights = weights[: nw // 2]
self.backward_layer.initial_weights = weights[nw // 2 :]
self.stateful = layer.stateful
self.return_sequences = layer.return_sequences
self.return_state = layer.return_state
self.supports_masking = True
self.input_spec = layer.input_spec
def _verify_layer_config(self):
"""Ensure the forward and backward layers have valid common property."""
if self.forward_layer.go_backwards == self.backward_layer.go_backwards:
raise ValueError(
"Forward layer and backward layer should have different "
"`go_backwards` value. Received: "
"forward_layer.go_backwards "
f"{self.forward_layer.go_backwards}, "
"backward_layer.go_backwards="
f"{self.backward_layer.go_backwards}"
)
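# These attributes must match between the two directions so that the
# merged outputs and states are consistent.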
common_attributes = ("stateful", "return_sequences", "return_state")
for a in common_attributes:
forward_value = getattr(self.forward_layer, a)
backward_value = getattr(self.backward_layer, a)
if forward_value != backward_value:
raise ValueError(
"Forward layer and backward layer are expected to have "
f'the same value for attribute "{a}", got '
f'"{forward_value}" for forward layer and '
f'"{backward_value}" for backward layer'
)
def compute_output_shape(self, sequences_shape, initial_state_shape=None):
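# Both directions produce outputs of the same shape, so start from the
# forward layer's output shape and adjust it for the merge mode.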
output_shape = self.forward_layer.compute_output_shape(sequences_shape)
if self.return_state:
output_shape, state_shape = output_shape[0], output_shape[1:]
if self.merge_mode == "concat":
output_shape = list(output_shape)
output_shape[-1] *= 2
output_shape = tuple(output_shape)
elif self.merge_mode is None:
output_shape = [output_shape, copy.copy(output_shape)]
if self.return_state:
if self.merge_mode is None:
return output_shape + state_shape + copy.copy(state_shape)
return [output_shape] + state_shape + copy.copy(state_shape)
return output_shape
def call(
self,
sequences,
initial_state=None,
mask=None,
training=None,
):
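# Only forward `training` and `mask` to the wrapped layers if their
# `call()` signatures accept those arguments.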
kwargs = {}
if self.forward_layer._call_has_training_arg():
kwargs["training"] = training
if self.forward_layer._call_has_mask_arg():
kwargs["mask"] = mask
if initial_state is not None:
# The initial states are not Keras tensors (e.g. eager tensors
# created from NumPy arrays). They are only passed in via the
# `initial_state` kwarg, and should be forwarded to the forward
# and backward layers via their `initial_state` kwargs as well.
forward_inputs, backward_inputs = sequences, sequences
half = len(initial_state) // 2
forward_state = initial_state[:half]
backward_state = initial_state[half:]
else:
forward_inputs, backward_inputs = sequences, sequences
forward_state, backward_state = None, None
y = self.forward_layer(
forward_inputs, initial_state=forward_state, **kwargs
)
y_rev = self.backward_layer(
backward_inputs, initial_state=backward_state, **kwargs
)
if self.return_state:
states = y[1:] + y_rev[1:]
y = y[0]
y_rev = y_rev[0]
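# The backward layer processed the input in reverse, so its
# per-timestep output is flipped back to forward time order
# before merging.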
if self.return_sequences:
y_rev = ops.flip(y_rev, axis=1)
if self.merge_mode == "concat":
output = ops.concatenate([y, y_rev], axis=-1)
elif self.merge_mode == "sum":
output = y + y_rev
elif self.merge_mode == "ave":
output = (y + y_rev) / 2
elif self.merge_mode == "mul":
output = y * y_rev
elif self.merge_mode is None:
output = [y, y_rev]
else:
raise ValueError(
"Unrecognized value for `merge_mode`. "
f"Received: {self.merge_mode}"
'Expected one of {"concat", "sum", "ave", "mul"}.'
)
if self.return_state:
if self.merge_mode is None:
return output + states
return [output] + states
return output
def reset_states(self):
# Compatibility alias.
self.reset_state()
def reset_state(self):
if not self.stateful:
raise AttributeError("Layer must be stateful.")
self.forward_layer.reset_state()
self.backward_layer.reset_state()
def build(self, sequences_shape, initial_state_shape=None):
self.forward_layer.build(sequences_shape)
self.backward_layer.build(sequences_shape)
self.built = True
def compute_mask(self, _, mask):
if isinstance(mask, list):
mask = mask[0]
if self.return_sequences:
if not self.merge_mode:
output_mask = [mask, mask]
else:
output_mask = mask
else:
output_mask = [None, None] if not self.merge_mode else None
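# State tensors (forward states followed by backward states) never
# carry a mask, so append a `None` entry for each of them.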
if self.return_state:
states = self.forward_layer.states
state_mask = [None for _ in states]
if isinstance(output_mask, list):
return output_mask + state_mask * 2
return [output_mask] + state_mask * 2
return output_mask
def get_config(self):
config = {"merge_mode": self.merge_mode}
config["layer"] = serialization_lib.serialize_keras_object(
self.forward_layer
)
config["backward_layer"] = serialization_lib.serialize_keras_object(
self.backward_layer
)
base_config = super().get_config()
return {**base_config, **config}
@classmethod
def from_config(cls, config, custom_objects=None):
# Instead of updating the input, create a copy and use that.
config = copy.deepcopy(config)
config["layer"] = serialization_lib.deserialize_keras_object(
config["layer"], custom_objects=custom_objects
)
# Handle (optional) backward layer instantiation.
backward_layer_config = config.pop("backward_layer", None)
if backward_layer_config is not None:
backward_layer = serialization_lib.deserialize_keras_object(
backward_layer_config, custom_objects=custom_objects
)
config["backward_layer"] = backward_layer
# Instantiate the wrapper, adjust it and return it.
layer = cls(**config)
return layer