Add Attention layer (#122)

* Add keras_core.layers.Attention

* Address style comments

* Fix jax
This commit is contained in:
Matt Watson 2023-05-09 16:25:50 -07:00 committed by Francois Chollet
parent 3873029035
commit c2623e2e98
3 changed files with 375 additions and 0 deletions

@ -1,4 +1,5 @@
from keras_core.layers.activations.activation import Activation
from keras_core.layers.attention.attention import Attention
from keras_core.layers.convolutional.conv1d import Conv1D
from keras_core.layers.convolutional.conv2d import Conv2D
from keras_core.layers.convolutional.conv3d import Conv3D

@ -0,0 +1,275 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Attention")
class Attention(Layer):
"""Dot-product attention layer, a.k.a. Luong-style attention.
Inputs are a list with 2 or 3 elements:
1. A query tensor of shape `(batch_size, Tq, dim)`.
2. A value tensor of shape `(batch_size, Tv, dim)`.
3. A optional key tensor of shape `(batch_size, Tv, dim)`. If none supplied,
the value tensor will be used as a key.
The calculation follows the steps:
1. Calculate attention scores using query and key with shape
`(batch_size, Tq, Tv)`.
2. Use scores to calculate a softmax distribution with shape
`(batch_size, Tq, Tv)`.
3. Use the softmax distribution to create a linear combination of value
with shape `(batch_size, Tq, dim)`.
Args:
use_scale: If `True`, will create a scalar variable to scale the
attention scores.
dropout: Float between 0 and 1. Fraction of the units to drop for the
attention scores. Defaults to 0.0.
score_mode: Function to use to compute attention scores, one of
`{"dot", "concat"}`. `"dot"` refers to the dot product between the
query and key vectors. `"concat"` refers to the hyperbolic tangent
of the concatenation of the query and key vectors.
Call Args:
inputs: List of the following tensors:
- query: Query `Tensor` of shape `(batch_size, Tq, dim)`.
- value: Value `Tensor` of shape `(batch_size, Tv, dim)`.
- key: Optional key `Tensor` of shape `(batch_size, Tv, dim)`. If
not given, will use `value` for both `key` and `value`, which is
the most common case.
mask: List of the following tensors:
- query_mask: A boolean mask `Tensor` of shape `(batch_size, Tq)`.
If given, the output will be zero at the positions where
`mask==False`.
- value_mask: A boolean mask `Tensor` of shape `(batch_size, Tv)`.
If given, will apply the mask such that values at positions
where `mask==False` do not contribute to the result.
return_attention_scores: bool, it `True`, returns the attention scores
(after masking and softmax) as an additional output argument.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (no dropout).
use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds
a mask such that position `i` cannot attend to positions `j > i`.
This prevents the flow of information from the future towards the
past. Defaults to `False`.
Output:
Attention outputs of shape `(batch_size, Tq, dim)`.
(Optional) Attention scores after masking and softmax with shape
`(batch_size, Tq, Tv)`.
"""
def __init__(
self,
use_scale=False,
score_mode="dot",
dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.use_scale = use_scale
self.score_mode = score_mode
self.dropout = dropout
if self.score_mode not in ["dot", "concat"]:
raise ValueError(
"Invalid value for argument score_mode. "
"Expected one of {'dot', 'concat'}. "
f"Received: score_mode={score_mode}"
)
def build(self, input_shape):
self.scale = None
self.concat_score_weight = None
if self.use_scale:
self.scale = self.add_weight(
name="scale",
shape=(),
initializer="ones",
dtype=self.dtype,
trainable=True,
)
if self.score_mode == "concat":
self.concat_score_weight = self.add_weight(
name="concat_score_weight",
shape=(),
initializer="ones",
dtype=self.dtype,
trainable=True,
)
self.built = True
def _calculate_scores(self, query, key):
"""Calculates attention scores as a query-key dot product.
Args:
query: Query tensor of shape `(batch_size, Tq, dim)`.
key: Key tensor of shape `(batch_size, Tv, dim)`.
Returns:
Tensor of shape `(batch_size, Tq, Tv)`.
"""
if self.score_mode == "dot":
scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))
if self.scale is not None:
scores *= self.scale
elif self.score_mode == "concat":
# Reshape tensors to enable broadcasting.
# Reshape into [batch_size, Tq, 1, dim].
q_reshaped = ops.expand_dims(query, axis=-2)
# Reshape into [batch_size, 1, Tv, dim].
k_reshaped = ops.expand_dims(key, axis=-3)
if self.scale is not None:
scores = self.concat_score_weight * ops.sum(
ops.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
)
else:
scores = self.concat_score_weight * ops.sum(
ops.tanh(q_reshaped + k_reshaped), axis=-1
)
return scores
def _apply_scores(self, scores, value, scores_mask=None, training=False):
"""Applies attention scores to the given value tensor.
To use this method in your attention layer, follow the steps:
* Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of
shape `(batch_size, Tv)` to calculate the attention `scores`.
* Pass `scores` and `value` tensors to this method. The method applies
`scores_mask`, calculates
`attention_distribution = softmax(scores)`, then returns
`matmul(attention_distribution, value).
* Apply `query_mask` and return the result.
Args:
scores: Scores float tensor of shape `(batch_size, Tq, Tv)`.
value: Value tensor of shape `(batch_size, Tv, dim)`.
scores_mask: A boolean mask `Tensor` of shape `(batch_size, 1, Tv)`
or `(batch_size, Tq, Tv)`. If given, scores at positions where
`scores_mask==False` do not contribute to the result. It must
contain at least one `True` value in each line along the last
dimension.
training: Python boolean indicating whether the layer should behave
in training mode (adding dropout) or in inference mode
(no dropout).
Returns:
Tensor of shape `(batch_size, Tq, dim)`.
Attention scores after masking and softmax with shape
`(batch_size, Tq, Tv)`.
"""
if scores_mask is not None:
padding_mask = ops.logical_not(scores_mask)
# Bias so padding positions do not contribute to attention
# distribution. Note 65504. is the max float16 value.
max_value = 65504.0 if scores.dtype == "float16" else 1.0e9
scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype)
weights = ops.softmax(scores, axis=-1)
if training and self.dropout > 0:
weights = backend.random.dropout(
weights,
self.dropout,
noise_shape=self.noise_shape,
seed=self.seed_generator,
)
return ops.matmul(weights, value), weights
def _calculate_score_mask(self, scores, v_mask, use_causal_mask):
if v_mask is not None:
# Mask of shape [batch_size, 1, Tv].
v_mask = ops.expand_dims(v_mask, axis=-2)
if not use_causal_mask:
return v_mask
# Creates a lower triangular mask, so position i cannot attend to
# positions j>i. This prevents the flow of information from the
# future into the past.
score_shape = ops.shape(scores)
# causal_mask_shape = [1, Tq, Tv].
mask_shape = (1, score_shape[-2], score_shape[-1])
ones_mask = ops.ones(shape=mask_shape, dtype="int32")
row_index = ops.cumsum(ones_mask, axis=-2)
col_index = ops.cumsum(ones_mask, axis=-1)
causal_mask = ops.greater_equal(row_index, col_index)
if v_mask is not None:
return causal_mask
return ops.logical_and(v_mask, causal_mask)
def call(
self,
inputs,
mask=None,
training=False,
return_attention_scores=False,
use_causal_mask=False,
):
self._validate_call_args(inputs=inputs, mask=mask)
q = inputs[0]
v = inputs[1]
k = inputs[2] if len(inputs) > 2 else v
q_mask = mask[0] if mask else None
v_mask = mask[1] if mask else None
scores = self._calculate_scores(query=q, key=k)
scores_mask = self._calculate_score_mask(
scores, v_mask, use_causal_mask
)
result, attention_scores = self._apply_scores(
scores=scores, value=v, scores_mask=scores_mask, training=training
)
if q_mask is not None:
# Mask of shape [batch_size, Tq, 1].
q_mask = ops.expand_dims(q_mask, axis=-1)
result *= ops.cast(q_mask, dtype=result.dtype)
if return_attention_scores:
return result, attention_scores
return result
def compute_mask(self, inputs, mask=None):
self._validate_call_args(inputs=inputs, mask=mask)
if mask is None or mask[0] is None:
return None
return ops.convert_to_tensor(mask[0])
def compute_output_shape(self, input_shape):
return input_shape[0]
def _validate_call_args(self, inputs, mask):
"""Validates arguments of the call method."""
class_name = self.__class__.__name__
if not isinstance(inputs, list):
raise ValueError(
f"{class_name} layer must be called on a list of inputs, "
"namely [query, value] or [query, value, key]. "
f"Received: inputs={inputs}."
)
if len(inputs) < 2 or len(inputs) > 3:
raise ValueError(
f"{class_name} layer accepts inputs list of length 2 or 3, "
"namely [query, value] or [query, value, key]. "
f"Received length: {len(inputs)}."
)
if mask is not None:
if not isinstance(mask, list):
raise ValueError(
f"{class_name} layer mask must be a list, "
f"namely [query_mask, value_mask]. Received: mask={mask}."
)
if len(mask) < 2 or len(mask) > 3:
raise ValueError(
f"{class_name} layer accepts mask list of length 2 or 3. "
f"Received: inputs={inputs}, mask={mask}."
)
def get_config(self):
base_config = super().get_config()
config = {
"use_scale": self.use_scale,
"score_mode": self.score_mode,
"dropout": self.dropout,
}
return {**base_config, **config}

@ -0,0 +1,99 @@
import numpy as np
from keras_core import layers
from keras_core import testing
class DenseTest(testing.TestCase):
def test_attention_basics(self):
# No scale, no concat.
self.run_layer_test(
layers.Attention,
init_kwargs={
"score_mode": "dot",
"dropout": 0.5,
},
input_shape=[(2, 3, 4), (2, 4, 4)],
expected_output_shape=(2, 3, 4),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
# Sale and concat.
self.run_layer_test(
layers.Attention,
init_kwargs={
"use_scale": True,
"score_mode": "concat",
"dropout": 0.5,
},
input_shape=[(2, 3, 4), (2, 4, 4)],
expected_output_shape=(2, 3, 4),
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
def test_attention_correctness(self):
query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
key = np.array([[[0.0, 1.0], [1.0, 0.0]]])
value = np.array([[[1.0, 2.0], [3.0, 4.0]]])
# Dot.
layer = layers.Attention(score_mode="dot")
output, scores = layer(
[query, value, key],
return_attention_scores=True,
)
self.assertAllClose(
output, [[[2.462, 3.462], [1.538, 2.538]]], atol=1e-3
)
self.assertAllClose(
scores, [[[0.269, 0.731], [0.731, 0.269]]], atol=1e-3
)
# Concat.
layer = layers.Attention(score_mode="concat")
output, scores = layer(
[query, value, key],
return_attention_scores=True,
)
self.assertAllClose(
output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3
)
self.assertAllClose(
scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3
)
def test_attention_with_mask(self):
layer = layers.Attention()
query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
value = np.array([[[1.0, 1.0], [1.0, 1.0]]])
query_mask = np.array([[True, False]])
value_mask = np.array([[True, False]])
output, scores = layer(
[query, value],
mask=[query_mask, value_mask],
return_attention_scores=True,
)
self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]])
self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]])
def test_attention_errors(self):
layer = layers.Attention()
tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]])
with self.assertRaisesRegex(ValueError, "must be called on a list"):
layer(tensor)
with self.assertRaisesRegex(ValueError, "length 2 or 3"):
layer([tensor, tensor, tensor, tensor])
with self.assertRaisesRegex(ValueError, "layer mask must be a list"):
layer([tensor, tensor], mask=tensor)
with self.assertRaisesRegex(ValueError, "length 2 or 3"):
layer([tensor, tensor], mask=[tensor])