Add layers.MultiHeadAttention (#169)
* Add layers.MultiHeadAttention
* Update build/compute_output_signature argument checks

We do not definitively know which arguments are tensor arguments at any given
invocation (e.g. arguments with a `None` value may still be tensor arguments).
So rather than checking that the build signature matches the tensor call
arguments exactly, we check that each build signature argument matches some
call argument.
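For illustration, a minimal sketch of the convention this enforces (the layer
below is hypothetical, not part of the diff): a layer whose call() takes two
tensor arguments should accept matching `*_shape` arguments in build().

    from keras_core import operations as ops
    from keras_core.layers.layer import Layer

    class Concat(Layer):
        # build() arguments mirror call(self, foo, bar), each with a
        # `_shape` suffix; Keras checks this correspondence when it
        # auto-builds the layer.
        def build(self, foo_shape, bar_shape):
            self.built = True

        def call(self, foo, bar):
            return ops.concatenate([foo, bar], axis=-1)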
parent 3eaa2675df
commit cc053ac309
@@ -1,6 +1,7 @@
 from keras_core.layers.activations.activation import Activation
 from keras_core.layers.attention.additive_attention import AdditiveAttention
 from keras_core.layers.attention.attention import Attention
+from keras_core.layers.attention.multi_head_attention import MultiHeadAttention
 from keras_core.layers.convolutional.conv1d import Conv1D
 from keras_core.layers.convolutional.conv1d_transpose import Conv1DTranspose
 from keras_core.layers.convolutional.conv2d import Conv2D
keras_core/layers/attention/multi_head_attention.py (new file, 649 lines)
@@ -0,0 +1,649 @@
import collections
import math
import string

import numpy as np

from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.activations.softmax import Softmax
from keras_core.layers.core.einsum_dense import EinsumDense
from keras_core.layers.layer import Layer
from keras_core.layers.regularization.dropout import Dropout


@keras_core_export("keras_core.layers.MultiHeadAttention")
class MultiHeadAttention(Layer):
    """MultiHeadAttention layer.

    This is an implementation of multi-headed attention as described in the
    paper "Attention Is All You Need"
    [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762).
    If `query`, `key`, `value` are the same, then
    this is self-attention. Each timestep in `query` attends to the
    corresponding sequence in `key`, and returns a fixed-width vector.

    This layer first projects `query`, `key` and `value`. These are
    (effectively) a list of tensors of length `num_attention_heads`, where the
    corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
    `(batch_size, <key/value dimensions>, key_dim)`,
    `(batch_size, <key/value dimensions>, value_dim)`.

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor.

    Finally, the result tensor with the last dimension as `value_dim` can take
    a linear projection and return.

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        value_dim: Size of each attention head for value.
        dropout: Dropout probability.
        use_bias: Boolean, whether the dense layers use bias vectors/matrices.
        output_shape: The expected shape of an output tensor, besides the
            batch and sequence dims. If not specified, projects back to the
            query feature dim (the query input's last dimension).
        attention_axes: Axes over which the attention is applied. `None` means
            attention over all axes, but batch, heads, and features.
        kernel_initializer: Initializer for dense layer kernels.
        bias_initializer: Initializer for dense layer biases.
        kernel_regularizer: Regularizer for dense layer kernels.
        bias_regularizer: Regularizer for dense layer biases.
        activity_regularizer: Regularizer for dense layer activity.
        kernel_constraint: Constraint for dense layer kernels.
        bias_constraint: Constraint for dense layer biases.

    Call arguments:
        query: Query tensor of shape `(B, T, dim)`, where `B` is the batch
            size, `T` is the target sequence length, and dim is the feature
            dimension.
        value: Value tensor of shape `(B, S, dim)`, where `B` is the batch
            size, `S` is the source sequence length, and dim is the feature
            dimension.
        key: Optional key tensor of shape `(B, S, dim)`. If not given, will
            use `value` for both `key` and `value`, which is the most common
            case.
        attention_mask: A boolean mask of shape `(B, T, S)`, that prevents
            attention to certain positions. The boolean mask specifies which
            query elements can attend to which key elements, 1 indicates
            attention and 0 indicates no attention. Broadcasting can happen
            for the missing batch dimensions and the head dimension.
        return_attention_scores: A boolean to indicate whether the output
            should be `(attention_output, attention_scores)` if `True`, or
            `attention_output` if `False`. Defaults to `False`.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
            Will go with either using the training mode of the parent
            layer/model, or `False` (inference) if there is no parent layer.
        use_causal_mask: A boolean to indicate whether to apply a causal mask
            to prevent tokens from attending to future tokens (e.g., used in
            a decoder Transformer).

    Returns:
        attention_output: The result of the computation, of shape `(B, T, E)`,
            where `T` is for target sequence shapes and `E` is the query input
            last dimension if `output_shape` is `None`. Otherwise, the
            multi-head outputs are projected to the shape specified by
            `output_shape`.
        attention_scores: (Optional) multi-head attention coefficients over
            attention axes.
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        value_dim=None,
        dropout=0.0,
        use_bias=True,
        output_shape=None,
        attention_axes=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self._num_heads = num_heads
        self._key_dim = key_dim
        self._value_dim = value_dim if value_dim else key_dim
        self._dropout = dropout
        self._use_bias = use_bias
        self._output_shape = output_shape
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)
        self._kernel_regularizer = regularizers.get(kernel_regularizer)
        self._bias_regularizer = regularizers.get(bias_regularizer)
        self._activity_regularizer = regularizers.get(activity_regularizer)
        self._kernel_constraint = constraints.get(kernel_constraint)
        self._bias_constraint = constraints.get(bias_constraint)
        if isinstance(attention_axes, int):
            attention_axes = (attention_axes,)
        elif attention_axes and not isinstance(attention_axes, (list, tuple)):
            raise ValueError(
                "`attention_axes` must be an int, list, or tuple. "
                f"Received: attention_axes={attention_axes}"
            )
        self._attention_axes = attention_axes
        self._built_from_signature = False

    def get_config(self):
        base_config = super().get_config()
        config = {
            "num_heads": self._num_heads,
            "key_dim": self._key_dim,
            "value_dim": self._value_dim,
            "dropout": self._dropout,
            "use_bias": self._use_bias,
            "output_shape": self._output_shape,
            "attention_axes": self._attention_axes,
            "kernel_initializer": initializers.serialize(
                self._kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self._bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self._kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self._bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self._activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self._kernel_constraint),
            "bias_constraint": constraints.serialize(self._bias_constraint),
        }
        return {**base_config, **config}

    def build(
        self,
        query_shape,
        value_shape,
        key_shape=None,
    ):
        """Builds layers and variables.

        Args:
            query_shape: Shape of the `query` tensor.
            value_shape: Shape of the `value` tensor.
            key_shape: Optional shape of the `key` tensor.
        """
        key_shape = value_shape if key_shape is None else key_shape
        query_rank = len(query_shape)
        value_rank = len(value_shape)
        key_rank = len(key_shape)
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            query_rank - 1, bound_dims=1, output_dims=2
        )
        self._query_dense = EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="query",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._query_dense.build(query_shape)
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            key_rank - 1, bound_dims=1, output_dims=2
        )
        self._key_dense = EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="key",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._key_dense.build(key_shape)
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            value_rank - 1, bound_dims=1, output_dims=2
        )
        self._value_dense = EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._value_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="value",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._value_dense.build(value_shape)

        # Builds the attention computations for multi-head dot product
        # attention. These computations could be wrapped into the keras
        # attention layer once it supports multi-head einsum computations.
        self._build_attention(output_rank)
        self._output_dense = self._make_output_dense(
            query_shape,
            self._get_common_kwargs_for_sublayer(),
            "attention_output",
        )
        output_dense_input_shape = list(
            self._query_dense.compute_output_shape(query_shape)
        )
        output_dense_input_shape[-1] = self._value_dim
        self._output_dense.build(tuple(output_dense_input_shape))
        self.built = True

    def _get_common_kwargs_for_sublayer(self):
        common_kwargs = dict(
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
        )
        # Create a new clone of the kernel/bias initializer, so that we don't
        # reuse the initializer instance, which could lead to the same init
        # value since the initializer is stateless.
        kernel_initializer = self._kernel_initializer.__class__.from_config(
            self._kernel_initializer.get_config()
        )
        bias_initializer = self._bias_initializer.__class__.from_config(
            self._bias_initializer.get_config()
        )
        common_kwargs["kernel_initializer"] = kernel_initializer
        common_kwargs["bias_initializer"] = bias_initializer
        return common_kwargs

    def _make_output_dense(self, query_shape, common_kwargs, name=None):
        """Builds the output projection matrix.

        Args:
            query_shape: Shape of the `query` tensor.
            common_kwargs: Common keyword arguments for the einsum layer.
            name: Name for the projection layer.

        Returns:
            Projection layer.
        """
        query_rank = len(query_shape)
        if self._output_shape:
            if not isinstance(self._output_shape, collections.abc.Sized):
                output_shape = [self._output_shape]
            else:
                output_shape = self._output_shape
        else:
            output_shape = [query_shape[-1]]
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            query_rank - 1, bound_dims=2, output_dims=len(output_shape)
        )
        return EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name=name,
            **common_kwargs,
        )

    def _build_attention(self, rank):
        """Builds multi-head dot-product attention computations.

        This function builds attributes necessary for `_compute_attention` to
        customize attention computation to replace the default dot-product
        attention.

        Args:
            rank: The rank of query, key, value tensors.
        """
        if self._attention_axes is None:
            self._attention_axes = tuple(range(1, rank - 2))
        else:
            self._attention_axes = tuple(self._attention_axes)
        (
            self._dot_product_equation,
            self._combine_equation,
            attn_scores_rank,
        ) = _build_attention_equation(rank, attn_axes=self._attention_axes)
        norm_axes = tuple(
            range(
                attn_scores_rank - len(self._attention_axes), attn_scores_rank
            )
        )
        self._softmax = Softmax(axis=norm_axes)
        self._dropout_layer = Dropout(rate=self._dropout)

    def _masked_softmax(self, attention_scores, attention_mask=None):
        # Normalize the attention scores to probabilities.
        # attention_scores = [B, N, T, S]
        if attention_mask is not None:
            # The expand dim happens starting from the `num_heads` dimension,
            # (<batch_dims>, num_heads, <query_attention_dims,
            # key_attention_dims>)
            mask_expansion_axis = -len(self._attention_axes) * 2 - 1
            for _ in range(
                len(attention_scores.shape) - len(attention_mask.shape)
            ):
                attention_mask = ops.expand_dims(
                    attention_mask, axis=mask_expansion_axis
                )
        return self._softmax(attention_scores, mask=attention_mask)
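
    # Shape bookkeeping sketch (illustrative, not part of the diff): in the
    # default case `attention_axes` has length 1, so
    # mask_expansion_axis = -1 * 2 - 1 = -3. The scores have rank 4 while a
    # user-supplied mask has rank 3, so the mask is expanded once at axis -3:
    # an attention_mask of shape (B, T, S) becomes (B, 1, T, S) and
    # broadcasts across heads.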

    def _compute_attention(
        self, query, key, value, attention_mask=None, training=None
    ):
        """Applies dot-product attention with query, key, value tensors.

        This function defines the computation inside `call` with projected
        multi-head Q, K, V inputs. Users can override this function for
        customized attention implementation.

        Args:
            query: Projected query tensor of shape `(B, T, N, key_dim)`.
            key: Projected key tensor of shape `(B, S, N, key_dim)`.
            value: Projected value tensor of shape `(B, S, N, value_dim)`.
            attention_mask: A boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions. It is generally not needed if
                the `query` and `value` (and/or `key`) are masked.
            training: Python boolean indicating whether the layer should
                behave in training mode (adding dropout) or in inference mode
                (doing nothing).

        Returns:
            attention_output: Multi-headed outputs of attention computation.
            attention_scores: Multi-headed attention weights.
        """
        # Note: Applying scalar multiply at the smaller end of einsum improves
        # XLA performance, but may introduce slight numeric differences in
        # the Transformer attention head.
        query = ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        attention_scores = ops.einsum(self._dot_product_equation, key, query)

        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_scores_dropout = self._dropout_layer(
            attention_scores, training=training
        )

        # `context_layer` = [B, T, N, H]
        attention_output = ops.einsum(
            self._combine_equation, attention_scores_dropout, value
        )
        return attention_output, attention_scores

    def call(
        self,
        query,
        value,
        key=None,
        query_mask=None,
        value_mask=None,
        key_mask=None,
        attention_mask=None,
        return_attention_scores=False,
        training=None,
        use_causal_mask=False,
    ):
        if key is None:
            key = value

        attention_mask = self._compute_attention_mask(
            query,
            value,
            query_mask=query_mask,
            value_mask=value_mask,
            key_mask=key_mask,
            attention_mask=attention_mask,
            use_causal_mask=use_causal_mask,
        )

        # N = `num_attention_heads`
        # H = `size_per_head`
        # `query` = [B, T, N, H]
        query = self._query_dense(query)

        # `key` = [B, S, N, H]
        key = self._key_dense(key)

        # `value` = [B, S, N, H]
        value = self._value_dense(value)

        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask, training
        )
        attention_output = self._output_dense(attention_output)

        if return_attention_scores:
            return attention_output, attention_scores
        return attention_output

    def _compute_attention_mask(
        self,
        query,
        value,
        query_mask=None,
        value_mask=None,
        key_mask=None,
        attention_mask=None,
        use_causal_mask=False,
    ):
        """Computes the attention mask, using the Keras masks of the inputs.

        * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
        * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
        * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
          mask is ignored if `key` is `None` or if `key is value`.
        * If `use_causal_mask=True`, then the causal mask is computed. Its
          shape is [1, T, S].

        All defined masks are merged using a logical AND operation (`&`).

        In general, if the `query` and `value` are masked, then there is no
        need to define the `attention_mask`.

        Args:
            query: Query tensor of shape `(B, T, dim)`.
            value: Value tensor of shape `(B, S, dim)`.
            query_mask: Optional boolean mask of shape `(B, T)` for `query`.
            value_mask: Optional boolean mask of shape `(B, S)` for `value`.
            key_mask: Optional boolean mask of shape `(B, S)` for `key`.
            attention_mask: A boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions.
            use_causal_mask: A boolean to indicate whether to apply a causal
                mask to prevent tokens from attending to future tokens (e.g.,
                used in a decoder Transformer).

        Returns:
            attention_mask: A boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions, based on the Keras masks of
                the `query`, `key`, `value`, and `attention_mask` tensors, and
                the causal mask if `use_causal_mask=True`.
        """
        auto_mask = None
        if query_mask is not None:
            query_mask = ops.cast(query_mask, "bool")  # defensive casting
            # B = batch size, T = max query length
            auto_mask = ops.expand_dims(query_mask, -1)  # shape is [B, T, 1]
        if value_mask is not None:
            value_mask = ops.cast(value_mask, "bool")  # defensive casting
            # B = batch size, S == max value length
            mask = ops.expand_dims(value_mask, -2)  # shape is [B, 1, S]
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if key_mask is not None:
            key_mask = ops.cast(key_mask, "bool")  # defensive casting
            # B == batch size, S == max key length == max value length
            mask = ops.expand_dims(key_mask, -2)  # shape is [B, 1, S]
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if use_causal_mask:
            # the shape of the causal mask is [1, T, S]
            mask = self._compute_causal_mask(query, value)
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if auto_mask is not None:
            # merge attention_mask & automatic mask, to shape [B, T, S]
            attention_mask = (
                auto_mask
                if attention_mask is None
                else ops.cast(attention_mask, "bool") & auto_mask
            )
        return attention_mask
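
    # Shape sketch for the merge above (illustrative, not part of the diff):
    # with query_mask of shape (B, T) and value_mask of shape (B, S),
    #     ops.expand_dims(query_mask, -1)  # -> (B, T, 1)
    #     ops.expand_dims(value_mask, -2)  # -> (B, 1, S)
    # and their logical AND broadcasts to the final (B, T, S) mask.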

    def _compute_causal_mask(self, query, value=None):
        """Computes a causal mask (e.g., for masked self-attention layers).

        For example, if query and value both contain sequences of length 4,
        this function returns a boolean tensor equal to:

        ```
        [[[True,  False, False, False],
          [True,  True,  False, False],
          [True,  True,  True,  False],
          [True,  True,  True,  True]]]
        ```

        Args:
            query: Query tensor of shape `(B, T, ...)`.
            value: Value tensor of shape `(B, S, ...)` (optional, defaults to
                query).

        Returns:
            mask: A boolean tensor of shape `(1, T, S)` containing a lower
                triangular matrix of shape `(T, S)`.
        """
        q_seq_length = ops.shape(query)[1]
        v_seq_length = q_seq_length if value is None else ops.shape(value)[1]
        ones_mask = ops.ones((1, q_seq_length, v_seq_length), dtype="int32")
        row_index = ops.cumsum(ones_mask, axis=-2)
        col_index = ops.cumsum(ones_mask, axis=-1)
        return ops.greater_equal(row_index, col_index)

    def compute_output_shape(
        self,
        query_shape,
        value_shape,
        key_shape=None,
    ):
        if key_shape is None:
            key_shape = value_shape

        if query_shape[-1] != value_shape[-1]:
            raise ValueError(
                "The last dimension of `query_shape` and `value_shape` "
                f"must be equal, but are {query_shape[-1]}, {value_shape[-1]}. "
                f"Received: query_shape={query_shape}, "
                f"value_shape={value_shape}"
            )

        if value_shape[1:-1] != key_shape[1:-1]:
            raise ValueError(
                "All dimensions of `value` and `key`, except the last one, "
                f"must be equal. Received: value_shape={value_shape} and "
                f"key_shape={key_shape}"
            )

        if self._output_shape:
            return query_shape[:-1] + self._output_shape

        return query_shape


def _index_to_einsum_variable(i):
    """Converts an index to an einsum variable name.

    We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'.
    """
    return string.ascii_lowercase[i]


def _build_attention_equation(rank, attn_axes):
    """Builds einsum equations for the attention computation.

    Query, key, value inputs after projection are expected to have the shape
    `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
    `bs` and `<non-attention dims>` are treated as `<batch dims>`.

    The attention operations can be generalized:
    1. Query-key dot product:
       (<batch dims>, <query attention dims>, num_heads, channels),
       (<batch dims>, <key attention dims>, num_heads, channels) ->
       (<batch dims>, num_heads, <query attention dims>,
       <key attention dims>)
    2. Combination:
       (<batch dims>, num_heads, <query attention dims>,
       <key attention dims>),
       (<batch dims>, <value attention dims>, num_heads, channels) ->
       (<batch dims>, <query attention dims>, num_heads, channels)

    Args:
        rank: Rank of query, key, value tensors.
        attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be
            applied to.

    Returns:
        Einsum equations.
    """
    target_notation = ""
    for i in range(rank):
        target_notation += _index_to_einsum_variable(i)
    # `batch_dims` includes the head dim.
    batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
    letter_offset = rank
    source_notation = ""
    for i in range(rank):
        if i in batch_dims or i == rank - 1:
            source_notation += target_notation[i]
        else:
            source_notation += _index_to_einsum_variable(letter_offset)
            letter_offset += 1

    product_notation = "".join(
        [target_notation[i] for i in batch_dims]
        + [target_notation[i] for i in attn_axes]
        + [source_notation[i] for i in attn_axes]
    )
    dot_product_equation = "%s,%s->%s" % (
        source_notation,
        target_notation,
        product_notation,
    )
    attn_scores_rank = len(product_notation)
    combine_equation = "%s,%s->%s" % (
        product_notation,
        source_notation,
        target_notation,
    )
    return dot_product_equation, combine_equation, attn_scores_rank
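

# Worked example (illustrative, not part of the diff): for the default case
# of rank=4 inputs (B, T, N, H) with attn_axes=(1,), the letters come out as
# a=batch, b=query seq, c=heads, d=channels, e=key/value seq, giving
#     dot_product_equation == "aecd,abcd->acbe"  # scores laid out (B,N,T,S)
#     combine_equation     == "acbe,aecd->abcd"  # back to (B, T, N, H)
# which matches the [B, N, T, S] score layout noted in `_masked_softmax`.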


def _build_proj_equation(free_dims, bound_dims, output_dims):
    """Builds an einsum equation for projections inside multi-head attention."""
    input_str = ""
    kernel_str = ""
    output_str = ""
    bias_axes = ""
    letter_offset = 0
    for i in range(free_dims):
        char = _index_to_einsum_variable(i + letter_offset)
        input_str += char
        output_str += char

    letter_offset += free_dims
    for i in range(bound_dims):
        char = _index_to_einsum_variable(i + letter_offset)
        input_str += char
        kernel_str += char

    letter_offset += bound_dims
    for i in range(output_dims):
        char = _index_to_einsum_variable(i + letter_offset)
        kernel_str += char
        output_str += char
        bias_axes += char
    equation = f"{input_str},{kernel_str}->{output_str}"

    return equation, bias_axes, len(output_str)
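

# Worked example (illustrative, not part of the diff): the query/key/value
# projections for rank-3 inputs use free_dims=2, bound_dims=1, output_dims=2,
# which yields
#     equation == "abc,cde->abde", bias_axes == "de", output rank 4
# i.e. (B, T, dim) x (dim, N, head_dim) -> (B, T, N, head_dim).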


def _get_output_shape(output_rank, known_last_dims):
    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
keras_core/layers/attention/multi_head_attention_test.py (new file, 238 lines)
@@ -0,0 +1,238 @@
import numpy as np
from absl.testing import parameterized

from keras_core import initializers
from keras_core import layers
from keras_core import testing


class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
    def test_basics(self):
        self.run_layer_test(
            layers.MultiHeadAttention,
            init_kwargs={
                "num_heads": 2,
                "key_dim": 2,
            },
            input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
            expected_output_shape=(2, 8, 16),
            expected_num_trainable_weights=8,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=True,
        )

        self.run_layer_test(
            layers.MultiHeadAttention,
            init_kwargs={
                "num_heads": 2,
                "key_dim": 2,
                "value_dim": 4,
                "use_bias": False,
                "dropout": 0.5,
            },
            input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)},
            expected_output_shape=(2, 8, 16),
            expected_num_trainable_weights=4,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=True,
        )

    @parameterized.named_parameters(
        ("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)),
        ("4d_inputs_1freebatch_mask3", (3, 4), (3, 2), (3, 4, 2), (2,)),
        ("4d_inputs_1freebatch_mask4", (3, 4), (3, 2), (3, 2, 4, 2), (2,)),
        ("4d_inputs_2d_attention", (3, 4), (3, 2), (3, 4, 3, 2), (1, 2)),
        ("5d_inputs_2d_attention", (5, 3, 4), (5, 3, 2), (3, 4, 3, 2), (2, 3)),
        (
            "5d_inputs_2d_attention_fullmask",
            (5, 3, 4),
            (5, 3, 2),
            (5, 3, 4, 3, 2),
            (2, 3),
        ),
    )
    def test_high_dim_attention(
        self, q_dims, v_dims, mask_dims, attention_axes
    ):
        batch_size, hidden_size = 3, 8
        query_shape = (batch_size,) + q_dims + (hidden_size,)
        value_shape = (batch_size,) + v_dims + (hidden_size,)
        self.run_layer_test(
            layers.MultiHeadAttention,
            init_kwargs={
                "num_heads": 2,
                "key_dim": 2,
                "attention_axes": attention_axes,
            },
            input_shape={
                "query_shape": query_shape,
                "value_shape": value_shape,
            },
            expected_output_shape=query_shape,
            expected_num_trainable_weights=8,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=True,
        )

    @parameterized.named_parameters(
        ("without_key_same_proj", (4, 8), (2, 8), None, None),
        ("with_key_same_proj", (4, 8), (2, 8), (2, 3), None),
        ("without_key_different_proj", (4, 8), (2, 8), None, (3, 4)),
        ("with_key_different_proj", (4, 8), (2, 8), (2, 3), (1, 5)),
        ("high_dim_same_proj", (4, 2, 3, 8), (1, 1, 5, 8), (1, 1, 5, 2), None),
        (
            "high_dim_different_proj",
            (4, 2, 3, 8),
            (1, 1, 5, 8),
            (1, 1, 5, 2),
            (3, 2),
        ),
    )
    def test_compute_output_shape(
        self, query_dims, value_dims, key_dims, output_shape
    ):
        """Test that the computed shape is equal to the layer output's shape."""
        layer = layers.MultiHeadAttention(
            num_heads=2,
            key_dim=2,
            value_dim=2,
            output_shape=output_shape,
        )
        batch_size = 7
        query_shape = (batch_size,) + query_dims
        value_shape = (batch_size,) + value_dims
        key_shape = (batch_size,) + key_dims if key_dims else None

        query = np.ones(query_shape)
        value = np.ones(value_shape)
        key = np.ones(key_shape) if key_shape else None
        output = layer(query=query, value=value, key=key)
        comp_output_shape = layer.compute_output_shape(
            query_shape, value_shape, key_shape
        )
        self.assertEqual(output.shape, comp_output_shape)

    @parameterized.named_parameters(
        ("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), 2),
        ("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)),
        (
            "key_value_dim_mismatch_high_dim",
            (2, 4, 2, 3, 8),
            (2, 1, 1, 5, 8),
            (2, 1, 15, 5, 2),
        ),
    )
    def test_shape_mismatch_error(self, query_shape, value_shape, key_shape):
        """Test dimension mismatches."""
        layer = layers.MultiHeadAttention(
            num_heads=4,
            key_dim=2,
            value_dim=2,
        )
        with self.assertRaisesRegex(ValueError, r"must be equal"):
            layer.compute_output_shape(query_shape, value_shape, key_shape)

    def test_initializer(self):
        # Test with a specified initializer.
        layer = layers.MultiHeadAttention(
            num_heads=12,
            key_dim=64,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
        )
        layer.build((2, 4, 8), (2, 4, 8))

        # Make sure the sub layers have different kernel init values.
        self.assertNotAllClose(
            layer._query_dense.kernel,
            layer._key_dense.kernel,
        )
        self.assertNotAllClose(
            layer._query_dense.kernel,
            layer._value_dense.kernel,
        )
        self.assertNotAllClose(
            layer._query_dense.kernel,
            layer._output_dense.kernel,
        )

    def test_query_mask_propagation(self):
        """Test automatic propagation of the query's mask."""
        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
        self.assertTrue(layer.supports_masking)
        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
        value = np.random.normal(size=(3, 3, 8))
        output = layer(query=masked_query, value=value)
        self.assertAllClose(masked_query._keras_mask, output._keras_mask)

    @parameterized.named_parameters(("causal", True), ("not_causal", False))
    def test_masking(self, use_causal_mask):
        """Test that the value and causal masks are taken into account."""
        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
        value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
        masked_value = layers.Embedding(6, 8, mask_zero=True)(value)
        output = layer(
            query=masked_query,
            value=masked_value,
            use_causal_mask=use_causal_mask,
        )
        mask = np.array(
            [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]
            + [[[1, 0, 0]] * 5]
            + [[[1, 1, 1]] + [[0, 0, 0]] * 4]
        ).astype(bool)
        if use_causal_mask:
            mask = mask & np.array(
                [[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]
            ).astype(bool)
        del masked_query._keras_mask
        del masked_value._keras_mask
        output_with_manual_mask = layer(
            query=masked_query, value=masked_value, attention_mask=mask
        )
        self.assertAllClose(output, output_with_manual_mask)

    def test_correctness(self):
        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])
        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])
        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])

        # Set up the layer.
        num_heads = 2
        key_dim = 2
        layer = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
        )
        layer.build(query.shape, value.shape, key.shape)

        # Set layer weights.
        kernel = np.identity(key_dim)
        # To get an identity kernel we need to add a head dim and repeat on
        # it.
        kernel = np.repeat(kernel[:, np.newaxis, :], num_heads, axis=1)
        # Zeros for all biases.
        bias = np.zeros((2, 2))
        output_bias = np.zeros((2,))
        layer.set_weights([kernel, bias] * 3 + [kernel, output_bias])

        # Call layer and assert output.
        output, scores = layer(
            query=query,
            value=value,
            key=key,
            return_attention_scores=True,
        )
        self.assertAllClose(output, [[[5.679, 5.679], [4.32, 4.32]]], atol=1e-3)
        self.assertAllClose(
            scores,
            [[[[0.33, 0.67], [0.67, 0.33]], [[0.33, 0.67], [0.67, 0.33]]]],
            atol=1e-3,
        )
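
# A sketch of where the expected numbers above come from (illustrative, not
# part of the diff): with identity projections, the scores for the first
# query row are softmax([0, 1] / sqrt(2)) ~= [0.33, 0.67], and the identity
# output kernel sums each head's channels, so the first output row is
# sum(0.33 * [1, 2] + 0.67 * [3, 4]) ~= 5.679 in both output channels.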
@@ -571,7 +571,7 @@ class Layer(Operation):
         else:
             # Use compute_output_shape() to return the right output spec
             call_spec = CallSpec(self.call, args, kwargs)
-            shapes_dict = get_shapes_dict(call_spec)
+            shapes_dict = get_shapes_dict(self.compute_output_shape, call_spec)
             if len(shapes_dict) == 1:
                 # Single arg: pass it positionally
                 input_shape = tuple(shapes_dict.values())[0]
@@ -723,7 +723,7 @@ class Layer(Operation):

     def _maybe_build(self, call_spec):
         if not self.built:
-            shapes_dict = get_shapes_dict(call_spec)
+            shapes_dict = get_shapes_dict(self.build, call_spec)
             self._build_shapes_dict = shapes_dict
             failure = False
             if len(shapes_dict) == 1:
@@ -741,9 +741,6 @@ class Layer(Operation):
                 else:
                     self.build(input_shape)
             else:
-                # More than one shape: pass them by name,
-                # and check that build() expects the right args.
-                check_build_signature(self.build, shapes_dict)
                 with backend.name_scope(self.name):
                     if utils.is_default(self.build):
                         if might_have_unbuilt_state(self):
@@ -751,29 +748,7 @@ class Layer(Operation):
                         if not status:
                             failure = True
                     else:
-                        run_build = True
-                        build_args = set(
-                            inspect.signature(self.build).parameters.keys()
-                        )
-                        for key in shapes_dict.keys():
-                            if key not in build_args:
-                                run_build = False
-                        if run_build:
-                            self.build(**shapes_dict)
-                        else:
-                            raise ValueError(
-                                "In a layer with multiple tensor arguments "
-                                "in call(), the build() method should accept "
-                                "corresponding `*_shape` arguments, e.g. "
-                                "if the call signature is "
-                                "`def call(self, x1, x2)` "
-                                "then the build signature should be "
-                                "`def build(self, x1_shape, x2_shape)`. "
-                                "Keras will not build this layer automatically "
-                                "since it does not conform to this. "
-                                "Expected the following build keys: "
-                                f"{list(shapes_dict.keys())}"
-                            )
+                        self.build(**shapes_dict)
             if failure:
                 if call_spec.eager:
                     # Will let the actual eager call do the state-building
@@ -821,7 +796,7 @@ class Layer(Operation):
             # Case: all input keyword arguments were plain tensors.
             input_tensors = {
                 # We strip the `_shape` suffix to recover kwarg names.
-                k[:-6]: backend.KerasTensor(shape)
+                k.removesuffix("_shape"): backend.KerasTensor(shape)
                 for k, shape in shapes_dict.items()
             }
             try:
@@ -992,16 +967,17 @@ def get_arguments_dict(fn, args, kwargs):
     return arg_dict


-def get_shapes_dict(call_spec):
+def get_shapes_dict(target_fn, call_spec):
     """Convert the call() arguments dict into a dict of input shape arguments.

     Example:

     ```
-    >>> get_shapes_dict(call_spec)
+    >>> get_shapes_dict(self.build, call_spec)
     {"input_a_shape": (2, 3)}
     ```
     """
+    expected_names = check_shapes_signature(target_fn, call_spec)
     shapes_dict = {}
     for k, v in call_spec.tensor_arguments_dict.items():
         if k == "mask" or k.startswith("mask_"):
@@ -1010,6 +986,8 @@ def get_shapes_dict(call_spec):
         if k == "kwargs" or k == "args":
             # Do not include catch-alls in shapes dict
             continue
+        if expected_names is not None and f"{k}_shape" not in expected_names:
+            continue
         if k in call_spec.nested_tensor_argument_names:
             shapes_dict[f"{k}_shape"] = nest.map_structure(
                 lambda x: backend.standardize_shape(x.shape), v
@@ -1019,22 +997,26 @@
     return shapes_dict


-def check_build_signature(build_fn, shapes_dict):
-    """Asserts that the argument names in build_fn match entries in shapes_dict.
+def check_shapes_signature(target_fn, call_spec):
+    """Asserts that the argument names in `target_fn` match arguments in `call`.

-    For instance if call() has the signature `def call(self, a, b)`
-    then we'll see `shapes_dict == {"a_shape": (...), "b_shape": (...)}`
-    and we expect build() to have signature `def build(self, a_shape, b_shape)`.
+    We use this to check that `build()` and `compute_output_shape()` arguments
+    align with `call()` arguments.

-    When there is a single tensor argument, we pass it positionally and thus
-    don't check names (if we did, it would force call() to always take
-    `input` as its first argument, which is usually not the case).
+    For instance if `build()` has the signature
+    `def build(self, a_shape, b_shape)` we expect `call()` to accept the
+    arguments `a` and `b`.
+
+    When there is a single argument accepted by `target_fn`, we do allow any
+    name and do not check the call signature.
+
+    Returns:
+        The list of argument names expected by `target_fn`, or
+        `None` if any passed name is acceptable.
     """
-    if len(shapes_dict) == 1:
-        return
-    if utils.is_default(build_fn):
-        return
-    sig = inspect.signature(build_fn)
+    if utils.is_default(target_fn):
+        return None
+    sig = inspect.signature(target_fn)
     expected_names = []
     for name, param in sig.parameters.items():
         if param.kind in (
@@ -1043,14 +1025,29 @@ def check_build_signature(build_fn, shapes_dict):
             param.KEYWORD_ONLY,
         ):
             expected_names.append(name)
-    if set(expected_names) != set(shapes_dict.keys()):
-        comma_separated = ", ".join(shapes_dict.keys())
-        raise ValueError(
-            "For a `call()` method with more than one tensor argument, "
-            "the arguments of the `build()` method should match the "
-            "tensor arguments of `call()` method. Here we expect the signature "
-            f"`build(self, {comma_separated})`."
+    if len(expected_names) == 1:
+        return None
+    for name in expected_names:
+        method_name = target_fn.__name__
+        error_preamble = (
+            f"For a `{method_name}()` method with more than one argument, all "
+            "arguments should have a `_shape` suffix and match an argument "
+            f"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` "
+            "would match `call(self, foo, bar)`."
         )
+        if not name.endswith("_shape"):
+            raise ValueError(
+                f"{error_preamble} Received `{method_name}()` argument "
+                f"`{name}`, which does not end in `_shape`."
+            )
+        expected_call_arg = name.removesuffix("_shape")
+        if expected_call_arg not in call_spec.arguments_dict:
+            raise ValueError(
+                f"{error_preamble} Received `{method_name}()` argument "
+                f"`{name}`, but `call()` does not have argument "
+                f"`{expected_call_arg}`."
+            )
+    return expected_names


 class CallContext:
@@ -588,3 +588,41 @@ class LayerTest(testing.TestCase):
         self.assertEqual(len(layer.trainable_variables), 2)
         self.assertEqual(len(layer.non_trainable_weights), 2)
         self.assertEqual(len(layer.non_trainable_variables), 3)
+
+    def test_build_signature_errors(self):
+        class NoShapeSuffix(layers.Layer):
+            def build(self, foo_shape, bar):
+                self._built = True
+
+            def call(self, foo, bar):
+                return foo + bar
+
+        class NonMatchingArgument(layers.Layer):
+            def build(self, foo_shape, baz_shape):
+                self._built = True
+
+            def call(self, foo, bar):
+                return foo + bar
+
+        class MatchingArguments(layers.Layer):
+            def build(self, foo_shape, bar_shape):
+                self._built = True
+
+            def call(self, foo, bar):
+                return foo + bar
+
+        foo = backend.numpy.ones((4, 4))
+        bar = backend.numpy.ones((4, 4))
+        with self.assertRaisesRegex(
+            ValueError,
+            r"argument `bar`, which does not end in `_shape`",
+        ):
+            NoShapeSuffix()(foo, bar)
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"`baz_shape`, but `call\(\)` does not have argument `baz`",
+        ):
+            NonMatchingArgument()(foo, bar)
+
+        MatchingArguments()(foo, bar)
@@ -1,10 +1,10 @@
 from keras_core import backend
-from keras_core import layers
 from keras_core.api_export import keras_core_export
+from keras_core.layers.layer import Layer


 @keras_core_export("keras_core.layers.Dropout")
-class Dropout(layers.Layer):
+class Dropout(Layer):
     """Applies dropout to the input.

     The `Dropout` layer randomly sets input units to 0 with a frequency of
@@ -51,7 +51,7 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
         for the linear transformation of the recurrent state. Default: 0.

     Call arguments:
-        inputs: A 2D tensor, with shape `(batch, features)`.
+        sequence: A 2D tensor, with shape `(batch, features)`.
         states: A 2D tensor with shape `(batch, units)`, which is the state
             from the previous time step.
         training: Python boolean indicating whether the layer should behave in
@@ -151,14 +151,14 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
             self.bias = None
         self.built = True

-    def call(self, inputs, states, training=False):
+    def call(self, sequence, states, training=False):
         prev_output = states[0] if isinstance(states, (list, tuple)) else states
-        dp_mask = self.get_dropout_mask(inputs)
+        dp_mask = self.get_dropout_mask(sequence)
         rec_dp_mask = self.get_recurrent_dropout_mask(prev_output)

         if training and dp_mask is not None:
-            inputs *= dp_mask
-        h = ops.matmul(inputs, self.kernel)
+            sequence *= dp_mask
+        h = ops.matmul(sequence, self.kernel)
         if self.bias is not None:
             h += self.bias

@@ -265,7 +265,7 @@ class SimpleRNN(RNN):
     Unrolling is only suitable for short sequences.

     Call arguments:
-        inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
+        sequence: A 3D tensor, with shape `[batch, timesteps, feature]`.
         mask: Binary tensor of shape `[batch, timesteps]` indicating whether
             a given timestep should be masked. An individual `True` entry
             indicates that the corresponding timestep should be utilized,
@@ -255,13 +255,21 @@ class TestCase(unittest.TestCase):
         # Build test.
         if input_shape is not None:
             layer = layer_cls(**init_kwargs)
-            layer.build(input_shape)
+            if isinstance(input_shape, dict):
+                layer.build(**input_shape)
+            else:
+                layer.build(input_shape)
             run_build_asserts(layer)

         # Symbolic call test.
         keras_tensor_inputs = create_keras_tensors(input_shape, input_dtype)
         layer = layer_cls(**init_kwargs)
-        keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
+        if isinstance(keras_tensor_inputs, dict):
+            keras_tensor_outputs = layer(
+                **keras_tensor_inputs, **call_kwargs
+            )
+        else:
+            keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
         run_build_asserts(layer)
         run_output_asserts(layer, keras_tensor_outputs, eager=False)

@@ -274,7 +282,10 @@ class TestCase(unittest.TestCase):
         if input_data is None:
             input_data = create_eager_tensors(input_shape, input_dtype)
         layer = layer_cls(**init_kwargs)
-        output_data = layer(input_data, **call_kwargs)
+        if isinstance(input_data, dict):
+            output_data = layer(**input_data, **call_kwargs)
+        else:
+            output_data = layer(input_data, **call_kwargs)
         run_output_asserts(layer, output_data, eager=True)

@@ -287,7 +298,7 @@ def create_keras_tensors(input_shape, dtype):
         return [keras_tensor.KerasTensor(s, dtype=dtype) for s in input_shape]
     if isinstance(input_shape, dict):
         return {
-            k: keras_tensor.KerasTensor(v, dtype=dtype)
+            k.removesuffix("_shape"): keras_tensor.KerasTensor(v, dtype=dtype)
            for k, v in input_shape.items()
         }

@@ -320,4 +331,7 @@ def create_eager_tensors(input_shape, dtype):
     if isinstance(input_shape, list):
         return [create_fn(s, dtype=dtype) for s in input_shape]
     if isinstance(input_shape, dict):
-        return {k: create_fn(v, dtype=dtype) for k, v in input_shape.items()}
+        return {
+            k.removesuffix("_shape"): create_fn(v, dtype=dtype)
+            for k, v in input_shape.items()
+        }