diff --git a/keras_core/layers/__init__.py b/keras_core/layers/__init__.py
index 1a44931a3..554bd93ea 100644
--- a/keras_core/layers/__init__.py
+++ b/keras_core/layers/__init__.py
@@ -1,4 +1,5 @@
 from keras_core.layers.activations.activation import Activation
+from keras_core.layers.attention.attention import Attention
 from keras_core.layers.convolutional.conv1d import Conv1D
 from keras_core.layers.convolutional.conv2d import Conv2D
 from keras_core.layers.convolutional.conv3d import Conv3D
diff --git a/keras_core/layers/attention/attention.py b/keras_core/layers/attention/attention.py
new file mode 100644
index 000000000..caebe19a3
--- /dev/null
+++ b/keras_core/layers/attention/attention.py
@@ -0,0 +1,275 @@
+from keras_core import backend
+from keras_core import operations as ops
+from keras_core.api_export import keras_core_export
+from keras_core.layers.layer import Layer
+
+
+@keras_core_export("keras_core.layers.Attention")
+class Attention(Layer):
+    """Dot-product attention layer, a.k.a. Luong-style attention.
+
+    Inputs are a list with 2 or 3 elements:
+    1. A query tensor of shape `(batch_size, Tq, dim)`.
+    2. A value tensor of shape `(batch_size, Tv, dim)`.
+    3. An optional key tensor of shape `(batch_size, Tv, dim)`. If none
+        supplied, the value tensor will be used as a key.
+
+    The calculation follows the steps:
+    1. Calculate attention scores using query and key with shape
+        `(batch_size, Tq, Tv)`.
+    2. Use scores to calculate a softmax distribution with shape
+        `(batch_size, Tq, Tv)`.
+    3. Use the softmax distribution to create a linear combination of value
+        with shape `(batch_size, Tq, dim)`.
+
+    Args:
+        use_scale: If `True`, will create a scalar variable to scale the
+            attention scores.
+        dropout: Float between 0 and 1. Fraction of the units to drop for the
+            attention scores. Defaults to 0.0.
+        score_mode: Function to use to compute attention scores, one of
+            `{"dot", "concat"}`. `"dot"` refers to the dot product between the
+            query and key vectors. `"concat"` refers to the hyperbolic tangent
+            of the concatenation of the query and key vectors.
+        seed: Optional integer. Seed for the dropout applied to the attention
+            scores when `dropout > 0`.
+
+    Call Args:
+        inputs: List of the following tensors:
+            - query: Query `Tensor` of shape `(batch_size, Tq, dim)`.
+            - value: Value `Tensor` of shape `(batch_size, Tv, dim)`.
+            - key: Optional key `Tensor` of shape `(batch_size, Tv, dim)`. If
+                not given, will use `value` for both `key` and `value`, which
+                is the most common case.
+        mask: List of the following tensors:
+            - query_mask: A boolean mask `Tensor` of shape `(batch_size, Tq)`.
+                If given, the output will be zero at the positions where
+                `mask==False`.
+            - value_mask: A boolean mask `Tensor` of shape `(batch_size, Tv)`.
+                If given, will apply the mask such that values at positions
+                where `mask==False` do not contribute to the result.
+        return_attention_scores: bool, if `True`, returns the attention scores
+            (after masking and softmax) as an additional output argument.
+        training: Python boolean indicating whether the layer should behave in
+            training mode (adding dropout) or in inference mode (no dropout).
+        use_causal_mask: Boolean. Set to `True` for decoder self-attention.
+            Adds a mask such that position `i` cannot attend to positions
+            `j > i`. This prevents the flow of information from the future
+            towards the past. Defaults to `False`.
+
+    Output:
+        Attention outputs of shape `(batch_size, Tq, dim)`.
+        (Optional) Attention scores after masking and softmax with shape
+        `(batch_size, Tq, Tv)`.
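+
+    Usage (a minimal sketch; the tensor shapes and variable names below are
+    illustrative only):
+
+    ```python
+    import numpy as np
+    from keras_core import layers
+
+    query = np.random.random((2, 3, 4)).astype("float32")  # (batch, Tq, dim)
+    value = np.random.random((2, 5, 4)).astype("float32")  # (batch, Tv, dim)
+
+    layer = layers.Attention()
+    # With two inputs, `value` is also used as `key`.
+    output = layer([query, value])  # shape: (2, 3, 4)
+
+    # Optionally also return the `(batch_size, Tq, Tv)` attention weights.
+    output, scores = layer([query, value], return_attention_scores=True)
+    ```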
+ """ + + def __init__( + self, + use_scale=False, + score_mode="dot", + dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.use_scale = use_scale + self.score_mode = score_mode + self.dropout = dropout + if self.score_mode not in ["dot", "concat"]: + raise ValueError( + "Invalid value for argument score_mode. " + "Expected one of {'dot', 'concat'}. " + f"Received: score_mode={score_mode}" + ) + + def build(self, input_shape): + self.scale = None + self.concat_score_weight = None + if self.use_scale: + self.scale = self.add_weight( + name="scale", + shape=(), + initializer="ones", + dtype=self.dtype, + trainable=True, + ) + if self.score_mode == "concat": + self.concat_score_weight = self.add_weight( + name="concat_score_weight", + shape=(), + initializer="ones", + dtype=self.dtype, + trainable=True, + ) + self.built = True + + def _calculate_scores(self, query, key): + """Calculates attention scores as a query-key dot product. + + Args: + query: Query tensor of shape `(batch_size, Tq, dim)`. + key: Key tensor of shape `(batch_size, Tv, dim)`. + + Returns: + Tensor of shape `(batch_size, Tq, Tv)`. + """ + if self.score_mode == "dot": + scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1])) + if self.scale is not None: + scores *= self.scale + elif self.score_mode == "concat": + # Reshape tensors to enable broadcasting. + # Reshape into [batch_size, Tq, 1, dim]. + q_reshaped = ops.expand_dims(query, axis=-2) + # Reshape into [batch_size, 1, Tv, dim]. + k_reshaped = ops.expand_dims(key, axis=-3) + if self.scale is not None: + scores = self.concat_score_weight * ops.sum( + ops.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1 + ) + else: + scores = self.concat_score_weight * ops.sum( + ops.tanh(q_reshaped + k_reshaped), axis=-1 + ) + + return scores + + def _apply_scores(self, scores, value, scores_mask=None, training=False): + """Applies attention scores to the given value tensor. + + To use this method in your attention layer, follow the steps: + + * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of + shape `(batch_size, Tv)` to calculate the attention `scores`. + * Pass `scores` and `value` tensors to this method. The method applies + `scores_mask`, calculates + `attention_distribution = softmax(scores)`, then returns + `matmul(attention_distribution, value). + * Apply `query_mask` and return the result. + + Args: + scores: Scores float tensor of shape `(batch_size, Tq, Tv)`. + value: Value tensor of shape `(batch_size, Tv, dim)`. + scores_mask: A boolean mask `Tensor` of shape `(batch_size, 1, Tv)` + or `(batch_size, Tq, Tv)`. If given, scores at positions where + `scores_mask==False` do not contribute to the result. It must + contain at least one `True` value in each line along the last + dimension. + training: Python boolean indicating whether the layer should behave + in training mode (adding dropout) or in inference mode + (no dropout). + + Returns: + Tensor of shape `(batch_size, Tq, dim)`. + Attention scores after masking and softmax with shape + `(batch_size, Tq, Tv)`. + """ + if scores_mask is not None: + padding_mask = ops.logical_not(scores_mask) + # Bias so padding positions do not contribute to attention + # distribution. Note 65504. is the max float16 value. 
+            max_value = 65504.0 if scores.dtype == "float16" else 1.0e9
+            scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype)
+
+        weights = ops.softmax(scores, axis=-1)
+        if training and self.dropout > 0:
+            weights = backend.random.dropout(
+                weights,
+                self.dropout,
+                seed=self.seed_generator,
+            )
+        return ops.matmul(weights, value), weights
+
+    def _calculate_score_mask(self, scores, v_mask, use_causal_mask):
+        if v_mask is not None:
+            # Mask of shape [batch_size, 1, Tv].
+            v_mask = ops.expand_dims(v_mask, axis=-2)
+        if not use_causal_mask:
+            return v_mask
+
+        # Creates a lower triangular mask, so position i cannot attend to
+        # positions j > i. This prevents the flow of information from the
+        # future into the past.
+        score_shape = ops.shape(scores)
+        # causal_mask_shape = [1, Tq, Tv].
+        mask_shape = (1, score_shape[-2], score_shape[-1])
+        ones_mask = ops.ones(shape=mask_shape, dtype="int32")
+        row_index = ops.cumsum(ones_mask, axis=-2)
+        col_index = ops.cumsum(ones_mask, axis=-1)
+        causal_mask = ops.greater_equal(row_index, col_index)
+
+        if v_mask is None:
+            return causal_mask
+        return ops.logical_and(v_mask, causal_mask)
+
+    def call(
+        self,
+        inputs,
+        mask=None,
+        training=False,
+        return_attention_scores=False,
+        use_causal_mask=False,
+    ):
+        self._validate_call_args(inputs=inputs, mask=mask)
+        q = inputs[0]
+        v = inputs[1]
+        k = inputs[2] if len(inputs) > 2 else v
+        q_mask = mask[0] if mask else None
+        v_mask = mask[1] if mask else None
+        scores = self._calculate_scores(query=q, key=k)
+        scores_mask = self._calculate_score_mask(
+            scores, v_mask, use_causal_mask
+        )
+        result, attention_scores = self._apply_scores(
+            scores=scores, value=v, scores_mask=scores_mask, training=training
+        )
+        if q_mask is not None:
+            # Mask of shape [batch_size, Tq, 1].
+            q_mask = ops.expand_dims(q_mask, axis=-1)
+            result *= ops.cast(q_mask, dtype=result.dtype)
+        if return_attention_scores:
+            return result, attention_scores
+        return result
+
+    def compute_mask(self, inputs, mask=None):
+        self._validate_call_args(inputs=inputs, mask=mask)
+        if mask is None or mask[0] is None:
+            return None
+        return ops.convert_to_tensor(mask[0])
+
+    def compute_output_shape(self, input_shape):
+        return input_shape[0]
+
+    def _validate_call_args(self, inputs, mask):
+        """Validates arguments of the call method."""
+        class_name = self.__class__.__name__
+        if not isinstance(inputs, list):
+            raise ValueError(
+                f"{class_name} layer must be called on a list of inputs, "
+                "namely [query, value] or [query, value, key]. "
+                f"Received: inputs={inputs}."
+            )
+        if len(inputs) < 2 or len(inputs) > 3:
+            raise ValueError(
+                f"{class_name} layer accepts inputs list of length 2 or 3, "
+                "namely [query, value] or [query, value, key]. "
+                f"Received length: {len(inputs)}."
+            )
+        if mask is not None:
+            if not isinstance(mask, list):
+                raise ValueError(
+                    f"{class_name} layer mask must be a list, "
+                    f"namely [query_mask, value_mask]. Received: mask={mask}."
+                )
+            if len(mask) < 2 or len(mask) > 3:
+                raise ValueError(
+                    f"{class_name} layer accepts mask list of length 2 or 3. "
+                    f"Received: inputs={inputs}, mask={mask}."
+                )
+
+    def get_config(self):
+        base_config = super().get_config()
+        config = {
+            "use_scale": self.use_scale,
+            "score_mode": self.score_mode,
+            "dropout": self.dropout,
+            "seed": self.seed,
+        }
+        return {**base_config, **config}
diff --git a/keras_core/layers/attention/attention_test.py b/keras_core/layers/attention/attention_test.py
new file mode 100644
index 000000000..c60facae0
--- /dev/null
+++ b/keras_core/layers/attention/attention_test.py
@@ -0,0 +1,99 @@
+import numpy as np
+
+from keras_core import layers
+from keras_core import testing
+
+
+class AttentionTest(testing.TestCase):
+    def test_attention_basics(self):
+        # No scale, no concat.
+        self.run_layer_test(
+            layers.Attention,
+            init_kwargs={
+                "score_mode": "dot",
+                "dropout": 0.5,
+            },
+            input_shape=[(2, 3, 4), (2, 4, 4)],
+            expected_output_shape=(2, 3, 4),
+            expected_num_trainable_weights=0,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=1,
+            expected_num_losses=0,
+            supports_masking=True,
+        )
+        # Scale and concat.
+        self.run_layer_test(
+            layers.Attention,
+            init_kwargs={
+                "use_scale": True,
+                "score_mode": "concat",
+                "dropout": 0.5,
+            },
+            input_shape=[(2, 3, 4), (2, 4, 4)],
+            expected_output_shape=(2, 3, 4),
+            expected_num_trainable_weights=2,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=1,
+            expected_num_losses=0,
+            supports_masking=True,
+        )
+
+    def test_attention_correctness(self):
+        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
+        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])
+        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])
+
+        # Dot.
+        layer = layers.Attention(score_mode="dot")
+        output, scores = layer(
+            [query, value, key],
+            return_attention_scores=True,
+        )
+        self.assertAllClose(
+            output, [[[2.462, 3.462], [1.538, 2.538]]], atol=1e-3
+        )
+        self.assertAllClose(
+            scores, [[[0.269, 0.731], [0.731, 0.269]]], atol=1e-3
+        )
+
+        # Concat.
+        layer = layers.Attention(score_mode="concat")
+        output, scores = layer(
+            [query, value, key],
+            return_attention_scores=True,
+        )
+        self.assertAllClose(
+            output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3
+        )
+        self.assertAllClose(
+            scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3
+        )
+
+    def test_attention_with_mask(self):
+        layer = layers.Attention()
+        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
+        value = np.array([[[1.0, 1.0], [1.0, 1.0]]])
+        query_mask = np.array([[True, False]])
+        value_mask = np.array([[True, False]])
+        output, scores = layer(
+            [query, value],
+            mask=[query_mask, value_mask],
+            return_attention_scores=True,
+        )
+        self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]])
+        self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]])
+
+    def test_attention_errors(self):
+        layer = layers.Attention()
+        tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]])
+        with self.assertRaisesRegex(ValueError, "must be called on a list"):
+            layer(tensor)
+
+        with self.assertRaisesRegex(ValueError, "length 2 or 3"):
+            layer([tensor, tensor, tensor, tensor])
+
+        with self.assertRaisesRegex(ValueError, "layer mask must be a list"):
+            layer([tensor, tensor], mask=tensor)
+
+        with self.assertRaisesRegex(ValueError, "length 2 or 3"):
+            layer([tensor, tensor], mask=[tensor])
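+
+    def test_attention_with_causal_mask(self):
+        # Illustrative extra check (a sketch beyond the original suite):
+        # with `use_causal_mask=True`, query position i should only attend
+        # to value positions j <= i.
+        layer = layers.Attention(score_mode="dot")
+        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
+        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])
+        output, scores = layer(
+            [query, value],
+            use_causal_mask=True,
+            return_attention_scores=True,
+        )
+        # Position 0 only sees value 0; position 1 sees both values.
+        self.assertAllClose(
+            scores, [[[1.0, 0.0], [0.119, 0.881]]], atol=1e-3
+        )
+        self.assertAllClose(
+            output, [[[1.0, 2.0], [2.762, 3.762]]], atol=1e-3
+        )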