Add layers.MultiHeadAttention (#169)

* Add layers.MultiHeadAttention

* Update build/compute_output_signature argument checks

We do not definitively know which arguments are tensor arguments
at any given invocation (e.g. arguments with a None value may be tensor
arguments).

So rather than check that the build signature matches perfectly with
tensor call arguments, we will check the build signature arguments match
with some call argument.
This commit is contained in:
Matt Watson 2023-05-16 18:22:42 -07:00 committed by Francois Chollet
parent 3eaa2675df
commit cc053ac309
8 changed files with 1001 additions and 64 deletions

@ -1,6 +1,7 @@
from keras_core.layers.activations.activation import Activation
from keras_core.layers.attention.additive_attention import AdditiveAttention
from keras_core.layers.attention.attention import Attention
from keras_core.layers.attention.multi_head_attention import MultiHeadAttention
from keras_core.layers.convolutional.conv1d import Conv1D
from keras_core.layers.convolutional.conv1d_transpose import Conv1DTranspose
from keras_core.layers.convolutional.conv2d import Conv2D

@ -0,0 +1,649 @@
import collections
import math
import string
import numpy as np
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.activations.softmax import Softmax
from keras_core.layers.core.einsum_dense import EinsumDense
from keras_core.layers.layer import Layer
from keras_core.layers.regularization.dropout import Dropout
@keras_core_export("keras_core.layers.Attention")
class MultiHeadAttention(Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention as described in the
paper "Attention is all you Need"
[Vaswani et al., 2017](https://arxiv.org/abs/1706.03762).
If `query`, `key,` `value` are the same, then
this is self-attention. Each timestep in `query` attends to the
corresponding sequence in `key`, and returns a fixed-width vector.
This layer first projects `query`, `key` and `value`. These are
(effectively) a list of tensors of length `num_attention_heads`, where the
corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
`(batch_size, <key/value dimensions>, key_dim)`,
`(batch_size, <key/value dimensions>, value_dim)`.
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor.
Finally, the result tensor with the last dimension as `value_dim` can take
a linear projection and return.
Args:
num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch
and sequence dims. If not specified, projects back to the query
feature dim (the query input's last dimension).
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
Call arguments:
query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size,
`T` is the target sequence length, and dim is the feature dimension.
value: Value tensor of shape `(B, S, dim)`, where `B` is the batch size,
`S` is the source sequence length, and dim is the feature dimension.
key: Optional key tensor of shape `(B, S, dim)`. If not given, will
use `value` for both `key` and `value`, which is the most common
case.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions. The boolean mask specifies which
query elements can attend to which key elements, 1 indicates
attention and 0 indicates no attention. Broadcasting can happen for
the missing batch dimensions and the head dimension.
return_attention_scores: A boolean to indicate whether the output should
be `(attention_output, attention_scores)` if `True`, or
`attention_output` if `False`. Defaults to `False`.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (no dropout).
Will go with either using the training mode of the parent
layer/model, or `False` (inference) if there is no parent layer.
use_causal_mask: A boolean to indicate whether to apply a causal mask to
prevent tokens from attending to future tokens (e.g., used in a
decoder Transformer).
Returns:
attention_output: The result of the computation, of shape `(B, T, E)`,
where `T` is for target sequence shapes and `E` is the query input
last dimension if `output_shape` is `None`. Otherwise, the
multi-head outputs are projected to the shape specified by
`output_shape`.
attention_scores: (Optional) multi-head attention coefficients over
attention axes.
"""
def __init__(
self,
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs,
):
super().__init__(**kwargs)
self.supports_masking = True
self._num_heads = num_heads
self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim
self._dropout = dropout
self._use_bias = use_bias
self._output_shape = output_shape
self._kernel_initializer = initializers.get(kernel_initializer)
self._bias_initializer = initializers.get(bias_initializer)
self._kernel_regularizer = regularizers.get(kernel_regularizer)
self._bias_regularizer = regularizers.get(bias_regularizer)
self._activity_regularizer = regularizers.get(activity_regularizer)
self._kernel_constraint = constraints.get(kernel_constraint)
self._bias_constraint = constraints.get(bias_constraint)
if isinstance(attention_axes, int):
attention_axes = (attention_axes,)
elif attention_axes and not isinstance(attention_axes, (list, tuple)):
raise ValueError(
"`attention_axes` must be an int, list, or tuple."
f"Received: attention_axes={attention_axes}"
)
self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self):
base_config = super().get_config()
config = {
"num_heads": self._num_heads,
"key_dim": self._key_dim,
"value_dim": self._value_dim,
"dropout": self._dropout,
"use_bias": self._use_bias,
"output_shape": self._output_shape,
"attention_axes": self._attention_axes,
"kernel_initializer": initializers.serialize(
self._kernel_initializer
),
"bias_initializer": initializers.serialize(self._bias_initializer),
"kernel_regularizer": regularizers.serialize(
self._kernel_regularizer
),
"bias_regularizer": regularizers.serialize(self._bias_regularizer),
"activity_regularizer": regularizers.serialize(
self._activity_regularizer
),
"kernel_constraint": constraints.serialize(self._kernel_constraint),
"bias_constraint": constraints.serialize(self._bias_constraint),
}
return {**base_config, **config}
def build(
self,
query_shape,
value_shape,
key_shape=None,
):
"""Builds layers and variables.
Args:
query_shape: Shape of the `query` tensor.
value_shape: Shape of the `value` tensor.
key: Optional shape of the `key` tensor.
"""
key_shape = value_shape if key_shape is None else key_shape
query_rank = len(query_shape)
value_rank = len(value_shape)
key_rank = len(key_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
query_rank - 1, bound_dims=1, output_dims=2
)
self._query_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._key_dim]
),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**self._get_common_kwargs_for_sublayer(),
)
self._query_dense.build(query_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_rank - 1, bound_dims=1, output_dims=2
)
self._key_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._key_dim]
),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**self._get_common_kwargs_for_sublayer(),
)
self._key_dense.build(key_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_rank - 1, bound_dims=1, output_dims=2
)
self._value_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._value_dim]
),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**self._get_common_kwargs_for_sublayer(),
)
self._value_dense.build(value_shape)
# Builds the attention computations for multi-head dot product
# attention. These computations could be wrapped into the keras
# attention layer once it supports multi-head einsum computations.
self._build_attention(output_rank)
self._output_dense = self._make_output_dense(
query_shape,
self._get_common_kwargs_for_sublayer(),
"attention_output",
)
output_dense_input_shape = list(
self._query_dense.compute_output_shape(query_shape)
)
output_dense_input_shape[-1] = self._value_dim
self._output_dense.build(tuple(output_dense_input_shape))
self.built = True
def _get_common_kwargs_for_sublayer(self):
common_kwargs = dict(
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
)
# Create new clone of kernel/bias initializer, so that we don't reuse
# the initializer instance, which could lead to same init value since
# initializer is stateless.
kernel_initializer = self._kernel_initializer.__class__.from_config(
self._kernel_initializer.get_config()
)
bias_initializer = self._bias_initializer.__class__.from_config(
self._bias_initializer.get_config()
)
common_kwargs["kernel_initializer"] = kernel_initializer
common_kwargs["bias_initializer"] = bias_initializer
return common_kwargs
def _make_output_dense(self, query_shape, common_kwargs, name=None):
"""Builds the output projection matrix.
Args:
free_dims: Number of free dimensions for einsum equation building.
common_kwargs: Common keyword arguments for einsum layer.
name: Name for the projection layer.
Returns:
Projection layer.
"""
query_rank = len(query_shape)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else:
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
query_rank - 1, bound_dims=2, output_dims=len(output_shape)
)
return EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name=name,
**common_kwargs,
)
def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
customize attention computation to replace the default dot-product
attention.
Args:
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
(
self._dot_product_equation,
self._combine_equation,
attn_scores_rank,
) = _build_attention_equation(rank, attn_axes=self._attention_axes)
norm_axes = tuple(
range(
attn_scores_rank - len(self._attention_axes), attn_scores_rank
)
)
self._softmax = Softmax(axis=norm_axes)
self._dropout_layer = Dropout(rate=self._dropout)
def _masked_softmax(self, attention_scores, attention_mask=None):
# Normalize the attention scores to probabilities.
# attention_scores = [B, N, T, S]
if attention_mask is not None:
# The expand dim happens starting from the `num_heads` dimension,
# (<batch_dims>, num_heads, <query_attention_dims,
# key_attention_dims>)
mask_expansion_axis = -len(self._attention_axes) * 2 - 1
for _ in range(
len(attention_scores.shape) - len(attention_mask.shape)
):
attention_mask = ops.expand_dims(
attention_mask, axis=mask_expansion_axis
)
return self._softmax(attention_scores, mask=attention_mask)
def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
multi-head Q, K, V inputs. Users can override this function for
customized attention implementation.
Args:
query: Projected query tensor of shape `(B, T, N, key_dim)`.
key: Projected key tensor of shape `(B, S, N, key_dim)`.
value: Projected value tensor of shape `(B, S, N, value_dim)`.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions. It is generally not needed if
the `query` and `value` (and/or `key`) are masked.
training: Python boolean indicating whether the layer should behave
in training mode (adding dropout) or in inference mode (doing
nothing).
Returns:
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training
)
# `context_layer` = [B, T, N, H]
attention_output = ops.einsum(
self._combine_equation, attention_scores_dropout, value
)
return attention_output, attention_scores
def call(
self,
query,
value,
key=None,
query_mask=None,
value_mask=None,
key_mask=None,
attention_mask=None,
return_attention_scores=False,
training=None,
use_causal_mask=False,
):
if key is None:
key = value
attention_mask = self._compute_attention_mask(
query,
value,
query_mask=query_mask,
value_mask=value_mask,
key_mask=key_mask,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, training
)
attention_output = self._output_dense(attention_output)
if return_attention_scores:
return attention_output, attention_scores
return attention_output
def _compute_attention_mask(
self,
query,
value,
query_mask=None,
value_mask=None,
key_mask=None,
attention_mask=None,
use_causal_mask=False,
):
"""Computes the attention mask, using the Keras masks of the inputs.
* The `query`'s mask is reshaped from [B, T] to [B, T, 1].
* The `value`'s mask is reshaped from [B, S] to [B, 1, S].
* The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
mask is ignored if `key` is `None` or if `key is value`.
* If `use_causal_mask=True`, then the causal mask is computed. Its shape
is [1, T, S].
All defined masks are merged using a logical AND operation (`&`).
In general, if the `query` and `value` are masked, then there is no need
to define the `attention_mask`.
Args:
query: Projected query tensor of shape `(B, T, N, key_dim)`.
key: Projected key tensor of shape `(B, T, N, key_dim)`.
value: Projected value tensor of shape `(B, T, N, value_dim)`.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
use_causal_mask: A boolean to indicate whether to apply a causal
mask to prevent tokens from attending to future tokens (e.g.,
used in a decoder Transformer).
Returns:
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions, based on the Keras masks of the
`query`, `key`, `value`, and `attention_mask` tensors, and the
causal mask if `use_causal_mask=True`.
"""
auto_mask = None
if query_mask is not None:
query_mask = ops.cast(query_mask, "bool") # defensive casting
# B = batch size, T = max query length
auto_mask = ops.expand_dims(query_mask, -1) # shape is [B, T, 1]
if value_mask is not None:
value_mask = ops.cast(value_mask, "bool") # defensive casting
# B = batch size, S == max value length
mask = ops.expand_dims(value_mask, -2) # shape is [B, 1, S]
auto_mask = mask if auto_mask is None else auto_mask & mask
if key_mask is not None:
key_mask = ops.cast(key_mask, "bool") # defensive casting
# B == batch size, S == max key length == max value length
mask = ops.expand_dims(key_mask, -2) # shape is [B, 1, S]
auto_mask = mask if auto_mask is None else auto_mask & mask
if use_causal_mask:
# the shape of the causal mask is [1, T, S]
mask = self._compute_causal_mask(query, value)
auto_mask = mask if auto_mask is None else auto_mask & mask
if auto_mask is not None:
# merge attention_mask & automatic mask, to shape [B, T, S]
attention_mask = (
auto_mask
if attention_mask is None
else ops.cast(attention_mask, bool) & auto_mask
)
return attention_mask
def _compute_causal_mask(self, query, value=None):
"""Computes a causal mask (e.g., for masked self-attention layers).
For example, if query and value both contain sequences of length 4,
this function returns a boolean tensor equal to:
```
[[[True, False, False, False],
[True, True, False, False],
[True, True, True, False],
[True, True, True, True]]]
```
Args:
query: query tensor of shape `(B, T, ...)`.
value: value tensor of shape `(B, S, ...)` (optional, defaults to
query).
Returns:
mask: a boolean tensor of shape `(1, T, S)` containing a lower
triangular matrix of shape `(T, S)`.
"""
q_seq_length = ops.shape(query)[1]
v_seq_length = q_seq_length if value is None else ops.shape(value)[1]
ones_mask = ops.ones((1, q_seq_length, v_seq_length), dtype="int32")
row_index = ops.cumsum(ones_mask, axis=-2)
col_index = ops.cumsum(ones_mask, axis=-1)
return ops.greater_equal(row_index, col_index)
def compute_output_shape(
self,
query_shape,
value_shape,
key_shape=None,
):
if key_shape is None:
key_shape = value_shape
if query_shape[-1] != value_shape[-1]:
raise ValueError(
"The last dimension of `query_shape` and `value_shape` "
f"must be equal, but are {query_shape[-1]}, {value_shape[-1]}. "
"Received: query_shape={query_shape}, value_shape={value_shape}"
)
if value_shape[1:-1] != key_shape[1:-1]:
raise ValueError(
"All dimensions of `value` and `key`, except the last one, "
f"must be equal. Received: value_shape={value_shape} and "
f"key_shape={key_shape}"
)
if self._output_shape:
return query_shape[:-1] + self._output_shape
return query_shape
def _index_to_einsum_variable(i):
"""Coverts an index to a einsum variable name.
We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'.
"""
return string.ascii_lowercase[i]
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
`(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
`bs` and `<non-attention dims>` are treated as `<batch dims>`.
The attention operations can be generalized:
1. Query-key dot product:
(<batch dims>, <query attention dims>, num_heads, channels),
(<batch dims>, <key attention dims>, num_heads, channels) ->
(<batch dims>, num_heads, <query attention dims>, <key attention dims>)
2. Combination:
(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch
dims>, <query attention dims>, num_heads, channels)
Args:
rank: Rank of query, key, value tensors.
attn_axes: List/tuple of axes, `[-1, rank)`,
that attention will be applied to.
Returns:
Einsum equations.
"""
target_notation = ""
for i in range(rank):
target_notation += _index_to_einsum_variable(i)
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _index_to_einsum_variable(letter_offset)
letter_offset += 1
product_notation = "".join(
[target_notation[i] for i in batch_dims]
+ [target_notation[i] for i in attn_axes]
+ [source_notation[i] for i in attn_axes]
)
dot_product_equation = "%s,%s->%s" % (
source_notation,
target_notation,
product_notation,
)
attn_scores_rank = len(product_notation)
combine_equation = "%s,%s->%s" % (
product_notation,
source_notation,
target_notation,
)
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _index_to_einsum_variable(i + letter_offset)
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _index_to_einsum_variable(i + letter_offset)
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _index_to_einsum_variable(i + letter_offset)
kernel_str += char
output_str += char
bias_axes += char
equation = f"{input_str},{kernel_str}->{output_str}"
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)

@ -0,0 +1,238 @@
import numpy as np
from absl.testing import parameterized
from keras_core import initializers
from keras_core import layers
from keras_core import testing
class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
def test_basics(self):
self.run_layer_test(
layers.MultiHeadAttention,
init_kwargs={
"num_heads": 2,
"key_dim": 2,
},
input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=8,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
self.run_layer_test(
layers.MultiHeadAttention,
init_kwargs={
"num_heads": 2,
"key_dim": 2,
"value_dim": 4,
"use_bias": False,
"dropout": 0.5,
},
input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=4,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
@parameterized.named_parameters(
("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)),
("4d_inputs_1freebatch_mask3", (3, 4), (3, 2), (3, 4, 2), (2,)),
("4d_inputs_1freebatch_mask4", (3, 4), (3, 2), (3, 2, 4, 2), (2,)),
("4d_inputs_2d_attention", (3, 4), (3, 2), (3, 4, 3, 2), (1, 2)),
("5d_inputs_2d_attention", (5, 3, 4), (5, 3, 2), (3, 4, 3, 2), (2, 3)),
(
"5d_inputs_2d_attention_fullmask",
(5, 3, 4),
(5, 3, 2),
(5, 3, 4, 3, 2),
(2, 3),
),
)
def test_high_dim_attention(
self, q_dims, v_dims, mask_dims, attention_axes
):
batch_size, hidden_size = 3, 8
query_shape = (batch_size,) + q_dims + (hidden_size,)
value_shape = (batch_size,) + v_dims + (hidden_size,)
self.run_layer_test(
layers.MultiHeadAttention,
init_kwargs={
"num_heads": 2,
"key_dim": 2,
"attention_axes": attention_axes,
},
input_shape={
"query_shape": query_shape,
"value_shape": value_shape,
},
expected_output_shape=query_shape,
expected_num_trainable_weights=8,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
@parameterized.named_parameters(
("without_key_same_proj", (4, 8), (2, 8), None, None),
("with_key_same_proj", (4, 8), (2, 8), (2, 3), None),
("wihtout_key_different_proj", (4, 8), (2, 8), None, (3, 4)),
("with_key_different_proj", (4, 8), (2, 8), (2, 3), (1, 5)),
("high_dim_same_proj", (4, 2, 3, 8), (1, 1, 5, 8), (1, 1, 5, 2), None),
(
"high_dim_different_proj",
(4, 2, 3, 8),
(1, 1, 5, 8),
(1, 1, 5, 2),
(3, 2),
),
)
def test_compute_output_shape(
self, query_dims, value_dims, key_dims, output_shape
):
"""Test computed shape is equal to the layer output's shape."""
layer = layers.MultiHeadAttention(
num_heads=2,
key_dim=2,
value_dim=2,
output_shape=output_shape,
)
batch_size = 7
query_shape = (batch_size,) + query_dims
value_shape = (batch_size,) + value_dims
key_shape = (batch_size,) + key_dims if key_dims else None
query = np.ones(query_shape)
value = np.ones(value_shape)
key = np.ones(key_shape) if key_shape else None
output = layer(query=query, value=value, key=key)
comp_output_shape = layer.compute_output_shape(
query_shape, value_shape, key_shape
)
self.assertEqual(output.shape, comp_output_shape)
@parameterized.named_parameters(
("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), 2),
("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)),
(
"key_value_dim_mismatch_high_dim",
(2, 4, 2, 3, 8),
(2, 1, 1, 5, 8),
(2, 1, 15, 5, 2),
),
)
def test_shape_mismatch_error(self, query_shape, value_shape, key_shape):
"""Test dimension mismatches"""
layer = layers.MultiHeadAttention(
num_heads=4,
key_dim=2,
value_dim=2,
)
with self.assertRaisesRegex(ValueError, r"must be equal"):
layer.compute_output_shape(query_shape, value_shape, key_shape)
def test_initializer(self):
# Test with a specified initializer.
layer = layers.MultiHeadAttention(
num_heads=12,
key_dim=64,
kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
)
layer.build((2, 4, 8), (2, 4, 8))
# Make sure the sub layers have different kernel init value.
self.assertNotAllClose(
layer._query_dense.kernel,
layer._key_dense.kernel,
)
self.assertNotAllClose(
layer._query_dense.kernel,
layer._value_dense.kernel,
)
self.assertNotAllClose(
layer._query_dense.kernel,
layer._output_dense.kernel,
)
def test_query_mask_progagation(self):
"""Test automatic propagation of the query's mask."""
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
self.assertTrue(layer.supports_masking)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
value = np.random.normal(size=(3, 3, 8))
output = layer(query=masked_query, value=value)
self.assertAllClose(masked_query._keras_mask, output._keras_mask)
@parameterized.named_parameters(("causal", True), ("not_causal", 0))
def test_masking(self, use_causal_mask):
"""Test that the value and causal masks are taken into account."""
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
masked_value = layers.Embedding(6, 8, mask_zero=True)(value)
output = layer(
query=masked_query,
value=masked_value,
use_causal_mask=use_causal_mask,
)
mask = np.array(
[[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]
+ [[[1, 0, 0]] * 5]
+ [[[1, 1, 1]] + [[0, 0, 0]] * 4]
).astype(bool)
if use_causal_mask:
mask = mask & np.array(
[[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]
).astype(bool)
del masked_query._keras_mask
del masked_value._keras_mask
output_with_manual_mask = layer(
query=masked_query, value=masked_value, attention_mask=mask
)
self.assertAllClose(output, output_with_manual_mask)
def test_correctness(self):
query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
key = np.array([[[0.0, 1.0], [1.0, 0.0]]])
value = np.array([[[1.0, 2.0], [3.0, 4.0]]])
# Setup layer.
num_heads = 2
key_dim = 2
layer = layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=key_dim,
)
layer.build(query.shape, key.shape, value.shape)
# Set layer weights.
kernel = np.identity(key_dim)
# To get an identity kernel we need to add a head dim and repeat on it.
kernel = np.repeat(kernel[:, np.newaxis, :], num_heads, axis=1)
# Zeros for all biases.
bias = np.zeros((2, 2))
output_bias = np.zeros((2,))
layer.set_weights([kernel, bias] * 3 + [kernel, output_bias])
# Call layer and assert output.
output, scores = layer(
query=query,
value=value,
key=key,
return_attention_scores=True,
)
self.assertAllClose(output, [[[5.679, 5.679], [4.32, 4.32]]], atol=1e-3)
self.assertAllClose(
scores,
[[[[0.33, 0.67], [0.67, 0.33]], [[0.33, 0.67], [0.67, 0.33]]]],
atol=1e-3,
)

@ -571,7 +571,7 @@ class Layer(Operation):
else:
# Use compute_output_shape() to return the right output spec
call_spec = CallSpec(self.call, args, kwargs)
shapes_dict = get_shapes_dict(call_spec)
shapes_dict = get_shapes_dict(self.compute_output_shape, call_spec)
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
@ -723,7 +723,7 @@ class Layer(Operation):
def _maybe_build(self, call_spec):
if not self.built:
shapes_dict = get_shapes_dict(call_spec)
shapes_dict = get_shapes_dict(self.build, call_spec)
self._build_shapes_dict = shapes_dict
failure = False
if len(shapes_dict) == 1:
@ -741,9 +741,6 @@ class Layer(Operation):
else:
self.build(input_shape)
else:
# More than one shape: pass them by name,
# and check that build() expects the right args.
check_build_signature(self.build, shapes_dict)
with backend.name_scope(self.name):
if utils.is_default(self.build):
if might_have_unbuilt_state(self):
@ -751,29 +748,7 @@ class Layer(Operation):
if not status:
failure = True
else:
run_build = True
build_args = set(
inspect.signature(self.build).parameters.keys()
)
for key in shapes_dict.keys():
if key not in build_args:
run_build = False
if run_build:
self.build(**shapes_dict)
else:
raise ValueError(
"In a layer with multiple tensor arguments "
"in call(), the build() method should accept "
"corresponding `*_shape` arguments, e.g. "
"if the call signature is "
"`def call(self, x1, x2)` "
"then the build signature should be "
"`def build(self, x1_shape, x2_shape)`. "
"Keras will not build this layer automatically "
"since it does not conform to this. "
"Expected the following build keys: "
f"{list(shapes_dict.keys())}"
)
self.build(**shapes_dict)
if failure:
if call_spec.eager:
# Will let the actual eager call do the state-building
@ -821,7 +796,7 @@ class Layer(Operation):
# Case: all input keyword arguments were plain tensors.
input_tensors = {
# We strip the `_shape` suffix to recover kwarg names.
k[:-6]: backend.KerasTensor(shape)
k.removesuffix("_shape"): backend.KerasTensor(shape)
for k, shape in shapes_dict.items()
}
try:
@ -992,16 +967,17 @@ def get_arguments_dict(fn, args, kwargs):
return arg_dict
def get_shapes_dict(call_spec):
def get_shapes_dict(target_fn, call_spec):
"""Convert the call() arguments dict into a dict of input shape arguments.
Example:
```
>>> get_shapes_dict(call_spec)
>>> get_shapes_dict(self.build, call_spec)
{"input_a_shape": (2, 3)}
```
"""
expected_names = check_shapes_signature(target_fn, call_spec)
shapes_dict = {}
for k, v in call_spec.tensor_arguments_dict.items():
if k == "mask" or k.startswith("mask_"):
@ -1010,6 +986,8 @@ def get_shapes_dict(call_spec):
if k == "kwargs" or k == "args":
# Do not include catch-alls in shapes dict
continue
if expected_names is not None and f"{k}_shape" not in expected_names:
continue
if k in call_spec.nested_tensor_argument_names:
shapes_dict[f"{k}_shape"] = nest.map_structure(
lambda x: backend.standardize_shape(x.shape), v
@ -1019,22 +997,26 @@ def get_shapes_dict(call_spec):
return shapes_dict
def check_build_signature(build_fn, shapes_dict):
"""Asserts that the argument names in build_fn match entries in shapes_dict.
def check_shapes_signature(target_fn, call_spec):
"""Asserts that the argument names in `target_fn` match arguments in `call`.
For instance if call() has the signature `def call(self, a, b)`
then we'll see `shapes_dict == {"a_shape": (...), "b_shape": (...)}
and we expect build() to have signature `def build(self, a_shape, b_shape)`.
We use this to check that `build()` and `compute_output_shape()` arguments
align with `call()` arguments.
When there is a single tensor argument, we pass it positionally and thus
don't check names (if we did, it would force call() to always take
`input` as its first argument, which is usually not the case).
For instance if `build()` has the signature
`def build(self, a_shape, b_shape)` we expect `call()` to accept the
arguments `a` and `b`.
When there is a single argument accepted by `target_fn`, we do allow any
name and do not check the call signature.
Returns:
The list of arguments names expected by the `target_fn` or
`None` if any passed name is acceptable.
"""
if len(shapes_dict) == 1:
return
if utils.is_default(build_fn):
return
sig = inspect.signature(build_fn)
if utils.is_default(target_fn):
return None
sig = inspect.signature(target_fn)
expected_names = []
for name, param in sig.parameters.items():
if param.kind in (
@ -1043,14 +1025,29 @@ def check_build_signature(build_fn, shapes_dict):
param.KEYWORD_ONLY,
):
expected_names.append(name)
if set(expected_names) != set(shapes_dict.keys()):
comma_separated = ", ".join(shapes_dict.keys())
raise ValueError(
"For a `call()` method with more than one tensor argument, "
"the arguments of the `build()` method should match the "
"tensor arguments of `call()` method. Here we expect the signature "
f"`build(self, {comma_separated})`."
if len(expected_names) == 1:
return None
for name in expected_names:
method_name = target_fn.__name__
error_preamble = (
f"For a `{method_name}()` method with more than one argument, all "
"arguments should have a `_shape` suffix and match an argument "
f"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` "
"would match `call(self, foo, bar)`."
)
if not name.endswith("_shape"):
raise ValueError(
f"{error_preamble} Received `{method_name}()` argument "
f"`{name}`, which does not end in `_shape`."
)
expected_call_arg = name.removesuffix("_shape")
if expected_call_arg not in call_spec.arguments_dict:
raise ValueError(
f"{error_preamble} Received `{method_name}()` argument "
f"`{name}`, but `call()` does not have argument "
f"`{expected_call_arg}`."
)
return expected_names
class CallContext:

@ -588,3 +588,41 @@ class LayerTest(testing.TestCase):
self.assertEqual(len(layer.trainable_variables), 2)
self.assertEqual(len(layer.non_trainable_weights), 2)
self.assertEqual(len(layer.non_trainable_variables), 3)
def test_build_signature_errors(self):
class NoShapeSuffix(layers.Layer):
def build(self, foo_shape, bar):
self._built = True
def call(self, foo, bar):
return foo + bar
class NonMatchingArgument(layers.Layer):
def build(self, foo_shape, baz_shape):
self._built = True
def call(self, foo, bar):
return foo + bar
class MatchingArguments(layers.Layer):
def build(self, foo_shape, bar_shape):
self._built = True
def call(self, foo, bar):
return foo + bar
foo = backend.numpy.ones((4, 4))
bar = backend.numpy.ones((4, 4))
with self.assertRaisesRegex(
ValueError,
r"argument `bar`, which does not end in `_shape`",
):
NoShapeSuffix()(foo, bar)
with self.assertRaisesRegex(
ValueError,
r"`baz_shape`, but `call\(\)` does not have argument `baz`",
):
NonMatchingArgument()(foo, bar)
MatchingArguments()(foo, bar)

@ -1,10 +1,10 @@
from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Dropout")
class Dropout(layers.Layer):
class Dropout(Layer):
"""Applies dropout to the input.
The `Dropout` layer randomly sets input units to 0 with a frequency of

@ -51,7 +51,7 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
for the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor, with shape `(batch, features)`.
sequence: A 2D tensor, with shape `(batch, features)`.
states: A 2D tensor with shape `(batch, units)`, which is the state
from the previous time step.
training: Python boolean indicating whether the layer should behave in
@ -151,14 +151,14 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
self.bias = None
self.built = True
def call(self, inputs, states, training=False):
def call(self, sequence, states, training=False):
prev_output = states[0] if isinstance(states, (list, tuple)) else states
dp_mask = self.get_dropout_mask(inputs)
dp_mask = self.get_dropout_mask(sequence)
rec_dp_mask = self.get_recurrent_dropout_mask(prev_output)
if training and dp_mask is not None:
inputs *= dp_mask
h = ops.matmul(inputs, self.kernel)
sequence *= dp_mask
h = ops.matmul(sequence, self.kernel)
if self.bias is not None:
h += self.bias
@ -265,7 +265,7 @@ class SimpleRNN(RNN):
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
sequence: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked. An individual `True` entry
indicates that the corresponding timestep should be utilized,

@ -255,13 +255,21 @@ class TestCase(unittest.TestCase):
# Build test.
if input_shape is not None:
layer = layer_cls(**init_kwargs)
layer.build(input_shape)
if isinstance(input_shape, dict):
layer.build(**input_shape)
else:
layer.build(input_shape)
run_build_asserts(layer)
# Symbolic call test.
keras_tensor_inputs = create_keras_tensors(input_shape, input_dtype)
layer = layer_cls(**init_kwargs)
keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
if isinstance(keras_tensor_inputs, dict):
keras_tensor_outputs = layer(
**keras_tensor_inputs, **call_kwargs
)
else:
keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
run_build_asserts(layer)
run_output_asserts(layer, keras_tensor_outputs, eager=False)
@ -274,7 +282,10 @@ class TestCase(unittest.TestCase):
if input_data is None:
input_data = create_eager_tensors(input_shape, input_dtype)
layer = layer_cls(**init_kwargs)
output_data = layer(input_data, **call_kwargs)
if isinstance(input_data, dict):
output_data = layer(**input_data, **call_kwargs)
else:
output_data = layer(input_data, **call_kwargs)
run_output_asserts(layer, output_data, eager=True)
@ -287,7 +298,7 @@ def create_keras_tensors(input_shape, dtype):
return [keras_tensor.KerasTensor(s, dtype=dtype) for s in input_shape]
if isinstance(input_shape, dict):
return {
k: keras_tensor.KerasTensor(v, dtype=dtype)
k.removesuffix("_shape"): keras_tensor.KerasTensor(v, dtype=dtype)
for k, v in input_shape.items()
}
@ -320,4 +331,7 @@ def create_eager_tensors(input_shape, dtype):
if isinstance(input_shape, list):
return [create_fn(s, dtype=dtype) for s in input_shape]
if isinstance(input_shape, dict):
return {k: create_fn(v, dtype=dtype) for k, v in input_shape.items()}
return {
k.removesuffix("_shape"): create_fn(v, dtype=dtype)
for k, v in input_shape.items()
}