116 lines
4.7 KiB
Python
116 lines
4.7 KiB
Python
"""Wrapper layer to apply every temporal slice of an input."""
|
|
|
|
from keras_core import backend
|
|
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.core.wrapper import Wrapper
|
|
from keras_core.layers.layer import Layer
|
|
|
|
|
|
@keras_core_export("keras_core.layers.TimeDistributed")
class TimeDistributed(Wrapper):
    """Wrapper that applies a layer to every temporal slice of an input.

    Every input should be at least 3D, and the dimension of index one of the
    input will be considered to be the temporal dimension.

    Consider a batch of 32 video samples, where each sample is a 128x128 RGB
    image with `channels_last` data format, across 10 timesteps.
    The batch input shape is `(32, 10, 128, 128, 3)`.

    You can then use `TimeDistributed` to apply the same `Conv2D` layer to each
    of the 10 timesteps, independently:

    >>> inputs = layers.Input(shape=(10, 128, 128, 3), batch_size=32)
    >>> conv_2d_layer = layers.Conv2D(64, (3, 3))
    >>> outputs = layers.TimeDistributed(conv_2d_layer)(inputs)
    >>> outputs.shape
    (32, 10, 126, 126, 64)

    Because `TimeDistributed` applies the same instance of `Conv2D` to each of
    the timesteps, the same set of weights is used at each timestep.

    Args:
        layer: a `keras_core.layers.Layer` instance.

    Call arguments:
        inputs: Input tensor of shape (batch, time, ...).
        training: Python boolean indicating whether the layer should behave in
            training mode or in inference mode. This argument is passed to the
            wrapped layer (only if the layer supports this argument).
        mask: Binary tensor whose first two dimensions are
            `(samples, timesteps)`, indicating whether a given timestep
            should be masked. This argument is passed to the wrapped layer
            (only if the layer supports this argument).
    """

    def __init__(self, layer, **kwargs):
        """Wraps `layer` so that it is applied to each timestep.

        Args:
            layer: the `keras_core.layers.Layer` instance to wrap.
            **kwargs: forwarded unchanged to the `Wrapper` constructor.

        Raises:
            ValueError: if `layer` is not a `Layer` instance.
        """
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `TimeDistributed` layer with a "
                f"`keras_core.layers.Layer` instance. Received: {layer}"
            )
        super().__init__(layer, **kwargs)
        self.supports_masking = True

    def _get_child_input_shape(self, input_shape):
        """Returns `input_shape` with the time axis (index 1) removed.

        Raises:
            ValueError: if `input_shape` is not a tuple/list or has fewer
                than 3 dimensions.
        """
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
            raise ValueError(
                "`TimeDistributed` Layer should be passed an `input_shape` "
                f"with at least 3 dimensions, received: {input_shape}"
            )
        return (input_shape[0], *input_shape[2:])

    def compute_output_shape(self, input_shape):
        """Returns the wrapped layer's output shape with the time axis added.

        The time dimension of `input_shape` is re-inserted at index 1 of the
        shape computed by the wrapped layer for a single timestep.
        """
        child_input_shape = self._get_child_input_shape(input_shape)
        child_output_shape = self.layer.compute_output_shape(child_input_shape)
        return (child_output_shape[0], input_shape[1], *child_output_shape[1:])

    def build(self, input_shape):
        """Builds through `Wrapper.build` using the per-timestep shape."""
        child_input_shape = self._get_child_input_shape(input_shape)
        super().build(child_input_shape)
        self.built = True

    def call(self, inputs, training=None, mask=None):
        """Applies the wrapped layer to every timestep of `inputs`.

        Args:
            inputs: tensor whose axis 0 is the batch and axis 1 is time.
            training: forwarded to the wrapped layer when its `call` accepts
                a `training` argument.
            mask: optional tensor whose first two dimensions must equal
                `(batch, timesteps)`. Its per-timestep slice is forwarded to
                the wrapped layer when its `call` accepts a `mask` argument.

        Returns:
            The per-timestep outputs of the wrapped layer stacked back along
            axis 1.

        Raises:
            ValueError: if the leading two dimensions of `mask` do not match
                those of `inputs`.
        """
        input_shape = inputs.shape
        mask_shape = None if mask is None else tuple(mask.shape)
        batch_size = input_shape[0]
        timesteps = input_shape[1]

        if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
            raise ValueError(
                "`TimeDistributed` Layer should be passed a `mask` of shape "
                f"({batch_size}, {timesteps}, ...), "
                f"received: mask.shape={mask_shape}"
            )

        def time_distributed_transpose(data):
            """Swaps the timestep and batch dimensions of a tensor."""
            axes = [1, 0, *range(2, len(data.shape))]
            return ops.transpose(data, axes=axes)

        # Put time on the leading axis so that `inputs[i]` is timestep `i`.
        inputs = time_distributed_transpose(inputs)
        if mask is not None:
            mask = time_distributed_transpose(mask)

        def step_function(i):
            """Runs the wrapped layer on timestep `i`."""
            kwargs = {}
            if self.layer._call_has_mask_arg() and mask is not None:
                kwargs["mask"] = mask[i]
            if self.layer._call_has_training_arg():
                kwargs["training"] = training
            return self.layer.call(inputs[i], **kwargs)

        # Implementation #1: if the time axis is static, use a Python for loop.
        if inputs.shape[0] is not None:
            outputs = ops.stack(
                [step_function(i) for i in range(inputs.shape[0])]
            )
            return time_distributed_transpose(outputs)

        # Implementation #2: use backend.vectorized_map.
        # This branch is only reached when the static time dimension is
        # unknown, so `timesteps` is `None` here and cannot be passed to
        # `ops.arange`; use the runtime length of the leading axis instead.
        dynamic_timesteps = ops.shape(inputs)[0]
        outputs = backend.vectorized_map(
            step_function, ops.arange(dynamic_timesteps)
        )
        return time_distributed_transpose(outputs)
|