376 lines
13 KiB
Python
376 lines
13 KiB
Python
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.merging.base_merge import Merge
|
|
from keras_core.utils.numerical_utils import normalize
|
|
|
|
|
|
def batch_dot(x, y, axes=None):
|
|
"""Batchwise dot product.
|
|
|
|
`batch_dot` is used to compute dot product of `x` and `y` when
|
|
`x` and `y` are data in batch, i.e. in a shape of `(batch_size, :)`.
|
|
`batch_dot` results in a tensor or variable with less dimensions
|
|
than the input. If the number of dimensions is reduced to 1,
|
|
we use `expand_dims` to make sure that ndim is at least 2.
|
|
|
|
Shape inference:
|
|
|
|
Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
|
|
If `axes` is (1, 2), to find the output shape of resultant tensor,
|
|
loop through each dimension in `x`'s shape and `y`'s shape:
|
|
|
|
* `x.shape[0]` : 100 : append to output shape
|
|
* `x.shape[1]` : 20 : do not append to output shape, dimension 1 of
|
|
`x` has been summed over. (`dot_axes[0]` = 1)
|
|
* `y.shape[0]` : 100 : do not append to output shape, always ignore
|
|
first dimension of `y`
|
|
* `y.shape[1]` : 30 : append to output shape
|
|
* `y.shape[2]` : 20 : do not append to output shape, dimension 2 of
|
|
`y` has been summed over.
|
|
(`dot_axes[1]` = 2) `output_shape` = `(100, 30)`
|
|
|
|
Examples:
|
|
|
|
>>> x_batch = np.ones(shape=(32, 20, 1))
|
|
>>> y_batch = np.ones(shape=(32, 30, 20))
|
|
>>> xy_batch_dot = batch_dot(x_batch, y_batch, axes=(1, 2))
|
|
|
|
Args:
|
|
x: Keras tensor or variable with `ndim >= 2`.
|
|
y: Keras tensor or variable with `ndim >= 2`.
|
|
axes: Tuple or list of integers with target dimensions, or single
|
|
integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]`
|
|
should be equal.
|
|
|
|
Returns:
|
|
A tensor with shape equal to the concatenation of `x`'s shape
|
|
(less the dimension that was summed over) and `y`'s shape (less the
|
|
batch dimension and the dimension that was summed over). If the final
|
|
rank is 1, we reshape it to `(batch_size, 1)`.
|
|
"""
|
|
|
|
x_shape = x.shape
|
|
y_shape = y.shape
|
|
|
|
x_ndim = len(x_shape)
|
|
y_ndim = len(y_shape)
|
|
|
|
if x_ndim < 2 or y_ndim < 2:
|
|
raise ValueError(
|
|
f"Cannot do batch_dot on inputs "
|
|
f"with rank < 2. "
|
|
f"Received inputs with shapes "
|
|
f"{x_shape} and {y_shape}."
|
|
)
|
|
|
|
x_batch_size = x_shape[0]
|
|
y_batch_size = y_shape[0]
|
|
|
|
if x_batch_size is not None and y_batch_size is not None:
|
|
if x_batch_size != y_batch_size:
|
|
raise ValueError(
|
|
f"Cannot do batch_dot on inputs "
|
|
f"with different batch sizes. "
|
|
f"Received inputs with shapes "
|
|
f"{x_shape} and {y_shape}."
|
|
)
|
|
if isinstance(axes, int):
|
|
axes = [axes, axes]
|
|
|
|
if axes is None:
|
|
if y_ndim == 2:
|
|
axes = [x_ndim - 1, y_ndim - 1]
|
|
else:
|
|
axes = [x_ndim - 1, y_ndim - 2]
|
|
|
|
if any(isinstance(a, (list, tuple)) for a in axes):
|
|
raise ValueError(
|
|
f"Multiple target dimensions are not supported. "
|
|
f"Expected: None, int, (int, int), "
|
|
f"Provided: {axes} "
|
|
)
|
|
|
|
# if tuple, convert to list.
|
|
axes = list(axes)
|
|
|
|
# convert negative indices.
|
|
if axes[0] < 0:
|
|
axes[0] += x_ndim
|
|
if axes[1] < 0:
|
|
axes[1] += y_ndim
|
|
|
|
# sanity checks
|
|
if 0 in axes:
|
|
raise ValueError(
|
|
"Cannot perform batch_dot over axis 0. "
|
|
"If your inputs are not batched, "
|
|
"add a dummy batch dimension to your "
|
|
"inputs using keras_core.operations.expand_dims(x, 0)"
|
|
)
|
|
a0, a1 = axes
|
|
d1 = x_shape[a0]
|
|
d2 = y_shape[a1]
|
|
|
|
if d1 is not None and d2 is not None and d1 != d2:
|
|
raise ValueError(
|
|
f"Cannot do batch_dot on inputs with shapes "
|
|
f"{x_shape} and {y_shape} with axes={axes}. "
|
|
f"x.shape[{axes[0]}] != y.shape[{axes[1]}] ({d1} != {d2})."
|
|
)
|
|
|
|
# backup ndims. Need them later.
|
|
orig_x_ndim = x_ndim
|
|
orig_y_ndim = y_ndim
|
|
|
|
# if rank is 2, expand to 3.
|
|
if x_ndim == 2:
|
|
x = ops.expand_dims(x, 1)
|
|
a0 += 1
|
|
x_ndim += 1
|
|
if y_ndim == 2:
|
|
y = ops.expand_dims(y, 2)
|
|
y_ndim += 1
|
|
|
|
# bring x's dimension to be reduced to last axis.
|
|
if a0 != x_ndim - 1:
|
|
pattern = list(range(x_ndim))
|
|
for i in range(a0, x_ndim - 1):
|
|
pattern[i] = pattern[i + 1]
|
|
pattern[-1] = a0
|
|
x = ops.transpose(x, pattern)
|
|
|
|
# bring y's dimension to be reduced to axis 1.
|
|
if a1 != 1:
|
|
pattern = list(range(y_ndim))
|
|
for i in range(a1, 1, -1):
|
|
pattern[i] = pattern[i - 1]
|
|
pattern[1] = a1
|
|
y = ops.transpose(y, pattern)
|
|
|
|
# normalize both inputs to rank 3.
|
|
if x_ndim > 3:
|
|
# squash middle dimensions of x.
|
|
x_shape = list(ops.shape(x))
|
|
x_mid_dims = x_shape[1:-1]
|
|
x_squashed_shape = ops.stack([x_shape[0], -1, x_shape[-1]])
|
|
x = ops.reshape(x, x_squashed_shape)
|
|
x_squashed = True
|
|
else:
|
|
x_squashed = False
|
|
|
|
if y_ndim > 3:
|
|
# squash trailing dimensions of y.
|
|
y_shape = list(ops.shape(y))
|
|
y_trail_dims = y_shape[2:]
|
|
y_squashed_shape = ops.stack([y_shape[0], y_shape[1], -1])
|
|
y = ops.reshape(y, y_squashed_shape)
|
|
y_squashed = True
|
|
else:
|
|
y_squashed = False
|
|
|
|
result = ops.matmul(x, y)
|
|
|
|
# if inputs were squashed, we have to reshape the matmul output.
|
|
output_shape = ops.shape(result)
|
|
do_reshape = False
|
|
|
|
if x_squashed:
|
|
output_shape = ops.concat(
|
|
[output_shape[:1], x_mid_dims, output_shape[-1:]], 0
|
|
)
|
|
do_reshape = True
|
|
|
|
if y_squashed:
|
|
output_shape = ops.concat([output_shape[:-1], y_trail_dims], 0)
|
|
do_reshape = True
|
|
|
|
if do_reshape:
|
|
result = ops.reshape(result, output_shape)
|
|
|
|
# if the inputs were originally rank 2, we remove the added 1 dim.
|
|
if orig_x_ndim == 2:
|
|
result = ops.squeeze(result, 1)
|
|
elif orig_y_ndim == 2:
|
|
result = ops.squeeze(result, -1)
|
|
|
|
return result
|
|
|
|
|
|
@keras_core_export("keras_core.layers.Dot")
|
|
class Dot(Merge):
|
|
"""Computes element-wise dot product of two tensors.
|
|
|
|
It takes a list of inputs of size 2, and the axes
|
|
corresponding to each input along with the dot product
|
|
is to be performed.
|
|
|
|
Let's say `x` and `y` are the two input tensors with shapes
|
|
`(2, 3, 5)` and `(2, 10, 3)`. The batch dimension should be
|
|
of same size for both the inputs, and `axes` should correspond
|
|
to the dimensions that have the same size in the corresponding
|
|
inputs. e.g. with `axes=(1, 2)`, the dot product of `x`, and `y`
|
|
will result in a tensor with shape `(2, 5, 10)`
|
|
|
|
Examples:
|
|
|
|
>>> x = np.arange(10).reshape(1, 5, 2)
|
|
>>> y = np.arange(10, 20).reshape(1, 2, 5)
|
|
>>> keras_core.layers.Dot(axes=(1, 2))([x, y])
|
|
|
|
Usage in a Keras model:
|
|
|
|
>>> x1 = keras_core.layers.Dense(8)(np.arange(10).reshape(5, 2))
|
|
>>> x2 = keras_core.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
|
|
>>> y = keras_core.layers.Dot(axes=1)([x1, x2])
|
|
|
|
Args:
|
|
axes: Integer or tuple of integers, axis or axes along which to
|
|
take the dot product. If a tuple, should be two integers
|
|
corresponding to the desired axis from the first input and the
|
|
desired axis from the second input, respectively. Note that the
|
|
size of the two selected axes must match.
|
|
normalize: Whether to L2-normalize samples along the dot product axis
|
|
before taking the dot product. If set to `True`, then
|
|
the output of the dot product is the cosine proximity
|
|
between the two samples.
|
|
**kwargs: Standard layer keyword arguments.
|
|
|
|
Returns:
|
|
A tensor, the dot product of the samples from the inputs.
|
|
"""
|
|
|
|
def __init__(self, axes, normalize=False, **kwargs):
|
|
super().__init__(**kwargs)
|
|
if not isinstance(axes, int):
|
|
if not isinstance(axes, (list, tuple)):
|
|
raise TypeError(
|
|
f"Invalid type for argument `axes`: it should be "
|
|
f"a list or an int. Received: axes={axes}"
|
|
)
|
|
if len(axes) != 2:
|
|
raise ValueError(
|
|
f"Invalid format for argument `axes`: it should contain "
|
|
f"two elements. Received: axes={axes}"
|
|
)
|
|
if not isinstance(axes[0], int) or not isinstance(axes[1], int):
|
|
raise ValueError(
|
|
f"Invalid format for argument `axes`: list elements should "
|
|
f"be integers. Received: axes={axes}"
|
|
)
|
|
self.axes = axes
|
|
self.normalize = normalize
|
|
self.supports_masking = True
|
|
self._reshape_required = False
|
|
|
|
def build(self, input_shape):
|
|
# Used purely for shape validation.
|
|
if not isinstance(input_shape[0], tuple) or len(input_shape) != 2:
|
|
raise ValueError(
|
|
f"A `Dot` layer should be called on a list of 2 inputs. "
|
|
f"Received: input_shape={input_shape}"
|
|
)
|
|
shape1 = input_shape[0]
|
|
shape2 = input_shape[1]
|
|
if shape1 is None or shape2 is None:
|
|
return
|
|
if isinstance(self.axes, int):
|
|
if self.axes < 0:
|
|
axes = [self.axes % len(shape1), self.axes % len(shape2)]
|
|
else:
|
|
axes = [self.axes] * 2
|
|
else:
|
|
axes = self.axes
|
|
if shape1[axes[0]] != shape2[axes[1]]:
|
|
raise ValueError(
|
|
f"Incompatible input shapes: "
|
|
f"axis values {shape1[axes[0]]} (at axis {axes[0]}) != "
|
|
f"{shape2[axes[1]]} (at axis {axes[1]}). "
|
|
f"Full input shapes: {shape1}, {shape2}"
|
|
)
|
|
self.built = True
|
|
|
|
def _merge_function(self, inputs):
|
|
if len(inputs) != 2:
|
|
raise ValueError(
|
|
f"A `Dot` layer should be called on exactly 2 inputs. "
|
|
f"Received: inputs={inputs}"
|
|
)
|
|
x1 = inputs[0]
|
|
x2 = inputs[1]
|
|
|
|
if isinstance(self.axes, int):
|
|
if self.axes < 0:
|
|
axes = [
|
|
self.axes % len(x1.shape),
|
|
self.axes % len(x2.shape),
|
|
]
|
|
else:
|
|
axes = [self.axes] * 2
|
|
else:
|
|
axes = []
|
|
for i in range(len(self.axes)):
|
|
if self.axes[i] < 0:
|
|
axes.append(self.axes[i] % len(inputs[i].shape))
|
|
else:
|
|
axes.append(self.axes[i])
|
|
|
|
if self.normalize:
|
|
x1 = normalize(x1, axis=axes[0])
|
|
x2 = normalize(x2, axis=axes[1])
|
|
output = batch_dot(x1, x2, axes)
|
|
return output
|
|
|
|
def compute_output_shape(self, input_shape):
|
|
if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
|
|
raise ValueError(
|
|
f"A `Dot` layer should be called on a list of 2 inputs. "
|
|
f"Received: input_shape={input_shape}"
|
|
)
|
|
shape1 = list(input_shape[0])
|
|
shape2 = list(input_shape[1])
|
|
if isinstance(self.axes, int):
|
|
if self.axes < 0:
|
|
axes = [self.axes % len(shape1), self.axes % len(shape2)]
|
|
else:
|
|
axes = [self.axes] * 2
|
|
else:
|
|
axes = self.axes
|
|
shape1.pop(axes[0])
|
|
shape2.pop(axes[1])
|
|
shape2.pop(0)
|
|
output_shape = shape1 + shape2
|
|
if len(output_shape) == 1:
|
|
output_shape += [1]
|
|
return tuple(output_shape)
|
|
|
|
def compute_mask(self, inputs, mask=None):
|
|
return None
|
|
|
|
def get_config(self):
|
|
config = {
|
|
"axes": self.axes,
|
|
"normalize": self.normalize,
|
|
}
|
|
base_config = super().get_config()
|
|
return dict(list(base_config.items()) + list(config.items()))
|
|
|
|
|
|
@keras_core_export("keras_core.layers.dot")
|
|
def dot(inputs, axes=-1, **kwargs):
|
|
"""Functional interface to the `Dot` layer.
|
|
|
|
Args:
|
|
inputs: A list of input tensors (at least 2).
|
|
axes: Integer or tuple of integers,
|
|
axis or axes along which to take the dot product.
|
|
normalize: Whether to L2-normalize samples along the
|
|
dot product axis before taking the dot product.
|
|
If set to True, then the output of the dot product
|
|
is the cosine proximity between the two samples.
|
|
**kwargs: Standard layer keyword arguments.
|
|
|
|
Returns:
|
|
A tensor, the dot product of the samples from the inputs.
|
|
"""
|
|
return Dot(axes=axes, **kwargs)(inputs)
|