from jax import lax
from jax import numpy as jnp
from tensorflow import nest

from keras_core.backend.common.stateless_scope import StatelessScope


def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
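    """Iterate `step_function` over the time dimension of `inputs`.

    Args:
        step_function: Callable with signature
            `(input_t, states) -> (output_t, new_states)`, applied at
            each timestep.
        inputs: Tensor, or nested structure of tensors, of shape
            `(batch, time, ...)`, or `(time, batch, ...)` if
            `time_major=True`.
        initial_states: List of tensors holding the initial state.
        go_backwards: If `True`, process the time dimension in reverse
            order.
        mask: Optional boolean tensor of shape `(batch, time)`; masked
            timesteps carry the previous output and states forward.
        constants: Optional list of tensors passed to `step_function` at
            every timestep (only forwarded on the unrolled path below).
        unroll: If `True`, iterate with a Python loop, which requires a
            statically known number of timesteps; otherwise compile the
            loop with `lax.scan`.
        input_length: Unused by this backend.
        time_major: If `True`, inputs and outputs use the
            `(time, batch, ...)` layout.
        zero_output_for_mask: If `True`, masked timesteps output zeros
            instead of repeating the previous output.
        return_all_outputs: If `True`, `outputs` holds the output for
            every timestep; otherwise only the final output is kept
            (the `lax.scan` path below always keeps every timestep).

    Returns:
        A tuple `(last_output, outputs, new_states)`.
    """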
    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return jnp.transpose(input_t, axes)

    if not time_major:
        inputs = nest.map_structure(swap_batch_timestep, inputs)

    flattened_inputs = nest.flatten(inputs)
    time_steps = flattened_inputs[0].shape[0]

    if mask is not None:
        if mask.dtype != "bool":
            mask = mask.astype("bool")
        if len(mask.shape) == 2:
            mask = jnp.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []
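
    # Broadcast a timestep mask against a data tensor: add trailing axes
    # until the ranks match, then tile across the feature dimensions.
    # For example, a (batch, 1) mask expanded against a (batch, units)
    # tensor is tiled to (batch, units).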

    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if nest.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be a tensor, but got {mask_t}"
            )
        if nest.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be a tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = jnp.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])
        return jnp.tile(mask_t, multiples)
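
    # Two execution strategies: with `unroll=True`, run an explicit
    # Python loop over timesteps (the sequence length must be statically
    # known); with `unroll=False`, trace a single step and let `lax.scan`
    # compile the loop.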

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Process the input tensors. Each input tensor needs to be split
        # on the time_step dim, and reversed if go_backwards is True. In
        # the case of nested input, the input is flattened and then
        # transformed individually. The result is a tuple of lists, where
        # each item in the tuple is a list of tensors with shape
        # (batch, feature).
        def _process_single_input_t(input_t):
            input_t = unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        if nest.is_nested(inputs):
            processed_input = nest.map_structure(
                _process_single_input_t, inputs
            )
        else:
            processed_input = (_process_single_input_t(inputs),)

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return nest.pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                if not successive_outputs:
                    prev_output = jnp.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                # Masked timesteps keep the previous output.
                output = jnp.where(tiled_mask_t, output, prev_output)

                # Masked timesteps also keep the previous states.
                flat_states = nest.flatten(states)
                flat_new_states = nest.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    jnp.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = nest.pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = jnp.stack(successive_outputs)

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = jnp.stack(successive_outputs)

    else:  # Unroll == False
        if mask is not None:

            def _step(states, current_input):
                current_input, current_mask = current_input
                # A sample is masked at this step when all of its mask
                # values are False.
                is_masked = jnp.all(
                    jnp.logical_not(current_mask), axis=-1, keepdims=True
                )

                output_t, new_states = step_function(current_input, states)

                if zero_output_for_mask:
                    masked_outs = jnp.where(
                        is_masked, jnp.zeros_like(output_t), output_t
                    )
                else:
                    # Assume the first state is the previous output.
                    output_tm1 = states[0]
                    masked_outs = jnp.where(is_masked, output_tm1, output_t)

                new_states = [
                    jnp.where(is_masked, s, ns)
                    for s, ns in zip(states, new_states)
                ]
                return (new_states, masked_outs)

            scan_xs = (inputs, mask)

        else:

            def _step(states, current_input):
                output_t, new_states = step_function(current_input, states)
                return new_states, output_t

            scan_xs = inputs
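
        # `lax.scan` threads the state structure through time and stacks
        # each step's output along a new leading (time) axis.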

        with StatelessScope():
            # We must use a stateless scope because `scan` will involve
            # JAX tracing -- any variable update at this stage would be
            # a leak.
            new_states, outputs = lax.scan(
                f=_step,
                init=initial_states,
                xs=scan_xs,
                reverse=go_backwards,
            )
        if go_backwards:
            outputs = jnp.flip(outputs, axis=0)
        last_output = outputs[-1]

    if not time_major:
        outputs = nest.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states
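

# Example usage (a minimal sketch; `step` and `x` are illustrative, not
# part of this module): a step function that accumulates its input into
# a single running-sum state.
#
#     def step(input_t, states):
#         new_state = states[0] + input_t
#         return new_state, [new_state]
#
#     x = jnp.ones((8, 5, 4))  # (batch, time, features)
#     last_output, outputs, new_states = rnn(step, x, [jnp.zeros((8, 4))])
#     # last_output: (8, 4), outputs: (8, 5, 4), new_states: [(8, 4)]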


def lstm(*args, **kwargs):
    raise NotImplementedError


def gru(*args, **kwargs):
    raise NotImplementedError


def unstack(x, axis=0):
    # Split `x` into a list of slices along `axis`, dropping that axis,
    # e.g. a (time, batch, feature) tensor becomes `time` tensors of
    # shape (batch, feature).
    return [
        lax.index_in_dim(x, i, axis, keepdims=False)
        for i in range(x.shape[axis])
    ]