keras/keras_core/backend/jax/rnn.py

from jax import lax
from jax import numpy as jnp
from tensorflow import nest

from keras_core.backend.common.stateless_scope import StatelessScope


def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
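    """Runs `step_function` over the time dimension of `inputs`.

    JAX implementation of the Keras backend `rnn` loop: `step_function`
    is applied once per timestep, threading `initial_states` through the
    sequence, with optional masking, sequence reversal (`go_backwards`),
    and Python-level loop unrolling (`unroll`).

    Returns:
        A tuple `(last_output, outputs, new_states)`, where `outputs`
        holds the per-timestep outputs (only the final timestep when
        `return_all_outputs=False`) and `new_states` is the state after
        the final timestep.
    """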

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return jnp.transpose(input_t, axes)

    if not time_major:
        inputs = nest.map_structure(swap_batch_timestep, inputs)

    flattened_inputs = nest.flatten(inputs)
    time_steps = flattened_inputs[0].shape[0]

    if mask is not None:
        if mask.dtype != "bool":
            mask = mask.astype("bool")
        if len(mask.shape) == 2:
            mask = jnp.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []
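
    # `_expand_mask` broadcasts a `(batch, 1)` mask tensor up to the rank
    # of `input_t` and tiles it across the feature dimensions, so it can
    # be used directly in `jnp.where` against outputs and states.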
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if nest.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if nest.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = jnp.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])
        return jnp.tile(mask_t, multiples)

    if unroll:
        if not time_steps:
            raise ValueError(
                "Unrolling requires a fixed number of timesteps."
            )
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Process the input tensors. The input tensor needs to be split on
        # the time_step dim, and reversed if go_backwards is True. In the
        # case of nested input, the input is flattened and then transformed
        # individually. The result of this will be a tuple of lists; each
        # item in the tuple is a list of tensors with shape (batch, feature).
        def _process_single_input_t(input_t):
            input_t = unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        if nest.is_nested(inputs):
            processed_input = nest.map_structure(
                _process_single_input_t, inputs
            )
        else:
            processed_input = (_process_single_input_t(inputs),)

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return nest.pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = unstack(mask)
            if go_backwards:
                mask_list.reverse()
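
            # At each step, where the mask is off for a sample, carry the
            # previous output and previous states forward unchanged.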
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)
                if not successive_outputs:
                    prev_output = jnp.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]
                output = jnp.where(tiled_mask_t, output, prev_output)

                flat_states = nest.flatten(states)
                flat_new_states = nest.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    jnp.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = nest.pack_sequence_as(states, flat_final_states)
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = jnp.stack(successive_outputs)
        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = jnp.stack(successive_outputs)

    else:  # Unroll == False
        if mask is not None:
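
            # Per-step function for `lax.scan`: where every mask feature is
            # off for a sample, substitute the previous output (or zeros if
            # `zero_output_for_mask`) and the previous states.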
            def _step(states, current_input):
                current_input, current_mask = current_input
                is_masked = jnp.all(
                    jnp.logical_not(current_mask), axis=-1, keepdims=True
                )
                output_t, new_states = step_function(current_input, states)
                if zero_output_for_mask:
                    masked_outs = jnp.where(
                        is_masked, jnp.zeros_like(output_t), output_t
                    )
                else:
                    # Assume the first state is the previous output.
                    output_tm1 = states[0]
                    masked_outs = jnp.where(is_masked, output_tm1, output_t)
                new_states = [
                    jnp.where(is_masked, s, ns)
                    for s, ns in zip(states, new_states)
                ]
                return (new_states, masked_outs)

            scan_xs = (inputs, mask)

        else:

            def _step(states, current_input):
                output_t, new_states = step_function(current_input, states)
                return new_states, output_t

            scan_xs = inputs
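
        # `lax.scan` iterates `_step` over the leading (time) axis of
        # `scan_xs`, carrying the state structure and stacking each step's
        # output along a new leading axis.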
        with StatelessScope():
            # We must use a stateless scope because `scan` will involve
            # JAX tracing -- any variable update at this stage would
            # be a leak.
            new_states, outputs = lax.scan(
                f=_step,
                init=initial_states,
                xs=scan_xs,
                reverse=go_backwards,
            )

        if go_backwards:
            outputs = jnp.flip(outputs, axis=0)

        last_output = outputs[-1]

    if not time_major:
        outputs = nest.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states
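

# No fused LSTM/GRU implementation is provided for the JAX backend;
# raising NotImplementedError lets callers fall back to the generic
# `rnn` loop above.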
def lstm(*args, **kwargs):
    raise NotImplementedError


def gru(*args, **kwargs):
    raise NotImplementedError


def unstack(x, axis=0):
    return [
        lax.index_in_dim(x, i, axis, keepdims=False)
        for i in range(x.shape[axis])
    ]
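

if __name__ == "__main__":
    # Minimal usage sketch: a hypothetical accumulating step function,
    # illustrating the `(last_output, outputs, new_states)` return
    # contract of `rnn` on a (batch=2, time=3, features=4) input.
    def _demo_step(inp, states):
        new_state = states[0] + inp
        return new_state, [new_state]

    x = jnp.ones((2, 3, 4))
    init = [jnp.zeros((2, 4))]
    last_output, outputs, new_states = rnn(_demo_step, x, init)
    assert last_output.shape == (2, 4)
    assert outputs.shape == (2, 3, 4)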