Enable training tests and fix a range of bugs

Francois Chollet 2023-05-19 11:40:17 -07:00
parent 770cb289f7
commit 427c533005
31 changed files with 160 additions and 159 deletions

@@ -162,13 +162,13 @@ def softmax(x, axis=-1):
@keras_core_export("keras_core.activations.elu")
def elu(x, alpha=1.0):
def elu(x):
"""Exponential Linear Unit.
The exponential linear unit (ELU) with `alpha > 0` is defined as:
- `x` if `x > 0`
- alpha * `exp(x) - 1` if `x < 0`
- `exp(x) - 1` if `x < 0`
ELUs have negative values which pushes the mean of the activations
closer to zero.
@@ -186,7 +186,7 @@ def elu(x, alpha=1.0):
- [Clevert et al., 2016](https://arxiv.org/abs/1511.07289)
"""
return ops.elu(x, alpha=alpha)
return ops.elu(x)
@keras_core_export("keras_core.activations.selu")
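
Note: the hunk above drops the `alpha` argument, fixing the behavior at `alpha=1`. A minimal NumPy sketch of what the simplified signature computes (a reference formula, not repository code):
```python
import numpy as np

def elu_reference(x):
    # alpha fixed at 1: exp(x) - 1 for x < 0, identity otherwise.
    x = np.asarray(x, dtype="float32")
    return np.where(x < 0, np.exp(x) - 1.0, x)

print(elu_reference([-1.0, 0.0, 1.0, 2.0, 3.0]))
# ~[-0.632, 0., 1., 2., 3.], matching the updated test expectations further down
```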

@@ -56,8 +56,8 @@ def hard_sigmoid(x):
return jnn.hard_sigmoid(x)
def elu(x, alpha=1.0):
return jnn.elu(x, alpha=alpha)
def elu(x):
return jnn.elu(x)
def selu(x):

@@ -2,6 +2,8 @@ from jax import lax
from jax import numpy as jnp
from tensorflow import nest
from keras_core.backend.common.stateless_scope import StatelessScope
def rnn(
step_function,
@@ -178,12 +180,16 @@ def rnn(
scan_xs = inputs
new_states, outputs = lax.scan(
f=_step,
init=initial_states,
xs=scan_xs,
reverse=go_backwards,
)
with StatelessScope():
# We must use a stateless scope because `scan` will involve
# JAX tracing -- any variable update at this stage would
# be a leak.
new_states, outputs = lax.scan(
f=_step,
init=initial_states,
xs=scan_xs,
reverse=go_backwards,
)
if go_backwards:
outputs = jnp.flip(outputs, axis=0)
last_output = outputs[-1]
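
Note: for context, a standalone toy of how `lax.scan` threads a carry through a step function (a cumulative sum, not the backend's actual `_step`); the `StatelessScope` added above only ensures no Keras variable is updated while JAX traces the scan:
```python
from jax import lax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next state, per-timestep output)

final_state, outputs = lax.scan(f=step, init=jnp.float32(0.0), xs=jnp.arange(5.0))
# outputs -> [0., 1., 3., 6., 10.], final_state -> 10.
```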

@@ -57,12 +57,8 @@ def hard_sigmoid(x):
return tf.clip_by_value(x, 0.0, 1.0)
def elu(x, alpha=1.0):
res = tf.nn.elu(x)
if alpha == 1:
return res
else:
return tf.where(x > 0, res, alpha * res)
def elu(x):
return tf.nn.elu(x)
def selu(x):

@@ -42,9 +42,10 @@ def max(x, axis=None, keepdims=False, initial=None):
# TensorFlow returns -inf by default for an empty list, but for consistency
# with other backends and the numpy API we want to throw in this case.
size_x = size(x)
tf.assert_greater(
size(x),
tf.constant(0, dtype=tf.int64),
size_x,
tf.constant(0, dtype=size_x.dtype),
message="Cannot compute the max of an empty tensor.",
)
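
Note: the fix makes the comparison dtype follow `size_x` instead of hard-coding int64. A standalone sketch of the same pattern using stock TensorFlow ops (`tf.size` stands in for the backend's `size()` helper):
```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
size_x = tf.size(x)  # int32 by default; its dtype drives the comparison constant
tf.debugging.assert_greater(
    size_x,
    tf.constant(0, dtype=size_x.dtype),
    message="Cannot compute the max of an empty tensor.",
)
```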

@@ -55,7 +55,6 @@ class TensorFlowTrainer(base_trainer.Trainer):
self._loss_tracker.update_state(loss)
# Compute gradients
# TODO: move value conversion to TF
if self.trainable_weights:
trainable_weights = [v.value for v in self.trainable_weights]
gradients = tape.gradient(loss, trainable_weights)
@@ -88,7 +87,6 @@ class TensorFlowTrainer(base_trainer.Trainer):
return y_pred
def make_train_function(self, force=False):
# TODO: support tf.distribute and steps_per_execution.
if self.train_function is not None and not force:
return self.train_function
@@ -131,10 +129,10 @@ class TensorFlowTrainer(base_trainer.Trainer):
self.train_function = train_function
def make_test_function(self, force=False):
# TODO: support tf.distribute and steps_per_execution.
if self.test_function is not None and not force:
return self.test_function
@tf.autograph.experimental.do_not_convert
def one_step_on_data(data):
"""Runs a single test step on a batch of data."""
return self.test_step(data)
@@ -173,10 +171,10 @@ class TensorFlowTrainer(base_trainer.Trainer):
self.test_function = test_function
def make_predict_function(self, force=False):
# TODO: support tf.distribute and steps_per_execution.
if self.predict_function is not None and not force:
return self.predict_function
@tf.autograph.experimental.do_not_convert
def one_step_on_data(data):
"""Runs a predict test step on a batch of data."""
return self.predict_step(data)
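
Note: these factories wrap per-batch steps for `fit`/`evaluate`/`predict`. As a rough, self-contained sketch of the train-step pattern involved, built from standard `tf.keras` pieces rather than keras_core's `TensorFlowTrainer`:
```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def one_step_on_data(data):
    x, y = data
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

loss = one_step_on_data(
    (np.random.random((4, 3)).astype("float32"),
     np.random.random((4, 1)).astype("float32"))
)
```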

@@ -47,10 +47,6 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""Produce random number based on the uniform distribution.
The generated values follow a uniform distribution in the range
`[minval, maxval)`. The lower bound `minval` is included in the range,
while the upper bound `maxval` is excluded.
Args:
shape: The shape of the random values to generate.
minval: Floats, defaults to 0. Lower bound of the range of
@@ -81,10 +77,6 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Produce random number based on the truncated normal distribution.
The values are drawn from a normal distribution with specified mean and
standard deviation, discarding and re-drawing any samples that are more
than two standard deviations from the mean.
Args:
shape: The shape of the random values to generate.
mean: Floats, defaults to 0. Mean of the random values to generate.
@@ -103,15 +95,14 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
across multiple calls, use as seed an instance
of `keras_core.backend.SeedGenerator`.
"""
# Take a larger standard normal dist, discard values outside 2 * stddev
# Offset by mean and stddev
x = normal(shape + (4,), mean=0, stddev=1, dtype=dtype, seed=seed)
valid = (x > -2) & (x < 2)
indexes = valid.max(-1, keepdim=True)[1]
trunc_x = torch.empty(shape)
trunc_x.data.copy_(x.gather(-1, indexes).squeeze(-1))
trunc_x.data.mul_(stddev).add_(mean)
return trunc_x
x = torch.empty(shape)
# TODO: setting seed globally via `manual_seed` might create side effects.
if seed is not None:
seed_val, _ = draw_seed(seed)
torch.manual_seed(int(seed_val))
return torch.nn.init.trunc_normal_(
x, mean=mean, std=stddev, a=-stddev * 2, b=stddev * 2
)
def dropout(inputs, rate, noise_shape=None, seed=None):
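
Note: the rewrite fills a tensor in place with `torch.nn.init.trunc_normal_`, clipped to two standard deviations, and seeds via the global RNG (hence the TODO above). A minimal sketch of that call in isolation:
```python
import torch

shape, mean, stddev = (3, 4), 0.0, 1.0
torch.manual_seed(1337)  # global seed; the TODO notes possible side effects
x = torch.empty(shape)
torch.nn.init.trunc_normal_(x, mean=mean, std=stddev, a=-stddev * 2, b=stddev * 2)
print(x.abs().max() <= 2 * stddev)  # samples stay inside the truncation bounds
```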

@@ -8,9 +8,11 @@ class Activation(Layer):
"""Applies an activation function to an output.
Args:
activation: Activation function. It could be a callable, or the name of
an activation from the `keras_core.activations` namespace.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
activation: Activation function. It could be
a callable, or the name of an activation
from the `keras_core.activations` namespace.
**kwargs: Base layer keyword arguments, such as
`name` and `dtype`.
Example:
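
Note: the hunk cuts off at the docstring's `Example:` section. A short usage sketch of the documented behavior, assuming the standard `keras_core.layers` export and the `"relu"` activation name:
```python
import numpy as np
from keras_core import layers

layer = layers.Activation("relu")  # name lookup into keras_core.activations
print(layer(np.array([-3.0, -1.0, 0.0, 2.0])))  # -> [0., 0., 0., 2.]
```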

@@ -10,22 +10,21 @@ class ELU(Layer):
Formula:
```
f(x) = alpha * (exp(x) - 1.) for x < 0
f(x) = (exp(x) - 1.) for x < 0
f(x) = x for x >= 0
```
Args:
alpha: float, slope of negative section. Defaults to 1.0.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
**kwargs: Base layer keyword arguments, such as
`name` and `dtype`.
"""
def __init__(self, alpha=1.0, **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.alpha = alpha
self.supports_masking = True
def call(self, inputs):
return activations.elu(inputs, alpha=self.alpha)
return activations.elu(inputs)
def compute_output_shape(self, input_shape):
return input_shape

@@ -2,7 +2,6 @@ import numpy as np
from keras_core import testing
from keras_core.layers.activations import elu
import tensorflow as tf
class ELUTest(testing.TestCase):
@@ -18,12 +17,7 @@ class ELUTest(testing.TestCase):
supports_masking=True,
)
def test_correctness(self):
x = np.random.random((2, 2, 5))
x = np.random.random((2, 5))
elu_layer = elu.ELU()
tf_elu_layer = tf.keras.layers.ELU()
self.assertAllClose(elu_layer(x), tf_elu_layer(x))
elu_layer = elu.ELU(alpha=0.7)
tf_elu_layer = tf.keras.layers.ELU(alpha=0.7)
self.assertAllClose(elu_layer(x), tf_elu_layer(x))
result = elu_layer(x[np.newaxis, :])[0]
self.assertAllClose(result, x, rtol=1e-05)

@@ -13,37 +13,46 @@ class PReLU(Layer):
Formula:
``` python
f(x) = alpha * x for x < 0
f(x) = negative_slope * x for x < 0
f(x) = x for x >= 0
```
where `alpha` is a learned array with the same shape as x.
where `negative_slope` is a learned array with the same shape as x.
Args:
alpha_initializer: Initializer function for the weights.
alpha_regularizer: Regularizer for the weights.
alpha_constraint: Constraint for the weights.
shared_axes: The axes along which to share learnable parameters for the
activation function. For example, if the incoming feature maps are
from a 2D convolution with output shape
`(batch, height, width, channels)`, and you wish to share parameters
across space so that each filter only has one set of parameters,
negative_slope_initializer: Initializer function for the weights.
negative_slope_regularizer: Regularizer for the weights.
negative_slope_constraint: Constraint for the weights.
shared_axes: The axes along which to share learnable
parameters for the activation function.
For example, if the incoming feature maps
are from a 2D convolution
with output shape `(batch, height, width, channels)`,
and you wish to share parameters across space
so that each filter only has one set of parameters,
set `shared_axes=[1, 2]`.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
**kwargs: Base layer keyword arguments, such as
`name` and `dtype`.
"""
def __init__(
self,
alpha_initializer="Zeros",
alpha_regularizer=None,
alpha_constraint=None,
negative_slope_initializer="Zeros",
negative_slope_regularizer=None,
negative_slope_constraint=None,
shared_axes=None,
**kwargs
):
super().__init__(**kwargs)
self.supports_masking = True
self.alpha_initializer = initializers.get(alpha_initializer)
self.alpha_regularizer = regularizers.get(alpha_regularizer)
self.alpha_constraint = constraints.get(alpha_constraint)
self.negative_slope_initializer = initializers.get(
negative_slope_initializer
)
self.negative_slope_regularizer = regularizers.get(
negative_slope_regularizer
)
self.negative_slope_constraint = constraints.get(
negative_slope_constraint
)
if shared_axes is None:
self.shared_axes = None
elif not isinstance(shared_axes, (list, tuple)):
@@ -56,12 +65,12 @@ class PReLU(Layer):
if self.shared_axes is not None:
for i in self.shared_axes:
param_shape[i - 1] = 1
self.alpha = self.add_weight(
self.negative_slope = self.add_weight(
shape=param_shape,
name="alpha",
initializer=self.alpha_initializer,
regularizer=self.alpha_regularizer,
constraint=self.alpha_constraint,
name="negative_slope",
initializer=self.negative_slope_initializer,
regularizer=self.negative_slope_regularizer,
constraint=self.negative_slope_constraint,
)
# Set input spec
axes = {}
@@ -74,21 +83,21 @@ class PReLU(Layer):
def call(self, inputs):
pos = activations.relu(inputs)
neg = -self.alpha * activations.relu(-inputs)
neg = -self.negative_slope * activations.relu(-inputs)
return pos + neg
def get_config(self):
config = super().get_config()
config.update(
{
"alpha_initializer": initializers.serialize(
self.alpha_initializer
"negative_slope_initializer": initializers.serialize(
self.negative_slope_initializer
),
"alpha_regularizer": regularizers.serialize(
self.alpha_regularizer
"negative_slope_regularizer": regularizers.serialize(
self.negative_slope_regularizer
),
"alpha_constraint": constraints.serialize(
self.alpha_constraint
"negative_slope_constraint": constraints.serialize(
self.negative_slope_constraint
),
"shared_axes": self.shared_axes,
}
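
Note: since the hunks above rename every `alpha*` argument and weight to `negative_slope*`, a plain NumPy reference of the formula computed by `call()` may help (a hypothetical standalone helper, not repository code):
```python
import numpy as np

def prelu_reference(x, negative_slope):
    # Mirrors call(): relu(x) + (-negative_slope) * relu(-x)
    pos = np.maximum(x, 0.0)
    neg = -negative_slope * np.maximum(-x, 0.0)
    return pos + neg

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(prelu_reference(x, negative_slope=0.3))  # -> [-0.6, -0.15, 0., 1.5]
```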

@@ -2,7 +2,6 @@ import numpy as np
from keras_core import testing
from keras_core.layers.activations import prelu
import tensorflow as tf
class PReLUTest(testing.TestCase):
@@ -10,9 +9,9 @@ class PReLUTest(testing.TestCase):
self.run_layer_test(
prelu.PReLU,
init_kwargs={
"alpha_initializer": "zeros",
"alpha_regularizer": "L1",
"alpha_constraint": "MaxNorm",
"negative_slope_initializer": "zeros",
"negative_slope_regularizer": "L1",
"negative_slope_constraint": "MaxNorm",
"shared_axes": 1,
},
input_shape=(2, 3, 4),
@@ -20,25 +19,15 @@ class PReLUTest(testing.TestCase):
)
def test_prelu_correctness(self):
inputs = np.random.randn(2, 10, 5, 3)
prelu_layer = prelu.PReLU(
alpha_initializer="glorot_uniform",
alpha_regularizer="l1",
alpha_constraint="non_neg",
shared_axes=(1, 2),
negative_slope_initializer="glorot_uniform",
negative_slope_regularizer="l1",
negative_slope_constraint="non_neg",
shared_axes=None,
)
tf_prelu_layer = tf.keras.layers.PReLU(
alpha_initializer="glorot_uniform",
alpha_regularizer="l1",
alpha_constraint="non_neg",
shared_axes=(1, 2),
)
prelu_layer.build(inputs.shape)
tf_prelu_layer.build(inputs.shape)
weights = np.random.random((1, 1, 3))
prelu_layer.alpha.assign(weights)
tf_prelu_layer.alpha.assign(weights)
self.assertAllClose(prelu_layer(inputs), tf_prelu_layer(inputs))
test_input = np.random.randn(10, 5)
result = prelu_layer(test_input)
expected_output = np.maximum(
0, test_input
) + prelu_layer.negative_slope.numpy() * np.minimum(0, test_input)
self.assertAllClose(result, expected_output)

@@ -17,11 +17,7 @@ class ReLU(Layer):
Example:
``` python
relu_layer = keras_core.layers.activations.ReLU(
max_value=10,
negative_slope=0.5,
threshold=0,
)
relu_layer = relu.ReLU(max_value=10, negative_slope=0.5, threshold=0)
input = np.array([-10, -5, 0.0, 5, 10])
result = relu_layer(input)
# result = [-5. , -2.5, 0. , 5. , 10.]
@@ -30,10 +26,12 @@ class ReLU(Layer):
Args:
max_value: Float >= 0. Maximum activation value. None means unlimited.
Defaults to `None`.
negative_slope: Float >= 0. Negative slope coefficient. Defaults to 0.0.
negative_slope: Float >= 0. Negative slope coefficient.
Defaults to 0.0.
threshold: Float >= 0. Threshold value for thresholded activation.
Defaults to 0.0.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
**kwargs: Base layer keyword arguments, such as
`name` and `dtype`.
"""
def __init__(
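
Note: the docstring example above (`[-5., -2.5, 0., 5., 10.]`) can be reproduced with a NumPy sketch of the documented piecewise behavior (not the layer's implementation):
```python
import numpy as np

def relu_reference(x, max_value=None, negative_slope=0.0, threshold=0.0):
    y = np.where(x >= threshold, x, negative_slope * (x - threshold))
    if max_value is not None:
        y = np.minimum(y, max_value)
    return y

x = np.array([-10.0, -5.0, 0.0, 5.0, 10.0])
print(relu_reference(x, max_value=10, negative_slope=0.5, threshold=0))
# -> [-5., -2.5, 0., 5., 10.]
```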

@@ -15,7 +15,7 @@ class Softmax(Layer):
```
Example:
>>>softmax_layer = keras_core.layers.activations.Softmax()
>>>softmax_layer = Softmax()
>>>input = np.array([1.0, 2.0, 1.0])
>>>result = softmax_layer(input)
[0.21194157, 0.5761169, 0.21194157]
@@ -24,10 +24,11 @@ class Softmax(Layer):
Args:
axis: Integer, or list of Integers, axis along which the softmax
normalization is applied.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
**kwargs: Base layer keyword arguments, such as
`name` and `dtype`.
Call arguments:
inputs: The inputs (logits) to the softmax layer.
inputs: The inputs, or logits to the softmax layer.
mask: A boolean mask of the same shape as `inputs`. The mask
specifies 1 to keep and 0 to mask. Defaults to `None`.
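
Note: the docstring example above can be checked against a stable NumPy softmax (a reference sketch, not the layer's code):
```python
import numpy as np

def softmax_reference(x):
    e = np.exp(x - np.max(x))  # subtract max for numerical stability
    return e / e.sum()

print(softmax_reference(np.array([1.0, 2.0, 1.0])))
# -> [0.21194157, 0.5761169, 0.21194157]
```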

@@ -20,6 +20,7 @@ class AdditiveAttentionTest(testing.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
# Sale.
self.run_layer_test(
@@ -35,6 +36,7 @@ class AdditiveAttentionTest(testing.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
def test_attention_correctness(self):

@@ -20,6 +20,7 @@ class AttentionTest(testing.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
# Sale and concat.
self.run_layer_test(
@@ -36,6 +37,7 @@ class AttentionTest(testing.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
def test_attention_correctness(self):

@@ -21,6 +21,7 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
self.run_layer_test(
@@ -39,6 +40,7 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
@parameterized.named_parameters(
@@ -78,6 +80,7 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
)
@parameterized.named_parameters(

@@ -24,8 +24,8 @@ class LambdaTest(testing.TestCase):
self.run_layer_test(
layers.Lambda,
init_kwargs={"function": ops.square, "mask": ops.ones((2, 3))},
input_shape=(2, 3),
expected_output_shape=(2, 3),
input_shape=(2, 3, 4),
expected_output_shape=(2, 3, 4),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,

@@ -706,7 +706,9 @@ class Layer(Operation):
for x in scope.losses:
if x in self._losses:
scope.losses.remove(x)
self._losses = []
self._losses.clear()
for layer in self._layers:
layer._clear_losses()
def add_metric(self):
# Permanently disabled

@@ -49,8 +49,8 @@ def batch_dot(x, y, axes=None):
rank is 1, we reshape it to `(batch_size, 1)`.
"""
x_shape = tuple(ops.shape(x))
y_shape = tuple(ops.shape(y))
x_shape = x.shape
y_shape = y.shape
x_ndim = len(x_shape)
y_ndim = len(y_shape)
@@ -301,8 +301,8 @@ class Dot(Merge):
if isinstance(self.axes, int):
if self.axes < 0:
axes = [
self.axes % x1.ndim,
self.axes % x2.ndim,
self.axes % len(x1.shape),
self.axes % len(x2.shape),
]
else:
axes = [self.axes] * 2
@@ -310,7 +310,7 @@ class Dot(Merge):
axes = []
for i in range(len(self.axes)):
if self.axes[i] < 0:
axes.append(self.axes[i] % inputs[i].ndim)
axes.append(self.axes[i] % len(inputs[i].shape))
else:
axes.append(self.axes[i])
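
Note: the change normalizes negative axes with `len(x.shape)`, presumably because symbolic tensors need not expose `.ndim`; the arithmetic itself is just Python's modulo on the rank:
```python
x1_shape, x2_shape = (2, 3, 4), (2, 4, 5)
axes = -1
normalized = [axes % len(x1_shape), axes % len(x2_shape)]
print(normalized)  # -> [2, 2]
```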

@@ -146,19 +146,18 @@ class GroupNormalization(Layer):
super().build(input_shape)
def call(self, inputs):
input_shape = inputs.shape
reshaped_inputs = self._reshape_into_groups(inputs)
normalized_inputs = self._apply_normalization(
reshaped_inputs, input_shape
reshaped_inputs, inputs.shape
)
return ops.reshape(normalized_inputs, input_shape)
return ops.reshape(normalized_inputs, ops.shape(inputs))
def _reshape_into_groups(self, inputs):
input_shape = inputs.shape
group_shape = [input_shape[i] for i in range(len(input_shape))]
input_shape = ops.shape(inputs)
group_shape = list(inputs.shape)
for i, e in enumerate(group_shape):
if e is None:
group_shape[i] = input_shape[i]
group_shape[self.axis] = input_shape[self.axis] // self.groups
group_shape.insert(self.axis, self.groups)

@@ -215,7 +215,7 @@ class LayerNormalization(Layer):
outputs = ops.cast(outputs, input_dtype)
# If some components of the shape got lost due to adjustments, fix that.
outputs = ops.reshape(outputs, input_shape)
outputs = ops.reshape(outputs, ops.shape(inputs))
return outputs

@@ -19,6 +19,7 @@ class HashedCrossingTest(testing.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
run_training_check=False,
)
self.run_layer_test(
layers.HashedCrossing,
@@ -30,6 +31,7 @@ class HashedCrossingTest(testing.TestCase):
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
run_training_check=False,
)
def test_correctness(self):

@@ -14,6 +14,7 @@ class RandomCropTest(testing.TestCase):
},
input_shape=(2, 3, 4),
supports_masking=False,
run_training_check=False,
)
def test_random_crop_full(self):
@@ -34,6 +35,7 @@ class RandomCropTest(testing.TestCase):
input_shape=(12, 8, 16, 3),
expected_output_shape=(12, 8, 8, 3),
supports_masking=False,
run_training_check=False,
)
def test_predicting_with_longer_height(self):
@@ -46,6 +48,7 @@ class RandomCropTest(testing.TestCase):
input_shape=(12, 8, 16, 3),
expected_output_shape=(12, 10, 8, 3),
supports_masking=False,
run_training_check=False,
)
def test_predicting_with_longer_width(self):
@@ -58,4 +61,5 @@ class RandomCropTest(testing.TestCase):
input_shape=(12, 8, 16, 3),
expected_output_shape=(12, 8, 18, 3),
supports_masking=False,
run_training_check=False,
)

@@ -44,6 +44,12 @@ class RNNCellWithDropout(layers.Layer, DropoutRNNCell):
class DropoutRNNCellTest(testing.TestCase):
def test_seed_tracking(self):
cell = RNNCellWithDropout(3, seed=1337)
self.assertEqual(len(cell.non_trainable_variables), 1)
layer = layers.RNN(cell)
self.assertEqual(len(layer.non_trainable_variables), 1)
def test_basics(self):
self.run_layer_test(
layers.RNN,
@@ -53,5 +59,6 @@ class DropoutRNNCellTest(testing.TestCase):
expected_output_shape=(3, 5),
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_non_trainable_variables=1,
supports_masking=True,
)

@@ -247,9 +247,6 @@ class GRUCell(Layer, DropoutRNNCell):
hh = self.activation(x_h + recurrent_h)
else:
if 0.0 < self.dropout < 1.0:
inputs = inputs * dp_mask[0]
# inputs projected by all gate matrices at once
matrix_x = ops.matmul(inputs, self.kernel)
if self.use_bias:

@@ -262,8 +262,6 @@ class LSTMCell(Layer, DropoutRNNCell):
h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
else:
if 0.0 < self.dropout < 1.0:
inputs = inputs * dp_mask[0]
z = ops.matmul(inputs, self.kernel)
z += ops.matmul(h_tm1, self.recurrent_kernel)

@@ -239,12 +239,8 @@ def hard_sigmoid(x):
class Elu(Operation):
def __init__(self, alpha=1.0):
super().__init__()
self.alpha = alpha
def call(self, x):
return backend.nn.elu(x, alpha=self.alpha)
return backend.nn.elu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@@ -253,10 +249,10 @@ class Elu(Operation):
@keras_core_export(
["keras_core.operations.elu", "keras_core.operations.nn.elu"]
)
def elu(x, alpha=1.0):
def elu(x):
if any_symbolic_tensors((x,)):
return Elu(alpha).symbolic_call(x)
return backend.nn.elu(x, alpha=alpha)
return Elu().symbolic_call(x)
return backend.nn.elu(x)
class Selu(Operation):

@@ -637,10 +637,6 @@ class NNOpsCorrectnessTest(testing.TestCase):
knn.elu(x),
[-0.63212055, 0, 1, 2, 3],
)
self.assertAllClose(
knn.elu(x, alpha=0.5),
[-0.31606027, 0, 1, 2, 3],
)
def test_selu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)

@@ -19,7 +19,11 @@ class SeedGenerator:
return [seed, 0]
self.state = Variable(
seed_initializer, shape=(2,), dtype="uint32", trainable=False
seed_initializer,
shape=(2,),
dtype="uint32",
trainable=False,
name="seed_generator_state",
)

@@ -106,6 +106,7 @@ class TestCase(unittest.TestCase):
supports_masking=None,
expected_mask_shape=None,
custom_objects=None,
run_training_check=True,
):
"""Run basic checks on a layer.
@@ -140,6 +141,8 @@ class TestCase(unittest.TestCase):
returned by compute_mask() (only supports 1 shape).
custom_objects: Dict of any custom objects to be
considered during deserialization.
run_training_check: Whether to attempt to train the layer
(if an input shape or input data was provided).
"""
if input_shape is not None and input_data is not None:
raise ValueError(
@@ -271,7 +274,9 @@ class TestCase(unittest.TestCase):
model = TestModel(layer)
model.compile(optimizer="sgd", loss="mse", jit_compile=False)
model.fit(np.array(input_data), np.array(output_data))
input_data = nest.map_structure(lambda x: np.array(x), input_data)
output_data = nest.map_structure(lambda x: np.array(x), output_data)
model.fit(input_data, output_data, verbose=0)
# Build test.
if input_shape is not None:
@@ -309,8 +314,8 @@ class TestCase(unittest.TestCase):
output_data = layer(input_data, **call_kwargs)
run_output_asserts(layer, output_data, eager=True)
# # Compiled training step - TODO
# run_training_step(layer, input_data, output_data)
if run_training_check:
run_training_step(layer, input_data, output_data)
def create_keras_tensors(input_shape, dtype):
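
Note: with the new `run_training_check` flag, a layer test can still opt out of the now-enabled compiled training step. A hypothetical invocation (layer choice and expectations are illustrative, not taken from this commit):
```python
from keras_core import layers, testing

class DenseBasicsTest(testing.TestCase):
    def test_dense_basics(self):
        self.run_layer_test(
            layers.Dense,
            init_kwargs={"units": 4},
            input_shape=(2, 3),
            expected_output_shape=(2, 4),
            expected_num_trainable_weights=2,  # kernel + bias
            run_training_check=False,  # skip the training step for this layer
        )
```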