Add GRU layer.

commit 42236e5d4e (parent 77b4fcc3dc)
@@ -11,7 +11,6 @@ from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.backend.jax import core
from keras_core.backend.jax import image
from keras_core.backend.jax import math
from keras_core.backend.jax import nn
@@ -1,7 +0,0 @@
import jax.numpy as jnp


def scatter(indices, values, shape):
    zeros = jnp.zeros(shape, values.dtype)
    key = tuple(jnp.moveaxis(indices, -1, 0))
    return zeros.at[key].set(values)
@@ -8,12 +8,14 @@ from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.backend.tensorflow import core
from keras_core.backend.tensorflow import image
from keras_core.backend.tensorflow import math
from keras_core.backend.tensorflow import nn
from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random
from keras_core.backend.tensorflow.rnn import gru
from keras_core.backend.tensorflow.rnn import lstm
from keras_core.backend.tensorflow.rnn import rnn
from keras_core.utils.naming import auto_name

DYNAMIC_SHAPES_OK = True
@@ -1,5 +0,0 @@
import tensorflow as tf


def scatter(indices, values, shape):
    return tf.scatter_nd(indices, values, shape)
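Editorial note (not part of the diff): the two deleted scatter helpers implement the same dense scatter operation on their respective backends. A minimal sketch of the shared semantics, assuming both TensorFlow and JAX are installed:

import numpy as np
import tensorflow as tf
import jax.numpy as jnp

indices = np.array([[0, 1], [2, 0]])              # full indices into the output
values = np.array([9.0, 7.0], dtype="float32")
shape = (3, 2)

tf_out = tf.scatter_nd(indices, values, shape)    # TensorFlow helper
zeros = jnp.zeros(shape, values.dtype)
key = tuple(jnp.moveaxis(jnp.asarray(indices), -1, 0))
jax_out = zeros.at[key].set(values)               # JAX helper
# Both yield [[0., 9.], [0., 0.], [7., 0.]]
np.testing.assert_allclose(np.asarray(tf_out), np.asarray(jax_out))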
@@ -602,458 +602,3 @@ def binary_crossentropy(target, output, from_logits=False):
    bce = target * tf.math.log(output)
    bce += (1 - target) * tf.math.log(1 - output)
    return -bce

def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
    """Iterates over the time dimension of a tensor.

    Args:
        step_function: RNN step function.
            Args;
                `input`; Tensor with shape `(samples, ...)` (no time dimension),
                    representing input for the batch of samples at a certain
                    time step.
                `states`; List of tensors.
            Returns;
                `output`; Tensor with shape `(samples, output_dim)`
                    (no time dimension).
                `new_states`; List of tensors, same length and shapes
                    as 'states'. The first state in the list must be the
                    output tensor at the previous timestep.
        inputs: Tensor of temporal data of shape `(samples, time, ...)`
            (at least 3D), or nested tensors, and each of which has shape
            `(samples, time, ...)`.
        initial_states: Tensor with shape `(samples, state_size)`
            (no time dimension), containing the initial values for the states
            used in the step function. In the case that state_size is in a
            nested shape, the shape of initial_states will also follow the
            nested structure.
        go_backwards: Boolean. If `True`, do the iteration over the time
            dimension in reverse order and return the reversed sequence.
        mask: Binary tensor with shape `(samples, time, 1)`,
            with a zero for every element that is masked.
        constants: List of constant values passed at each step.
        unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
        input_length: An integer or a 1-D Tensor, depending on whether
            the time dimension is fixed-length or not. In case of variable
            length input, it is used for masking in case there's no mask
            specified.
        time_major: Boolean. If `True`, the inputs and outputs will be in shape
            `(timesteps, batch, ...)`, whereas in the False case, it will be
            `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
            efficient because it avoids transposes at the beginning and end of
            the RNN calculation. However, most TensorFlow data is batch-major,
            so by default this function accepts input and emits output in
            batch-major form.
        zero_output_for_mask: Boolean. If `True`, the output for masked timestep
            will be zeros, whereas in the `False` case, output from previous
            timestep is returned.
        return_all_outputs: Boolean. If `True`, return the recurrent outputs for
            all timesteps in the sequence. If `False`, only return the output
            for the last timestep (which consumes less memory).

    Returns:
        A tuple, `(last_output, outputs, new_states)`.
            - `last_output`: the latest output of the rnn,
                with shape `(samples, ...)`.
            - `outputs`:
                - If `return_all_outputs=True`: a tensor with shape
                    `(samples, time, ...)` where each entry `outputs[s, t]` is the
                    output of the step function at time `t` for sample `s`
                - Else, a tensor equal to `last_output` with shape
                    `(samples, 1, ...)`
            - `new_states`: list of tensors, latest states returned by
                the step function, of shape `(samples, ...)`.
    """

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return tf.transpose(input_t, axes)

    if not time_major:
        inputs = tf.nest.map_structure(swap_batch_timestep, inputs)

    flatted_inputs = tf.nest.flatten(inputs)
    time_steps = flatted_inputs[0].shape[0]
    batch = flatted_inputs[0].shape[1]
    time_steps_t = tf.shape(flatted_inputs[0])[0]

    for input_ in flatted_inputs:
        input_.shape.with_rank_at_least(3)

    if mask is not None:
        if mask.dtype != tf.bool:
            mask = tf.cast(mask, tf.bool)
        if len(mask.shape) == 2:
            mask = tf.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []

    # tf.where needs its condition tensor to be the same shape as its two
    # result tensors, but in our case the condition (mask) tensor is
    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
    # So we need to broadcast the mask to match the shape of inputs.
    # That's what the tile call does, it just repeats the mask along its
    # second dimension n times.
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if tf.nest.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if tf.nest.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = tf.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
        return tf.tile(mask_t, multiples)

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Process the input tensors. The input tensor need to be split on the
        # time_step dim, and reverse if go_backwards is True. In the case of
        # nested input, the input is flattened and then transformed
        # individually. The result of this will be a tuple of lists, each of
        # the item in tuple is list of the tensor with shape (batch, feature)
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        if tf.nest.is_nested(inputs):
            processed_input = tf.nest.map_structure(
                _process_single_input_t, inputs
            )
        else:
            processed_input = (_process_single_input_t(inputs),)

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return tf.nest.pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = tf.unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                if not successive_outputs:
                    prev_output = tf.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                output = tf.where(tiled_mask_t, output, prev_output)

                flat_states = tf.nest.flatten(states)
                flat_new_states = tf.nest.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    tf.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = tf.nest.pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

            if zero_output_for_mask:
                last_output = tf.where(
                    _expand_mask(mask_list[-1], last_output),
                    last_output,
                    tf.zeros_like(last_output),
                )
                outputs = tf.where(
                    _expand_mask(mask, outputs, fixed_dim=2),
                    outputs,
                    tf.zeros_like(outputs),
                )

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

    else:  # Unroll == False
        states = tuple(initial_states)

        # Create input tensor array, if the inputs is nested tensors, then it
        # will be flattened first, and tensor array will be created one per
        # flattened tensor.
        input_ta = tuple(
            tf.TensorArray(
                dtype=inp.dtype,
                size=time_steps_t,
                tensor_array_name=f"input_ta_{i}",
            )
            for i, inp in enumerate(flatted_inputs)
        )
        input_ta = tuple(
            ta.unstack(input_)
            if not go_backwards
            else ta.unstack(tf.reverse(input_, [0]))
            for ta, input_ in zip(input_ta, flatted_inputs)
        )

        # Get the time(0) input and compute the output for that, the output will
        # be used to determine the dtype of output tensor array. Don't read from
        # input_ta due to TensorArray clear_after_read default to True.
        input_time_zero = tf.nest.pack_sequence_as(
            inputs, [inp[0] for inp in flatted_inputs]
        )
        # output_time_zero is used to determine the cell output shape and its
        # dtype. the value is discarded.
        output_time_zero, _ = step_function(
            input_time_zero, tuple(initial_states) + tuple(constants)
        )

        output_ta_size = time_steps_t if return_all_outputs else 1
        output_ta = tuple(
            tf.TensorArray(
                dtype=out.dtype,
                size=output_ta_size,
                element_shape=out.shape,
                tensor_array_name=f"output_ta_{i}",
            )
            for i, out in enumerate(tf.nest.flatten(output_time_zero))
        )

        time = tf.constant(0, dtype="int32", name="time")

        if input_length is None:
            max_iterations = time_steps_t
        else:
            max_iterations = tf.reduce_max(input_length)

        while_loop_kwargs = {
            "cond": lambda time, *_: time < time_steps_t,
            "maximum_iterations": max_iterations,
            "parallel_iterations": 32,
            "swap_memory": True,
        }
        if mask is not None:
            if go_backwards:
                mask = tf.reverse(mask, [0])

            mask_ta = tf.TensorArray(
                dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
            )
            mask_ta = mask_ta.unstack(mask)

            def masking_fn(time):
                return mask_ta.read(time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
                    for o in flat_out
                )
                return tuple(
                    tf.where(m, o, fm)
                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
                )

        elif isinstance(input_length, tf.Tensor):
            if go_backwards:
                max_len = tf.reduce_max(input_length, axis=0)
                rev_input_length = tf.subtract(max_len - 1, input_length)

                def masking_fn(time):
                    return tf.less(rev_input_length, time)

            else:

                def masking_fn(time):
                    return tf.greater(input_length, time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                return tuple(
                    tf.where(mask_t, o, zo)
                    for (o, zo) in zip(flat_out, flat_mask)
                )

        else:
            masking_fn = None

        if masking_fn is not None:
            # Mask for the T output will be base on the output of T - 1. In the
            # case T = 0, a zero filled tensor will be used.
            flat_zero_output = tuple(
                tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
            )

            def _step(time, output_ta_t, prev_output, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    prev_output: tuple of outputs from time - 1.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                # maybe set shape.
                current_input = tf.nest.pack_sequence_as(inputs, current_input)
                mask_t = masking_fn(time)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                # mask output
                flat_output = tf.nest.flatten(output)
                flat_mask_output = (
                    flat_zero_output
                    if zero_output_for_mask
                    else tf.nest.flatten(prev_output)
                )
                flat_new_output = compute_masked_output(
                    mask_t, flat_output, flat_mask_output
                )

                # mask states
                flat_state = tf.nest.flatten(states)
                flat_new_state = tf.nest.flatten(new_states)
                flat_final_state = compute_masked_output(
                    mask_t, flat_new_state, flat_state
                )
                new_states = tf.nest.pack_sequence_as(
                    new_states, flat_final_state
                )

                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_new_output)
                )

                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
                    new_states
                )

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta, flat_zero_output) + states,
                **while_loop_kwargs,
            )
            # Skip final_outputs[2] which is the output for final timestep.
            new_states = final_outputs[3:]
        else:

            def _step(time, output_ta_t, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                current_input = tf.nest.pack_sequence_as(inputs, current_input)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                flat_new_state = tf.nest.flatten(new_states)

                flat_output = tf.nest.flatten(output)
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_output)
                )

                new_states = tf.nest.pack_sequence_as(
                    initial_states, flat_new_state
                )
                return (time + 1, output_ta_t) + tuple(new_states)

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta) + states,
                **while_loop_kwargs,
            )
            new_states = final_outputs[2:]

        output_ta = final_outputs[1]

        outputs = tuple(o.stack() for o in output_ta)
        last_output = tuple(o[-1] for o in outputs)

        outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
        last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)

    # static shape inference
    def set_shape(output_):
        if isinstance(output_, tf.Tensor):
            shape = output_.shape.as_list()
            if return_all_outputs:
                shape[0] = time_steps
            else:
                shape[0] = 1
            shape[1] = batch
            output_.set_shape(shape)
        return output_

    outputs = tf.nest.map_structure(set_shape, outputs)

    if not time_major:
        outputs = tf.nest.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states
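Editorial note (not part of the diff): a minimal sketch of how the rnn() helper above is driven by a step function. The cumulative-sum "cell" here is hypothetical and only illustrates the calling convention:

import tensorflow as tf

def step_function(cell_inputs, states):
    # cell_inputs: (batch, features) for one timestep; states: list of tensors.
    new_state = states[0] + cell_inputs           # running sum as the "cell"
    return new_state, [new_state]

x = tf.ones((2, 5, 3))                            # (batch, time, features)
initial_states = [tf.zeros((2, 3))]
last_output, outputs, new_states = rnn(
    step_function, x, initial_states, unroll=False, return_all_outputs=True
)
print(outputs.shape)                              # (2, 5, 3); last_output is all 5s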
keras_core/backend/tensorflow/rnn.py (new file, 763 lines)
@@ -0,0 +1,763 @@
import tensorflow as tf


def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
    # [The body of rnn() in this new file is identical, line for line, to the
    # rnn() implementation removed in the hunk above; it is not repeated here.]

def gru(
    inputs,
    initial_state,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
    reset_after=True,
):
    args_supported = _do_gru_arguments_support_cudnn(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=unroll,
        bias=bias,
        reset_after=reset_after,
    )
    inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
    if not args_supported or not inputs_supported:
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_gru(
            inputs,
            initial_state,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except tf.errors.InvalidArgumentError:
        # cuDNN op not found.
        raise NotImplementedError

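Editorial note (not part of the diff): gru() raises NotImplementedError whenever the cuDNN fast path cannot be used, so the caller is expected to fall back to the generic rnn() loop. A hedged sketch of that calling pattern; the layer-side wiring and the gru_step function are assumptions, not shown in this diff:

try:
    last_output, outputs, state = gru(
        inputs, initial_state, mask,
        kernel, recurrent_kernel, bias,
        activation, recurrent_activation,
        return_sequences=True,
    )
except NotImplementedError:
    # Hypothetical fallback: gru_step would apply the GRU cell equations for
    # one timestep, and rnn() iterates it over the sequence.
    last_output, outputs, new_states = rnn(
        gru_step, inputs, [initial_state], mask=mask, return_all_outputs=True
    )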
def _do_gru_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    bias,
    reset_after,
):
    from keras_core import activations
    from keras_core import operations as ops

    return (
        activation in (activations.tanh, tf.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)
        and not unroll
        and bias is not None
        and reset_after
    )


def _do_rnn_inputs_support_cudnn(mask, time_major):
    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        if mask is not None:
            return tf.reduce_all(mask)
        return True
    if mask is None:
        return True
    if time_major:
        mask = tf.transpose(mask)
    return tf.logical_and(
        _is_sequence_right_padded(mask),
        tf.logical_not(_has_fully_masked_sequence(mask)),
    )

def _is_sequence_right_padded(mask):
    """Check the mask tensor and see if it is right padded.

    For the cuDNN kernel, it uses the sequence length param to skip the
    tailing timesteps. If the data is left padded, or not strictly right
    padded (has masked values in the middle of the sequence), then the cuDNN
    kernel won't work properly in those cases.

    Left padded data: [[False, False, True, True, True]].
    Right padded data: [[True, True, True, False, False]].
    Mixture of mask/unmasked data: [[True, False, True, False, False]].

    Note that for the mixed data example above, the actual data the RNN should
    see are those 2 Trues (index 0 and 2); the index 1 False should be ignored
    and not pollute the internal states.

    Args:
        mask: the Boolean tensor with shape [batch, timestep]

    Returns:
        boolean scalar tensor, whether the mask is strictly right padded.
    """
    max_seq_length = tf.shape(mask)[1]
    count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length)
    return tf.reduce_all(tf.equal(mask, right_padded_mask))


def _has_fully_masked_sequence(mask):
    # The cuDNN kernel will error out if the input sequence contains any
    # fully masked data. We work around this issue by rerouting the
    # computation to the standard kernel, until the issue on the cuDNN side
    # has been fixed. For a fully masked sequence, it will contain all Falses.
    # To make it easy to check, we invert the boolean and check if any
    # sequence has all True.
    return tf.reduce_any(tf.reduce_all(tf.logical_not(mask), axis=1))

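Editorial note (not part of the diff): a small worked example of the two mask checks above; the values are easy to verify by hand:

import tensorflow as tf

mask = tf.constant([[True, True, True, False, False],   # strictly right padded
                    [True, False, True, False, False]]) # hole in the middle
print(_is_sequence_right_padded(mask).numpy())           # False (second row breaks it)
print(_is_sequence_right_padded(mask[:1]).numpy())       # True
print(_has_fully_masked_sequence(mask).numpy())          # False (no all-False row)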
def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):
    """Utility function to convert variables to cuDNN compatible parameters.

    Note that Keras weights for kernels are different from the cuDNN format.
    Eg.:

    ```
    Keras                 cuDNN
    [[0, 1, 2],  <--->  [[0, 2, 4],
     [3, 4, 5]]          [1, 3, 5]]
    ```

    If the input weights need to be in a unified format, then set
    `transpose_weights=True` to convert the weights.

    Args:
        weights: list of weights for the kernels and recurrent kernels.
        biases: list of biases for each individual gate.
        shape: the shape for the converted variables that will be fed to cuDNN.
        transpose_weights: boolean, whether to transpose the weights.

    Returns:
        The converted weights that can be fed to cuDNN ops as params.
    """

    def convert(w):
        return tf.transpose(w) if transpose_weights else w

    weights = [tf.reshape(convert(x), shape) for x in weights]
    biases = [tf.reshape(x, shape) for x in biases]
    return tf.concat(weights + biases, axis=0)

def _compute_sequence_length_from_mask(mask, time_major):
    """Calculate the sequence length tensor (1-D) based on the masking tensor.

    The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
    any timestep that should be masked, the corresponding field will be False.
    Consider the following example:
        a = [[True, True, False, False],
             [True, True, True, False]]
    It is a (2, 4) tensor, and the corresponding sequence length result should
    be a 1D tensor with value [2, 3]. Note that the masking tensor must be
    right padded, which can be checked by, e.g., `_is_sequence_right_padded()`.

    Args:
        mask: Boolean tensor with shape [batch, timestep] or [timestep, batch]
            if time_major=True.
        time_major: Boolean, which indicates whether the mask is time major or
            batch major.

    Returns:
        sequence_length: 1D int32 tensor.
    """
    timestep_index = 0 if time_major else 1
    return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index)

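Editorial note (not part of the diff): the docstring example above, executed:

import tensorflow as tf

a = tf.constant([[True, True, False, False],
                 [True, True, True, False]])
print(_compute_sequence_length_from_mask(a, time_major=False).numpy())  # [2 3]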
@tf.function(autograph=False)
def _cudnn_gru(
    inputs,
    initial_state,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    """GRU with cuDNN implementation, which is only available for GPU."""
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None

    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h, cuDNN expects one more dim of num_layers before or after the
    # batch dim, for time major or batch major inputs respectively.
    init_h = tf.expand_dims(initial_state, axis=seq_axis)

    weights = tf.split(kernel, 3, axis=1)
    weights += tf.split(recurrent_kernel, 3, axis=1)
    # Note that the bias was initialized as shape (2, 3 * units), flatten it to
    # (6 * units).
    bias = tf.split(tf.reshape(bias, [-1]), 6)

    if tf.sysconfig.get_build_info()["is_cuda_build"]:
        # Note that the gate order for cuDNN is different from the canonical
        # format. The canonical format is [z, r, h], whereas cuDNN is
        # [r, z, h]. The swap needs to be done for kernel, recurrent_kernel,
        # input_bias and recurrent_bias.
        # z is the update gate weights.
        # r is the reset gate weights.
        # h is the output gate weights.
        weights[0], weights[1] = weights[1], weights[0]
        weights[3], weights[4] = weights[4], weights[3]
        bias[0], bias[1] = bias[1], bias[0]
        bias[3], bias[4] = bias[4], bias[3]

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=bias,
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5, 4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        if go_backwards:
            # Reverse axis 0 since the input is already converted to
            # time major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    state = tf.squeeze(h, axis=seq_axis)

    # In the case of variable length input, the cuDNN kernel will fill zeros
    # for the output, whereas the default Keras behavior is to bring over the
    # previous output for t-1, so that in the return_sequences=False case, the
    # user can quickly get the final effective output instead of just 0s at
    # the last timestep. In order to mimic the default Keras behavior, we copy
    # the final h state as the last_output, since it is numerically the same
    # as the output.
    if sequence_lengths is not None:
        last_output = state

    # Match CPU return format.
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        state,
    )

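Editorial note (not part of the diff): the is_cuda_build branch inside _cudnn_gru() reorders the per-gate weight blocks from the canonical [z, r, h] layout to cuDNN's [r, z, h]. A minimal sketch of that reindexing, on labels instead of tensors:

kernel_gates = ["z", "r", "h"]          # blocks 0..2 (split from `kernel`)
recurrent_gates = ["z", "r", "h"]       # blocks 3..5 (split from `recurrent_kernel`)
blocks = kernel_gates + recurrent_gates
blocks[0], blocks[1] = blocks[1], blocks[0]
blocks[3], blocks[4] = blocks[4], blocks[3]
print(blocks)                           # ['r', 'z', 'h', 'r', 'z', 'h']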
def lstm():
    pass
@@ -114,15 +114,12 @@ class TensorFlowTrainer(base_trainer.Trainer):
            )
            return outputs

        def train_function(iterator):
            """Runs a training execution with multiple steps."""
            for _ in tf.range(self.steps_per_execution):
                outputs = one_step_on_iterator(iterator)
            return outputs

        if not self.run_eagerly:
            train_function = tf.function(train_function, reduce_retracing=True)
            train_function = tf.function(
                one_step_on_iterator, reduce_retracing=True
            )
        else:
            train_function = one_step_on_iterator
        self.train_function = train_function

    def make_test_function(self, force=False):
@@ -247,7 +244,6 @@ class TensorFlowTrainer(base_trainer.Trainer):
                shuffle=shuffle,
                class_weight=class_weight,
                distribute_strategy=self.distribute_strategy,
                steps_per_execution=self.steps_per_execution,
            )

        # Container that configures and calls callbacks.
@@ -449,9 +445,7 @@ class TFEpochIterator(EpochIterator):
                        self.data_adapter.get_tf_dataset()
                    )
                )
                for step in range(
                    0, self.steps_per_epoch, self.steps_per_execution
                ):
                for step in range(self.steps_per_epoch):
                    yield step, self._current_iterator
        else:
            iterator = iter(
@@ -460,14 +454,12 @@ class TFEpochIterator(EpochIterator):
                )
            )
            if self.num_batches:
                for step in range(
                    0, self.num_batches, self.steps_per_execution
                ):
                for step in range(self.num_batches):
                    yield step, iterator
            else:
                step = -1
                while True:
                    step += self.steps_per_execution
                    step += 1
                    self._steps_seen = step + 1
                    yield step, iterator
        self.data_adapter.on_epoch_end()
@@ -99,6 +99,7 @@ from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.rnn.gru import GRU
from keras_core.layers.rnn.rnn import RNN
from keras_core.layers.rnn.simple_rnn import SimpleRNN
from keras_core.layers.rnn.stacked_rnn_cells import StackedRNNCells
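Editorial note (not part of the diff): with this export in place, the new layer is expected to be usable like the other Keras Core RNN layers. A hedged usage sketch:

import numpy as np
from keras_core import layers

x = np.random.random((4, 10, 8)).astype("float32")   # (batch, time, features)
gru_layer = layers.GRU(16, return_sequences=False)
y = gru_layer(x)
print(y.shape)                                       # (4, 16)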
@@ -14,11 +14,11 @@ from keras_core.operations.operation_utils import compute_conv_output_shape
class BaseConv(Layer):
    """Abstract N-D convolution layer (private, used as implementation base).

    This layer creates a convolution kernel that is convolved (actually
    cross-correlated) with the layer input to produce a tensor of outputs. If
    `use_bias` is True (and a `bias_initializer` is provided), a bias vector is
    created and added to the outputs. Finally, if `activation` is not `None`, it
    is applied to the outputs as well.
    This layer creates a convolution kernel that is convolved
    (actually cross-correlated) with the layer input to produce a tensor of
    outputs. If `use_bias` is True (and a `bias_initializer` is provided),
    a bias vector is created and added to the outputs. Finally, if
    `activation` is not `None`, it is applied to the outputs as well.

    Note: layer attributes cannot be modified after the layer has been called
    once (except the `trainable` attribute).
@@ -27,12 +27,12 @@ class BaseConv(Layer):
        rank: int, the rank of the convolution, e.g. 2 for 2D convolution.
        filters: int, the dimension of the output space (the number of filters
            in the convolution).
        kernel_size: int or tuple/list of `rank` integers, specifying the size
            of the convolution window.
        strides: int or tuple/list of `rank` integers, specifying the stride
            length of the convolution. If only one int is specified, the same
            stride size will be used for all dimensions. `strides > 1` is
            incompatible with `dilation_rate > 1`.
        kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
            size of the convolution window.
        strides: int or tuple/list of N integers, specifying the stride length
            of the convolution. If only one int is specified, the same stride
            size will be used for all dimensions. `stride value != 1` is
            incompatible with `dilation_rate != 1`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
@@ -44,9 +44,9 @@ class BaseConv(Layer):
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.
        dilation_rate: int or tuple/list of `rank` integers, specifying the
            dilation rate to use for dilated convolution. If only one int is
            specified, the same dilation rate will be used for all dimensions.
        dilation_rate: int or tuple/list of N integers, specifying the dilation
            rate to use for dilated convolution. If only one int is specified,
            the same dilation rate will be used for all dimensions.
        groups: A positive int specifying the number of groups in which the
            input is split along the channel axis. Each group is convolved
            separately with `filters // groups` filters. The output is the
@@ -97,7 +97,7 @@ class BaseConv(Layer):
        super().__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=activity_regularizer,
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs,
        )
        self.rank = rank
@@ -144,14 +144,14 @@ class BaseConv(Layer):

        if not all(self.kernel_size):
            raise ValueError(
                "The argument `kernel_size` cannot contain 0. Received "
                f"kernel_size={self.kernel_size}."
                "The argument `kernel_size` cannot contain 0(s). Received: "
                f"{self.kernel_size}"
            )

        if not all(self.strides):
            raise ValueError(
                "The argument `strides` cannot contains 0. Received "
                f"strides={self.strides}"
                "The argument `strides` cannot contains 0(s). Received: "
                f"{self.strides}"
            )

        if max(self.strides) > 1 and max(self.dilation_rate) > 1:
@@ -14,25 +14,26 @@ from keras_core.layers.layer import Layer


class BaseConvTranspose(Layer):
    """Abstract N-D transposed convolution layer.
    """Abstract N-D transpose convolution layer.

    The need for transposed convolutions generally arises from the desire to use
    a transformation going in the opposite direction of a normal convolution,
    i.e., from something that has the shape of the output of some convolution to
    something that has the shape of its input while maintaining a connectivity
    pattern that is compatible with said convolution.
    The need for transposed convolutions generally arises
    from the desire to use a transformation going in the opposite direction
    of a normal convolution, i.e., from something that has the shape of the
    output of some convolution to something that has the shape of its input
    while maintaining a connectivity pattern that is compatible with
    said convolution.

    Args:
        rank: int, the rank of the transposed convolution, e.g. 2 for 2D
            transposed convolution.
        filters: int, the dimension of the output space (the number of filters
            in the transposed convolution).
        kernel_size: int or tuple/list of `rank` integers, specifying the size
            of the transposed convolution window.
        strides: int or tuple/list of `rank` integers, specifying the stride
            length of the transposed convolution. If only one int is specified,
            the same stride size will be used for all dimensions.
            `strides > 1` is incompatible with `dilation_rate > 1`.
        kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
            size of the transposed convolution window.
        strides: int or tuple/list of N integers, specifying the stride length
            of the transposed convolution. If only one int is specified, the
            same stride size will be used for all dimensions.
            `stride value != 1` is incompatible with `dilation_rate != 1`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
@@ -44,9 +45,9 @@ class BaseConvTranspose(Layer):
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.
        dilation_rate: int or tuple/list of `rank` integers, specifying the
            dilation rate to use for dilated convolution. If only one int is
            specified, the same dilation rate will be used for all dimensions.
        dilation_rate: int or tuple/list of N integers, specifying the dilation
            rate to use for dilated convolution. If only one int is specified,
            the same dilation rate will be used for all dimensions.
        activation: Activation function. If `None`, no activation is applied.
        use_bias: bool, if `True`, bias will be added to the output.
        kernel_initializer: Initializer for the convolution kernel. If `None`,
@@ -130,14 +131,14 @@ class BaseConvTranspose(Layer):

        if not all(self.kernel_size):
            raise ValueError(
                "The argument `kernel_size` cannot contain 0. Received "
                f"kernel_size={self.kernel_size}."
                "The argument `kernel_size` cannot contain 0. Received: "
                f"{self.kernel_size}"
            )

        if not all(self.strides):
            raise ValueError(
                "The argument `strides` cannot contains 0. Received "
                f"strides={self.strides}."
                "The argument `strides` cannot contains 0. Received: "
                f"{self.strides}"
            )

        if max(self.strides) > 1 and max(self.dilation_rate) > 1:
@@ -39,12 +39,12 @@ class BaseDepthwiseConv(Layer):
        depth_multiplier: The number of depthwise convolution output channels
            for each input channel. The total number of depthwise convolution
            output channels will be equal to `input_channel * depth_multiplier`.
        kernel_size: int or tuple/list of `rank` integers, specifying the size
            of the depthwise convolution window.
        strides: int or tuple/list of `rank` integers, specifying the stride
            length of the depthwise convolution. If only one int is specified,
            the same stride size will be used for all dimensions.
            `strides > 1` is incompatible with `dilation_rate > 1`.
        kernel_size: int or tuple/list of N integers (N=`rank`), specifying the
            size of the depthwise convolution window.
        strides: int or tuple/list of N integers, specifying the stride length
            of the depthwise convolution. If only one int is specified, the same
            stride size will be used for all dimensions. `stride value != 1` is
            incompatible with `dilation_rate != 1`.
        padding: string, either `"valid"` or `"same"` (case-insensitive).
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the same
@@ -56,9 +56,9 @@ class BaseDepthwiseConv(Layer):
            `(batch, features, steps)`. It defaults to the `image_data_format`
            value found in your Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be `"channels_last"`.
        dilation_rate: int or tuple/list of `rank` integers, specifying the
            dilation rate to use for dilated convolution. If only one int is
            specified, the same dilation rate will be used for all dimensions.
        dilation_rate: int or tuple/list of N integers, specifying the dilation
            rate to use for dilated convolution. If only one int is specified,
            the same dilation rate will be used for all dimensions.
        activation: Activation function. If `None`, no activation is applied.
        use_bias: bool, if `True`, bias will be added to the output.
        kernel_initializer: Initializer for the convolution kernel. If `None`,
@ -143,14 +143,14 @@ class BaseDepthwiseConv(Layer):
|
||||
|
||||
if not all(self.kernel_size):
|
||||
raise ValueError(
|
||||
"The argument `kernel_size` cannot contain 0. Received "
|
||||
f"kernel_size={self.kernel_size}."
|
||||
"The argument `kernel_size` cannot contain 0(s). Received: "
|
||||
f"{self.kernel_size}"
|
||||
)
|
||||
|
||||
if not all(self.strides):
|
||||
raise ValueError(
|
||||
"The argument `strides` cannot contains 0. Received "
|
||||
f"strides={self.strides}"
|
||||
"The argument `strides` cannot contains 0(s). Received: "
|
||||
f"{self.strides}"
|
||||
)
|
||||
|
||||
if max(self.strides) > 1 and max(self.dilation_rate) > 1:
|
||||
|
@ -21,8 +21,8 @@ class Conv1D(BaseConv):
|
||||
kernel_size: int or tuple/list of 1 integer, specifying the size of the
|
||||
convolution window.
|
||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||
of the convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the convolution. `stride value != 1` is incompatible with
|
||||
`dilation_rate != 1`.
|
||||
padding: string, `"valid"`, `"same"` or `"causal"`(case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
@ -72,9 +72,9 @@ class Conv1D(BaseConv):
|
||||
|
||||
Output shape:
|
||||
- If `data_format="channels_last"`:
|
||||
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
|
||||
A 3D tensor with shape: `(batch_shape, new_steps, channels)`
|
||||
- If `data_format="channels_first"`:
|
||||
A 3D tensor with shape: `(batch_shape, filters, new_steps)`
|
||||
A 3D tensor with shape: `(batch_shape, channels, new_steps)`
|
||||
|
||||
Returns:
|
||||
A 3D tensor representing `activation(conv1d(inputs, kernel) + bias)`.
|
||||
|
@ -25,8 +25,8 @@ class Conv1DTranspose(BaseConvTranspose):
|
||||
kernel_size: int or tuple/list of 1 integer, specifying the size of the
|
||||
transposed convolution window.
|
||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||
of the transposed convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the transposed convolution. `stride value != 1` is incompatible
|
||||
with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
@ -40,6 +40,11 @@ class Conv1DTranspose(BaseConvTranspose):
|
||||
If you never set it, then it will be `"channels_last"`.
|
||||
dilation_rate: int or tuple/list of 1 integers, specifying the dilation
|
||||
rate to use for dilated transposed convolution.
|
||||
groups: A positive int specifying the number of groups in which the
|
||||
input is split along the channel axis. Each group is convolved
|
||||
separately with `filters // groups` filters. The output is the
|
||||
concatenation of all the `groups` results along the channel axis.
|
||||
Input channels and `filters` must both be divisible by `groups`.
|
||||
activation: Activation function. If `None`, no activation is applied.
|
||||
use_bias: bool, if `True`, bias will be added to the output.
|
||||
kernel_initializer: Initializer for the convolution kernel. If `None`,
|
||||
@ -66,9 +71,9 @@ class Conv1DTranspose(BaseConvTranspose):
|
||||
|
||||
Output shape:
|
||||
- If `data_format="channels_last"`:
|
||||
A 3D tensor with shape: `(batch_shape, new_steps, filters)`
|
||||
A 3D tensor with shape: `(batch_shape, new_steps, channels)`
|
||||
- If `data_format="channels_first"`:
|
||||
A 3D tensor with shape: `(batch_shape, filters, new_steps)`
|
||||
A 3D tensor with shape: `(batch_shape, channels, new_steps)`
|
||||
|
||||
Returns:
|
||||
A 3D tensor representing
|
||||
|
@ -20,8 +20,8 @@ class Conv2D(BaseConv):
|
||||
kernel_size: int or tuple/list of 2 integer, specifying the size of the
|
||||
convolution window.
|
||||
strides: int or tuple/list of 2 integer, specifying the stride length
|
||||
of the convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the convolution. `stride value != 1` is incompatible with
|
||||
`dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
@ -68,7 +68,7 @@ class Conv2D(BaseConv):
|
||||
|
||||
Output shape:
|
||||
- If `data_format="channels_last"`:
|
||||
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
|
||||
A 4D tensor with shape: `(batch_size, new_height, new_width filters)`
|
||||
- If `data_format="channels_first"`:
|
||||
A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`
|
||||
|
||||
|
@ -25,8 +25,8 @@ class Conv2DTranspose(BaseConvTranspose):
|
||||
kernel_size: int or tuple/list of 1 integer, specifying the size of the
|
||||
transposed convolution window.
|
||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||
of the transposed convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the transposed convolution. `stride value != 1` is incompatible
|
||||
with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
@ -68,7 +68,7 @@ class Conv2DTranspose(BaseConvTranspose):
|
||||
|
||||
Output shape:
|
||||
- If `data_format="channels_last"`:
|
||||
A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`
|
||||
A 4D tensor with shape: `(batch_size, new_height, new_width filters)`
|
||||
- If `data_format="channels_first"`:
|
||||
A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`
|
||||
|
||||
|
@ -20,8 +20,8 @@ class Conv3D(BaseConv):
|
||||
kernel_size: int or tuple/list of 3 integer, specifying the size of the
|
||||
convolution window.
|
||||
strides: int or tuple/list of 3 integer, specifying the stride length
|
||||
of the convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the convolution. `stride value != 1` is incompatible with
|
||||
`dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
@ -72,10 +72,10 @@ class Conv3D(BaseConv):
|
||||
- If `data_format="channels_last"`:
|
||||
5D tensor with shape:
|
||||
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
|
||||
filters)`
|
||||
channels)`
|
||||
- If `data_format="channels_first"`:
|
||||
5D tensor with shape:
|
||||
`(batch_size, filters, new_spatial_dim1, new_spatial_dim2,
|
||||
`(batch_size, channels, new_spatial_dim1, new_spatial_dim2,
|
||||
new_spatial_dim3)`
|
||||
|
||||
Returns:
|
||||
|
@ -25,8 +25,8 @@ class Conv3DTranspose(BaseConvTranspose):
|
||||
kernel_size: int or tuple/list of 1 integer, specifying the size of the
|
||||
transposed convolution window.
|
||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||
of the transposed convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the transposed convolution. `stride value != 1` is incompatible
|
||||
with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
@ -72,10 +72,10 @@ class Conv3DTranspose(BaseConvTranspose):
|
||||
- If `data_format="channels_last"`:
|
||||
5D tensor with shape:
|
||||
`(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,
|
||||
filters)`
|
||||
channels)`
|
||||
- If `data_format="channels_first"`:
|
||||
5D tensor with shape:
|
||||
`(batch_size, filters, new_spatial_dim1, new_spatial_dim2,
|
||||
`(batch_size, channels, new_spatial_dim1, new_spatial_dim2,
|
||||
new_spatial_dim3)`
|
||||
|
||||
Returns:
|
||||
|
@ -417,4 +417,4 @@ class ConvTransposeCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
||||
tf_keras_layer.bias.assign(bias_weights)
|
||||
outputs = layer(inputs)
|
||||
expected = tf_keras_layer(inputs)
|
||||
self.assertAllClose(outputs, expected)
|
||||
self.assertAllClose(outputs, expected, atol=1e-5)
|
||||
|
@ -34,8 +34,8 @@ class DepthwiseConv1D(BaseDepthwiseConv):
|
||||
kernel_size: int or tuple/list of 1 integer, specifying the size of the
|
||||
depthwise convolution window.
|
||||
strides: int or tuple/list of 1 integer, specifying the stride length
|
||||
of the convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the convolution. `stride value != 1` is incompatible
|
||||
with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
|
@ -34,8 +34,8 @@ class DepthwiseConv2D(BaseDepthwiseConv):
|
||||
kernel_size: int or tuple/list of 2 integer, specifying the size of the
|
||||
depthwise convolution window.
|
||||
strides: int or tuple/list of 2 integer, specifying the stride length
|
||||
of the depthwise convolution. `strides > 1` is incompatible with
|
||||
`dilation_rate > 1`.
|
||||
of the depthwise convolution. `stride value != 1` is incompatible
|
||||
with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
|
@ -29,8 +29,8 @@ class SeparableConv1D(BaseSeparableConv):
|
||||
depthwise convolution window.
|
||||
strides: int or tuple/list of 1 integers, specifying the stride length
|
||||
of the depthwise convolution. If only one int is specified, the same
|
||||
stride size will be used for all dimensions. `strides > 1` is
|
||||
incompatible with `dilation_rate > 1`.
|
||||
stride size will be used for all dimensions. `stride value != 1` is
|
||||
incompatible with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
|
@ -29,8 +29,8 @@ class SeparableConv2D(BaseSeparableConv):
|
||||
depthwise convolution window.
|
||||
strides: int or tuple/list of 2 integers, specifying the stride length
|
||||
of the depthwise convolution. If only one int is specified, the same
|
||||
stride size will be used for all dimensions. `strides > 1` is
|
||||
incompatible with `dilation_rate > 1`.
|
||||
stride size will be used for all dimensions. `stride value != 1` is
|
||||
incompatible with `dilation_rate != 1`.
|
||||
padding: string, either `"valid"` or `"same"` (case-insensitive).
|
||||
`"valid"` means no padding. `"same"` results in padding evenly to
|
||||
the left/right or up/down of the input such that output has the same
|
||||
|
607
keras_core/layers/rnn/gru.py
Normal file
@ -0,0 +1,607 @@
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core import activations
|
||||
from keras_core import backend
|
||||
from keras_core import constraints
|
||||
from keras_core import initializers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import regularizers
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.input_spec import InputSpec
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.layers.rnn.dropout_rnn_cell import DropoutRNNCell
|
||||
from keras_core.layers.rnn.rnn import RNN
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.GRUCell")
|
||||
class GRUCell(Layer, DropoutRNNCell):
|
||||
"""Cell class for the GRU layer.
|
||||
|
||||
This class processes one step within the whole time sequence input, whereas
|
||||
`keras_core.layer.GRU` processes the whole sequence.
|
||||
|
||||
Args:
|
||||
units: Positive integer, dimensionality of the output space.
|
||||
activation: Activation function to use. Default: hyperbolic tangent
|
||||
(`tanh`). If you pass None, no activation is applied
|
||||
(ie. "linear" activation: `a(x) = x`).
|
||||
recurrent_activation: Activation function to use for the recurrent step.
|
||||
Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
|
||||
applied (ie. "linear" activation: `a(x) = x`).
|
||||
use_bias: Boolean, (default `True`), whether the layer
|
||||
should use a bias vector.
|
||||
kernel_initializer: Initializer for the `kernel` weights matrix,
|
||||
used for the linear transformation of the inputs. Default:
|
||||
`"glorot_uniform"`.
|
||||
recurrent_initializer: Initializer for the `recurrent_kernel`
|
||||
weights matrix, used for the linear transformation
|
||||
of the recurrent state. Default: `"orthogonal"`.
|
||||
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
|
||||
kernel_regularizer: Regularizer function applied to the `kernel` weights
|
||||
matrix. Default: `None`.
|
||||
recurrent_regularizer: Regularizer function applied to the
|
||||
`recurrent_kernel` weights matrix. Default: `None`.
|
||||
bias_regularizer: Regularizer function applied to the bias vector.
|
||||
Default: `None`.
|
||||
kernel_constraint: Constraint function applied to the `kernel` weights
|
||||
matrix. Default: `None`.
|
||||
recurrent_constraint: Constraint function applied to the
|
||||
`recurrent_kernel` weights matrix. Default: `None`.
|
||||
bias_constraint: Constraint function applied to the bias vector.
|
||||
Default: `None`.
|
||||
dropout: Float between 0 and 1. Fraction of the units to drop for the
|
||||
linear transformation of the inputs. Default: 0.
|
||||
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
|
||||
for the linear transformation of the recurrent state. Default: 0.
|
||||
reset_after: GRU convention (whether to apply reset gate after or
|
||||
before matrix multiplication). False = "before",
|
||||
True = "after" (default and cuDNN compatible).
|
||||
|
||||
Call arguments:
|
||||
inputs: A 2D tensor, with shape `(batch, features)`.
|
||||
states: A 2D tensor with shape `(batch, units)`, which is the state
|
||||
from the previous time step.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
training mode or in inference mode. Only relevant when `dropout` or
|
||||
`recurrent_dropout` is used.
|
||||
|
||||
Example:
|
||||
|
||||
>>> inputs = np.random.random((32, 10, 8))
|
||||
>>> rnn = keras_core.layers.RNN(keras_core.layers.GRUCell(4))
|
||||
>>> output = rnn(inputs)
|
||||
>>> output.shape
|
||||
(32, 4)
|
||||
>>> rnn = keras_core.layers.RNN(
|
||||
... keras_core.layers.GRUCell(4),
|
||||
... return_sequences=True,
|
||||
... return_state=True)
|
||||
>>> whole_sequence_output, final_state = rnn(inputs)
|
||||
>>> whole_sequence_output.shape
|
||||
(32, 10, 4)
|
||||
>>> final_state.shape
|
||||
(32, 4)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
units,
|
||||
activation="tanh",
|
||||
recurrent_activation="sigmoid",
|
||||
use_bias=True,
|
||||
kernel_initializer="glorot_uniform",
|
||||
recurrent_initializer="orthogonal",
|
||||
bias_initializer="zeros",
|
||||
kernel_regularizer=None,
|
||||
recurrent_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
recurrent_constraint=None,
|
||||
bias_constraint=None,
|
||||
dropout=0.0,
|
||||
recurrent_dropout=0.0,
|
||||
reset_after=True,
|
||||
seed=None,
|
||||
**kwargs,
|
||||
):
|
||||
if units <= 0:
|
||||
raise ValueError(
|
||||
"Received an invalid value for argument `units`, "
|
||||
f"expected a positive integer, got {units}."
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self.units = units
|
||||
self.activation = activations.get(activation)
|
||||
self.recurrent_activation = activations.get(recurrent_activation)
|
||||
self.use_bias = use_bias
|
||||
|
||||
self.kernel_initializer = initializers.get(kernel_initializer)
|
||||
self.recurrent_initializer = initializers.get(recurrent_initializer)
|
||||
self.bias_initializer = initializers.get(bias_initializer)
|
||||
|
||||
self.kernel_regularizer = regularizers.get(kernel_regularizer)
|
||||
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
|
||||
self.bias_regularizer = regularizers.get(bias_regularizer)
|
||||
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.recurrent_constraint = constraints.get(recurrent_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
self.dropout = min(1.0, max(0.0, dropout))
|
||||
self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
|
||||
self.seed = seed
|
||||
self.seed_generator = backend.random.SeedGenerator(seed=seed)
|
||||
|
||||
self.reset_after = reset_after
|
||||
self.state_size = self.units
|
||||
self.output_size = self.units
|
||||
|
||||
def build(self, input_shape):
|
||||
super().build(input_shape)
|
||||
input_dim = input_shape[-1]
|
||||
self.kernel = self.add_weight(
|
||||
shape=(input_dim, self.units * 3),
|
||||
name="kernel",
|
||||
initializer=self.kernel_initializer,
|
||||
regularizer=self.kernel_regularizer,
|
||||
constraint=self.kernel_constraint,
|
||||
)
|
||||
self.recurrent_kernel = self.add_weight(
|
||||
shape=(self.units, self.units * 3),
|
||||
name="recurrent_kernel",
|
||||
initializer=self.recurrent_initializer,
|
||||
regularizer=self.recurrent_regularizer,
|
||||
constraint=self.recurrent_constraint,
|
||||
)
|
||||
|
||||
if self.use_bias:
|
||||
if not self.reset_after:
|
||||
bias_shape = (3 * self.units,)
|
||||
else:
|
||||
# separate biases for input and recurrent kernels
|
||||
# Note: the shape is intentionally different from CuDNNGRU
|
||||
# biases `(2 * 3 * self.units,)`, so that we can distinguish the
|
||||
# classes when loading and converting saved weights.
|
||||
bias_shape = (2, 3 * self.units)
|
||||
self.bias = self.add_weight(
|
||||
shape=bias_shape,
|
||||
name="bias",
|
||||
initializer=self.bias_initializer,
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint,
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs, states, training=None):
|
||||
h_tm1 = (
|
||||
states[0] if nest.is_nested(states) else states
|
||||
) # previous state
|
||||
|
||||
dp_mask = self.get_dropout_mask(inputs)
|
||||
rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
|
||||
|
||||
if self.use_bias:
|
||||
if not self.reset_after:
|
||||
input_bias, recurrent_bias = self.bias, None
|
||||
else:
|
||||
input_bias, recurrent_bias = (
|
||||
ops.squeeze(e, axis=0)
|
||||
for e in ops.split(self.bias, self.bias.shape[0], axis=0)
|
||||
)
|
||||
|
||||
if training and 0.0 < self.dropout < 1.0:
|
||||
inputs *= dp_mask
|
||||
inputs_z = inputs
|
||||
inputs_r = inputs
|
||||
inputs_h = inputs
|
||||
|
||||
x_z = ops.matmul(inputs_z, self.kernel[:, : self.units])
|
||||
x_r = ops.matmul(inputs_r, self.kernel[:, self.units : self.units * 2])
|
||||
x_h = ops.matmul(inputs_h, self.kernel[:, self.units * 2 :])
|
||||
|
||||
if self.use_bias:
|
||||
x_z += input_bias[: self.units]
|
||||
x_r += input_bias[self.units : self.units * 2]
|
||||
x_h += input_bias[self.units * 2 :]
|
||||
|
||||
if training and 0.0 < self.recurrent_dropout < 1.0:
|
||||
h_tm1 *= rec_dp_mask
|
||||
h_tm1_z = h_tm1
|
||||
h_tm1_r = h_tm1
|
||||
h_tm1_h = h_tm1
|
||||
|
||||
recurrent_z = ops.matmul(
|
||||
h_tm1_z, self.recurrent_kernel[:, : self.units]
|
||||
)
|
||||
recurrent_r = ops.matmul(
|
||||
h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2]
|
||||
)
|
||||
if self.reset_after and self.use_bias:
|
||||
recurrent_z += recurrent_bias[: self.units]
|
||||
recurrent_r += recurrent_bias[self.units : self.units * 2]
|
||||
|
||||
z = self.recurrent_activation(x_z + recurrent_z)
|
||||
r = self.recurrent_activation(x_r + recurrent_r)
|
||||
|
||||
# reset gate applied after/before matrix multiplication
|
||||
if self.reset_after:
|
||||
recurrent_h = ops.matmul(
|
||||
h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
|
||||
)
|
||||
if self.use_bias:
|
||||
recurrent_h += recurrent_bias[self.units * 2 :]
|
||||
recurrent_h = r * recurrent_h
|
||||
else:
|
||||
recurrent_h = ops.matmul(
|
||||
r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
|
||||
)
|
||||
|
||||
hh = self.activation(x_h + recurrent_h)
|
||||
|
||||
# previous and candidate state mixed by update gate
|
||||
h = z * h_tm1 + (1 - z) * hh
|
||||
new_state = [h] if nest.is_nested(states) else h
|
||||
return h, new_state
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"units": self.units,
|
||||
"activation": activations.serialize(self.activation),
|
||||
"recurrent_activation": activations.serialize(
|
||||
self.recurrent_activation
|
||||
),
|
||||
"use_bias": self.use_bias,
|
||||
"kernel_initializer": initializers.serialize(
|
||||
self.kernel_initializer
|
||||
),
|
||||
"recurrent_initializer": initializers.serialize(
|
||||
self.recurrent_initializer
|
||||
),
|
||||
"bias_initializer": initializers.serialize(self.bias_initializer),
|
||||
"kernel_regularizer": regularizers.serialize(
|
||||
self.kernel_regularizer
|
||||
),
|
||||
"recurrent_regularizer": regularizers.serialize(
|
||||
self.recurrent_regularizer
|
||||
),
|
||||
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
|
||||
"kernel_constraint": constraints.serialize(self.kernel_constraint),
|
||||
"recurrent_constraint": constraints.serialize(
|
||||
self.recurrent_constraint
|
||||
),
|
||||
"bias_constraint": constraints.serialize(self.bias_constraint),
|
||||
"dropout": self.dropout,
|
||||
"recurrent_dropout": self.recurrent_dropout,
|
||||
"reset_after": self.reset_after,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
def get_initial_state(self, batch_size=None):
|
||||
return [
|
||||
ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype)
|
||||
]
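
The gate arithmetic in `GRUCell.call` above can be summarized in a few lines of plain NumPy. This is only an illustrative sketch of the default `reset_after=True` path with biases (hypothetical helper `gru_step`; it is not part of the commit and ignores dropout):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_tm1, kernel, recurrent_kernel, input_bias, recurrent_bias):
    # kernel: (input_dim, 3 * units), recurrent_kernel: (units, 3 * units)
    x_z, x_r, x_h = np.split(x @ kernel + input_bias, 3, axis=-1)
    r_z, r_r, r_h = np.split(h_tm1 @ recurrent_kernel + recurrent_bias, 3, axis=-1)
    z = sigmoid(x_z + r_z)           # update gate
    r = sigmoid(x_r + r_r)           # reset gate
    hh = np.tanh(x_h + r * r_h)      # candidate state; reset applied after the matmul
    return z * h_tm1 + (1 - z) * hh  # mix previous and candidate state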
|
||||
|
||||
|
||||
class GRU(RNN):
|
||||
"""Gated Recurrent Unit - Cho et al. 2014.
|
||||
|
||||
Based on available runtime hardware and constraints, this layer
|
||||
will choose different implementations (cuDNN-based or backend-native)
|
||||
to maximize the performance. If a GPU is available and all
|
||||
the arguments to the layer meet the requirement of the cuDNN kernel
|
||||
(see below for details), the layer will use a fast cuDNN implementation
|
||||
when using the TensorFlow backend.

The requirements to use the cuDNN implementation are:

1. `activation` == `tanh`
2. `recurrent_activation` == `sigmoid`
3. `recurrent_dropout` == 0
4. `unroll` is `False`
5. `use_bias` is `True`
6. `reset_after` is `True`
7. Inputs, if masking is used, are strictly right-padded.
8. Eager execution is enabled in the outermost context.
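
As a concrete illustration of this checklist (not part of the commit): with the TensorFlow backend on a GPU, the first layer below satisfies all eight conditions and can take the fused path, while the second cannot because of its non-zero `recurrent_dropout`:

fast_gru = keras_core.layers.GRU(64)                          # defaults meet every condition
slow_gru = keras_core.layers.GRU(64, recurrent_dropout=0.2)   # condition 3 fails; generic loop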
|
||||
|
||||
There are two variants of the GRU implementation. The default one is based
|
||||
on [v3](https://arxiv.org/abs/1406.1078v3) and has reset gate applied to
|
||||
hidden state before matrix multiplication. The other one is based on
|
||||
[original](https://arxiv.org/abs/1406.1078v1) and has the order reversed.
|
||||
|
||||
The second variant is compatible with CuDNNGRU (GPU-only) and allows
|
||||
inference on CPU. Thus it has separate biases for `kernel` and
|
||||
`recurrent_kernel`. To use this variant, set `reset_after=True` and
|
||||
`recurrent_activation='sigmoid'`.
|
||||
|
||||
For example:
|
||||
|
||||
>>> inputs = np.random.random((32, 10, 8))
|
||||
>>> gru = keras_core.layers.GRU(4)
|
||||
>>> output = gru(inputs)
|
||||
>>> output.shape
|
||||
(32, 4)
|
||||
>>> gru = keras_core.layers.GRU(4, return_sequences=True, return_state=True)
|
||||
>>> whole_sequence_output, final_state = gru(inputs)
|
||||
>>> whole_sequence_output.shape
|
||||
(32, 10, 4)
|
||||
>>> final_state.shape
|
||||
(32, 4)
|
||||
|
||||
Args:
|
||||
units: Positive integer, dimensionality of the output space.
|
||||
activation: Activation function to use.
|
||||
Default: hyperbolic tangent (`tanh`).
|
||||
If you pass `None`, no activation is applied
|
||||
(ie. "linear" activation: `a(x) = x`).
|
||||
recurrent_activation: Activation function to use
|
||||
for the recurrent step.
|
||||
Default: sigmoid (`sigmoid`).
|
||||
If you pass `None`, no activation is applied
|
||||
(ie. "linear" activation: `a(x) = x`).
|
||||
use_bias: Boolean, (default `True`), whether the layer
|
||||
should use a bias vector.
|
||||
kernel_initializer: Initializer for the `kernel` weights matrix,
|
||||
used for the linear transformation of the inputs. Default:
|
||||
`"glorot_uniform"`.
|
||||
recurrent_initializer: Initializer for the `recurrent_kernel`
|
||||
weights matrix, used for the linear transformation of the recurrent
|
||||
state. Default: `"orthogonal"`.
|
||||
bias_initializer: Initializer for the bias vector. Default: `"zeros"`.
|
||||
kernel_regularizer: Regularizer function applied to the `kernel` weights
|
||||
matrix. Default: `None`.
|
||||
recurrent_regularizer: Regularizer function applied to the
|
||||
`recurrent_kernel` weights matrix. Default: `None`.
|
||||
bias_regularizer: Regularizer function applied to the bias vector.
|
||||
Default: `None`.
|
||||
activity_regularizer: Regularizer function applied to the output of the
|
||||
layer (its "activation"). Default: `None`.
|
||||
kernel_constraint: Constraint function applied to the `kernel` weights
|
||||
matrix. Default: `None`.
|
||||
recurrent_constraint: Constraint function applied to the
|
||||
`recurrent_kernel` weights matrix. Default: `None`.
|
||||
bias_constraint: Constraint function applied to the bias vector.
|
||||
Default: `None`.
|
||||
dropout: Float between 0 and 1. Fraction of the units to drop for the
|
||||
linear transformation of the inputs. Default: 0.
|
||||
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
|
||||
for the linear transformation of the recurrent state. Default: 0.
|
||||
return_sequences: Boolean. Whether to return the last output
|
||||
in the output sequence, or the full sequence. Default: `False`.
|
||||
return_state: Boolean. Whether to return the last state in addition
|
||||
to the output. Default: `False`.
|
||||
go_backwards: Boolean (default `False`).
|
||||
If `True`, process the input sequence backwards and return the
|
||||
reversed sequence.
|
||||
stateful: Boolean (default False). If `True`, the last state
|
||||
for each sample at index i in a batch will be used as initial
|
||||
state for the sample of index i in the following batch.
|
||||
unroll: Boolean (default False).
|
||||
If `True`, the network will be unrolled,
|
||||
else a symbolic loop will be used.
|
||||
Unrolling can speed up an RNN,
|
||||
although it tends to be more memory-intensive.
|
||||
Unrolling is only suitable for short sequences.
|
||||
reset_after: GRU convention (whether to apply reset gate after or
|
||||
before matrix multiplication). `False` is `"before"`,
|
||||
`True` is `"after"` (default and cuDNN compatible).
|
||||
|
||||
Call arguments:
|
||||
inputs: A 3D tensor, with shape `(batch, timesteps, feature)`.
|
||||
mask: Binary tensor of shape `(samples, timesteps)` indicating whether
|
||||
a given timestep should be masked (optional).
|
||||
An individual `True` entry indicates that the corresponding timestep
|
||||
should be utilized, while a `False` entry indicates that the
|
||||
corresponding timestep should be ignored. Defaults to `None`.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
training mode or in inference mode. This argument is passed to the
|
||||
cell when calling it. This is only relevant if `dropout` or
|
||||
`recurrent_dropout` is used (optional). Defaults to `None`.
|
||||
initial_state: List of initial state tensors to be passed to the first
|
||||
call of the cell (optional, `None` causes creation
|
||||
of zero-filled initial state tensors). Defaults to `None`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
units,
|
||||
activation="tanh",
|
||||
recurrent_activation="sigmoid",
|
||||
use_bias=True,
|
||||
kernel_initializer="glorot_uniform",
|
||||
recurrent_initializer="orthogonal",
|
||||
bias_initializer="zeros",
|
||||
kernel_regularizer=None,
|
||||
recurrent_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
recurrent_constraint=None,
|
||||
bias_constraint=None,
|
||||
dropout=0.0,
|
||||
recurrent_dropout=0.0,
|
||||
return_sequences=False,
|
||||
return_state=False,
|
||||
go_backwards=False,
|
||||
stateful=False,
|
||||
unroll=False,
|
||||
reset_after=True,
|
||||
seed=None,
|
||||
**kwargs,
|
||||
):
|
||||
cell = GRUCell(
|
||||
units,
|
||||
activation=activation,
|
||||
recurrent_activation=recurrent_activation,
|
||||
use_bias=use_bias,
|
||||
kernel_initializer=kernel_initializer,
|
||||
recurrent_initializer=recurrent_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
recurrent_regularizer=recurrent_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
kernel_constraint=kernel_constraint,
|
||||
recurrent_constraint=recurrent_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
dropout=dropout,
|
||||
recurrent_dropout=recurrent_dropout,
|
||||
reset_after=reset_after,
|
||||
dtype=kwargs.get("dtype", None),
|
||||
trainable=kwargs.get("trainable", True),
|
||||
name="gru_cell",
|
||||
seed=seed,
|
||||
)
|
||||
super().__init__(
|
||||
cell,
|
||||
return_sequences=return_sequences,
|
||||
return_state=return_state,
|
||||
go_backwards=go_backwards,
|
||||
stateful=stateful,
|
||||
unroll=unroll,
|
||||
activity_regularizer=activity_regularizer,
|
||||
**kwargs,
|
||||
)
|
||||
self.input_spec = InputSpec(ndim=3)
|
||||
|
||||
def inner_loop(self, sequence, initial_state, mask, training=False):
|
||||
if nest.is_nested(initial_state):
|
||||
initial_state = initial_state[0]
|
||||
if nest.is_nested(mask):
|
||||
mask = mask[0]
|
||||
|
||||
try:
|
||||
# Backends are allowed to specify (optionally) optimized
|
||||
# implementation of the inner GRU loop. In the case of
|
||||
# TF for instance, it will leverage cuDNN when feasible, and
|
||||
# it will raise NotImplementedError otherwise.
|
||||
return backend.gru(
|
||||
sequence,
|
||||
initial_state,
|
||||
mask,
|
||||
kernel=self.cell.kernel,
|
||||
recurrent_kernel=self.cell.recurrent_kernel,
|
||||
bias=self.cell.bias,
|
||||
activation=self.cell.activation,
|
||||
recurrent_activation=self.cell.recurrent_activation,
|
||||
return_sequences=self.return_sequences,
|
||||
go_backwards=self.go_backwards,
|
||||
unroll=self.unroll,
|
||||
reset_after=self.cell.reset_after,
|
||||
)
|
||||
except NotImplementedError:
|
||||
return super().inner_loop(
|
||||
sequence, initial_state, mask=mask, training=training
|
||||
)
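
The `try`/`except` above encodes a simple contract: a backend may expose an optimized `gru()` that raises `NotImplementedError` whenever its fast path does not apply, in which case the layer silently falls back to the generic cell loop. A minimal sketch of that contract (hypothetical `custom_backend_gru`, shown only for illustration):

def custom_backend_gru(sequence, initial_state, mask, **kwargs):
    # Hypothetical backend hook: only the fused, cuDNN-style case is handled here.
    if mask is not None or kwargs.get("unroll", False):
        # Unsupported configurations defer to RNN.inner_loop via this exception.
        raise NotImplementedError("No fused GRU kernel for this configuration.")
    # ... a fused implementation would run here and return
    # (last_output, outputs, new_states), mirroring what backend.rnn returns.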
|
||||
|
||||
def call(self, sequence, initial_state=None, mask=None, training=None):
|
||||
return super().call(
|
||||
sequence, mask=mask, training=training, initial_state=initial_state
|
||||
)
|
||||
|
||||
@property
|
||||
def units(self):
|
||||
return self.cell.units
|
||||
|
||||
@property
|
||||
def activation(self):
|
||||
return self.cell.activation
|
||||
|
||||
@property
|
||||
def recurrent_activation(self):
|
||||
return self.cell.recurrent_activation
|
||||
|
||||
@property
|
||||
def use_bias(self):
|
||||
return self.cell.use_bias
|
||||
|
||||
@property
|
||||
def kernel_initializer(self):
|
||||
return self.cell.kernel_initializer
|
||||
|
||||
@property
|
||||
def recurrent_initializer(self):
|
||||
return self.cell.recurrent_initializer
|
||||
|
||||
@property
|
||||
def bias_initializer(self):
|
||||
return self.cell.bias_initializer
|
||||
|
||||
@property
|
||||
def kernel_regularizer(self):
|
||||
return self.cell.kernel_regularizer
|
||||
|
||||
@property
|
||||
def recurrent_regularizer(self):
|
||||
return self.cell.recurrent_regularizer
|
||||
|
||||
@property
|
||||
def bias_regularizer(self):
|
||||
return self.cell.bias_regularizer
|
||||
|
||||
@property
|
||||
def kernel_constraint(self):
|
||||
return self.cell.kernel_constraint
|
||||
|
||||
@property
|
||||
def recurrent_constraint(self):
|
||||
return self.cell.recurrent_constraint
|
||||
|
||||
@property
|
||||
def bias_constraint(self):
|
||||
return self.cell.bias_constraint
|
||||
|
||||
@property
|
||||
def dropout(self):
|
||||
return self.cell.dropout
|
||||
|
||||
@property
|
||||
def recurrent_dropout(self):
|
||||
return self.cell.recurrent_dropout
|
||||
|
||||
@property
|
||||
def reset_after(self):
|
||||
return self.cell.reset_after
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"units": self.units,
|
||||
"activation": activations.serialize(self.activation),
|
||||
"recurrent_activation": activations.serialize(
|
||||
self.recurrent_activation
|
||||
),
|
||||
"use_bias": self.use_bias,
|
||||
"kernel_initializer": initializers.serialize(
|
||||
self.kernel_initializer
|
||||
),
|
||||
"recurrent_initializer": initializers.serialize(
|
||||
self.recurrent_initializer
|
||||
),
|
||||
"bias_initializer": initializers.serialize(self.bias_initializer),
|
||||
"kernel_regularizer": regularizers.serialize(
|
||||
self.kernel_regularizer
|
||||
),
|
||||
"recurrent_regularizer": regularizers.serialize(
|
||||
self.recurrent_regularizer
|
||||
),
|
||||
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
|
||||
"activity_regularizer": regularizers.serialize(
|
||||
self.activity_regularizer
|
||||
),
|
||||
"kernel_constraint": constraints.serialize(self.kernel_constraint),
|
||||
"recurrent_constraint": constraints.serialize(
|
||||
self.recurrent_constraint
|
||||
),
|
||||
"bias_constraint": constraints.serialize(self.bias_constraint),
|
||||
"dropout": self.dropout,
|
||||
"recurrent_dropout": self.recurrent_dropout,
|
||||
"reset_after": self.reset_after,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
del base_config["cell"]
|
||||
return {**base_config, **config}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
172
keras_core/layers/rnn/gru_test.py
Normal file
@ -0,0 +1,172 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import initializers
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() != "tensorflow",
|
||||
reason="Only implemented for TF for now.",
|
||||
)
|
||||
class GRUTest(testing.TestCase):
|
||||
def test_basics(self):
|
||||
self.run_layer_test(
|
||||
layers.GRU,
|
||||
init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5},
|
||||
input_shape=(3, 2, 4),
|
||||
call_kwargs={"training": True},
|
||||
expected_output_shape=(3, 3),
|
||||
expected_num_trainable_weights=3,
|
||||
expected_num_non_trainable_weights=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.GRU,
|
||||
init_kwargs={
|
||||
"units": 3,
|
||||
"return_sequences": True,
|
||||
"bias_regularizer": "l1",
|
||||
"kernel_regularizer": "l2",
|
||||
"recurrent_regularizer": "l2",
|
||||
},
|
||||
input_shape=(3, 2, 4),
|
||||
expected_output_shape=(3, 2, 3),
|
||||
expected_num_losses=3,
|
||||
expected_num_trainable_weights=3,
|
||||
expected_num_non_trainable_weights=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
|
||||
def test_correctness(self):
|
||||
sequence = np.arange(72).reshape((3, 6, 4)).astype("float32")
|
||||
layer = layers.GRU(
|
||||
3,
|
||||
kernel_initializer=initializers.Constant(0.01),
|
||||
recurrent_initializer=initializers.Constant(0.02),
|
||||
bias_initializer=initializers.Constant(0.03),
|
||||
)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.5217289, 0.5217289, 0.5217289],
|
||||
[0.6371659, 0.6371659, 0.6371659],
|
||||
[0.39384964, 0.39384964, 0.3938496],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
layer = layers.GRU(
|
||||
3,
|
||||
kernel_initializer=initializers.Constant(0.01),
|
||||
recurrent_initializer=initializers.Constant(0.02),
|
||||
bias_initializer=initializers.Constant(0.03),
|
||||
go_backwards=True,
|
||||
)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.24406259, 0.24406259, 0.24406259],
|
||||
[0.611516, 0.611516, 0.611516],
|
||||
[0.3928808, 0.3928808, 0.3928808],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
layer = layers.GRU(
|
||||
3,
|
||||
kernel_initializer=initializers.Constant(0.01),
|
||||
recurrent_initializer=initializers.Constant(0.02),
|
||||
bias_initializer=initializers.Constant(0.03),
|
||||
unroll=True,
|
||||
)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.5217289, 0.5217289, 0.5217289],
|
||||
[0.6371659, 0.6371659, 0.6371659],
|
||||
[0.39384964, 0.39384964, 0.3938496],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
layer = layers.GRU(
|
||||
3,
|
||||
kernel_initializer=initializers.Constant(0.01),
|
||||
recurrent_initializer=initializers.Constant(0.02),
|
||||
bias_initializer=initializers.Constant(0.03),
|
||||
reset_after=False,
|
||||
)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.51447755, 0.51447755, 0.51447755],
|
||||
[0.6426879, 0.6426879, 0.6426879],
|
||||
[0.40208298, 0.40208298, 0.40208298],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
layer = layers.GRU(
|
||||
3,
|
||||
kernel_initializer=initializers.Constant(0.01),
|
||||
recurrent_initializer=initializers.Constant(0.02),
|
||||
bias_initializer=initializers.Constant(0.03),
|
||||
use_bias=False,
|
||||
)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.49988455, 0.49988455, 0.49988455],
|
||||
[0.64701194, 0.64701194, 0.64701194],
|
||||
[0.4103359, 0.4103359, 0.4103359],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
def test_statefulness(self):
|
||||
sequence = np.arange(24).reshape((2, 3, 4)).astype("float32")
|
||||
layer = layers.GRU(
|
||||
4,
|
||||
stateful=True,
|
||||
kernel_initializer=initializers.Constant(0.01),
|
||||
recurrent_initializer=initializers.Constant(0.02),
|
||||
bias_initializer=initializers.Constant(0.03),
|
||||
)
|
||||
layer(sequence)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.29542392, 0.29542392, 0.29542392, 0.29542392],
|
||||
[0.5885018, 0.5885018, 0.5885018, 0.5885018],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
layer.reset_state()
|
||||
layer(sequence)
|
||||
output = layer(sequence)
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[
|
||||
[0.29542392, 0.29542392, 0.29542392, 0.29542392],
|
||||
[0.5885018, 0.5885018, 0.5885018, 0.5885018],
|
||||
]
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
# TODO: test masking
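
The TODO above leaves masking untested. A possible sketch of such a test, following the conventions of the cases above (expected numeric values omitted, since they would need to be generated against a reference run; this is not part of the commit):

    def test_masking(self):
        sequence = np.arange(24).reshape((2, 3, 4)).astype("float32")
        mask = np.array([[True, True, False], [True, False, False]])
        layer = layers.GRU(
            3,
            kernel_initializer=initializers.Constant(0.01),
            recurrent_initializer=initializers.Constant(0.02),
            bias_initializer=initializers.Constant(0.03),
        )
        output = layer(sequence, mask=mask)
        # Masked timesteps must not influence the result, so the output should
        # match running the layer on the corresponding truncated sequences.
        self.assertEqual(output.shape, (2, 3))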
|
0
keras_core/layers/rnn/lstm.py
Normal file
0
keras_core/layers/rnn/lstm_test.py
Normal file
@ -321,6 +321,32 @@ class RNN(Layer):
|
||||
for v in self.states:
|
||||
v.assign(ops.zeros_like(v))
|
||||
|
||||
def inner_loop(self, sequence, initial_state, mask, training=False):
|
||||
cell_kwargs = {}
|
||||
if isinstance(self.cell, Layer) and self.cell._call_has_training_arg():
|
||||
cell_kwargs["training"] = training
|
||||
|
||||
def step(inputs, states):
|
||||
output, new_states = self.cell(inputs, states, **cell_kwargs)
|
||||
if not nest.is_nested(new_states):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
|
||||
if not nest.is_nested(initial_state):
|
||||
initial_state = [initial_state]
|
||||
|
||||
return backend.rnn(
|
||||
step,
|
||||
sequence,
|
||||
initial_state,
|
||||
go_backwards=self.go_backwards,
|
||||
mask=mask,
|
||||
unroll=self.unroll,
|
||||
input_length=sequence.shape[1],
|
||||
zero_output_for_mask=self.zero_output_for_mask,
|
||||
return_all_outputs=self.return_sequences,
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
sequence,
|
||||
@ -335,18 +361,12 @@ class RNN(Layer):
|
||||
"time dimension is undefined. \n"
|
||||
"- If using a Sequential model, "
|
||||
"specify the time dimension by passing "
|
||||
"an `input_shape` or `batch_input_shape` "
|
||||
"argument to your first layer. If your "
|
||||
"first layer is an Embedding, you can "
|
||||
"also use the `input_length` argument.\n"
|
||||
"an `Input()` as your first layer.\n"
|
||||
"- If using the functional API, specify "
|
||||
"the time dimension by passing a `shape` "
|
||||
"or `batch_shape` argument to your Input layer."
|
||||
"or `batch_shape` argument to your `Input()`."
|
||||
)
|
||||
|
||||
cell_kwargs = {}
|
||||
if isinstance(self.cell, Layer) and self.cell._call_has_training_arg():
|
||||
cell_kwargs["training"] = training
|
||||
if initial_state is None:
|
||||
if self.stateful:
|
||||
initial_state = self.states
|
||||
@ -366,22 +386,11 @@ class RNN(Layer):
|
||||
lambda x: ops.cast(x, dtype=self.compute_dtype), initial_state
|
||||
)
|
||||
|
||||
def step(inputs, states):
|
||||
output, new_states = self.cell(inputs, states, **cell_kwargs)
|
||||
if not nest.is_nested(new_states):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
|
||||
last_output, outputs, states = backend.nn.rnn(
|
||||
step,
|
||||
sequence,
|
||||
initial_state,
|
||||
go_backwards=self.go_backwards,
|
||||
last_output, outputs, states = self.inner_loop(
|
||||
sequence=sequence,
|
||||
initial_state=initial_state,
|
||||
mask=mask,
|
||||
unroll=self.unroll,
|
||||
input_length=timesteps,
|
||||
zero_output_for_mask=self.zero_output_for_mask,
|
||||
return_all_outputs=self.return_sequences,
|
||||
training=training,
|
||||
)
|
||||
self._maybe_reset_dropout_masks(self.cell)
|
||||
|
||||
|
@ -51,7 +51,7 @@ class SimpleRNNCell(Layer, DropoutRNNCell):
|
||||
for the linear transformation of the recurrent state. Default: 0.
|
||||
|
||||
Call arguments:
|
||||
inputs: A 2D tensor, with shape `(batch, feature)`.
|
||||
inputs: A 2D tensor, with shape `(batch, features)`.
|
||||
states: A 2D tensor with shape `(batch, units)`, which is the state
|
||||
from the previous time step.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
@ -334,7 +334,7 @@ class SimpleRNN(RNN):
|
||||
dropout=dropout,
|
||||
recurrent_dropout=recurrent_dropout,
|
||||
seed=seed,
|
||||
dtype=kwargs.get("dtype"),
|
||||
dtype=kwargs.get("dtype", None),
|
||||
trainable=kwargs.get("trainable", True),
|
||||
name="simple_rnn_cell",
|
||||
)
|
||||
@ -347,10 +347,9 @@ class SimpleRNN(RNN):
|
||||
unroll=unroll,
|
||||
**kwargs,
|
||||
)
|
||||
self.activity_regularizer = regularizers.get(activity_regularizer)
|
||||
self.input_spec = [InputSpec(ndim=3)]
|
||||
|
||||
def call(self, inputs, mask=None, training=None, initial_state=None):
|
||||
def call(self, inputs, initial_state=None, mask=None, training=None):
|
||||
return super().call(
|
||||
inputs, mask=mask, training=training, initial_state=initial_state
|
||||
)
|
||||
|
@ -102,5 +102,7 @@ class MetricTest(testing.TestCase):
|
||||
self.assertEqual(len(metric.variables), 6)
|
||||
|
||||
def test_serialization(self):
|
||||
# TODO
|
||||
pass
|
||||
self.run_class_serialization_test(
|
||||
ExampleMetric(name="mse"),
|
||||
custom_objects={"ExampleMetric": ExampleMetric},
|
||||
)
|
||||
|
@ -1,22 +0,0 @@
|
||||
"""
|
||||
scatter
|
||||
"""
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core.backend import KerasTensor
|
||||
from keras_core.backend import any_symbolic_tensors
|
||||
from keras_core.operations.operation import Operation
|
||||
|
||||
|
||||
class Scatter(Operation):
|
||||
def call(self, indices, values, shape):
|
||||
return backend.core.scatter(indices, values, shape)
|
||||
|
||||
def compute_output_spec(self, indices, values, shape):
|
||||
return KerasTensor(shape, dtype=values.dtype)
|
||||
|
||||
|
||||
def scatter(indices, values, shape):
|
||||
if any_symbolic_tensors((indices, values, shape)):
|
||||
return Scatter().symbolic_call(indices, values, shape)
|
||||
return backend.core.scatter(indices, values, shape)
|
@ -1,64 +0,0 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import testing
|
||||
from keras_core.operations import core
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class CoreOpsDynamicShapeTest(testing.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class CoreOpsStaticShapeTest(testing.TestCase):
|
||||
def test_scatter(self):
|
||||
# Requires dtype
|
||||
indices = np.array([[0]], dtype="int32")
|
||||
values = np.array([0], dtype="int32")
|
||||
shape = (8,)
|
||||
self.assertEqual(core.scatter(indices, values, shape).shape, (8,))
|
||||
|
||||
|
||||
class CoreOpsCorrectnessTest(testing.TestCase):
|
||||
def test_scatter(self):
|
||||
# Test 1D
|
||||
indices = np.array([[1], [3], [4], [7]])
|
||||
values = np.array([9, 10, 11, 12])
|
||||
self.assertAllClose(
|
||||
core.scatter(indices, values, (8,)),
|
||||
[0, 9, 0, 10, 11, 0, 0, 12],
|
||||
)
|
||||
# Test 2D
|
||||
indices = np.array([[0, 1], [2, 0]])
|
||||
values = np.array([5, 10])
|
||||
self.assertAllClose(
|
||||
core.scatter(indices, values, (3, 2)), [[0, 5], [0, 0], [10, 0]]
|
||||
)
|
||||
# Test 3D
|
||||
indices = np.array([[1], [3]])
|
||||
values = np.array(
|
||||
[
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
]
|
||||
)
|
||||
self.assertAllClose(
|
||||
core.scatter(indices, values, (4, 4, 4)),
|
||||
[
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
],
|
||||
)
|
||||
# Test slices
|
||||
indices = np.array([[2], [4]])
|
||||
values = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
self.assertAllClose(
|
||||
core.scatter(indices, values, (6, 3)),
|
||||
[[0, 0, 0], [0, 0, 0], [1, 2, 3], [0, 0, 0], [4, 5, 6], [0, 0, 0]],
|
||||
)
|
@ -67,11 +67,15 @@ def get_metric(identifier, y_true, y_pred):
|
||||
y_true, y_pred
|
||||
)
|
||||
if is_binary:
|
||||
metric_obj = metrics_module.binary_accuracy
|
||||
metric_obj = metrics_module.BinaryAccuracy(name=str(identifier))
|
||||
elif is_sparse_categorical:
|
||||
metric_obj = metrics_module.sparse_categorical_accuracy
|
||||
metric_obj = metrics_module.SparseCategoricalAccuracy(
|
||||
name=str(identifier)
|
||||
)
|
||||
else:
|
||||
metric_obj = metrics_module.categorical_accuracy
|
||||
metric_obj = metrics_module.CategoricalAccuracy(
|
||||
name=str(identifier)
|
||||
)
|
||||
|
||||
if not isinstance(metric_obj, metrics_module.Metric):
|
||||
if isinstance(identifier, str):
|
||||
|
@ -178,6 +178,21 @@ class TestCompileMetrics(testing.TestCase):
|
||||
self.assertAllClose(result["mean_squared_error"], 0.0)
|
||||
self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
|
||||
|
||||
def test_name_conversions(self):
|
||||
compile_metrics = CompileMetrics(
|
||||
metrics=["acc", "accuracy"],
|
||||
weighted_metrics=[],
|
||||
)
|
||||
y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
|
||||
y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])
|
||||
compile_metrics.build(y_true, y_pred)
|
||||
compile_metrics.update_state(y_true, y_pred, sample_weight=None)
|
||||
result = compile_metrics.result()
|
||||
self.assertTrue(isinstance(result, dict))
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertAllClose(result["acc"], 0.333333)
|
||||
self.assertAllClose(result["accuracy"], 0.333333)
|
||||
|
||||
|
||||
class TestCompileLoss(testing.TestCase):
|
||||
def test_single_output_case(self):
|
||||
|
@ -60,10 +60,8 @@ class EpochIterator:
|
||||
steps_per_epoch=None,
|
||||
shuffle=False,
|
||||
class_weight=None,
|
||||
steps_per_execution=1,
|
||||
):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.steps_per_execution = steps_per_execution
|
||||
if steps_per_epoch:
|
||||
self._current_iterator = None
|
||||
self._insufficient_data = False
|
||||
@ -106,7 +104,9 @@ class EpochIterator:
|
||||
"sample_weights", "the sample weights", "PyDataset"
|
||||
)
|
||||
elif isinstance(x, types.GeneratorType):
|
||||
self.data_adapter = generator_data_adapter.GeneratorDataAdapter(x)
|
||||
self.data_adapter = generator_data_adapter.GeneratorDataAdapter(
|
||||
x, shuffle=shuffle
|
||||
)
|
||||
if y is not None:
|
||||
raise_unsupported_arg("y", "the targets", "PyDataset")
|
||||
if sample_weight is not None:
|
||||
|
@ -26,7 +26,6 @@ class Trainer:
|
||||
metrics=None,
|
||||
weighted_metrics=None,
|
||||
run_eagerly=False,
|
||||
steps_per_execution=1,
|
||||
jit_compile=True,
|
||||
):
|
||||
self.optimizer = optimizers.get(optimizer)
|
||||
@ -50,7 +49,6 @@ class Trainer:
|
||||
self.stop_training = False
|
||||
self.compiled = True
|
||||
self._loss_tracker = metrics_module.Mean(name="loss")
|
||||
self.steps_per_execution = steps_per_execution
|
||||
|
||||
self._compile_config = serialization_lib.SerializableDict(
|
||||
optimizer=optimizer,
|
||||
@ -59,7 +57,6 @@ class Trainer:
|
||||
metrics=metrics,
|
||||
weighted_metrics=weighted_metrics,
|
||||
run_eagerly=run_eagerly,
|
||||
steps_per_execution=steps_per_execution,
|
||||
jit_compile=jit_compile,
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import initializers
|
||||
@ -8,7 +7,6 @@ from keras_core import losses
|
||||
from keras_core import metrics
|
||||
from keras_core import optimizers
|
||||
from keras_core import testing
|
||||
from keras_core.callbacks.callback import Callback
|
||||
|
||||
if backend.backend() == "jax":
|
||||
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
||||
@ -215,32 +213,3 @@ class TestTrainer(testing.TestCase):
|
||||
|
||||
def test_predict_flow_jit(self):
|
||||
self._test_predict_flow(run_eagerly=False, jit_compile=True)
|
||||
|
||||
# TODO: Remove the skipif when implemented steps_per_execution for JAX.
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() != "tensorflow",
|
||||
reason="JAX does not support steps_per_execution yet",
|
||||
)
|
||||
def test_steps_per_execution_steps_count(self):
|
||||
class StepCount(Callback):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.count = 0
|
||||
self.batches = [0, 3, 6]
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
assert batch == self.batches[self.count]
|
||||
self.count += 1
|
||||
|
||||
x = np.ones((100, 4))
|
||||
y = np.ones((100, 1))
|
||||
model = ExampleModel(units=1)
|
||||
model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
|
||||
step_count = StepCount()
|
||||
model.fit(x=x, y=y, batch_size=16, callbacks=[step_count])
|
||||
self.assertEqual(step_count.count, 3)
|
||||
|
||||
model_2 = ExampleModel(units=1)
|
||||
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
|
||||
model_2.fit(x=x, y=y, batch_size=16)
|
||||
self.assertAllClose(model.get_weights(), model_2.get_weights())
|
||||
|