from keras_core import backend from keras_core import operations as ops from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer @keras_core_export("keras_core.layers.Attention") class Attention(Layer): """Dot-product attention layer, a.k.a. Luong-style attention. Inputs are a list with 2 or 3 elements: 1. A `query` tensor of shape `(batch_size, Tq, dim)`. 2. A `value` tensor of shape `(batch_size, Tv, dim)`. 3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none supplied, `value` will be used as a `key`. The calculation follows the steps: 1. Calculate attention scores using `query` and `key` with shape `(batch_size, Tq, Tv)`. 2. Use scores to calculate a softmax distribution with shape `(batch_size, Tq, Tv)`. 3. Use the softmax distribution to create a linear combination of `value` with shape `(batch_size, Tq, dim)`. Args: use_scale: If `True`, will create a scalar variable to scale the attention scores. dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores. Defaults to 0.0. score_mode: Function to use to compute attention scores, one of `{"dot", "concat"}`. `"dot"` refers to the dot product between the query and key vectors. `"concat"` refers to the hyperbolic tangent of the concatenation of the `query` and `key` vectors. Call Args: inputs: List of the following tensors: - `query`: Query tensor of shape `(batch_size, Tq, dim)`. - `value`: Value tensor of shape `(batch_size, Tv, dim)`. - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If not given, will use `value` for both `key` and `value`, which is the most common case. mask: List of the following tensors: - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`. If given, the output will be zero at the positions where `mask==False`. - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`. If given, will apply the mask such that values at positions where `mask==False` do not contribute to the result. return_attention_scores: bool, it `True`, returns the attention scores (after masking and softmax) as an additional output argument. training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds a mask such that position `i` cannot attend to positions `j > i`. This prevents the flow of information from the future towards the past. Defaults to `False`. Output: Attention outputs of shape `(batch_size, Tq, dim)`. (Optional) Attention scores after masking and softmax with shape `(batch_size, Tq, Tv)`. """ def __init__( self, use_scale=False, score_mode="dot", dropout=0.0, **kwargs, ): super().__init__(**kwargs) self.use_scale = use_scale self.score_mode = score_mode self.dropout = dropout if self.score_mode not in ["dot", "concat"]: raise ValueError( "Invalid value for argument score_mode. " "Expected one of {'dot', 'concat'}. " f"Received: score_mode={score_mode}" ) def build(self, input_shape): self._validate_inputs(input_shape) self.scale = None self.concat_score_weight = None if self.use_scale: self.scale = self.add_weight( name="scale", shape=(), initializer="ones", dtype=self.dtype, trainable=True, ) if self.score_mode == "concat": self.concat_score_weight = self.add_weight( name="concat_score_weight", shape=(), initializer="ones", dtype=self.dtype, trainable=True, ) self.built = True def _calculate_scores(self, query, key): """Calculates attention scores as a query-key dot product. Args: query: Query tensor of shape `(batch_size, Tq, dim)`. key: Key tensor of shape `(batch_size, Tv, dim)`. Returns: Tensor of shape `(batch_size, Tq, Tv)`. """ if self.score_mode == "dot": scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1])) if self.scale is not None: scores *= self.scale elif self.score_mode == "concat": # Reshape tensors to enable broadcasting. # Reshape into [batch_size, Tq, 1, dim]. q_reshaped = ops.expand_dims(query, axis=-2) # Reshape into [batch_size, 1, Tv, dim]. k_reshaped = ops.expand_dims(key, axis=-3) if self.scale is not None: scores = self.concat_score_weight * ops.sum( ops.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1 ) else: scores = self.concat_score_weight * ops.sum( ops.tanh(q_reshaped + k_reshaped), axis=-1 ) return scores def _apply_scores(self, scores, value, scores_mask=None, training=False): """Applies attention scores to the given value tensor. To use this method in your attention layer, follow the steps: * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of shape `(batch_size, Tv)` to calculate the attention `scores`. * Pass `scores` and `value` tensors to this method. The method applies `scores_mask`, calculates `attention_distribution = softmax(scores)`, then returns `matmul(attention_distribution, value). * Apply `query_mask` and return the result. Args: scores: Scores float tensor of shape `(batch_size, Tq, Tv)`. value: Value tensor of shape `(batch_size, Tv, dim)`. scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)` or `(batch_size, Tq, Tv)`. If given, scores at positions where `scores_mask==False` do not contribute to the result. It must contain at least one `True` value in each line along the last dimension. training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). Returns: Tensor of shape `(batch_size, Tq, dim)`. Attention scores after masking and softmax with shape `(batch_size, Tq, Tv)`. """ if scores_mask is not None: padding_mask = ops.logical_not(scores_mask) # Bias so padding positions do not contribute to attention # distribution. Note 65504. is the max float16 value. max_value = 65504.0 if scores.dtype == "float16" else 1.0e9 scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype) weights = ops.softmax(scores, axis=-1) if training and self.dropout > 0: weights = backend.random.dropout( weights, self.dropout, noise_shape=self.noise_shape, seed=self.seed_generator, ) return ops.matmul(weights, value), weights def _calculate_score_mask(self, scores, v_mask, use_causal_mask): if v_mask is not None: # Mask of shape [batch_size, 1, Tv]. v_mask = ops.expand_dims(v_mask, axis=-2) if not use_causal_mask: return v_mask # Creates a lower triangular mask, so position i cannot attend to # positions j>i. This prevents the flow of information from the # future into the past. score_shape = ops.shape(scores) # causal_mask_shape = [1, Tq, Tv]. mask_shape = (1, score_shape[-2], score_shape[-1]) ones_mask = ops.ones(shape=mask_shape, dtype="int32") row_index = ops.cumsum(ones_mask, axis=-2) col_index = ops.cumsum(ones_mask, axis=-1) causal_mask = ops.greater_equal(row_index, col_index) if v_mask is not None: return causal_mask return ops.logical_and(v_mask, causal_mask) def call( self, inputs, mask=None, training=False, return_attention_scores=False, use_causal_mask=False, ): self._validate_inputs(inputs=inputs, mask=mask) q = inputs[0] v = inputs[1] k = inputs[2] if len(inputs) > 2 else v q_mask = mask[0] if mask else None v_mask = mask[1] if mask else None scores = self._calculate_scores(query=q, key=k) scores_mask = self._calculate_score_mask( scores, v_mask, use_causal_mask ) result, attention_scores = self._apply_scores( scores=scores, value=v, scores_mask=scores_mask, training=training ) if q_mask is not None: # Mask of shape [batch_size, Tq, 1]. q_mask = ops.expand_dims(q_mask, axis=-1) result *= ops.cast(q_mask, dtype=result.dtype) if return_attention_scores: return result, attention_scores return result def compute_mask(self, inputs, mask=None): self._validate_inputs(inputs=inputs, mask=mask) if mask is None or mask[0] is None: return None return ops.convert_to_tensor(mask[0]) def compute_output_shape(self, input_shape): return input_shape[0] def _validate_inputs(self, inputs, mask=None): """Validates arguments of the call method.""" class_name = self.__class__.__name__ if not isinstance(inputs, list): raise ValueError( f"{class_name} layer must be called on a list of inputs, " "namely [query, value] or [query, value, key]. " f"Received: inputs={inputs}." ) if len(inputs) < 2 or len(inputs) > 3: raise ValueError( f"{class_name} layer accepts inputs list of length 2 or 3, " "namely [query, value] or [query, value, key]. " f"Received length: {len(inputs)}." ) if mask is not None: if not isinstance(mask, list): raise ValueError( f"{class_name} layer mask must be a list, " f"namely [query_mask, value_mask]. Received: mask={mask}." ) if len(mask) < 2 or len(mask) > 3: raise ValueError( f"{class_name} layer accepts mask list of length 2 or 3. " f"Received: inputs={inputs}, mask={mask}." ) def get_config(self): base_config = super().get_config() config = { "use_scale": self.use_scale, "score_mode": self.score_mode, "dropout": self.dropout, } return {**base_config, **config}