import tensorflow as tf


def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
"""Iterates over the time dimension of a tensor.
|
|
|
|
Args:
|
|
step_function: RNN step function.
|
|
Args;
|
|
`input`; Tensor with shape `(samples, ...)` (no time dimension),
|
|
representing input for the batch of samples at a certain
|
|
time step.
|
|
`states`; List of tensors.
|
|
Returns;
|
|
`output`; Tensor with shape `(samples, output_dim)`
|
|
(no time dimension).
|
|
`new_states`; List of tensors, same length and shapes
|
|
as 'states'. The first state in the list must be the
|
|
output tensor at the previous timestep.
|
|
inputs: Tensor of temporal data of shape `(samples, time, ...)`
|
|
(at least 3D), or nested tensors, and each of which has shape
|
|
`(samples, time, ...)`.
|
|
initial_states: Tensor with shape `(samples, state_size)`
|
|
(no time dimension), containing the initial values for the states
|
|
used in the step function. In the case that state_size is in a
|
|
nested shape, the shape of initial_states will also follow the
|
|
nested structure.
|
|
go_backwards: Boolean. If `True`, do the iteration over the time
|
|
dimension in reverse order and return the reversed sequence.
|
|
mask: Binary tensor with shape `(samples, time, 1)`,
|
|
with a zero for every element that is masked.
|
|
constants: List of constant values passed at each step.
|
|
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
|
|
input_length: An integer or a 1-D Tensor, depending on whether
|
|
the time dimension is fixed-length or not. In case of variable
|
|
length input, it is used for masking in case there's no mask
|
|
specified.
|
|
time_major: Boolean. If `True`, the inputs and outputs will be in shape
|
|
`(timesteps, batch, ...)`, whereas in the False case, it will be
|
|
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
|
|
efficient because it avoids transposes at the beginning and end of
|
|
the RNN calculation. However, most TensorFlow data is batch-major,
|
|
so by default this function accepts input and emits output in
|
|
batch-major form.
|
|
zero_output_for_mask: Boolean. If `True`, the output for masked timestep
|
|
will be zeros, whereas in the `False` case, output from previous
|
|
timestep is returned.
|
|
return_all_outputs: Boolean. If `True`, return the recurrent outputs for
|
|
all timesteps in the sequence. If `False`, only return the output
|
|
for the last timestep (which consumes less memory).
|
|
|
|
Returns:
|
|
A tuple, `(last_output, outputs, new_states)`.
|
|
- `last_output`: the latest output of the rnn,
|
|
with shape `(samples, ...)`.
|
|
- `outputs`:
|
|
- If `return_all_outputs=True`: a tensor with shape
|
|
`(samples, time, ...)` where each entry `outputs[s, t]` is the
|
|
output of the step function at time `t` for sample `s`
|
|
- Else, a tensor equal to `last_output` with shape
|
|
`(samples, 1, ...)`
|
|
- `new_states`: list of tensors, latest states returned by
|
|
the step function, of shape `(samples, ...)`.
|
|
"""

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dims of the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return tf.transpose(input_t, axes)

    if not time_major:
        inputs = tf.nest.map_structure(swap_batch_timestep, inputs)

    flattened_inputs = tf.nest.flatten(inputs)
    time_steps = flattened_inputs[0].shape[0]
    time_steps_t = tf.shape(flattened_inputs[0])[0]

    for input_ in flattened_inputs:
        input_.shape.with_rank_at_least(3)

    if mask is not None:
        if mask.dtype != tf.bool:
            mask = tf.cast(mask, tf.bool)
        if len(mask.shape) == 2:
            mask = tf.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []

    # tf.where needs its condition tensor to be the same shape as its two
    # result tensors, but in our case the condition (mask) tensor is
    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
    # So we need to broadcast the mask to match the shape of the inputs.
    # That's what the tile call does: it just repeats the mask along its
    # second dimension n times.
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if tf.nest.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if tf.nest.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = tf.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
        return tf.tile(mask_t, multiples)
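
    # Illustrative sketch (comment only, not executed; the values are
    # made up): with a batch of 2 samples, a per-timestep mask column of
    # shape (2, 1) and an input of shape (2, 4), `_expand_mask` rank-aligns
    # the mask and tiles it to (2, 4) so `tf.where` can select elementwise:
    #     _expand_mask(tf.constant([[True], [False]]), tf.zeros((2, 4)))
    #     # -> bool tensor of shape (2, 4): row 0 all True, row 1 all False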

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Process the input tensors. The input tensors need to be split on
        # the time_step dim, and reversed if go_backwards is True. In the
        # case of nested input, the input is flattened and then transformed
        # individually. The result is a tuple of lists; each item in the
        # tuple is a list of tensors with shape (batch, feature).
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t)  # unstack along the time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        if tf.nest.is_nested(inputs):
            processed_input = tf.nest.map_structure(
                _process_single_input_t, inputs
            )
        else:
            processed_input = (_process_single_input_t(inputs),)

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return tf.nest.pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = tf.unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                if not successive_outputs:
                    prev_output = tf.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                output = tf.where(tiled_mask_t, output, prev_output)

                flat_states = tf.nest.flatten(states)
                flat_new_states = tf.nest.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    tf.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = tf.nest.pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

            if zero_output_for_mask:
                last_output = tf.where(
                    _expand_mask(mask_list[-1], last_output),
                    last_output,
                    tf.zeros_like(last_output),
                )
                outputs = tf.where(
                    _expand_mask(mask, outputs, fixed_dim=2),
                    outputs,
                    tf.zeros_like(outputs),
                )

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

    else:  # Unroll == False
        states = tuple(initial_states)

        # Create the input tensor arrays. If the input is a nested
        # structure of tensors, it is flattened first, and one tensor array
        # is created per flattened tensor.
        input_ta = tuple(
            tf.TensorArray(
                dtype=inp.dtype,
                size=time_steps_t,
                tensor_array_name=f"input_ta_{i}",
            )
            for i, inp in enumerate(flattened_inputs)
        )
        input_ta = tuple(
            ta.unstack(input_)
            if not go_backwards
            else ta.unstack(tf.reverse(input_, [0]))
            for ta, input_ in zip(input_ta, flattened_inputs)
        )

        # Get the time(0) input and compute the output for it; the output
        # will be used to determine the dtype of the output tensor array.
        # Don't read from input_ta, because TensorArray's clear_after_read
        # defaults to True.
        input_time_zero = tf.nest.pack_sequence_as(
            inputs, [inp[0] for inp in flattened_inputs]
        )
        # output_time_zero is used to determine the cell output shape and
        # its dtype. The value is discarded.
        output_time_zero, _ = step_function(
            input_time_zero, tuple(initial_states) + tuple(constants)
        )

        output_ta_size = time_steps_t if return_all_outputs else 1
        output_ta = tuple(
            tf.TensorArray(
                dtype=out.dtype,
                size=output_ta_size,
                element_shape=out.shape,
                tensor_array_name=f"output_ta_{i}",
            )
            for i, out in enumerate(tf.nest.flatten(output_time_zero))
        )

        time = tf.constant(0, dtype="int32", name="time")

        if input_length is None:
            max_iterations = time_steps_t
        else:
            max_iterations = tf.reduce_max(input_length)

        while_loop_kwargs = {
            "cond": lambda time, *_: time < time_steps_t,
            "maximum_iterations": max_iterations,
            "parallel_iterations": 32,
            "swap_memory": True,
        }
        if mask is not None:
            if go_backwards:
                mask = tf.reverse(mask, [0])

            mask_ta = tf.TensorArray(
                dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
            )
            mask_ta = mask_ta.unstack(mask)

            def masking_fn(time):
                return mask_ta.read(time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
                    for o in flat_out
                )
                return tuple(
                    tf.where(m, o, fm)
                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
                )

        elif isinstance(input_length, tf.Tensor):
            if go_backwards:
                max_len = tf.reduce_max(input_length, axis=0)
                rev_input_length = tf.subtract(max_len - 1, input_length)

                def masking_fn(time):
                    return tf.less(rev_input_length, time)

            else:

                def masking_fn(time):
                    return tf.greater(input_length, time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                return tuple(
                    tf.where(mask_t, o, zo)
                    for (o, zo) in zip(flat_out, flat_mask)
                )

        else:
            masking_fn = None

        if masking_fn is not None:
            # The mask for the output at step T is based on the output at
            # step T - 1. For T = 0, a zero-filled tensor is used instead.
            flat_zero_output = tuple(
                tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
            )

            def _step(time, output_ta_t, prev_output, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    prev_output: tuple of outputs from time - 1.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t, output) +
                        tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                # maybe set shape.
                current_input = tf.nest.pack_sequence_as(
                    inputs, current_input
                )
                mask_t = masking_fn(time)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                # mask output
                flat_output = tf.nest.flatten(output)
                flat_mask_output = (
                    flat_zero_output
                    if zero_output_for_mask
                    else tf.nest.flatten(prev_output)
                )
                flat_new_output = compute_masked_output(
                    mask_t, flat_output, flat_mask_output
                )

                # mask states
                flat_state = tf.nest.flatten(states)
                flat_new_state = tf.nest.flatten(new_states)
                flat_final_state = compute_masked_output(
                    mask_t, flat_new_state, flat_state
                )
                new_states = tf.nest.pack_sequence_as(
                    new_states, flat_final_state
                )

                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_new_output)
                )

                return (
                    time + 1,
                    output_ta_t,
                    tuple(flat_new_output),
                ) + tuple(new_states)

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta, flat_zero_output) + states,
                **while_loop_kwargs,
            )
            # Skip final_outputs[2], which is the output for the final
            # timestep.
            new_states = final_outputs[3:]
        else:

            def _step(time, output_ta_t, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                current_input = tf.nest.pack_sequence_as(
                    inputs, current_input
                )
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                flat_new_state = tf.nest.flatten(new_states)

                flat_output = tf.nest.flatten(output)
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_output)
                )

                new_states = tf.nest.pack_sequence_as(
                    initial_states, flat_new_state
                )
                return (time + 1, output_ta_t) + tuple(new_states)

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta) + states,
                **while_loop_kwargs,
            )
            new_states = final_outputs[2:]

        output_ta = final_outputs[1]

        outputs = tuple(o.stack() for o in output_ta)
        last_output = tuple(o[-1] for o in outputs)

        outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
        last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)

    if not time_major:
        outputs = tf.nest.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states
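

# Illustrative usage sketch (an assumption added for documentation, not part
# of the original API): a minimal step function wired through `rnn`. The
# cell below is a toy running-sum accumulator; shapes and names are made up
# for the demo, and the function is never called at import time.
def _example_rnn_usage():
    def step_function(inp, states):
        # Output and the single state are both the running sum over time.
        new_state = states[0] + inp
        return new_state, [new_state]

    x = tf.ones((8, 5, 4))  # (samples, time, features)
    initial_states = [tf.zeros((8, 4))]
    last_output, outputs, new_states = rnn(step_function, x, initial_states)
    # outputs: (8, 5, 4); last_output: (8, 4), equal to x summed over time.
    return last_output, outputs, new_states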


def gru(
    inputs,
    initial_state,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
    reset_after=True,
):
    args_supported = _do_gru_arguments_support_cudnn(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=unroll,
        bias=bias,
        reset_after=reset_after,
    )
    inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
    if not args_supported or not inputs_supported:
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_gru(
            inputs,
            initial_state,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except tf.errors.InvalidArgumentError:
        # cuDNN op not found.
        raise NotImplementedError
    except tf.errors.NotFoundError:
        # Alternative error: device not found for op.
        raise NotImplementedError
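

# Usage note (a sketch of the assumed caller structure, not from the
# original file): `gru` is a cuDNN fast path that raises
# `NotImplementedError` whenever the fused kernel cannot be used
# (unsupported activations, `unroll=True`, a mask that is not right-padded,
# or the op missing on the current device). A caller typically wraps it as:
#
#     try:
#         out = gru(inputs, state, mask, kernel, rec_kernel, bias,
#                   activations.tanh, activations.sigmoid)
#     except NotImplementedError:
#         out = rnn(step_function, inputs, [state], mask=mask)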


def _do_gru_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    bias,
    reset_after,
):
    from keras_core import activations
    from keras_core import operations as ops

    return (
        activation in (activations.tanh, tf.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
        and not unroll
        and bias is not None
        and reset_after
    )


def _do_lstm_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    bias,
):
    from keras_core import activations
    from keras_core import operations as ops

    return (
        activation in (activations.tanh, tf.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
        and not unroll
        and bias is not None
    )


def _do_rnn_inputs_support_cudnn(mask, time_major):
    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        if mask is not None:
            return tf.reduce_all(mask)
        return True
    if mask is None:
        return True
    if time_major:
        mask = tf.transpose(mask)
    return tf.logical_and(
        _is_sequence_right_padded(mask),
        tf.logical_not(_has_fully_masked_sequence(mask)),
    )


def _is_sequence_right_padded(mask):
    """Check the mask tensor and see if it is right padded.

    The cuDNN kernel uses the sequence length param to skip the tailing
    timesteps. If the data is left padded, or not strictly right padded
    (i.e., there are masked values in the middle of the sequence), then the
    cuDNN kernel won't work properly in those cases.

    Left padded data: [[False, False, True, True, True]].
    Right padded data: [[True, True, True, False, False]].
    Mixture of masked/unmasked data: [[True, False, True, False, False]].

    Note that for the mixed data example above, the actual data the RNN
    should see are the two Trues (indices 0 and 2); the False at index 1
    should be ignored and must not pollute the internal states.

    Args:
        mask: the Boolean tensor with shape [batch, timestep].

    Returns:
        boolean scalar tensor, whether the mask is strictly right padded.
    """
    max_seq_length = tf.shape(mask)[1]
    count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length)
    return tf.reduce_all(
        tf.equal(
            tf.cast(mask, dtype="bool"),
            tf.cast(right_padded_mask, dtype="bool"),
        )
    )
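

# Illustrative values (comment only, assumed for the demo): in eager mode,
#     _is_sequence_right_padded(tf.constant([[True, True, False]]))
# evaluates to True (a contiguous run of Trues starting at t=0), while
#     _is_sequence_right_padded(tf.constant([[True, False, True]]))
# evaluates to False (a masked step sits between unmasked ones), which is
# what routes such inputs away from the cuDNN kernel.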


def _has_fully_masked_sequence(mask):
    # The cuDNN kernel will error out if the input sequence contains any
    # fully masked data. We work around this issue by rerouting the
    # computation to the standard kernel, until the issue on the cuDNN side
    # has been fixed. A fully masked sequence contains all Falses. To make
    # the check easy, we invert the boolean and check whether any sequence
    # is all True.
    return tf.reduce_any(
        tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1)
    )


def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):
    """Utility function to convert variables to cuDNN-compatible parameters.

    Note that Keras weights for kernels are different from the cuDNN format.
    E.g.:

    ```
    Keras                 cuDNN
    [[0, 1, 2],  <--->  [[0, 2, 4],
     [3, 4, 5]]          [1, 3, 5]]
    ```

    If the input weights need to be in a unified format, then set
    `transpose_weights=True` to convert the weights.

    Args:
        weights: list of weights for the kernels and recurrent kernels.
        biases: list of biases for the individual gates.
        shape: the shape for the converted variables that will be fed to
            cuDNN.
        transpose_weights: boolean, whether to transpose the weights.

    Returns:
        The converted weights that can be fed to cuDNN ops as params.
    """

    def convert(w):
        return tf.transpose(w) if transpose_weights else w

    weights = [tf.reshape(convert(x), shape) for x in weights]
    biases = [tf.reshape(x, shape) for x in biases]
    return tf.concat(weights + biases, axis=0)
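

# A minimal sketch (not part of the original file) of what the conversion
# above produces; the 2x3 kernels and 3-element biases are made-up demo
# values. Each kernel slice is transposed, flattened by `shape=[-1]`, and
# all weights followed by all biases are concatenated into the single 1-D
# parameter blob that the tf.raw_ops.CudnnRNN* ops expect.
def _example_standardize_cudnn_weights():
    weights = [
        tf.reshape(tf.range(6, dtype="float32"), (2, 3)) for _ in range(2)
    ]
    biases = [tf.zeros((3,)) for _ in range(2)]
    # Result is a flat tensor of length 2 * 6 + 2 * 3 = 18.
    return _standardize_cudnn_weights(
        weights, biases, shape=tf.constant([-1]), transpose_weights=True
    )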


def _compute_sequence_length_from_mask(mask, time_major):
    """Calculate the sequence length tensor (1-D) from the masking tensor.

    The masking tensor is a 2D boolean tensor with shape [batch, timestep].
    For any timestep that should be masked, the corresponding field will be
    False. Consider the following example:
      a = [[True, True, False, False],
           [True, True, True, False]]
    It is a (2, 4) tensor, and the corresponding sequence length result
    should be a 1-D tensor with value [2, 3]. Note that the masking tensor
    must be right padded, which can be checked by, e.g.,
    `_is_sequence_right_padded()`.

    Args:
        mask: Boolean tensor with shape [batch, timestep], or
            [timestep, batch] if time_major=True.
        time_major: Boolean, which indicates whether the mask is time-major
            or batch-major.

    Returns:
        sequence_length: 1-D int32 tensor.
    """
    timestep_index = 0 if time_major else 1
    return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index)
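

# Executable version of the docstring example above (a sketch, defined but
# never called at import time):
def _example_compute_sequence_length_from_mask():
    # Right-padded mask for two sequences of effective lengths 2 and 3.
    mask = tf.constant(
        [[True, True, False, False], [True, True, True, False]]
    )
    # Summing the True entries per row yields the int32 lengths [2, 3].
    return _compute_sequence_length_from_mask(mask, time_major=False)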


@tf.function(autograph=False)
def _cudnn_gru(
    inputs,
    initial_state,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    """GRU with cuDNN implementation, which is only available for GPU."""
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None

    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h, cuDNN expects one more dim of num_layers before or after
    # the batch dim, for time-major or batch-major inputs respectively.
    init_h = tf.expand_dims(initial_state, axis=seq_axis)

    weights = tf.split(kernel, 3, axis=1)
    weights += tf.split(recurrent_kernel, 3, axis=1)
    # Note that the bias was initialized as shape (2, 3 * units); flatten it
    # to (6 * units).
    bias = tf.split(tf.reshape(bias, [-1]), 6)

    if tf.sysconfig.get_build_info()["is_cuda_build"]:
        # Note that the gate order for cuDNN is different from the canonical
        # format. The canonical format is [z, r, h], whereas cuDNN is
        # [r, z, h]. The swap needs to be done for kernel, recurrent_kernel,
        # input_bias, and recurrent_bias.
        # z is the update gate weights.
        # r is the reset gate weights.
        # h is the output gate weights.
        weights[0], weights[1] = weights[1], weights[0]
        weights[3], weights[4] = weights[4], weights[3]
        bias[0], bias[1] = bias[1], bias[0]
        bias[3], bias[4] = bias[4], bias[3]

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=bias,
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5, 4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        if go_backwards:
            # Reverse axis 0 since the input is already converted to
            # time-major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    state = tf.squeeze(h, axis=seq_axis)

    # In the case of variable-length input, the cuDNN kernel fills zeros
    # into the output, whereas the default Keras behavior is to carry over
    # the output from t - 1, so that in the return_sequences=False case the
    # user can quickly get the final effective output instead of just 0s at
    # the last timestep. To mimic the default Keras behavior, we copy the
    # final h state as the last_output, since it is numerically the same as
    # the output.
    if sequence_lengths is not None:
        last_output = state

    # Match CPU return format.
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        state,
    )


def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
):
    args_supported = _do_lstm_arguments_support_cudnn(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=unroll,
        bias=bias,
    )
    inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
    if not args_supported or not inputs_supported:
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_lstm(
            inputs,
            initial_state_h,
            initial_state_c,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except tf.errors.InvalidArgumentError:
        # cuDNN op not found.
        raise NotImplementedError
    except tf.errors.NotFoundError:
        # Alternative error: device not found for op.
        raise NotImplementedError


def _cudnn_lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None

    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h and init_c, cuDNN expects one more dim of num_layers before
    # or after the batch dim, for time-major or batch-major inputs
    # respectively.
    init_h = tf.expand_dims(initial_state_h, axis=seq_axis)
    init_c = tf.expand_dims(initial_state_c, axis=seq_axis)

    weights = tf.split(kernel, 4, axis=1)
    weights += tf.split(recurrent_kernel, 4, axis=1)
    # cuDNN has an extra set of biases for inputs. We disable them (setting
    # them to 0), so that mathematically it is the same as the canonical
    # LSTM implementation.
    full_bias = tf.concat((tf.zeros_like(bias), bias), 0)

    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        # ROCm MIOpen's weight sequence for LSTM is different from both the
        # canonical and the cuDNN format:
        # MIOpen: [i, f, o, c]  cuDNN/canonical: [i, f, c, o]
        # i is the input gate weights.
        # f is the forget gate weights.
        # o is the output gate weights.
        # c is the cell gate weights.
        weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
        # full_bias is a tensor of shape (8*n,)
        full_bias = tf.split(full_bias, 8, axis=0)
        full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=tf.split(full_bias, 8),
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5, 4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=init_c,
            params=params,
            is_training=True,
            rnn_mode="lstm",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        # # Fill the array with shape [batch] with value of max timesteps.
        # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
        #                                  array_ops.shape(inputs)[0])
        if go_backwards:
            # Reverse axis 0 since the input is already converted to
            # time-major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, c, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=init_c,
            params=params,
            is_training=True,
            rnn_mode="lstm",
        )

    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    h = tf.squeeze(h, axis=seq_axis)
    c = tf.squeeze(c, axis=seq_axis)

    # In the case of variable-length input, the cuDNN kernel fills zeros
    # into the output, whereas the default Keras behavior is to carry over
    # the output from t - 1, so that in the return_sequences=False case the
    # user can quickly get the final effective output instead of just 0s at
    # the last timestep. To mimic the default Keras behavior, we copy the
    # final h state as the last_output, since it is numerically the same as
    # the output.
    if sequence_lengths is not None:
        last_output = h

    # Match CPU return format.
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (last_output, outputs, [h, c])