keras/keras_core/layers/merging/dot.py
Francois Chollet 8852928f6e Add DenseNet
2023-05-22 16:51:41 -07:00

376 lines
13 KiB
Python

from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge
from keras_core.utils.numerical_utils import normalize
def batch_dot(x, y, axes=None):
    """Batchwise dot product.

    `batch_dot` is used to compute dot product of `x` and `y` when
    `x` and `y` are data in batch, i.e. in a shape of `(batch_size, :)`.
    `batch_dot` results in a tensor or variable with less dimensions
    than the input. If the number of dimensions is reduced to 1,
    we use `expand_dims` to make sure that ndim is at least 2.

    Shape inference:
      Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
      If `axes` is (1, 2), to find the output shape of resultant tensor,
      loop through each dimension in `x`'s shape and `y`'s shape:
      * `x.shape[0]` : 100 : append to output shape
      * `x.shape[1]` : 20 : do not append to output shape, dimension 1 of
        `x` has been summed over. (`dot_axes[0]` = 1)
      * `y.shape[0]` : 100 : do not append to output shape, always ignore
        first dimension of `y`
      * `y.shape[1]` : 30 : append to output shape
      * `y.shape[2]` : 20 : do not append to output shape, dimension 2 of
        `y` has been summed over. (`dot_axes[1]` = 2)
      `output_shape` = `(100, 30)`

    Examples:

    >>> x_batch = np.ones(shape=(32, 20, 1))
    >>> y_batch = np.ones(shape=(32, 30, 20))
    >>> xy_batch_dot = batch_dot(x_batch, y_batch, axes=(1, 2))

    Args:
        x: Keras tensor or variable with `ndim >= 2`.
        y: Keras tensor or variable with `ndim >= 2`.
        axes: Tuple or list of integers with target dimensions, or single
            integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]`
            should be equal.

    Returns:
        A tensor with shape equal to the concatenation of `x`'s shape
        (less the dimension that was summed over) and `y`'s shape (less the
        batch dimension and the dimension that was summed over). If the final
        rank is 1, we reshape it to `(batch_size, 1)`.

    Raises:
        ValueError: if either input has rank < 2, if the static batch sizes
            disagree, if `axes` contains nested lists/tuples, if a reduction
            axis is 0 (the batch axis), or if the sizes of the two reduction
            axes are statically known and unequal.
    """
    x_shape = x.shape
    y_shape = y.shape
    x_ndim = len(x_shape)
    y_ndim = len(y_shape)
    # Each input needs a batch axis plus at least one data axis.
    if x_ndim < 2 or y_ndim < 2:
        raise ValueError(
            f"Cannot do batch_dot on inputs "
            f"with rank < 2. "
            f"Received inputs with shapes "
            f"{x_shape} and {y_shape}."
        )
    x_batch_size = x_shape[0]
    y_batch_size = y_shape[0]
    # Only validate batch sizes that are statically known; `None` means a
    # dynamic batch dimension and is accepted as-is.
    if x_batch_size is not None and y_batch_size is not None:
        if x_batch_size != y_batch_size:
            raise ValueError(
                f"Cannot do batch_dot on inputs "
                f"with different batch sizes. "
                f"Received inputs with shapes "
                f"{x_shape} and {y_shape}."
            )
    # A single int means "use this axis index in both inputs".
    if isinstance(axes, int):
        axes = [axes, axes]
    # Default: contract x's last axis with y's last axis (rank-2 y) or
    # second-to-last axis (higher-rank y).
    if axes is None:
        if y_ndim == 2:
            axes = [x_ndim - 1, y_ndim - 1]
        else:
            axes = [x_ndim - 1, y_ndim - 2]
    # Only a single reduction axis per input is supported.
    if any(isinstance(a, (list, tuple)) for a in axes):
        raise ValueError(
            f"Multiple target dimensions are not supported. "
            f"Expected: None, int, (int, int), "
            f"Provided: {axes} "
        )
    # if tuple, convert to list (mutated below when converting negatives).
    axes = list(axes)
    # convert negative indices.
    if axes[0] < 0:
        axes[0] += x_ndim
    if axes[1] < 0:
        axes[1] += y_ndim
    # sanity checks: axis 0 is the batch axis and must not be contracted.
    if 0 in axes:
        raise ValueError(
            "Cannot perform batch_dot over axis 0. "
            "If your inputs are not batched, "
            "add a dummy batch dimension to your "
            "inputs using keras_core.operations.expand_dims(x, 0)"
        )
    a0, a1 = axes
    d1 = x_shape[a0]
    d2 = y_shape[a1]
    # The contracted dimensions must match when both are statically known.
    if d1 is not None and d2 is not None and d1 != d2:
        raise ValueError(
            f"Cannot do batch_dot on inputs with shapes "
            f"{x_shape} and {y_shape} with axes={axes}. "
            f"x.shape[{axes[0]}] != y.shape[{axes[1]}] ({d1} != {d2})."
        )
    # backup ndims. Need them later to squeeze the padding axes back out.
    orig_x_ndim = x_ndim
    orig_y_ndim = y_ndim
    # if rank is 2, expand to 3 so everything below can assume rank >= 3.
    if x_ndim == 2:
        # New axis is inserted at position 1, before a0, so a0 shifts by 1.
        x = ops.expand_dims(x, 1)
        a0 += 1
        x_ndim += 1
    if y_ndim == 2:
        # New axis is appended after a1 (position 2), so a1 is unchanged.
        y = ops.expand_dims(y, 2)
        y_ndim += 1
    # bring x's dimension to be reduced to last axis: left-rotate the axes
    # from a0 onward by one position so axis a0 lands at the end.
    if a0 != x_ndim - 1:
        pattern = list(range(x_ndim))
        for i in range(a0, x_ndim - 1):
            pattern[i] = pattern[i + 1]
        pattern[-1] = a0
        x = ops.transpose(x, pattern)
    # bring y's dimension to be reduced to axis 1: right-rotate the axes
    # from 1 through a1 by one position so axis a1 lands at index 1.
    if a1 != 1:
        pattern = list(range(y_ndim))
        for i in range(a1, 1, -1):
            pattern[i] = pattern[i - 1]
        pattern[1] = a1
        y = ops.transpose(y, pattern)
    # normalize both inputs to rank 3 so a single batched matmul applies.
    if x_ndim > 3:
        # squash middle dimensions of x into one (batch, -1, reduce) shape;
        # the original middle dims are kept to restore the shape afterwards.
        x_shape = list(ops.shape(x))
        x_mid_dims = x_shape[1:-1]
        x_squashed_shape = ops.stack([x_shape[0], -1, x_shape[-1]])
        x = ops.reshape(x, x_squashed_shape)
        x_squashed = True
    else:
        x_squashed = False
    if y_ndim > 3:
        # squash trailing dimensions of y into one (batch, reduce, -1) shape.
        y_shape = list(ops.shape(y))
        y_trail_dims = y_shape[2:]
        y_squashed_shape = ops.stack([y_shape[0], y_shape[1], -1])
        y = ops.reshape(y, y_squashed_shape)
        y_squashed = True
    else:
        y_squashed = False
    # Contraction: (batch, M, K) @ (batch, K, N) -> (batch, M, N).
    result = ops.matmul(x, y)
    # if inputs were squashed, we have to reshape the matmul output back to
    # the un-squashed middle/trailing dimensions.
    output_shape = ops.shape(result)
    do_reshape = False
    if x_squashed:
        output_shape = ops.concat(
            [output_shape[:1], x_mid_dims, output_shape[-1:]], 0
        )
        do_reshape = True
    if y_squashed:
        output_shape = ops.concat([output_shape[:-1], y_trail_dims], 0)
        do_reshape = True
    if do_reshape:
        result = ops.reshape(result, output_shape)
    # if the inputs were originally rank 2, we remove the added 1 dim.
    # Note: when BOTH inputs were rank 2, only x's padding axis is removed,
    # leaving a (batch_size, 1) result as documented above.
    if orig_x_ndim == 2:
        result = ops.squeeze(result, 1)
    elif orig_y_ndim == 2:
        result = ops.squeeze(result, -1)
    return result
@keras_core_export("keras_core.layers.Dot")
class Dot(Merge):
    """Computes element-wise dot product of two tensors.

    It takes a list of inputs of size 2, and the axes
    corresponding to each input along with the dot product
    is to be performed.

    Let's say `x` and `y` are the two input tensors with shapes
    `(2, 3, 5)` and `(2, 10, 3)`. The batch dimension should be
    of same size for both the inputs, and `axes` should correspond
    to the dimensions that have the same size in the corresponding
    inputs. e.g. with `axes=(1, 2)`, the dot product of `x`, and `y`
    will result in a tensor with shape `(2, 5, 10)`

    Examples:

    >>> x = np.arange(10).reshape(1, 5, 2)
    >>> y = np.arange(10, 20).reshape(1, 2, 5)
    >>> keras_core.layers.Dot(axes=(1, 2))([x, y])

    Usage in a Keras model:

    >>> x1 = keras_core.layers.Dense(8)(np.arange(10).reshape(5, 2))
    >>> x2 = keras_core.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
    >>> y = keras_core.layers.Dot(axes=1)([x1, x2])

    Args:
        axes: Integer or tuple of integers, axis or axes along which to
            take the dot product. If a tuple, should be two integers
            corresponding to the desired axis from the first input and the
            desired axis from the second input, respectively. Note that the
            size of the two selected axes must match.
        normalize: Whether to L2-normalize samples along the dot product axis
            before taking the dot product. If set to `True`, then
            the output of the dot product is the cosine proximity
            between the two samples.
        **kwargs: Standard layer keyword arguments.

    Returns:
        A tensor, the dot product of the samples from the inputs.
    """

    def __init__(self, axes, normalize=False, **kwargs):
        super().__init__(**kwargs)
        # `axes` may be a single int (shared by both inputs) or a pair of
        # ints (one per input). Anything else is rejected up front.
        if not isinstance(axes, int):
            if not isinstance(axes, (list, tuple)):
                raise TypeError(
                    f"Invalid type for argument `axes`: it should be "
                    f"a list or an int. Received: axes={axes}"
                )
            if len(axes) != 2:
                raise ValueError(
                    f"Invalid format for argument `axes`: it should contain "
                    f"two elements. Received: axes={axes}"
                )
            if not all(isinstance(a, int) for a in axes):
                raise ValueError(
                    f"Invalid format for argument `axes`: list elements should "
                    f"be integers. Received: axes={axes}"
                )
        self.axes = axes
        self.normalize = normalize
        self.supports_masking = True
        self._reshape_required = False

    def _resolved_axes(self, rank1, rank2):
        """Return `self.axes` as a 2-element sequence.

        A negative int axis is resolved against each input's rank; a
        list/tuple `axes` is returned unchanged.
        """
        if not isinstance(self.axes, int):
            return self.axes
        if self.axes < 0:
            return [self.axes % rank1, self.axes % rank2]
        return [self.axes, self.axes]

    def build(self, input_shape):
        # Used purely for shape validation; this layer creates no weights.
        if not isinstance(input_shape[0], tuple) or len(input_shape) != 2:
            raise ValueError(
                f"A `Dot` layer should be called on a list of 2 inputs. "
                f"Received: input_shape={input_shape}"
            )
        shape1, shape2 = input_shape[0], input_shape[1]
        if shape1 is None or shape2 is None:
            return
        axes = self._resolved_axes(len(shape1), len(shape2))
        # The contracted dimensions must agree between the two inputs.
        if shape1[axes[0]] != shape2[axes[1]]:
            raise ValueError(
                f"Incompatible input shapes: "
                f"axis values {shape1[axes[0]]} (at axis {axes[0]}) != "
                f"{shape2[axes[1]]} (at axis {axes[1]}). "
                f"Full input shapes: {shape1}, {shape2}"
            )
        self.built = True

    def _merge_function(self, inputs):
        """Compute the (optionally normalized) batched dot product."""
        if len(inputs) != 2:
            raise ValueError(
                f"A `Dot` layer should be called on exactly 2 inputs. "
                f"Received: inputs={inputs}"
            )
        x1, x2 = inputs[0], inputs[1]
        if isinstance(self.axes, int):
            axes = self._resolved_axes(len(x1.shape), len(x2.shape))
        else:
            # Resolve each negative axis against the rank of its own input.
            axes = [
                axis % len(tensor.shape) if axis < 0 else axis
                for axis, tensor in zip(self.axes, inputs)
            ]
        if self.normalize:
            # L2-normalize along the contraction axes so the dot product
            # becomes a cosine proximity.
            x1 = normalize(x1, axis=axes[0])
            x2 = normalize(x2, axis=axes[1])
        return batch_dot(x1, x2, axes)

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
            raise ValueError(
                f"A `Dot` layer should be called on a list of 2 inputs. "
                f"Received: input_shape={input_shape}"
            )
        shape1, shape2 = list(input_shape[0]), list(input_shape[1])
        axes = self._resolved_axes(len(shape1), len(shape2))
        # Drop the contracted axis of each input and the batch axis of the
        # second input, then concatenate what remains.
        shape1.pop(axes[0])
        shape2.pop(axes[1])
        shape2.pop(0)
        output_shape = shape1 + shape2
        # A rank-1 result is padded out to (batch_size, 1).
        if len(output_shape) == 1:
            output_shape += [1]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        # The dot product collapses the masked axes; no mask is propagated.
        return None

    def get_config(self):
        return {
            **super().get_config(),
            "axes": self.axes,
            "normalize": self.normalize,
        }
@keras_core_export("keras_core.layers.dot")
def dot(inputs, axes=-1, **kwargs):
    """Functional interface to the `Dot` layer.

    Args:
        inputs: A list of input tensors (at least 2).
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.

    Returns:
        A tensor, the dot product of the samples from the inputs.
    """
    layer = Dot(axes=axes, **kwargs)
    return layer(inputs)