keras/keras_core/backend/tensorflow/rnn.py

import tensorflow as tf
def rnn(
step_function,
inputs,
initial_states,
go_backwards=False,
mask=None,
constants=None,
unroll=False,
input_length=None,
time_major=False,
zero_output_for_mask=False,
return_all_outputs=True,
):
"""Iterates over the time dimension of a tensor.
Args:
step_function: RNN step function.
Args:
`input`: Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
`states`: List of tensors.
Returns:
`output`: Tensor with shape `(samples, output_dim)`
(no time dimension).
`new_states`: List of tensors, same length and shapes
as `states`. The first state in the list must be the
output tensor at the previous timestep.
inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D), or nested tensors, each of which has shape
`(samples, time, ...)`.
initial_states: Tensor with shape `(samples, state_size)`
(no time dimension), containing the initial values for the states
used in the step function. In the case that state_size is in a
nested shape, the shape of initial_states will also follow the
nested structure.
go_backwards: Boolean. If `True`, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: An integer or a 1-D Tensor, depending on whether
the time dimension is fixed-length or not. In case of variable
length input, it is used for masking in case there's no mask
specified.
time_major: Boolean. If `True`, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of
the RNN calculation. However, most TensorFlow data is batch-major,
so by default this function accepts input and emits output in
batch-major form.
zero_output_for_mask: Boolean. If `True`, the output for masked timesteps
will be zeros, whereas in the `False` case, the output from the previous
timestep is returned.
return_all_outputs: Boolean. If `True`, return the recurrent outputs for
all timesteps in the sequence. If `False`, only return the output
for the last timestep (which consumes less memory).
Returns:
A tuple, `(last_output, outputs, new_states)`.
- `last_output`: the latest output of the rnn,
with shape `(samples, ...)`.
- `outputs`:
- If `return_all_outputs=True`: a tensor with shape
`(samples, time, ...)` where each entry `outputs[s, t]` is the
output of the step function at time `t` for sample `s`
- Else, a tensor equal to `last_output` with shape
`(samples, 1, ...)`
- `new_states`: list of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
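Example:

A minimal usage sketch (hypothetical shapes and a toy step function, shown
only to illustrate the calling convention, not taken from an actual cell):

```python
import tensorflow as tf

batch, timesteps, features, units = 4, 10, 16, 8
inputs = tf.random.normal((batch, timesteps, features))
initial_states = [tf.zeros((batch, units))]
kernel = tf.random.normal((features, units))
recurrent_kernel = tf.random.normal((units, units))

def step_function(inp, states):
    # Simple tanh cell whose output doubles as its only state.
    prev_output = states[0]
    output = tf.tanh(
        tf.matmul(inp, kernel) + tf.matmul(prev_output, recurrent_kernel)
    )
    return output, [output]

last_output, outputs, new_states = rnn(
    step_function, inputs, initial_states
)
# last_output: (4, 8), outputs: (4, 10, 8),
# new_states: state tensors of shape (4, 8)
```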
"""
def swap_batch_timestep(input_t):
# Swap the batch and timestep dim for the incoming tensor.
axes = list(range(len(input_t.shape)))
axes[0], axes[1] = 1, 0
return tf.transpose(input_t, axes)
if not time_major:
inputs = tf.nest.map_structure(swap_batch_timestep, inputs)
flatted_inputs = tf.nest.flatten(inputs)
time_steps = flatted_inputs[0].shape[0]
batch = flatted_inputs[0].shape[1]
time_steps_t = tf.shape(flatted_inputs[0])[0]
for input_ in flatted_inputs:
input_.shape.with_rank_at_least(3)
if mask is not None:
if mask.dtype != tf.bool:
mask = tf.cast(mask, tf.bool)
if len(mask.shape) == 2:
mask = tf.expand_dims(mask, axis=-1)
if not time_major:
mask = swap_batch_timestep(mask)
if constants is None:
constants = []
# tf.where needs its condition tensor to be the same shape as its two
# result tensors, but in our case the condition (mask) tensor is
# (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
# So we need to broadcast the mask to match the shape of inputs.
# That's what the tile call does, it just repeats the mask along its
# second dimension n times.
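# Illustrative (hypothetical shapes): a mask_t of shape (32, 1) expanded
# against an input_t of shape (32, 64) needs no extra dims (ranks already
# match) and is tiled by multiples [1, 64] into shape (32, 64).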
def _expand_mask(mask_t, input_t, fixed_dim=1):
if tf.nest.is_nested(mask_t):
raise ValueError(
f"mask_t is expected to be tensor, but got {mask_t}"
)
if tf.nest.is_nested(input_t):
raise ValueError(
f"input_t is expected to be tensor, but got {input_t}"
)
rank_diff = len(input_t.shape) - len(mask_t.shape)
for _ in range(rank_diff):
mask_t = tf.expand_dims(mask_t, -1)
multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
return tf.tile(mask_t, multiples)
if unroll:
if not time_steps:
raise ValueError("Unrolling requires a fixed number of timesteps.")
states = tuple(initial_states)
successive_states = []
successive_outputs = []
# Process the input tensors. The input tensor needs to be split on the
# time_step dim and reversed if go_backwards is True. In the case of
# nested input, the input is flattened and then transformed
# individually. The result of this will be a tuple of lists, where each
# item in the tuple is a list of tensors with shape (batch, feature).
def _process_single_input_t(input_t):
input_t = tf.unstack(input_t) # unstack for time_step dim
if go_backwards:
input_t.reverse()
return input_t
if tf.nest.is_nested(inputs):
processed_input = tf.nest.map_structure(
_process_single_input_t, inputs
)
else:
processed_input = (_process_single_input_t(inputs),)
def _get_input_tensor(time):
inp = [t_[time] for t_ in processed_input]
return tf.nest.pack_sequence_as(inputs, inp)
if mask is not None:
mask_list = tf.unstack(mask)
if go_backwards:
mask_list.reverse()
for i in range(time_steps):
inp = _get_input_tensor(i)
mask_t = mask_list[i]
output, new_states = step_function(
inp, tuple(states) + tuple(constants)
)
tiled_mask_t = _expand_mask(mask_t, output)
if not successive_outputs:
prev_output = tf.zeros_like(output)
else:
prev_output = successive_outputs[-1]
output = tf.where(tiled_mask_t, output, prev_output)
flat_states = tf.nest.flatten(states)
flat_new_states = tf.nest.flatten(new_states)
tiled_mask_t = tuple(
_expand_mask(mask_t, s) for s in flat_states
)
flat_final_states = tuple(
tf.where(m, s, ps)
for m, s, ps in zip(
tiled_mask_t, flat_new_states, flat_states
)
)
states = tf.nest.pack_sequence_as(states, flat_final_states)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = tf.stack(successive_outputs)
if zero_output_for_mask:
last_output = tf.where(
_expand_mask(mask_list[-1], last_output),
last_output,
tf.zeros_like(last_output),
)
outputs = tf.where(
_expand_mask(mask, outputs, fixed_dim=2),
outputs,
tf.zeros_like(outputs),
)
else: # mask is None
for i in range(time_steps):
inp = _get_input_tensor(i)
output, states = step_function(
inp, tuple(states) + tuple(constants)
)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = tf.stack(successive_outputs)
else: # Unroll == False
states = tuple(initial_states)
# Create the input TensorArray. If the inputs are nested tensors, they
# will be flattened first, and one TensorArray will be created per
# flattened tensor.
input_ta = tuple(
tf.TensorArray(
dtype=inp.dtype,
size=time_steps_t,
tensor_array_name=f"input_ta_{i}",
)
for i, inp in enumerate(flatted_inputs)
)
input_ta = tuple(
ta.unstack(input_)
if not go_backwards
else ta.unstack(tf.reverse(input_, [0]))
for ta, input_ in zip(input_ta, flatted_inputs)
)
# Get the time(0) input and compute the output for that; the output will
# be used to determine the dtype of the output tensor array. Don't read
# from input_ta because TensorArray's clear_after_read defaults to True.
input_time_zero = tf.nest.pack_sequence_as(
inputs, [inp[0] for inp in flatted_inputs]
)
# output_time_zero is used to determine the cell output shape and its
# dtype. The value is discarded.
output_time_zero, _ = step_function(
input_time_zero, tuple(initial_states) + tuple(constants)
)
output_ta_size = time_steps_t if return_all_outputs else 1
output_ta = tuple(
tf.TensorArray(
dtype=out.dtype,
size=output_ta_size,
element_shape=out.shape,
tensor_array_name=f"output_ta_{i}",
)
for i, out in enumerate(tf.nest.flatten(output_time_zero))
)
time = tf.constant(0, dtype="int32", name="time")
if input_length is None:
max_iterations = time_steps_t
else:
max_iterations = tf.reduce_max(input_length)
while_loop_kwargs = {
"cond": lambda time, *_: time < time_steps_t,
"maximum_iterations": max_iterations,
"parallel_iterations": 32,
"swap_memory": True,
}
if mask is not None:
if go_backwards:
mask = tf.reverse(mask, [0])
mask_ta = tf.TensorArray(
dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
)
mask_ta = mask_ta.unstack(mask)
def masking_fn(time):
return mask_ta.read(time)
def compute_masked_output(mask_t, flat_out, flat_mask):
tiled_mask_t = tuple(
_expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
for o in flat_out
)
return tuple(
tf.where(m, o, fm)
for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
)
elif isinstance(input_length, tf.Tensor):
if go_backwards:
max_len = tf.reduce_max(input_length, axis=0)
rev_input_length = tf.subtract(max_len - 1, input_length)
def masking_fn(time):
return tf.less(rev_input_length, time)
else:
def masking_fn(time):
return tf.greater(input_length, time)
def compute_masked_output(mask_t, flat_out, flat_mask):
return tuple(
tf.where(mask_t, o, zo)
for (o, zo) in zip(flat_out, flat_mask)
)
else:
masking_fn = None
if masking_fn is not None:
# The mask for the output at step T is based on the output at step T - 1.
# In the case T = 0, a zero-filled tensor is used.
flat_zero_output = tuple(
tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
)
def _step(time, output_ta_t, prev_output, *states):
"""RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
"""
current_input = tuple(ta.read(time) for ta in input_ta)
# maybe set shape.
current_input = tf.nest.pack_sequence_as(inputs, current_input)
mask_t = masking_fn(time)
output, new_states = step_function(
current_input, tuple(states) + tuple(constants)
)
# mask output
flat_output = tf.nest.flatten(output)
flat_mask_output = (
flat_zero_output
if zero_output_for_mask
else tf.nest.flatten(prev_output)
)
flat_new_output = compute_masked_output(
mask_t, flat_output, flat_mask_output
)
# mask states
flat_state = tf.nest.flatten(states)
flat_new_state = tf.nest.flatten(new_states)
flat_final_state = compute_masked_output(
mask_t, flat_new_state, flat_state
)
new_states = tf.nest.pack_sequence_as(
new_states, flat_final_state
)
ta_index_to_write = time if return_all_outputs else 0
output_ta_t = tuple(
ta.write(ta_index_to_write, out)
for ta, out in zip(output_ta_t, flat_new_output)
)
return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
new_states
)
final_outputs = tf.while_loop(
body=_step,
loop_vars=(time, output_ta, flat_zero_output) + states,
**while_loop_kwargs,
)
# Skip final_outputs[2] which is the output for final timestep.
new_states = final_outputs[3:]
else:
def _step(time, output_ta_t, *states):
"""RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
*states: List of states.
Returns:
Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
"""
current_input = tuple(ta.read(time) for ta in input_ta)
current_input = tf.nest.pack_sequence_as(inputs, current_input)
output, new_states = step_function(
current_input, tuple(states) + tuple(constants)
)
flat_new_state = tf.nest.flatten(new_states)
flat_output = tf.nest.flatten(output)
ta_index_to_write = time if return_all_outputs else 0
output_ta_t = tuple(
ta.write(ta_index_to_write, out)
for ta, out in zip(output_ta_t, flat_output)
)
new_states = tf.nest.pack_sequence_as(
initial_states, flat_new_state
)
return (time + 1, output_ta_t) + tuple(new_states)
final_outputs = tf.while_loop(
body=_step,
loop_vars=(time, output_ta) + states,
**while_loop_kwargs,
)
new_states = final_outputs[2:]
output_ta = final_outputs[1]
outputs = tuple(o.stack() for o in output_ta)
last_output = tuple(o[-1] for o in outputs)
outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)
# static shape inference
def set_shape(output_):
if isinstance(output_, tf.Tensor):
shape = output_.shape.as_list()
if return_all_outputs:
shape[0] = time_steps
else:
shape[0] = 1
shape[1] = batch
output_.set_shape(shape)
return output_
outputs = tf.nest.map_structure(set_shape, outputs)
if not time_major:
outputs = tf.nest.map_structure(swap_batch_timestep, outputs)
return last_output, outputs, new_states
def gru(
inputs,
initial_state,
mask,
kernel,
recurrent_kernel,
bias,
activation,
recurrent_activation,
return_sequences=False,
go_backwards=False,
unroll=False,
time_major=False,
reset_after=True,
):
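"""GRU forward pass backed by the cuDNN kernel (GPU only).
Raises `NotImplementedError` when the arguments or inputs are not supported
by cuDNN, so the caller can fall back to the generic `rnn()` implementation.
"""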
args_supported = _do_gru_arguments_support_cudnn(
activation=activation,
recurrent_activation=recurrent_activation,
unroll=unroll,
bias=bias,
reset_after=reset_after,
)
inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
if not args_supported or not inputs_supported:
raise NotImplementedError
from keras_core.backend.tensorflow import Variable
if isinstance(kernel, Variable):
kernel = kernel.value
if isinstance(recurrent_kernel, Variable):
recurrent_kernel = recurrent_kernel.value
if isinstance(bias, Variable):
bias = bias.value
try:
return _cudnn_gru(
inputs,
initial_state,
kernel,
recurrent_kernel,
bias,
mask,
time_major,
go_backwards,
return_sequences,
)
except tf.errors.InvalidArgumentError:
# cuDNN op not found.
raise NotImplementedError
except tf.errors.NotFoundError:
# alternative error: device not found for op
raise NotImplementedError
def _do_gru_arguments_support_cudnn(
activation,
recurrent_activation,
unroll,
bias,
reset_after,
):
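"""Return whether the GRU arguments match what the cuDNN kernel supports:
tanh activation, sigmoid recurrent activation, no unrolling, a bias, and
`reset_after=True`."""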
from keras_core import activations
from keras_core import operations as ops
return (
activation in (activations.tanh, tf.tanh, ops.tanh)
and recurrent_activation
in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
and not unroll
and bias is not None
and reset_after
)
def _do_lstm_arguments_support_cudnn(
activation,
recurrent_activation,
unroll,
bias,
):
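"""Return whether the LSTM arguments match what the cuDNN kernel supports:
tanh activation, sigmoid recurrent activation, no unrolling, and a bias."""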
from keras_core import activations
from keras_core import operations as ops
return (
activation in (activations.tanh, tf.tanh, ops.tanh)
and recurrent_activation
in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
and not unroll
and bias is not None
)
def _do_rnn_inputs_support_cudnn(mask, time_major):
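"""Return whether the inputs/mask are compatible with the cuDNN kernel.
On non-ROCm builds this requires the mask to be right padded with no fully
masked sequence."""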
if tf.sysconfig.get_build_info()["is_rocm_build"]:
if mask is not None:
return tf.reduce_all(mask)
return True
if mask is None:
return True
if time_major:
mask = tf.transpose(mask)
return tf.logical_and(
_is_sequence_right_padded(mask),
tf.logical_not(_has_fully_masked_sequence(mask)),
)
def _is_sequence_right_padded(mask):
"""Check the mask tensor and see if it right padded.
For cuDNN kernel, it uses the sequence length param to skip the tailing
timestep. If the data is left padded, or not a strict right padding (has
masked value in the middle of the sequence), then cuDNN kernel won't be work
properly in those cases.
Left padded data: [[False, False, True, True, True]].
Right padded data: [[True, True, True, False, False]].
Mixture of mask/unmasked data: [[True, False, True, False, False]].
Note that for the mixed data example above, the actually data RNN should see
are those 2 Trues (index 0 and 2), the index 1 False should be ignored and
not pollute the internal states.
Args:
mask: the Boolean tensor with shape [batch, timestep]
Returns:
boolean scalar tensor, whether the mask is strictly right padded.
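Example (illustrative values):
mask = [[True, True, False, False]] -> count_of_true = [2]
tf.sequence_mask([2], maxlen=4) -> [[True, True, False, False]],
which equals `mask`, so this sequence is strictly right padded.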
"""
max_seq_length = tf.shape(mask)[1]
count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length)
return tf.reduce_all(tf.equal(mask, right_padded_mask))
def _has_fully_masked_sequence(mask):
# The cuDNN kernel will error out if the input sequence contains any fully
# masked data. We work around this issue by rerouting the computation to the
# standard kernel until the issue on the cuDNN side has been fixed. A fully
# masked sequence contains all Falses. To make it easy to check, we invert
# the booleans and check whether any sequence is all True.
return tf.reduce_any(tf.reduce_all(tf.logical_not(mask), axis=1))
def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):
"""Utility function convert variable to cuDNN compatible parameter.
Note that Keras weights for kernels are different from the cuDNN format.
Eg.:
```
Keras cuDNN
[[0, 1, 2], <---> [[0, 2, 4],
[3, 4, 5]] [1, 3, 5]]
```
If the input weights need to be in a unified format, then set
`transpose_weights=True` to convert the weights.
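For example (hypothetical shapes): with three kernels of shape
`(units, units)`, three biases of shape `(units,)`, `shape=[-1]`, and
`transpose_weights=True`, each kernel is transposed and flattened to 1-D,
each bias is flattened to 1-D, and everything is concatenated into a single
1-D parameter vector.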
Args:
weights: list of weights for the kernels and recurrent kernels.
biases: list of biases for individual gate.
shape: the shape for the converted variables that will be fed to cuDNN.
transpose_weights: boolean, whether to transpose the weights.
Returns:
The converted weights that can be fed to cuDNN ops as the param.
"""
def convert(w):
return tf.transpose(w) if transpose_weights else w
weights = [tf.reshape(convert(x), shape) for x in weights]
biases = [tf.reshape(x, shape) for x in biases]
return tf.concat(weights + biases, axis=0)
def _compute_sequence_length_from_mask(mask, time_major):
"""Calculate the sequence length tensor (1-D) based on the masking tensor.
The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
any timestep that should be masked, the corresponding field will be False.
Consider the following example:
a = [[True, True, False, False],
[True, True, True, False]]
It is a (2, 4) tensor, and the corresponding sequence length result should
be a 1D tensor with values [2, 3]. Note that the masking tensor must be
right padded, which can be checked by, e.g., `_is_sequence_right_padded()`.
Args:
mask: Boolean tensor with shape [batch, timestep] or [timestep, batch] if
time_major=True.
time_major: Boolean, which indicates whether the mask is time major or
batch major.
Returns:
sequence_length: 1D int32 tensor.
"""
timestep_index = 0 if time_major else 1
return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index)
@tf.function(autograph=False)
def _cudnn_gru(
inputs,
initial_state,
kernel,
recurrent_kernel,
bias,
mask,
time_major,
go_backwards,
return_sequences,
):
"""GRU with cuDNN implementation which is only available for GPU."""
if mask is not None:
sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
else:
sequence_lengths = None
if not time_major and sequence_lengths is None:
inputs = tf.transpose(inputs, perm=(1, 0, 2))
seq_axis, batch_axis = (0, 1)
else:
seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
# For init_h, cuDNN expects one more dim of num_layers before or after batch
# dim for time major or batch major inputs respectively
init_h = tf.expand_dims(initial_state, axis=seq_axis)
weights = tf.split(kernel, 3, axis=1)
weights += tf.split(recurrent_kernel, 3, axis=1)
# Note that the bias was initialized as shape (2, 3 * units), flatten it to
# (6 * units)
bias = tf.split(tf.reshape(bias, [-1]), 6)
if tf.sysconfig.get_build_info()["is_cuda_build"]:
# Note that the gate order for cuDNN is different from the canonical
# format. The canonical format is [z, r, h], whereas cuDNN is [r, z, h].
# The swap needs to be done for kernel, recurrent_kernel, input_bias,
# and recurrent_bias.
# z is update gate weights.
# r is reset gate weights.
# h is output gate weights.
weights[0], weights[1] = weights[1], weights[0]
weights[3], weights[4] = weights[4], weights[3]
bias[0], bias[1] = bias[1], bias[0]
bias[3], bias[4] = bias[4], bias[3]
params = _standardize_cudnn_weights(
weights=weights,
biases=bias,
shape=tf.constant([-1]),
transpose_weights=True,
)
if sequence_lengths is not None:
if go_backwards:
# Three reversals are required. E.g.,
# normal input = [1, 2, 3, 0, 0] # where 0 needs to be masked
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
# output_from_cudnn = [6, 5, 4, 0, 0]
# expected_output = [0, 0, 6, 5 ,4]
inputs = tf.reverse_sequence(
inputs,
sequence_lengths,
seq_axis=seq_axis,
batch_axis=batch_axis,
)
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
input=inputs,
input_h=init_h,
input_c=0,
params=params,
is_training=True,
rnn_mode="gru",
sequence_lengths=sequence_lengths,
time_major=time_major,
)
if go_backwards:
outputs = tf.reverse_sequence(
outputs,
sequence_lengths,
seq_axis=seq_axis,
batch_axis=batch_axis,
)
outputs = tf.reverse(outputs, axis=[seq_axis])
else:
if go_backwards:
# Reverse axis 0 since the input is already converted to time major.
inputs = tf.reverse(inputs, axis=[0])
outputs, h, _, _ = tf.raw_ops.CudnnRNN(
input=inputs,
input_h=init_h,
input_c=0,
params=params,
is_training=True,
rnn_mode="gru",
)
last_output = outputs[-1]
if not time_major and sequence_lengths is None and return_sequences:
outputs = tf.transpose(outputs, perm=[1, 0, 2])
state = tf.squeeze(h, axis=seq_axis)
# In the case of variable length input, the cuDNN kernel will fill zeros
# for the output, whereas the default Keras behavior is to carry over the
# output from t-1, so that in the return_sequences=False case, the user
# can quickly get the final effective output instead of just 0s at the
# last timestep. In order to mimic the default Keras behavior, we copy the
# final h state as the last_output, since it is numerically the same as
# the output.
if sequence_lengths is not None:
last_output = state
# Match CPU return format
if not return_sequences:
outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)
return (
last_output,
outputs,
state,
)
def lstm(
inputs,
initial_state_h,
initial_state_c,
mask,
kernel,
recurrent_kernel,
bias,
activation,
recurrent_activation,
return_sequences=False,
go_backwards=False,
unroll=False,
time_major=False,
):
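"""LSTM forward pass backed by the cuDNN kernel (GPU only).
Raises `NotImplementedError` when the arguments or inputs are not supported
by cuDNN, so the caller can fall back to the generic `rnn()` implementation.
"""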
args_supported = _do_lstm_arguments_support_cudnn(
activation=activation,
recurrent_activation=recurrent_activation,
unroll=unroll,
bias=bias,
)
inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
if not args_supported or not inputs_supported:
raise NotImplementedError
from keras_core.backend.tensorflow import Variable
if isinstance(kernel, Variable):
kernel = kernel.value
if isinstance(recurrent_kernel, Variable):
recurrent_kernel = recurrent_kernel.value
if isinstance(bias, Variable):
bias = bias.value
try:
return _cudnn_lstm(
inputs,
initial_state_h,
initial_state_c,
kernel,
recurrent_kernel,
bias,
mask,
time_major,
go_backwards,
return_sequences,
)
except tf.errors.InvalidArgumentError:
# cuDNN op not found.
raise NotImplementedError
except tf.errors.NotFoundError:
# alternative error: device not found for op
raise NotImplementedError
def _cudnn_lstm(
inputs,
initial_state_h,
initial_state_c,
kernel,
recurrent_kernel,
bias,
mask,
time_major,
go_backwards,
return_sequences,
):
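"""LSTM with cuDNN implementation which is only available for GPU."""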
if mask is not None:
sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
else:
sequence_lengths = None
if not time_major and sequence_lengths is None:
inputs = tf.transpose(inputs, perm=(1, 0, 2))
seq_axis, batch_axis = (0, 1)
else:
seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
# For init_h and init_c, cuDNN expects one more dim of num_layers before or
# after batch dim for time major or batch major inputs respectively
init_h = tf.expand_dims(initial_state_h, axis=seq_axis)
init_c = tf.expand_dims(initial_state_c, axis=seq_axis)
weights = tf.split(kernel, 4, axis=1)
weights += tf.split(recurrent_kernel, 4, axis=1)
# cuDNN has an extra set of biases for the inputs; we disable them (setting
# them to 0) so that mathematically it is the same as the canonical LSTM
# implementation.
full_bias = tf.concat((tf.zeros_like(bias), bias), 0)
if tf.sysconfig.get_build_info()["is_rocm_build"]:
# ROCm MIOpen's weight sequence for LSTM is different from both
# canonical and Cudnn format
# MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
# i is input gate weights.
# f is forget gate weights.
# o is output gate weights.
# c is cell gate weights.
weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
# full_bias is a tensor of shape (8*n,)
full_bias = tf.split(full_bias, 8, axis=0)
full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
params = _standardize_cudnn_weights(
weights=weights,
biases=tf.split(full_bias, 8),
shape=tf.constant([-1]),
transpose_weights=True,
)
if sequence_lengths is not None:
if go_backwards:
# Three reversals are required. E.g.,
# normal input = [1, 2, 3, 0, 0] # where 0 needs to be masked
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
# output_from_cudnn = [6, 5, 4, 0, 0]
# expected_output = [0, 0, 6, 5 ,4]
inputs = tf.reverse_sequence(
inputs,
sequence_lengths,
seq_axis=seq_axis,
batch_axis=batch_axis,
)
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
input=inputs,
input_h=init_h,
input_c=init_c,
params=params,
is_training=True,
rnn_mode="lstm",
sequence_lengths=sequence_lengths,
time_major=time_major,
)
if go_backwards:
outputs = tf.reverse_sequence(
outputs,
sequence_lengths,
seq_axis=seq_axis,
batch_axis=batch_axis,
)
outputs = tf.reverse(outputs, axis=[seq_axis])
else:
# # Fill the array with shape [batch] with value of max timesteps.
# sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
# array_ops.shape(inputs)[0])
if go_backwards:
# Reverse axis 0 since the input is already converted to time major.
inputs = tf.reverse(inputs, axis=[0])
outputs, h, c, _ = tf.raw_ops.CudnnRNN(
input=inputs,
input_h=init_h,
input_c=init_c,
params=params,
is_training=True,
rnn_mode="lstm",
)
last_output = outputs[-1]
if not time_major and sequence_lengths is None and return_sequences:
outputs = tf.transpose(outputs, perm=[1, 0, 2])
h = tf.squeeze(h, axis=seq_axis)
c = tf.squeeze(c, axis=seq_axis)
# In the case of variable length input, the cuDNN kernel will fill zeros
# for the output, whereas the default Keras behavior is to carry over the
# output from t-1, so that in the return_sequences=False case, the user
# can quickly get the final effective output instead of just 0s at the
# last timestep. In order to mimic the default Keras behavior, we copy the
# final h state as the last_output, since it is numerically the same as
# the output.
if sequence_lengths is not None:
last_output = h
# Match CPU return format
if not return_sequences:
outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)
return (last_output, outputs, [h, c])