Add GRU layer.

Francois Chollet 2023-05-12 15:39:48 -07:00
parent 77b4fcc3dc
commit 42236e5d4e
36 changed files with 1706 additions and 722 deletions

@ -11,7 +11,6 @@ from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.backend.jax import core
from keras_core.backend.jax import image
from keras_core.backend.jax import math
from keras_core.backend.jax import nn

@ -1,7 +0,0 @@
import jax.numpy as jnp
def scatter(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype)
key = tuple(jnp.moveaxis(indices, -1, 0))
return zeros.at[key].set(values)

@ -8,12 +8,14 @@ from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.backend.tensorflow import core
from keras_core.backend.tensorflow import image
from keras_core.backend.tensorflow import math
from keras_core.backend.tensorflow import nn
from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random
from keras_core.backend.tensorflow.rnn import gru
from keras_core.backend.tensorflow.rnn import lstm
from keras_core.backend.tensorflow.rnn import rnn
from keras_core.utils.naming import auto_name
DYNAMIC_SHAPES_OK = True

@ -1,5 +0,0 @@
import tensorflow as tf
def scatter(indices, values, shape):
return tf.scatter_nd(indices, values, shape)

@ -602,458 +602,3 @@ def binary_crossentropy(target, output, from_logits=False):
bce = target * tf.math.log(output)
bce += (1 - target) * tf.math.log(1 - output)
return -bce
def rnn(
step_function,
inputs,
initial_states,
go_backwards=False,
mask=None,
constants=None,
unroll=False,
input_length=None,
time_major=False,
zero_output_for_mask=False,
return_all_outputs=True,
):
"""Iterates over the time dimension of a tensor.
Args:
step_function: RNN step function.
Args:
`input`: Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
`states`: List of tensors.
Returns:
`output`: Tensor with shape `(samples, output_dim)`
(no time dimension).
`new_states`: List of tensors, same length and shapes
as `states`. The first state in the list must be the
output tensor at the previous timestep.
inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D), or nested tensors, each of which has shape
`(samples, time, ...)`.
initial_states: Tensor with shape `(samples, state_size)`
(no time dimension), containing the initial values for the states
used in the step function. In the case that state_size is in a
nested shape, the shape of initial_states will also follow the
nested structure.
go_backwards: Boolean. If `True`, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: An integer or a 1-D Tensor, depending on whether
the time dimension is fixed-length or not. In case of variable
length input, it is used for masking in case there's no mask
specified.
time_major: Boolean. If `True`, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of
the RNN calculation. However, most TensorFlow data is batch-major,
so by default this function accepts input and emits output in
batch-major form.
zero_output_for_mask: Boolean. If `True`, the output for masked timestep
will be zeros, whereas in the `False` case, output from previous
timestep is returned.
return_all_outputs: Boolean. If `True`, return the recurrent outputs for
all timesteps in the sequence. If `False`, only return the output
for the last timestep (which consumes less memory).
Returns:
A tuple, `(last_output, outputs, new_states)`.
- `last_output`: the latest output of the rnn,
with shape `(samples, ...)`.
- `outputs`:
- If `return_all_outputs=True`: a tensor with shape
`(samples, time, ...)` where each entry `outputs[s, t]` is the
output of the step function at time `t` for sample `s`
- Else, a tensor equal to `last_output` with shape
`(samples, 1, ...)`
- `new_states`: list of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
"""
def swap_batch_timestep(input_t):
# Swap the batch and timestep dim for the incoming tensor.
axes = list(range(len(input_t.shape)))
axes[0], axes[1] = 1, 0
return tf.transpose(input_t, axes)
if not time_major:
inputs = tf.nest.map_structure(swap_batch_timestep, inputs)
flatted_inputs = tf.nest.flatten(inputs)
time_steps = flatted_inputs[0].shape[0]
batch = flatted_inputs[0].shape[1]
time_steps_t = tf.shape(flatted_inputs[0])[0]
for input_ in flatted_inputs:
input_.shape.with_rank_at_least(3)
if mask is not None:
if mask.dtype != tf.bool:
mask = tf.cast(mask, tf.bool)
if len(mask.shape) == 2:
mask = tf.expand_dims(mask, axis=-1)
if not time_major:
mask = swap_batch_timestep(mask)
if constants is None:
constants = []
# tf.where needs its condition tensor to be the same shape as its two
# result tensors, but in our case the condition (mask) tensor is
# (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
# So we need to broadcast the mask to match the shape of inputs.
# That's what the tile call does, it just repeats the mask along its
# second dimension n times.
def _expand_mask(mask_t, input_t, fixed_dim=1):
if tf.nest.is_nested(mask_t):
raise ValueError(
f"mask_t is expected to be tensor, but got {mask_t}"
)
if tf.nest.is_nested(input_t):
raise ValueError(
f"input_t is expected to be tensor, but got {input_t}"
)
rank_diff = len(input_t.shape) - len(mask_t.shape)
for _ in range(rank_diff):
mask_t = tf.expand_dims(mask_t, -1)
multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
return tf.tile(mask_t, multiples)
if unroll:
if not time_steps:
raise ValueError("Unrolling requires a fixed number of timesteps.")
states = tuple(initial_states)
successive_states = []
successive_outputs = []
# Process the input tensors. The input tensor needs to be split on the
# time_step dim, and reversed if go_backwards is True. In the case of
# nested input, the input is flattened and then transformed
# individually. The result of this will be a tuple of lists, where each
# item in the tuple is a list of tensors with shape (batch, feature).
def _process_single_input_t(input_t):
input_t = tf.unstack(input_t) # unstack for time_step dim
if go_backwards:
input_t.reverse()
return input_t
if tf.nest.is_nested(inputs):
processed_input = tf.nest.map_structure(
_process_single_input_t, inputs
)
else:
processed_input = (_process_single_input_t(inputs),)
def _get_input_tensor(time):
inp = [t_[time] for t_ in processed_input]
return tf.nest.pack_sequence_as(inputs, inp)
if mask is not None:
mask_list = tf.unstack(mask)
if go_backwards:
mask_list.reverse()
for i in range(time_steps):
inp = _get_input_tensor(i)
mask_t = mask_list[i]
output, new_states = step_function(
inp, tuple(states) + tuple(constants)
)
tiled_mask_t = _expand_mask(mask_t, output)
if not successive_outputs:
prev_output = tf.zeros_like(output)
else:
prev_output = successive_outputs[-1]
output = tf.where(tiled_mask_t, output, prev_output)
flat_states = tf.nest.flatten(states)
flat_new_states = tf.nest.flatten(new_states)
tiled_mask_t = tuple(
_expand_mask(mask_t, s) for s in flat_states
)
flat_final_states = tuple(
tf.where(m, s, ps)
for m, s, ps in zip(
tiled_mask_t, flat_new_states, flat_states
)
)
states = tf.nest.pack_sequence_as(states, flat_final_states)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = tf.stack(successive_outputs)
if zero_output_for_mask:
last_output = tf.where(
_expand_mask(mask_list[-1], last_output),
last_output,
tf.zeros_like(last_output),
)
outputs = tf.where(
_expand_mask(mask, outputs, fixed_dim=2),
outputs,
tf.zeros_like(outputs),
)
else: # mask is None
for i in range(time_steps):
inp = _get_input_tensor(i)
output, states = step_function(
inp, tuple(states) + tuple(constants)
)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = tf.stack(successive_outputs)
else: # Unroll == False
states = tuple(initial_states)
# Create the input TensorArrays. If the inputs are nested tensors, they
# are flattened first, and one TensorArray is created per flattened
# tensor.
input_ta = tuple(
tf.TensorArray(
dtype=inp.dtype,
size=time_steps_t,
tensor_array_name=f"input_ta_{i}",
)
for i, inp in enumerate(flatted_inputs)
)
input_ta = tuple(
ta.unstack(input_)
if not go_backwards
else ta.unstack(tf.reverse(input_, [0]))
for ta, input_ in zip(input_ta, flatted_inputs)
)
# Get the time(0) input and compute the output for that; the output will
# be used to determine the dtype of the output tensor array. Don't read
# from input_ta because TensorArray's clear_after_read defaults to True.
input_time_zero = tf.nest.pack_sequence_as(
inputs, [inp[0] for inp in flatted_inputs]
)
# output_time_zero is used to determine the cell output shape and its
# dtype. The value is discarded.
output_time_zero, _ = step_function(
input_time_zero, tuple(initial_states) + tuple(constants)
)
output_ta_size = time_steps_t if return_all_outputs else 1
output_ta = tuple(
tf.TensorArray(
dtype=out.dtype,
size=output_ta_size,
element_shape=out.shape,
tensor_array_name=f"output_ta_{i}",
)
for i, out in enumerate(tf.nest.flatten(output_time_zero))
)
time = tf.constant(0, dtype="int32", name="time")
if input_length is None:
max_iterations = time_steps_t
else:
max_iterations = tf.reduce_max(input_length)
while_loop_kwargs = {
"cond": lambda time, *_: time < time_steps_t,
"maximum_iterations": max_iterations,
"parallel_iterations": 32,
"swap_memory": True,
}
if mask is not None:
if go_backwards:
mask = tf.reverse(mask, [0])
mask_ta = tf.TensorArray(
dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
)
mask_ta = mask_ta.unstack(mask)
def masking_fn(time):
return mask_ta.read(time)
def compute_masked_output(mask_t, flat_out, flat_mask):
tiled_mask_t = tuple(
_expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
for o in flat_out
)
return tuple(
tf.where(m, o, fm)
for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
)
elif isinstance(input_length, tf.Tensor):
if go_backwards:
max_len = tf.reduce_max(input_length, axis=0)
rev_input_length = tf.subtract(max_len - 1, input_length)
def masking_fn(time):
return tf.less(rev_input_length, time)
else:
def masking_fn(time):
return tf.greater(input_length, time)
def compute_masked_output(mask_t, flat_out, flat_mask):
return tuple(
tf.where(mask_t, o, zo)
for (o, zo) in zip(flat_out, flat_mask)
)
else:
masking_fn = None
if masking_fn is not None:
# The mask for the output at time T is based on the output at T - 1. In
# the case T = 0, a zero-filled tensor will be used.
flat_zero_output = tuple(
tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
)
def _step(time, output_ta_t, prev_output, *states):
"""RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
"""
current_input = tuple(ta.read(time) for ta in input_ta)
# maybe set shape.
current_input = tf.nest.pack_sequence_as(inputs, current_input)
mask_t = masking_fn(time)
output, new_states = step_function(
current_input, tuple(states) + tuple(constants)
)
# mask output
flat_output = tf.nest.flatten(output)
flat_mask_output = (
flat_zero_output
if zero_output_for_mask
else tf.nest.flatten(prev_output)
)
flat_new_output = compute_masked_output(
mask_t, flat_output, flat_mask_output
)
# mask states
flat_state = tf.nest.flatten(states)
flat_new_state = tf.nest.flatten(new_states)
flat_final_state = compute_masked_output(
mask_t, flat_new_state, flat_state
)
new_states = tf.nest.pack_sequence_as(
new_states, flat_final_state
)
ta_index_to_write = time if return_all_outputs else 0
output_ta_t = tuple(
ta.write(ta_index_to_write, out)
for ta, out in zip(output_ta_t, flat_new_output)
)
return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
new_states
)
final_outputs = tf.while_loop(
body=_step,
loop_vars=(time, output_ta, flat_zero_output) + states,
**while_loop_kwargs,
)
# Skip final_outputs[2] which is the output for final timestep.
new_states = final_outputs[3:]
else:
def _step(time, output_ta_t, *states):
"""RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
*states: List of states.
Returns:
Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
"""
current_input = tuple(ta.read(time) for ta in input_ta)
current_input = tf.nest.pack_sequence_as(inputs, current_input)
output, new_states = step_function(
current_input, tuple(states) + tuple(constants)
)
flat_new_state = tf.nest.flatten(new_states)
flat_output = tf.nest.flatten(output)
ta_index_to_write = time if return_all_outputs else 0
output_ta_t = tuple(
ta.write(ta_index_to_write, out)
for ta, out in zip(output_ta_t, flat_output)
)
new_states = tf.nest.pack_sequence_as(
initial_states, flat_new_state
)
return (time + 1, output_ta_t) + tuple(new_states)
final_outputs = tf.while_loop(
body=_step,
loop_vars=(time, output_ta) + states,
**while_loop_kwargs,
)
new_states = final_outputs[2:]
output_ta = final_outputs[1]
outputs = tuple(o.stack() for o in output_ta)
last_output = tuple(o[-1] for o in outputs)
outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)
# static shape inference
def set_shape(output_):
if isinstance(output_, tf.Tensor):
shape = output_.shape.as_list()
if return_all_outputs:
shape[0] = time_steps
else:
shape[0] = 1
shape[1] = batch
output_.set_shape(shape)
return output_
outputs = tf.nest.map_structure(set_shape, outputs)
if not time_major:
outputs = tf.nest.map_structure(swap_batch_timestep, outputs)
return last_output, outputs, new_states

@ -0,0 +1,763 @@
import tensorflow as tf
def rnn(
step_function,
inputs,
initial_states,
go_backwards=False,
mask=None,
constants=None,
unroll=False,
input_length=None,
time_major=False,
zero_output_for_mask=False,
return_all_outputs=True,
):
"""Iterates over the time dimension of a tensor.
Args:
step_function: RNN step function.
Args:
`input`: Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
`states`: List of tensors.
Returns:
`output`: Tensor with shape `(samples, output_dim)`
(no time dimension).
`new_states`: List of tensors, same length and shapes
as `states`. The first state in the list must be the
output tensor at the previous timestep.
inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D), or nested tensors, each of which has shape
`(samples, time, ...)`.
initial_states: Tensor with shape `(samples, state_size)`
(no time dimension), containing the initial values for the states
used in the step function. In the case that state_size is in a
nested shape, the shape of initial_states will also follow the
nested structure.
go_backwards: Boolean. If `True`, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: An integer or a 1-D Tensor, depending on whether
the time dimension is fixed-length or not. In case of variable
length input, it is used for masking in case there's no mask
specified.
time_major: Boolean. If `True`, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of
the RNN calculation. However, most TensorFlow data is batch-major,
so by default this function accepts input and emits output in
batch-major form.
zero_output_for_mask: Boolean. If `True`, the output for masked timestep
will be zeros, whereas in the `False` case, output from previous
timestep is returned.
return_all_outputs: Boolean. If `True`, return the recurrent outputs for
all timesteps in the sequence. If `False`, only return the output
for the last timestep (which consumes less memory).
Returns:
A tuple, `(last_output, outputs, new_states)`.
- `last_output`: the latest output of the rnn,
with shape `(samples, ...)`.
- `outputs`:
- If `return_all_outputs=True`: a tensor with shape
`(samples, time, ...)` where each entry `outputs[s, t]` is the
output of the step function at time `t` for sample `s`
- Else, a tensor equal to `last_output` with shape
`(samples, 1, ...)`
- `new_states`: list of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
"""
def swap_batch_timestep(input_t):
# Swap the batch and timestep dim for the incoming tensor.
axes = list(range(len(input_t.shape)))
axes[0], axes[1] = 1, 0
return tf.transpose(input_t, axes)
if not time_major:
inputs = tf.nest.map_structure(swap_batch_timestep, inputs)
flatted_inputs = tf.nest.flatten(inputs)
time_steps = flatted_inputs[0].shape[0]
batch = flatted_inputs[0].shape[1]
time_steps_t = tf.shape(flatted_inputs[0])[0]
for input_ in flatted_inputs:
input_.shape.with_rank_at_least(3)
if mask is not None:
if mask.dtype != tf.bool:
mask = tf.cast(mask, tf.bool)
if len(mask.shape) == 2:
mask = tf.expand_dims(mask, axis=-1)
if not time_major:
mask = swap_batch_timestep(mask)
if constants is None:
constants = []
# tf.where needs its condition tensor to be the same shape as its two
# result tensors, but in our case the condition (mask) tensor is
# (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
# So we need to broadcast the mask to match the shape of inputs.
# That's what the tile call does, it just repeats the mask along its
# second dimension n times.
def _expand_mask(mask_t, input_t, fixed_dim=1):
if tf.nest.is_nested(mask_t):
raise ValueError(
f"mask_t is expected to be tensor, but got {mask_t}"
)
if tf.nest.is_nested(input_t):
raise ValueError(
f"input_t is expected to be tensor, but got {input_t}"
)
rank_diff = len(input_t.shape) - len(mask_t.shape)
for _ in range(rank_diff):
mask_t = tf.expand_dims(mask_t, -1)
multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
return tf.tile(mask_t, multiples)
if unroll:
if not time_steps:
raise ValueError("Unrolling requires a fixed number of timesteps.")
states = tuple(initial_states)
successive_states = []
successive_outputs = []
# Process the input tensors. The input tensor needs to be split on the
# time_step dim, and reversed if go_backwards is True. In the case of
# nested input, the input is flattened and then transformed
# individually. The result of this will be a tuple of lists, where each
# item in the tuple is a list of tensors with shape (batch, feature).
def _process_single_input_t(input_t):
input_t = tf.unstack(input_t) # unstack for time_step dim
if go_backwards:
input_t.reverse()
return input_t
if tf.nest.is_nested(inputs):
processed_input = tf.nest.map_structure(
_process_single_input_t, inputs
)
else:
processed_input = (_process_single_input_t(inputs),)
def _get_input_tensor(time):
inp = [t_[time] for t_ in processed_input]
return tf.nest.pack_sequence_as(inputs, inp)
if mask is not None:
mask_list = tf.unstack(mask)
if go_backwards:
mask_list.reverse()
for i in range(time_steps):
inp = _get_input_tensor(i)
mask_t = mask_list[i]
output, new_states = step_function(
inp, tuple(states) + tuple(constants)
)
tiled_mask_t = _expand_mask(mask_t, output)
if not successive_outputs:
prev_output = tf.zeros_like(output)
else:
prev_output = successive_outputs[-1]
output = tf.where(tiled_mask_t, output, prev_output)
flat_states = tf.nest.flatten(states)
flat_new_states = tf.nest.flatten(new_states)
tiled_mask_t = tuple(
_expand_mask(mask_t, s) for s in flat_states
)
flat_final_states = tuple(
tf.where(m, s, ps)
for m, s, ps in zip(
tiled_mask_t, flat_new_states, flat_states
)
)
states = tf.nest.pack_sequence_as(states, flat_final_states)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = tf.stack(successive_outputs)
if zero_output_for_mask:
last_output = tf.where(
_expand_mask(mask_list[-1], last_output),
last_output,
tf.zeros_like(last_output),
)
outputs = tf.where(
_expand_mask(mask, outputs, fixed_dim=2),
outputs,
tf.zeros_like(outputs),
)
else: # mask is None
for i in range(time_steps):
inp = _get_input_tensor(i)
output, states = step_function(
inp, tuple(states) + tuple(constants)
)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = tf.stack(successive_outputs)
else: # Unroll == False
states = tuple(initial_states)
# Create the input TensorArrays. If the inputs are nested tensors, they
# are flattened first, and one TensorArray is created per flattened
# tensor.
input_ta = tuple(
tf.TensorArray(
dtype=inp.dtype,
size=time_steps_t,
tensor_array_name=f"input_ta_{i}",
)
for i, inp in enumerate(flatted_inputs)
)
input_ta = tuple(
ta.unstack(input_)
if not go_backwards
else ta.unstack(tf.reverse(input_, [0]))
for ta, input_ in zip(input_ta, flatted_inputs)
)
# Get the time(0) input and compute the output for that; the output will
# be used to determine the dtype of the output tensor array. Don't read
# from input_ta because TensorArray's clear_after_read defaults to True.
input_time_zero = tf.nest.pack_sequence_as(
inputs, [inp[0] for inp in flatted_inputs]
)
# output_time_zero is used to determine the cell output shape and its
# dtype. The value is discarded.
output_time_zero, _ = step_function(
input_time_zero, tuple(initial_states) + tuple(constants)
)
output_ta_size = time_steps_t if return_all_outputs else 1
output_ta = tuple(
tf.TensorArray(
dtype=out.dtype,
size=output_ta_size,
element_shape=out.shape,
tensor_array_name=f"output_ta_{i}",
)
for i, out in enumerate(tf.nest.flatten(output_time_zero))
)
time = tf.constant(0, dtype="int32", name="time")
if input_length is None:
max_iterations = time_steps_t
else:
max_iterations = tf.reduce_max(input_length)
while_loop_kwargs = {
"cond": lambda time, *_: time < time_steps_t,
"maximum_iterations": max_iterations,
"parallel_iterations": 32,
"swap_memory": True,
}
if mask is not None:
if go_backwards:
mask = tf.reverse(mask, [0])
mask_ta = tf.TensorArray(
dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
)
mask_ta = mask_ta.unstack(mask)
def masking_fn(time):
return mask_ta.read(time)
def compute_masked_output(mask_t, flat_out, flat_mask):
tiled_mask_t = tuple(
_expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
for o in flat_out
)
return tuple(
tf.where(m, o, fm)
for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
)
elif isinstance(input_length, tf.Tensor):
if go_backwards:
max_len = tf.reduce_max(input_length, axis=0)
rev_input_length = tf.subtract(max_len - 1, input_length)
def masking_fn(time):
return tf.less(rev_input_length, time)
else:
def masking_fn(time):
return tf.greater(input_length, time)
def compute_masked_output(mask_t, flat_out, flat_mask):
return tuple(
tf.where(mask_t, o, zo)
for (o, zo) in zip(flat_out, flat_mask)
)
else:
masking_fn = None
if masking_fn is not None:
# The mask for the output at time T is based on the output at T - 1. In
# the case T = 0, a zero-filled tensor will be used.
flat_zero_output = tuple(
tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
)
def _step(time, output_ta_t, prev_output, *states):
"""RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
"""
current_input = tuple(ta.read(time) for ta in input_ta)
# maybe set shape.
current_input = tf.nest.pack_sequence_as(inputs, current_input)
mask_t = masking_fn(time)
output, new_states = step_function(
current_input, tuple(states) + tuple(constants)
)
# mask output
flat_output = tf.nest.flatten(output)
flat_mask_output = (
flat_zero_output
if zero_output_for_mask
else tf.nest.flatten(prev_output)
)
flat_new_output = compute_masked_output(
mask_t, flat_output, flat_mask_output
)
# mask states
flat_state = tf.nest.flatten(states)
flat_new_state = tf.nest.flatten(new_states)
flat_final_state = compute_masked_output(
mask_t, flat_new_state, flat_state
)
new_states = tf.nest.pack_sequence_as(
new_states, flat_final_state
)
ta_index_to_write = time if return_all_outputs else 0
output_ta_t = tuple(
ta.write(ta_index_to_write, out)
for ta, out in zip(output_ta_t, flat_new_output)
)
return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
new_states
)
final_outputs = tf.while_loop(
body=_step,
loop_vars=(time, output_ta, flat_zero_output) + states,
**while_loop_kwargs,
)
# Skip final_outputs[2] which is the output for final timestep.
new_states = final_outputs[3:]
else:
def _step(time, output_ta_t, *states):
"""RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
*states: List of states.
Returns:
Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
"""
current_input = tuple(ta.read(time) for ta in input_ta)
current_input = tf.nest.pack_sequence_as(inputs, current_input)
output, new_states = step_function(
current_input, tuple(states) + tuple(constants)
)
flat_new_state = tf.nest.flatten(new_states)
flat_output = tf.nest.flatten(output)
ta_index_to_write = time if return_all_outputs else 0
output_ta_t = tuple(
ta.write(ta_index_to_write, out)
for ta, out in zip(output_ta_t, flat_output)
)
new_states = tf.nest.pack_sequence_as(
initial_states, flat_new_state
)
return (time + 1, output_ta_t) + tuple(new_states)
final_outputs = tf.while_loop(
body=_step,
loop_vars=(time, output_ta) + states,
**while_loop_kwargs,
)
new_states = final_outputs[2:]
output_ta = final_outputs[1]
outputs = tuple(o.stack() for o in output_ta)
last_output = tuple(o[-1] for o in outputs)
outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)
# static shape inference
def set_shape(output_):
if isinstance(output_, tf.Tensor):
shape = output_.shape.as_list()
if return_all_outputs:
shape[0] = time_steps
else:
shape[0] = 1
shape[1] = batch
output_.set_shape(shape)
return output_
outputs = tf.nest.map_structure(set_shape, outputs)
if not time_major:
outputs = tf.nest.map_structure(swap_batch_timestep, outputs)
return last_output, outputs, new_states
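For reference, a minimal usage sketch of `rnn` (not part of this diff; the toy step function and shapes are illustrative assumptions). The step function simply accumulates its inputs, so the single state also serves as the previous output, as the docstring requires.

import tensorflow as tf

def cumulative_step(current_input, states):
    # The single state doubles as the previous output.
    new_state = states[0] + current_input
    return new_state, [new_state]

x = tf.random.normal((2, 5, 3))   # (batch, time, features)
initial = [tf.zeros((2, 3))]      # one state of shape (batch, features)
last_output, outputs, new_states = rnn(
    cumulative_step, x, initial, return_all_outputs=True
)
# last_output has shape (2, 3); outputs has shape (2, 5, 3)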
def gru(
inputs,
initial_state,
mask,
kernel,
recurrent_kernel,
bias,
activation,
recurrent_activation,
return_sequences=False,
go_backwards=False,
unroll=False,
time_major=False,
reset_after=True,
):
args_supported = _do_gru_arguments_support_cudnn(
activation=activation,
recurrent_activation=recurrent_activation,
unroll=unroll,
bias=bias,
reset_after=reset_after,
)
inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
if not args_supported or not inputs_supported:
raise NotImplementedError
from keras_core.backend.tensorflow import Variable
if isinstance(kernel, Variable):
kernel = kernel.value
if isinstance(recurrent_kernel, Variable):
recurrent_kernel = recurrent_kernel.value
if isinstance(bias, Variable):
bias = bias.value
try:
return _cudnn_gru(
inputs,
initial_state,
kernel,
recurrent_kernel,
bias,
mask,
time_major,
go_backwards,
return_sequences,
)
except tf.errors.InvalidArgumentError:
# cuDNN op not found.
raise NotImplementedError
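The `NotImplementedError` above signals that this fast path cannot handle the given arguments or inputs; the caller is expected to catch it and fall back to the generic `rnn` loop. A minimal sketch of that dispatch pattern (an assumption about the calling layer, not shown in this diff):

def run_with_cudnn_fallback(fast_path, fallback, *args, **kwargs):
    # Try the cuDNN-backed fast path; if it declines with
    # NotImplementedError, run the generic implementation instead.
    try:
        return fast_path(*args, **kwargs)
    except NotImplementedError:
        return fallback(*args, **kwargs)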
def _do_gru_arguments_support_cudnn(
activation,
recurrent_activation,
unroll,
bias,
reset_after,
):
from keras_core import activations
from keras_core import operations as ops
return (
activation in (activations.tanh, tf.tanh, ops.tanh)
and recurrent_activation
in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
and not unroll
and bias is not None
and reset_after
)
def _do_rnn_inputs_support_cudnn(mask, time_major):
if tf.sysconfig.get_build_info()["is_rocm_build"]:
if mask is not None:
return tf.reduce_all(mask)
return True
if mask is None:
return True
if time_major:
mask = tf.transpose(mask)
return tf.logical_and(
_is_sequence_right_padded(mask),
tf.logical_not(_has_fully_masked_sequence(mask)),
)
def _is_sequence_right_padded(mask):
"""Check the mask tensor and see if it right padded.
For cuDNN kernel, it uses the sequence length param to skip the tailing
timestep. If the data is left padded, or not a strict right padding (has
masked value in the middle of the sequence), then cuDNN kernel won't be work
properly in those cases.
Left padded data: [[False, False, True, True, True]].
Right padded data: [[True, True, True, False, False]].
Mixture of mask/unmasked data: [[True, False, True, False, False]].
Note that for the mixed data example above, the actually data RNN should see
are those 2 Trues (index 0 and 2), the index 1 False should be ignored and
not pollute the internal states.
Args:
mask: the Boolean tensor with shape [batch, timestep]
Returns:
boolean scalar tensor, whether the mask is strictly right padded.
"""
max_seq_length = tf.shape(mask)[1]
count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length)
return tf.reduce_all(tf.equal(mask, right_padded_mask))
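A quick check of this helper on the padding examples from its docstring (a sketch, not part of the diff; relies on the module's `tf` import):

right = tf.constant([[True, True, True, False, False]])
left = tf.constant([[False, False, True, True, True]])
print(_is_sequence_right_padded(right).numpy())  # True
print(_is_sequence_right_padded(left).numpy())   # False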
def _has_fully_masked_sequence(mask):
# The cuDNN kernel will error out if the input sequence contains any
# fully masked data. We work around this issue by rerouting the
# computation to the standard kernel until the issue on the cuDNN side
# has been fixed. A fully masked sequence contains all Falses. To make it
# easy to check, we invert the boolean and check if any sequence is all True.
return tf.reduce_any(tf.reduce_all(tf.logical_not(mask), axis=1))
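And the fully-masked check on a small example (sketch, not part of the diff; the second row is fully masked, so the result is True):

mask = tf.constant([[True, True, False],
                    [False, False, False]])
print(_has_fully_masked_sequence(mask).numpy())  # True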
def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):
"""Utility function convert variable to cuDNN compatible parameter.
Note that Keras weights for kernels are different from the cuDNN format.
Eg.:
```
Keras cuDNN
[[0, 1, 2], <---> [[0, 2, 4],
[3, 4, 5]] [1, 3, 5]]
```
If the input weights need to be in a unified format, then set
`transpose_weights=True` to convert the weights.
Args:
weights: list of weights for the kernels and recurrent kernels.
biases: list of biases for individual gate.
shape: the shape for the converted variables that will be fed to cuDNN.
transpose_weights: boolean, whether to transpose the weights.
Returns:
The converted weights that can be fed to cuDNN ops as params.
"""
def convert(w):
return tf.transpose(w) if transpose_weights else w
weights = [tf.reshape(convert(x), shape) for x in weights]
biases = [tf.reshape(x, shape) for x in biases]
return tf.concat(weights + biases, axis=0)
def _compute_sequence_length_from_mask(mask, time_major):
"""Calculate the sequence length tensor (1-D) based on the masking tensor.
The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
any timestep that should be masked, the corresponding field will be False.
Consider the following example:
a = [[True, True, False, False],
[True, True, True, False]]
It is a (2, 4) tensor, and the corresponding sequence length result should
be a 1D tensor with value [2, 3]. Note that the masking tensor must be
right padded, which can be checked by, e.g., `_is_sequence_right_padded()`.
Args:
mask: Boolean tensor with shape [batch, timestep] or [timestep, batch] if
time_major=True.
time_major: Boolean, which indicates whether the mask is time major or
batch major.
Returns:
sequence_length: 1D int32 tensor.
"""
timestep_index = 0 if time_major else 1
return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index)
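Checking the docstring's example concretely (sketch, not part of the diff; relies on the module's `tf` import):

a = tf.constant([[True, True, False, False],
                 [True, True, True, False]])
print(_compute_sequence_length_from_mask(a, time_major=False).numpy())  # [2 3]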
@tf.function(autograph=False)
def _cudnn_gru(
inputs,
initial_state,
kernel,
recurrent_kernel,
bias,
mask,
time_major,
go_backwards,
return_sequences,
):
"""GRU with cuDNN implementation which is only available for GPU."""
if mask is not None:
sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
else:
sequence_lengths = None
if not time_major and sequence_lengths is None:
inputs = tf.transpose(inputs, perm=(1, 0, 2))
seq_axis, batch_axis = (0, 1)
else:
seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
# For init_h, cuDNN expects one more dim of num_layers before or after batch
# dim for time major or batch major inputs respectively
init_h = tf.expand_dims(initial_state, axis=seq_axis)
weights = tf.split(kernel, 3, axis=1)
weights += tf.split(recurrent_kernel, 3, axis=1)
# Note that the bias was initialized as shape (2, 3 * units), flatten it to
# (6 * units)
bias = tf.split(tf.reshape(bias, [-1]), 6)
if tf.sysconfig.get_build_info()["is_cuda_build"]:
# Note that the gate order for cuDNN is different from the canonical
# format. The canonical format is [z, r, h], whereas cuDNN is [r, z, h].
# The swap needs to be done for kernel, recurrent_kernel, input_bias, and
# recurrent_bias.
# z is update gate weights.
# r is reset gate weights.
# h is output gate weights.
weights[0], weights[1] = weights[1], weights[0]
weights[3], weights[4] = weights[4], weights[3]
bias[0], bias[1] = bias[1], bias[0]
bias[3], bias[4] = bias[4], bias[3]
params = _standardize_cudnn_weights(
weights=weights,
biases=bias,
shape=tf.constant([-1]),
transpose_weights=True,
)
if sequence_lengths is not None:
if go_backwards:
# Three reversals are required. E.g.,
# normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
# output_from_cudnn = [6, 5, 4, 0, 0]
# expected_output = [0, 0, 6, 5, 4]
inputs = tf.reverse_sequence(
inputs,
sequence_lengths,
seq_axis=seq_axis,
batch_axis=batch_axis,
)
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
input=inputs,
input_h=init_h,
input_c=0,
params=params,
is_training=True,
rnn_mode="gru",
sequence_lengths=sequence_lengths,
time_major=time_major,
)
if go_backwards:
outputs = tf.reverse_sequence(
outputs,
sequence_lengths,
seq_axis=seq_axis,
batch_axis=batch_axis,
)
outputs = tf.reverse(outputs, axis=[seq_axis])
else:
if go_backwards:
# Reverse axis 0 since the input is already converted to time major.
inputs = tf.reverse(inputs, axis=[0])
outputs, h, _, _ = tf.raw_ops.CudnnRNN(
input=inputs,
input_h=init_h,
input_c=0,
params=params,
is_training=True,
rnn_mode="gru",
)
last_output = outputs[-1]
if not time_major and sequence_lengths is None and return_sequences:
outputs = tf.transpose(outputs, perm=[1, 0, 2])
state = tf.squeeze(h, axis=seq_axis)
# In the case of variable-length input, the cuDNN kernel fills zeros for
# the output, whereas the default Keras behavior is to carry over the
# output from t-1, so that in the return_sequences=False case the user
# can get the final effective output instead of just 0s at the last
# timestep. To mimic the default Keras behavior, we copy the final
# h state as the last_output, since it is numerically the same as the output.
if sequence_lengths is not None:
last_output = state
# Match CPU return format
if not return_sequences:
outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)
return (
last_output,
outputs,
state,
)
def lstm():
pass

@ -114,15 +114,12 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def train_function(iterator):
"""Runs a training execution with multiple steps."""
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if not self.run_eagerly:
train_function = tf.function(train_function, reduce_retracing=True)
train_function = tf.function(
one_step_on_iterator, reduce_retracing=True
)
else:
train_function = one_step_on_iterator
self.train_function = train_function
def make_test_function(self, force=False):
@ -247,7 +244,6 @@ class TensorFlowTrainer(base_trainer.Trainer):
shuffle=shuffle,
class_weight=class_weight,
distribute_strategy=self.distribute_strategy,
steps_per_execution=self.steps_per_execution,
)
# Container that configures and calls callbacks.
@ -449,9 +445,7 @@ class TFEpochIterator(EpochIterator):
self.data_adapter.get_tf_dataset()
)
)
for step in range(
0, self.steps_per_epoch, self.steps_per_execution
):
for step in range(self.steps_per_epoch):
yield step, self._current_iterator
else:
iterator = iter(
@ -460,14 +454,12 @@ class TFEpochIterator(EpochIterator):
)
)
if self.num_batches:
for step in range(
0, self.num_batches, self.steps_per_execution
):
for step in range(self.num_batches):
yield step, iterator
else:
step = -1
while True:
step += self.steps_per_execution
step += 1
self._steps_seen = step + 1
yield step, iterator
self.data_adapter.on_epoch_end()

@ -99,6 +99,7 @@ from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.rnn.gru import GRU
from keras_core.layers.rnn.rnn import RNN
from keras_core.layers.rnn.simple_rnn import SimpleRNN
from keras_core.layers.rnn.stacked_rnn_cells import StackedRNNCells

@ -14,11 +14,11 @@ from keras_core.operations.operation_utils import compute_conv_output_shape
class BaseConv(Layer):
"""Abstract N-D convolution layer (private, used as implementation base).
This layer creates a convolution kernel that is convolved (actually
cross-correlated) with the layer input to produce a tensor of outputs. If
`use_bias` is True (and a `bias_initializer` is provided), a bias vector is
created and added to the outputs. Finally, if `activation` is not `None`, it
is applied to the outputs as well.
This layer creates a convolution kernel that is convolved
(actually cross-correlated) with the layer input to produce a tensor of
outputs. If `use_bias` is True (and a `bias_initializer` is provided),
a bias vector is created and added to the outputs. Finally, if
`activation` is not `None`, it is applied to the outputs as well.
Note: layer attributes cannot be modified after the layer has been called
once (except the `trainable` attribute).
@ -27,12 +27,12 @@ class BaseConv(Layer):
rank: int, the rank of the convolution, e.g. 2 for 2D convolution.
filters: int, the dimension of the output space (the number of filters
in the convolution).
kernel_size: int or tuple/list of `rank` integers, specifying the size
of the convolution window.
strides: int or tuple/list of `rank` integers, specifying the stride
length of the convolution. If only one int is specified, the same
stride size will be used for all dimensions. `strides > 1` is
incompatible with `dilation_rate > 1`.
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the convolution. If only one int is specified, the same stride
size will be used for all dimensions. `stride value != 1` is
incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -44,9 +44,9 @@ class BaseConv(Layer):
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of `rank` integers, specifying the
dilation rate to use for dilated convolution. If only one int is
specified, the same dilation rate will be used for all dimensions.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
@ -97,7 +97,7 @@ class BaseConv(Layer):
super().__init__(
trainable=trainable,
name=name,
activity_regularizer=activity_regularizer,
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs,
)
self.rank = rank
@ -144,14 +144,14 @@ class BaseConv(Layer):
if not all(self.kernel_size):
raise ValueError(
"The argument `kernel_size` cannot contain 0. Received "
f"kernel_size={self.kernel_size}."
"The argument `kernel_size` cannot contain 0(s). Received: "
f"{self.kernel_size}"
)
if not all(self.strides):
raise ValueError(
"The argument `strides` cannot contains 0. Received "
f"strides={self.strides}"
"The argument `strides` cannot contains 0(s). Received: "
f"{self.strides}"
)
if max(self.strides) > 1 and max(self.dilation_rate) > 1:

@ -14,25 +14,26 @@ from keras_core.layers.layer import Layer
class BaseConvTranspose(Layer):
"""Abstract N-D transposed convolution layer.
"""Abstract N-D transpose convolution layer.
The need for transposed convolutions generally arises from the desire to use
a transformation going in the opposite direction of a normal convolution,
i.e., from something that has the shape of the output of some convolution to
something that has the shape of its input while maintaining a connectivity
pattern that is compatible with said convolution.
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.
Args:
rank: int, the rank of the transposed convolution, e.g. 2 for 2D
transposed convolution.
filters: int, the dimension of the output space (the number of filters
in the transposed convolution).
kernel_size: int or tuple/list of `rank` integers, specifying the size
of the transposed convolution window.
strides: int or tuple/list of `rank` integers, specifying the stride
length of the transposed convolution. If only one int is specified,
the same stride size will be used for all dimensions.
`strides > 1` is incompatible with `dilation_rate > 1`.
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the transposed convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the transposed convolution. If only one int is specified, the
same stride size will be used for all dimensions.
`stride value != 1` is incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -44,9 +45,9 @@ class BaseConvTranspose(Layer):
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of `rank` integers, specifying the
dilation rate to use for dilated convolution. If only one int is
specified, the same dilation rate will be used for all dimensions.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
@ -130,14 +131,14 @@ class BaseConvTranspose(Layer):
if not all(self.kernel_size):
raise ValueError(
"The argument `kernel_size` cannot contain 0. Received "
f"kernel_size={self.kernel_size}."
"The argument `kernel_size` cannot contain 0. Received: "
f"{self.kernel_size}"
)
if not all(self.strides):
raise ValueError(
"The argument `strides` cannot contains 0. Received "
f"strides={self.strides}."
"The argument `strides` cannot contains 0. Received: "
f"{self.strides}"
)
if max(self.strides) > 1 and max(self.dilation_rate) > 1:

@ -39,12 +39,12 @@ class BaseDepthwiseConv(Layer):
depth_multiplier: The number of depthwise convolution output channels
for each input channel. The total number of depthwise convolution
output channels will be equal to `input_channel * depth_multiplier`.
kernel_size: int or tuple/list of `rank` integers, specifying the size
of the depthwise convolution window.
strides: int or tuple/list of `rank` integers, specifying the stride
length of the depthwise convolution. If only one int is specified,
the same stride size will be used for all dimensions.
`strides > 1` is incompatible with `dilation_rate > 1`.
kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
size of the depthwise convolution window.
strides: int or tuple/list of N integers, specifying the stride length
of the depthwise convolution. If only one int is specified, the same
stride size will be used for all dimensions. `stride value != 1` is
incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -56,9 +56,9 @@ class BaseDepthwiseConv(Layer):
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of `rank` integers, specifying the
dilation rate to use for dilated convolution. If only one int is
specified, the same dilation rate will be used for all dimensions.
dilation_rate: int or tuple/list of N integers, specifying the dilation
rate to use for dilated convolution. If only one int is specified,
the same dilation rate will be used for all dimensions.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
@ -143,14 +143,14 @@ class BaseDepthwiseConv(Layer):
if not all(self.kernel_size):
raise ValueError(
"The argument `kernel_size` cannot contain 0. Received "
f"kernel_size={self.kernel_size}."
"The argument `kernel_size` cannot contain 0(s). Received: "
f"{self.kernel_size}"
)
if not all(self.strides):
raise ValueError(
"The argument `strides` cannot contains 0. Received "
f"strides={self.strides}"
"The argument `strides` cannot contains 0(s). Received: "
f"{self.strides}"
)
if max(self.strides) > 1 and max(self.dilation_rate) > 1:

@ -21,8 +21,8 @@ class Conv1D(BaseConv):
kernel_size: int or tuple/list of 1 integer, specifying the size of the
convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the convolution. `stride value != 1` is incompatible with
`dilation_rate != 1`.
padding: string, `"valid"`, `"same"` or `"causal"`(case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -72,9 +72,9 @@ class Conv1D(BaseConv):
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
A 3D tensor with shape: `(batch_shape, new_steps, channels)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, filters, new_steps)`
A 3D tensor with shape: `(batch_shape, channels, new_steps)`
Returns:
A 3D tensor representing `activation(conv1d(inputs, kernel) + bias)`.

@ -25,8 +25,8 @@ class Conv1DTranspose(BaseConvTranspose):
kernel_size: int or tuple/list of 1 integer, specifying the size of the
transposed convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the transposed convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the transposed convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -40,6 +40,11 @@ class Conv1DTranspose(BaseConvTranspose):
If you never set it, then it will be `"channels_last"`.
dilation_rate: int or tuple/list of 1 integers, specifying the dilation
rate to use for dilated transposed convolution.
groups: A positive int specifying the number of groups in which the
input is split along the channel axis. Each group is convolved
separately with `filters // groups` filters. The output is the
concatenation of all the `groups` results along the channel axis.
Input channels and `filters` must both be divisible by `groups`.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
@ -66,9 +71,9 @@ class Conv1DTranspose(BaseConvTranspose):
Output shape:
- If `data_format="channels_last"`:
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
A 3D tensor with shape: `(batch_shape, new_steps, channels)`
- If `data_format="channels_first"`:
A 3D tensor with shape: `(batch_shape, filters, new_steps)`
A 3D tensor with shape: `(batch_shape, channels, new_steps)`
Returns:
A 3D tensor representing

@ -20,8 +20,8 @@ class Conv2D(BaseConv):
kernel_size: int or tuple/list of 2 integer, specifying the size of the
convolution window.
strides: int or tuple/list of 2 integer, specifying the stride length
of the convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the convolution. `stride value != 1` is incompatible with
`dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -68,7 +68,7 @@ class Conv2D(BaseConv):
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
A 4D tensor with shape: `(batch_size, new_height, new_width filters)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`

@ -25,8 +25,8 @@ class Conv2DTranspose(BaseConvTranspose):
kernel_size: int or tuple/list of 1 integer, specifying the size of the
transposed convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the transposed convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the transposed convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -68,7 +68,7 @@ class Conv2DTranspose(BaseConvTranspose):
Output shape:
- If `data_format="channels_last"`:
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
A 4D tensor with shape: `(batch_size, new_height, new_width filters)`
- If `data_format="channels_first"`:
A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`

@ -20,8 +20,8 @@ class Conv3D(BaseConv):
kernel_size: int or tuple/list of 3 integer, specifying the size of the
convolution window.
strides: int or tuple/list of 3 integer, specifying the stride length
of the convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the convolution. `stride value != 1` is incompatible with
`dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -72,10 +72,10 @@ class Conv3D(BaseConv):
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
filters)`
channels)`
- If `data_format="channels_first"`:
5D tensor with shape:
`(batch_size, filters, new_spatial_dim1, new_spatial_dim2,
`(batch_size, channels, new_spatial_dim1, new_spatial_dim2,
new_spatial_dim3)`
Returns:

@ -25,8 +25,8 @@ class Conv3DTranspose(BaseConvTranspose):
kernel_size: int or tuple/list of 1 integer, specifying the size of the
transposed convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the transposed convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the transposed convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
@ -72,10 +72,10 @@ class Conv3DTranspose(BaseConvTranspose):
- If `data_format="channels_last"`:
5D tensor with shape:
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
filters)`
channels)`
- If `data_format="channels_first"`:
5D tensor with shape:
`(batch_size, filters, new_spatial_dim1, new_spatial_dim2,
`(batch_size, channels, new_spatial_dim1, new_spatial_dim2,
new_spatial_dim3)`
Returns:

@ -417,4 +417,4 @@ class ConvTransposeCorrectnessTest(testing.TestCase, parameterized.TestCase):
tf_keras_layer.bias.assign(bias_weights)
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
self.assertAllClose(outputs, expected, atol=1e-5)

@ -34,8 +34,8 @@ class DepthwiseConv1D(BaseDepthwiseConv):
kernel_size: int or tuple/list of 1 integer, specifying the size of the
depthwise convolution window.
strides: int or tuple/list of 1 integer, specifying the stride length
of the convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same

@ -34,8 +34,8 @@ class DepthwiseConv2D(BaseDepthwiseConv):
kernel_size: int or tuple/list of 2 integer, specifying the size of the
depthwise convolution window.
strides: int or tuple/list of 2 integer, specifying the stride length
of the depthwise convolution. `strides > 1` is incompatible with
`dilation_rate > 1`.
of the depthwise convolution. `stride value != 1` is incompatible
with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same

@ -29,8 +29,8 @@ class SeparableConv1D(BaseSeparableConv):
depthwise convolution window.
strides: int or tuple/list of 1 integers, specifying the stride length
of the depthwise convolution. If only one int is specified, the same
stride size will be used for all dimensions. `strides > 1` is
incompatible with `dilation_rate > 1`.
stride size will be used for all dimensions. `stride value != 1` is
incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same

@ -29,8 +29,8 @@ class SeparableConv2D(BaseSeparableConv):
depthwise convolution window.
strides: int or tuple/list of 2 integers, specifying the stride length
of the depthwise convolution. If only one int is specified, the same
stride size will be used for all dimensions. `strides > 1` is
incompatible with `dilation_rate > 1`.
stride size will be used for all dimensions. `stride value != 1` is
incompatible with `dilation_rate != 1`.
padding: string, either `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the same
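The stride/dilation constraint repeated in the convolution docstrings above can be exercised directly. A minimal sketch, assuming the layer rejects the combination with a `ValueError` (the exact point where the check fires, init or build, is not shown in this commit):

import numpy as np
from keras_core import layers

try:
    layer = layers.SeparableConv2D(
        filters=8, kernel_size=3, strides=2, dilation_rate=2
    )
    layer(np.zeros((1, 16, 16, 3)))  # force build/call in case init does not raise
except ValueError as e:
    print(e)  # the documented strides/dilation incompatibility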

@ -0,0 +1,607 @@
from tensorflow import nest
from keras_core import activations
from keras_core import backend
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.layers.rnn.dropout_rnn_cell import DropoutRNNCell
from keras_core.layers.rnn.rnn import RNN
@keras_core_export("keras_core.layers.GRUCell")
class GRUCell(Layer, DropoutRNNCell):
"""Cell class for the GRU layer.
This class processes one step within the whole time sequence input, whereas
`keras_core.layer.GRU` processes the whole sequence.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use. Default: hyperbolic tangent
(`tanh`). If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use for the recurrent step.
Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
applied (i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer
should use a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`"glorot_uniform"`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation
of the recurrent state. Default: `"orthogonal"`.
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector.
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector.
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). `False` = `"before"`,
`True` = `"after"` (default and cuDNN compatible).
Call arguments:
inputs: A 2D tensor, with shape `(batch, features)`.
states: A 2D tensor with shape `(batch, units)`, which is the state
from the previous time step.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Example:
>>> inputs = np.random.random((32, 10, 8))
>>> rnn = keras_core.layers.RNN(keras_core.layers.GRUCell(4))
>>> output = rnn(inputs)
>>> output.shape
(32, 4)
>>> rnn = keras_core.layers.RNN(
... keras_core.layers.GRUCell(4),
... return_sequences=True,
... return_state=True)
>>> whole_sequence_output, final_state = rnn(inputs)
>>> whole_sequence_output.shape
(32, 10, 4)
>>> final_state.shape
(32, 4)
"""
def __init__(
self,
units,
activation="tanh",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
reset_after=True,
seed=None,
**kwargs,
):
if units <= 0:
raise ValueError(
"Received an invalid value for argument `units`, "
f"expected a positive integer, got {units}."
)
super().__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.dropout = min(1.0, max(0.0, dropout))
self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
self.seed = seed
self.seed_generator = backend.random.SeedGenerator(seed=seed)
self.reset_after = reset_after
self.state_size = self.units
self.output_size = self.units
def build(self, input_shape):
super().build(input_shape)
input_dim = input_shape[-1]
self.kernel = self.add_weight(
shape=(input_dim, self.units * 3),
name="kernel",
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 3),
name="recurrent_kernel",
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
)
if self.use_bias:
if not self.reset_after:
bias_shape = (3 * self.units,)
else:
# separate biases for input and recurrent kernels
# Note: the shape is intentionally different from CuDNNGRU
# biases `(2 * 3 * self.units,)`, so that we can distinguish the
# classes when loading and converting saved weights.
bias_shape = (2, 3 * self.units)
self.bias = self.add_weight(
shape=bias_shape,
name="bias",
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
)
else:
self.bias = None
self.built = True
def call(self, inputs, states, training=None):
h_tm1 = (
states[0] if nest.is_nested(states) else states
) # previous state
dp_mask = self.get_dropout_mask(inputs)
rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
if self.use_bias:
if not self.reset_after:
input_bias, recurrent_bias = self.bias, None
else:
input_bias, recurrent_bias = (
ops.squeeze(e, axis=0)
for e in ops.split(self.bias, self.bias.shape[0], axis=0)
)
if training and 0.0 < self.dropout < 1.0:
inputs *= dp_mask
inputs_z = inputs
inputs_r = inputs
inputs_h = inputs
x_z = ops.matmul(inputs_z, self.kernel[:, : self.units])
x_r = ops.matmul(inputs_r, self.kernel[:, self.units : self.units * 2])
x_h = ops.matmul(inputs_h, self.kernel[:, self.units * 2 :])
if self.use_bias:
x_z += input_bias[: self.units]
x_r += input_bias[self.units : self.units * 2]
x_h += input_bias[self.units * 2 :]
if training and 0.0 < self.recurrent_dropout < 1.0:
h_tm1 *= rec_dp_mask
h_tm1_z = h_tm1
h_tm1_r = h_tm1
h_tm1_h = h_tm1
recurrent_z = ops.matmul(
h_tm1_z, self.recurrent_kernel[:, : self.units]
)
recurrent_r = ops.matmul(
h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2]
)
if self.reset_after and self.use_bias:
recurrent_z += recurrent_bias[: self.units]
recurrent_r += recurrent_bias[self.units : self.units * 2]
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
# reset gate applied after/before matrix multiplication
if self.reset_after:
recurrent_h = ops.matmul(
h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
)
if self.use_bias:
recurrent_h += recurrent_bias[self.units * 2 :]
recurrent_h = r * recurrent_h
else:
recurrent_h = ops.matmul(
r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
)
hh = self.activation(x_h + recurrent_h)
# previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
new_state = [h] if nest.is_nested(states) else h
return h, new_state
def get_config(self):
config = {
"units": self.units,
"activation": activations.serialize(self.activation),
"recurrent_activation": activations.serialize(
self.recurrent_activation
),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"recurrent_initializer": initializers.serialize(
self.recurrent_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"recurrent_regularizer": regularizers.serialize(
self.recurrent_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"recurrent_constraint": constraints.serialize(
self.recurrent_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"reset_after": self.reset_after,
}
base_config = super().get_config()
return {**base_config, **config}
def get_initial_state(self, batch_size=None):
return [
ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype)
]
class GRU(RNN):
"""Gated Recurrent Unit - Cho et al. 2014.
Based on available runtime hardware and constraints, this layer
will choose different implementations (cuDNN-based or backend-native)
to maximize the performance. If a GPU is available and all
the arguments to the layer meet the requirement of the cuDNN kernel
(see below for details), the layer will use a fast cuDNN implementation
when using the TensorFlow backend.
The requirements to use the cuDNN implementation are:
1. `activation` == `tanh`
2. `recurrent_activation` == `sigmoid`
3. `recurrent_dropout` == 0
4. `unroll` is `False`
5. `use_bias` is `True`
6. `reset_after` is `True`
7. Inputs, if masking is used, are strictly right-padded.
8. Eager execution is enabled in the outermost context.
There are two variants of the GRU implementation. The default one is based
on [v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate applied to
the hidden state before the matrix multiplication. The other one is based on
the [original](https://arxiv.org/abs/1406.1078v1) version and has the order
reversed.
The second variant is compatible with CuDNNGRU (GPU-only) and allows
inference on CPU. Thus it has separate biases for `kernel` and
`recurrent_kernel`. To use this variant, set `reset_after=True` and
`recurrent_activation='sigmoid'`.
For example:
>>> inputs = np.random.random((32, 10, 8))
>>> gru = keras_core.layers.GRU(4)
>>> output = gru(inputs)
>>> output.shape
(32, 4)
>>> gru = keras_core.layers.GRU(4, return_sequences=True, return_state=True)
>>> whole_sequence_output, final_state = gru(inputs)
>>> whole_sequence_output.shape
(32, 10, 4)
>>> final_state.shape
(32, 4)
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
Default: sigmoid (`sigmoid`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer
should use a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`"glorot_uniform"`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent
state. Default: `"orthogonal"`.
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector.
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector.
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state in addition
to the output. Default: `False`.
go_backwards: Boolean (default `False`).
If `True`, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If `True`, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If `True`, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). `False` is `"before"`,
`True` is `"after"` (default and cuDNN compatible).
Call arguments:
inputs: A 3D tensor, with shape `(batch, timesteps, feature)`.
mask: Binary tensor of shape `(batch, timesteps)` indicating whether
a given timestep should be masked (optional).
An individual `True` entry indicates that the corresponding timestep
should be utilized, while a `False` entry indicates that the
corresponding timestep should be ignored. Defaults to `None`.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the
cell when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used (optional). Defaults to `None`.
initial_state: List of initial state tensors to be passed to the first
call of the cell (optional, `None` causes creation
of zero-filled initial state tensors). Defaults to `None`.
"""
def __init__(
self,
units,
activation="tanh",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
bias_initializer="zeros",
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.0,
recurrent_dropout=0.0,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
reset_after=True,
seed=None,
**kwargs,
):
cell = GRUCell(
units,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
reset_after=reset_after,
dtype=kwargs.get("dtype", None),
trainable=kwargs.get("trainable", True),
name="gru_cell",
seed=seed,
)
super().__init__(
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
activity_regularizer=activity_regularizer,
**kwargs,
)
self.input_spec = InputSpec(ndim=3)
def inner_loop(self, sequence, initial_state, mask, training=False):
if nest.is_nested(initial_state):
initial_state = initial_state[0]
if nest.is_nested(mask):
mask = mask[0]
try:
# A backend may (optionally) provide an optimized implementation
# of the inner GRU loop. The TensorFlow backend, for instance,
# leverages cuDNN when feasible and raises NotImplementedError
# otherwise, which triggers the generic fallback below.
return backend.gru(
sequence,
initial_state,
mask,
kernel=self.cell.kernel,
recurrent_kernel=self.cell.recurrent_kernel,
bias=self.cell.bias,
activation=self.cell.activation,
recurrent_activation=self.cell.recurrent_activation,
return_sequences=self.return_sequences,
go_backwards=self.go_backwards,
unroll=self.unroll,
reset_after=self.cell.reset_after,
)
except NotImplementedError:
return super().inner_loop(
sequence, initial_state, mask=mask, training=training
)
def call(self, sequence, initial_state=None, mask=None, training=None):
return super().call(
sequence, mask=mask, training=training, initial_state=initial_state
)
@property
def units(self):
return self.cell.units
@property
def activation(self):
return self.cell.activation
@property
def recurrent_activation(self):
return self.cell.recurrent_activation
@property
def use_bias(self):
return self.cell.use_bias
@property
def kernel_initializer(self):
return self.cell.kernel_initializer
@property
def recurrent_initializer(self):
return self.cell.recurrent_initializer
@property
def bias_initializer(self):
return self.cell.bias_initializer
@property
def kernel_regularizer(self):
return self.cell.kernel_regularizer
@property
def recurrent_regularizer(self):
return self.cell.recurrent_regularizer
@property
def bias_regularizer(self):
return self.cell.bias_regularizer
@property
def kernel_constraint(self):
return self.cell.kernel_constraint
@property
def recurrent_constraint(self):
return self.cell.recurrent_constraint
@property
def bias_constraint(self):
return self.cell.bias_constraint
@property
def dropout(self):
return self.cell.dropout
@property
def recurrent_dropout(self):
return self.cell.recurrent_dropout
@property
def reset_after(self):
return self.cell.reset_after
def get_config(self):
config = {
"units": self.units,
"activation": activations.serialize(self.activation),
"recurrent_activation": activations.serialize(
self.recurrent_activation
),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"recurrent_initializer": initializers.serialize(
self.recurrent_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"recurrent_regularizer": regularizers.serialize(
self.recurrent_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"recurrent_constraint": constraints.serialize(
self.recurrent_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
"dropout": self.dropout,
"recurrent_dropout": self.recurrent_dropout,
"reset_after": self.reset_after,
}
base_config = super().get_config()
del base_config["cell"]
return {**base_config, **config}
@classmethod
def from_config(cls, config):
return cls(**config)
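For reference, the gate arithmetic implemented by `GRUCell.call` above (with `reset_after=True`, biases and dropout omitted) reduces to a few lines. This is an illustrative NumPy sketch, not part of the commit:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_tm1, kernel, recurrent_kernel):
    # kernel: (input_dim, 3 * units), recurrent_kernel: (units, 3 * units),
    # both laid out as [update | reset | candidate] blocks, as in build().
    x_z, x_r, x_h = np.split(x @ kernel, 3, axis=-1)
    rec_z, rec_r, rec_h = np.split(h_tm1 @ recurrent_kernel, 3, axis=-1)
    z = sigmoid(x_z + rec_z)         # update gate
    r = sigmoid(x_r + rec_r)         # reset gate
    hh = np.tanh(x_h + r * rec_h)    # candidate state; reset applied after matmul
    return z * h_tm1 + (1 - z) * hh  # mix previous and candidate states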

@ -0,0 +1,172 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import initializers
from keras_core import layers
from keras_core import testing
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Only implemented for TF for now.",
)
class GRUTest(testing.TestCase):
def test_basics(self):
self.run_layer_test(
layers.GRU,
init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5},
input_shape=(3, 2, 4),
call_kwargs={"training": True},
expected_output_shape=(3, 3),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
self.run_layer_test(
layers.GRU,
init_kwargs={
"units": 3,
"return_sequences": True,
"bias_regularizer": "l1",
"kernel_regularizer": "l2",
"recurrent_regularizer": "l2",
},
input_shape=(3, 2, 4),
expected_output_shape=(3, 2, 3),
expected_num_losses=3,
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
supports_masking=True,
)
def test_correctness(self):
sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
layer = layers.GRU(
3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.5217289, 0.5217289, 0.5217289],
[0.6371659, 0.6371659, 0.6371659],
[0.39384964, 0.39384964, 0.3938496],
]
),
output,
)
layer = layers.GRU(
3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
go_backwards=True,
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.24406259, 0.24406259, 0.24406259],
[0.611516, 0.611516, 0.611516],
[0.3928808, 0.3928808, 0.3928808],
]
),
output,
)
layer = layers.GRU(
3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
unroll=True,
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.5217289, 0.5217289, 0.5217289],
[0.6371659, 0.6371659, 0.6371659],
[0.39384964, 0.39384964, 0.3938496],
]
),
output,
)
layer = layers.GRU(
3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
reset_after=False,
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.51447755, 0.51447755, 0.51447755],
[0.6426879, 0.6426879, 0.6426879],
[0.40208298, 0.40208298, 0.40208298],
]
),
output,
)
layer = layers.GRU(
3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
use_bias=False,
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.49988455, 0.49988455, 0.49988455],
[0.64701194, 0.64701194, 0.64701194],
[0.4103359, 0.4103359, 0.4103359],
]
),
output,
)
def test_statefulness(self):
sequence = np.arange(24).reshape((2, 3, 4)).astype("float32")
layer = layers.GRU(
4,
stateful=True,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
)
layer(sequence)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.29542392, 0.29542392, 0.29542392, 0.29542392],
[0.5885018, 0.5885018, 0.5885018, 0.5885018],
]
),
output,
)
layer.reset_state()
layer(sequence)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[0.29542392, 0.29542392, 0.29542392, 0.29542392],
[0.5885018, 0.5885018, 0.5885018, 0.5885018],
]
),
output,
)
# TODO: test masking
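A possible shape for the masking test flagged above, to live inside `GRUTest`. It is a sketch only: it assumes the standard Keras semantics where masked trailing timesteps are skipped, so a right-padded sample should match its truncated counterpart.

def test_masking_right_padded(self):
    sequence = np.arange(24).reshape((2, 3, 4)).astype("float32")
    mask = np.array([[True, True, False], [True, True, True]])
    layer = layers.GRU(
        3,
        kernel_initializer=initializers.Constant(0.01),
        recurrent_initializer=initializers.Constant(0.02),
        bias_initializer=initializers.Constant(0.03),
    )
    masked_output = layer(sequence, mask=mask)
    # Sample 0 has its last timestep masked, so its output should match
    # running the same layer on the truncated two-step sequence.
    truncated_output = layer(sequence[:1, :2, :])
    self.assertAllClose(masked_output[:1], truncated_output)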

@ -321,6 +321,32 @@ class RNN(Layer):
for v in self.states:
v.assign(ops.zeros_like(v))
def inner_loop(self, sequence, initial_state, mask, training=False):
cell_kwargs = {}
if isinstance(self.cell, Layer) and self.cell._call_has_training_arg():
cell_kwargs["training"] = training
def step(inputs, states):
output, new_states = self.cell(inputs, states, **cell_kwargs)
if not nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
if not nest.is_nested(initial_state):
initial_state = [initial_state]
return backend.rnn(
step,
sequence,
initial_state,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
input_length=sequence.shape[1],
zero_output_for_mask=self.zero_output_for_mask,
return_all_outputs=self.return_sequences,
)
def call(
self,
sequence,
@ -335,18 +361,12 @@ class RNN(Layer):
"time dimension is undefined. \n"
"- If using a Sequential model, "
"specify the time dimension by passing "
"an `input_shape` or `batch_input_shape` "
"argument to your first layer. If your "
"first layer is an Embedding, you can "
"also use the `input_length` argument.\n"
"an `Input()` as your first layer.\n"
"- If using the functional API, specify "
"the time dimension by passing a `shape` "
"or `batch_shape` argument to your Input layer."
"or `batch_shape` argument to your `Input()`."
)
cell_kwargs = {}
if isinstance(self.cell, Layer) and self.cell._call_has_training_arg():
cell_kwargs["training"] = training
if initial_state is None:
if self.stateful:
initial_state = self.states
@ -366,22 +386,11 @@ class RNN(Layer):
lambda x: ops.cast(x, dtype=self.compute_dtype), initial_state
)
def step(inputs, states):
output, new_states = self.cell(inputs, states, **cell_kwargs)
if not nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
last_output, outputs, states = backend.nn.rnn(
step,
sequence,
initial_state,
go_backwards=self.go_backwards,
last_output, outputs, states = self.inner_loop(
sequence=sequence,
initial_state=initial_state,
mask=mask,
unroll=self.unroll,
input_length=timesteps,
zero_output_for_mask=self.zero_output_for_mask,
return_all_outputs=self.return_sequences,
training=training,
)
self._maybe_reset_dropout_masks(self.cell)
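The `inner_loop` refactor above is what lets subclasses such as `GRU` try a fused backend kernel first and fall back to this generic `backend.rnn` loop. The control flow in isolation (hypothetical function names, for illustration only):

def run_inner_loop(sequence, initial_state, fused_impl, generic_impl):
    # fused_impl stands in for a backend-specific kernel (e.g. a cuDNN path)
    # that raises NotImplementedError for unsupported configurations.
    try:
        return fused_impl(sequence, initial_state)
    except NotImplementedError:
        # generic_impl stands in for the step-by-step backend.rnn loop above.
        return generic_impl(sequence, initial_state)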

@ -51,7 +51,7 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
for the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor, with shape `(batch, feature)`.
inputs: A 2D tensor, with shape `(batch, features)`.
states: A 2D tensor with shape `(batch, units)`, which is the state
from the previous time step.
training: Python boolean indicating whether the layer should behave in
@ -334,7 +334,7 @@ class SimpleRNN(RNN):
dropout=dropout,
recurrent_dropout=recurrent_dropout,
seed=seed,
dtype=kwargs.get("dtype"),
dtype=kwargs.get("dtype", None),
trainable=kwargs.get("trainable", True),
name="simple_rnn_cell",
)
@ -347,10 +347,9 @@ class SimpleRNN(RNN):
unroll=unroll,
**kwargs,
)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
def call(self, inputs, mask=None, training=None, initial_state=None):
def call(self, inputs, initial_state=None, mask=None, training=None):
return super().call(
inputs, mask=mask, training=training, initial_state=initial_state
)

@ -102,5 +102,7 @@ class MetricTest(testing.TestCase):
self.assertEqual(len(metric.variables), 6)
def test_serialization(self):
# TODO
pass
self.run_class_serialization_test(
ExampleMetric(name="mse"),
custom_objects={"ExampleMetric": ExampleMetric},
)

@ -1,22 +0,0 @@
"""
scatter
"""
from keras_core import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
class Scatter(Operation):
def call(self, indices, values, shape):
return backend.core.scatter(indices, values, shape)
def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)
def scatter(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return Scatter().symbolic_call(indices, values, shape)
return backend.core.scatter(indices, values, shape)

@ -1,64 +0,0 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import testing
from keras_core.operations import core
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class CoreOpsDynamicShapeTest(testing.TestCase):
pass
class CoreOpsStaticShapeTest(testing.TestCase):
def test_scatter(self):
# Requires dtype
indices = np.array([[0]], dtype="int32")
values = np.array([0], dtype="int32")
shape = (8,)
self.assertEqual(core.scatter(indices, values, shape).shape, (8,))
class CoreOpsCorrectnessTest(testing.TestCase):
def test_scatter(self):
# Test 1D
indices = np.array([[1], [3], [4], [7]])
values = np.array([9, 10, 11, 12])
self.assertAllClose(
core.scatter(indices, values, (8,)),
[0, 9, 0, 10, 11, 0, 0, 12],
)
# Test 2D
indices = np.array([[0, 1], [2, 0]])
values = np.array([5, 10])
self.assertAllClose(
core.scatter(indices, values, (3, 2)), [[0, 5], [0, 0], [10, 0]]
)
# Test 3D
indices = np.array([[1], [3]])
values = np.array(
[
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
]
)
self.assertAllClose(
core.scatter(indices, values, (4, 4, 4)),
[
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
],
)
# Test slices
indices = np.array([[2], [4]])
values = np.array([[1, 2, 3], [4, 5, 6]])
self.assertAllClose(
core.scatter(indices, values, (6, 3)),
[[0, 0, 0], [0, 0, 0], [1, 2, 3], [0, 0, 0], [4, 5, 6], [0, 0, 0]],
)
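For reference, the removed `scatter` operation built a zero tensor of the given shape and wrote `values` at `indices`. A NumPy sketch of those semantics (illustrative only, not the deleted backend implementations):

import numpy as np

def scatter_reference(indices, values, shape):
    out = np.zeros(shape, dtype=np.asarray(values).dtype)
    for index, value in zip(np.asarray(indices), np.asarray(values)):
        # Each index row addresses a position (or slice) that receives the
        # corresponding update, mirroring tf.scatter_nd semantics.
        out[tuple(index)] = value
    return out

# Matches the removed 1D test case above:
# scatter_reference([[1], [3], [4], [7]], [9, 10, 11, 12], (8,))
# -> [0, 9, 0, 10, 11, 0, 0, 12]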

@ -67,11 +67,15 @@ def get_metric(identifier, y_true, y_pred):
y_true, y_pred
)
if is_binary:
metric_obj = metrics_module.binary_accuracy
metric_obj = metrics_module.BinaryAccuracy(name=str(identifier))
elif is_sparse_categorical:
metric_obj = metrics_module.sparse_categorical_accuracy
metric_obj = metrics_module.SparseCategoricalAccuracy(
name=str(identifier)
)
else:
metric_obj = metrics_module.categorical_accuracy
metric_obj = metrics_module.CategoricalAccuracy(
name=str(identifier)
)
if not isinstance(metric_obj, metrics_module.Metric):
if isinstance(identifier, str):

@ -178,6 +178,21 @@ class TestCompileMetrics(testing.TestCase):
self.assertAllClose(result["mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
def test_name_conversions(self):
compile_metrics = CompileMetrics(
metrics=["acc", "accuracy"],
weighted_metrics=[],
)
y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])
compile_metrics.build(y_true, y_pred)
compile_metrics.update_state(y_true, y_pred, sample_weight=None)
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertEqual(len(result), 2)
self.assertAllClose(result["acc"], 0.333333)
self.assertAllClose(result["accuracy"], 0.333333)
class TestCompileLoss(testing.TestCase):
def test_single_output_case(self):

@ -60,10 +60,8 @@ class EpochIterator:
steps_per_epoch=None,
shuffle=False,
class_weight=None,
steps_per_execution=1,
):
self.steps_per_epoch = steps_per_epoch
self.steps_per_execution = steps_per_execution
if steps_per_epoch:
self._current_iterator = None
self._insufficient_data = False
@ -106,7 +104,9 @@ class EpochIterator:
"sample_weights", "the sample weights", "PyDataset"
)
elif isinstance(x, types.GeneratorType):
self.data_adapter = generator_data_adapter.GeneratorDataAdapter(x)
self.data_adapter = generator_data_adapter.GeneratorDataAdapter(
x, shuffle=shuffle
)
if y is not None:
raise_unsupported_arg("y", "the targets", "PyDataset")
if sample_weight is not None:

@ -26,7 +26,6 @@ class Trainer:
metrics=None,
weighted_metrics=None,
run_eagerly=False,
steps_per_execution=1,
jit_compile=True,
):
self.optimizer = optimizers.get(optimizer)
@ -50,7 +49,6 @@ class Trainer:
self.stop_training = False
self.compiled = True
self._loss_tracker = metrics_module.Mean(name="loss")
self.steps_per_execution = steps_per_execution
self._compile_config = serialization_lib.SerializableDict(
optimizer=optimizer,
@ -59,7 +57,6 @@ class Trainer:
metrics=metrics,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
steps_per_execution=steps_per_execution,
jit_compile=jit_compile,
)

@ -1,5 +1,4 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import initializers
@ -8,7 +7,6 @@ from keras_core import losses
from keras_core import metrics
from keras_core import optimizers
from keras_core import testing
from keras_core.callbacks.callback import Callback
if backend.backend() == "jax":
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
@ -215,32 +213,3 @@ class TestTrainer(testing.TestCase):
def test_predict_flow_jit(self):
self._test_predict_flow(run_eagerly=False, jit_compile=True)
# TODO: Remove the skipif when implemented steps_per_execution for JAX.
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="JAX does not support steps_per_execution yet",
)
def test_steps_per_execution_steps_count(self):
class StepCount(Callback):
def __init__(self):
super().__init__()
self.count = 0
self.batches = [0, 3, 6]
def on_batch_begin(self, batch, logs=None):
assert batch == self.batches[self.count]
self.count += 1
x = np.ones((100, 4))
y = np.ones((100, 1))
model = ExampleModel(units=1)
model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
step_count = StepCount()
model.fit(x=x, y=y, batch_size=16, callbacks=[step_count])
self.assertEqual(step_count.count, 3)
model_2 = ExampleModel(units=1)
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
model_2.fit(x=x, y=y, batch_size=16)
self.assertAllClose(model.get_weights(), model_2.get_weights())