Merge branch 'main' of github.com:keras-team/keras-core

Francois Chollet 2023-05-17 16:06:25 -07:00
parent f94df3479a
commit b781ef7e53
16 changed files with 353 additions and 110 deletions

@@ -481,22 +481,28 @@ class JAXTrainer(base_trainer.Trainer):
)
def _predict_step(trainable_variables, non_trainable_variables, data):
return [
self.stateless_call(
outputs, _ = self.stateless_call(
trainable_variables, non_trainable_variables, data[0]
)
]
return outputs
def _predict_multi_step(
trainable_variables, non_trainable_variables, data
):
outputs = []
for single_step_data in data:
outputs += _predict_step(
outputs = _predict_step(
trainable_variables, non_trainable_variables, data[:1]
)
for single_step_data in data[1:]:
step_outputs = _predict_step(
trainable_variables,
non_trainable_variables,
[single_step_data],
)
outputs = tf.nest.map_structure(
lambda t1, t2: jax.numpy.concatenate([t1, t2]),
outputs,
step_outputs,
)
return outputs
if self.steps_per_execution > 1:
@@ -519,15 +525,7 @@ class JAXTrainer(base_trainer.Trainer):
callbacks.on_predict_begin()
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
multi_step_return_values = predict_step(
trainable_variables, non_trainable_variables, x
)
for batch_outputs, _ in multi_step_return_values:
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
@@ -536,12 +534,21 @@ class JAXTrainer(base_trainer.Trainer):
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(
batch_output
),
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
return outputs
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
batch_outputs = predict_step(
trainable_variables, non_trainable_variables, x
)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(

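The multi-step predict path above folds `steps_per_execution` batches into one compiled call and then stitches the per-step results back together by concatenating matching leaves of the output structure. A minimal sketch of that stitching pattern, assuming each step simply returns a nest of arrays and using a stand-in `predict_step` in place of the real stateless call:

import jax.numpy as jnp
import numpy as np
import tensorflow as tf


def predict_step(data):
    # Stand-in for the real `_predict_step`, which also threads the
    # trainable/non-trainable variables through `stateless_call`.
    return {"preds": np.asarray(data[0]) * 2.0}


def predict_multi_step(data):
    # Run the first step, then concatenate each later step's outputs onto
    # the matching leaves of the structure along the batch axis.
    outputs = predict_step(data[:1])
    for single_step_data in data[1:]:
        step_outputs = predict_step([single_step_data])
        outputs = tf.nest.map_structure(
            lambda t1, t2: jnp.concatenate([t1, t2]), outputs, step_outputs
        )
    return outputs


batches = [np.ones((4, 3)), np.ones((4, 3)), np.ones((2, 3))]
print(predict_multi_step(batches)["preds"].shape)  # (10, 3)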
@@ -115,13 +115,13 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def mutli_step_on_iterator(iterator):
def multi_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
train_function = mutli_step_on_iterator
train_function = multi_step_on_iterator
else:
train_function = one_step_on_iterator
@@ -157,13 +157,13 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def mutli_step_on_iterator(iterator):
def multi_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
test_function = mutli_step_on_iterator
test_function = multi_step_on_iterator
else:
test_function = one_step_on_iterator
@@ -186,9 +186,8 @@ class TensorFlowTrainer(base_trainer.Trainer):
one_step_on_data, jit_compile=True, reduce_retracing=True
)
def one_step_on_iterator(iterator):
"""Runs a single predict step given a Dataset iterator."""
data = next(iterator)
def one_step_on_data_distributed(data):
data = data[0]
outputs = self.distribute_strategy.run(
one_step_on_data, args=(data,)
)
@@ -197,20 +196,21 @@ class TensorFlowTrainer(base_trainer.Trainer):
self.distribute_strategy,
reduction=self.distribute_reduction_method,
)
return [outputs]
return outputs
def mutli_step_on_iterator(iterator):
outputs = []
for _ in tf.range(self.steps_per_execution):
outputs += one_step_on_iterator(iterator)
def multi_step_on_data(data):
outputs = one_step_on_data_distributed(data[:1])
for single_step_data in data[1:]:
step_outputs = one_step_on_data_distributed([single_step_data])
outputs = tf.nest.map_structure(
lambda t1, t2: concat([t1, t2]), outputs, step_outputs
)
return outputs
if self.steps_per_execution > 1:
# TODO(haifengj): Use multi_step_on_iterator.
# predict_function = mutli_step_on_iterator
predict_function = one_step_on_iterator
predict_function = multi_step_on_data
else:
predict_function = one_step_on_iterator
predict_function = one_step_on_data_distributed
if not self.run_eagerly:
predict_function = tf.function(
@@ -431,14 +431,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
model=self,
)
self.make_predict_function()
callbacks.on_predict_begin()
outputs = None
with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_predict_batch_begin(step)
multi_batch_outputs = self.predict_function(iterator)
for batch_outputs in multi_batch_outputs:
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
@@ -447,12 +440,39 @@ class TensorFlowTrainer(base_trainer.Trainer):
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(
batch_output
),
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
return outputs
def get_data(iterator):
"""Returns data for the next execution."""
data = []
for _ in range(self.steps_per_execution):
try:
single_step_data = next(iterator)
except (StopIteration, tf.errors.OutOfRangeError) as e:
if len(data) > 0:
# Suppress the error when there is still remaining data.
return data
else:
# Re-raise the error for
# TFEpochIterator.catch_stop_iteration() to catch when
# no data is left.
raise e
data.append(single_step_data)
return data
self.make_predict_function()
callbacks.on_predict_begin()
outputs = None
with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_predict_batch_begin(step)
data = get_data(iterator)
batch_outputs = self.predict_function(data)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(

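The `get_data` helper above drains up to `steps_per_execution` batches from the iterator per `predict_function` call, returning a short final chunk instead of raising when the dataset ends mid-chunk. A small self-contained sketch of that behaviour; the dataset and step count here are made up for illustration:

import tensorflow as tf


def get_data(iterator, steps_per_execution):
    """Pulls up to `steps_per_execution` batches from `iterator`."""
    data = []
    for _ in range(steps_per_execution):
        try:
            data.append(next(iterator))
        except (StopIteration, tf.errors.OutOfRangeError):
            if data:
                return data  # Partial chunk at the end of the epoch.
            raise  # Nothing left; let catch_stop_iteration() handle it.
    return data


it = iter(tf.data.Dataset.range(5).batch(2))
print([t.numpy().tolist() for t in get_data(it, 2)])  # [[0, 1], [2, 3]]
print([t.numpy().tolist() for t in get_data(it, 2)])  # [[4]]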
@@ -103,6 +103,7 @@ from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.reshaping.zero_padding3d import ZeroPadding3D
from keras_core.layers.rnn.bidirectional import Bidirectional
from keras_core.layers.rnn.conv_lstm1d import ConvLSTM1D
from keras_core.layers.rnn.conv_lstm2d import ConvLSTM2D

@@ -0,0 +1,122 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.ZeroPadding3D")
class ZeroPadding3D(Layer):
"""Zero-padding layer for 3D data (spatial or spatio-temporal).
Examples:
>>> input_shape = (1, 1, 2, 2, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> y = keras_core.layers.ZeroPadding3D(padding=2)(x)
>>> y.shape
(1, 5, 6, 6, 3)
Args:
padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding is applied to depth, height,
and width.
- If tuple of 3 ints: interpreted as three different symmetric
padding values for depth, height, and width:
`(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
- If tuple of 3 tuples of 2 ints: interpreted as
`((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
When unspecified, uses the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json` (if it exists). Defaults to
`"channels_last"`.
Input shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_axis_to_pad, second_axis_to_pad,
third_axis_to_pad, depth)`
- If `data_format` is `"channels_first"`:
`(batch_size, depth, first_axis_to_pad, second_axis_to_pad,
third_axis_to_pad)`
Output shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_padded_axis, second_padded_axis,
third_padded_axis, depth)`
- If `data_format` is `"channels_first"`:
`(batch_size, depth, first_padded_axis, second_padded_axis,
third_padded_axis)`
"""
def __init__(
self,
padding=((1, 1), (1, 1), (1, 1)),
data_format=None,
name=None,
dtype=None,
):
super().__init__(name=name, dtype=dtype)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(padding, int):
self.padding = (
(padding, padding),
(padding, padding),
(padding, padding),
)
elif hasattr(padding, "__len__"):
if len(padding) != 3:
raise ValueError(
f"`padding` should have 3 elements. Received: {padding}."
)
dim1_padding = padding[0]
if isinstance(dim1_padding, int):
dim1_padding = (dim1_padding, dim1_padding)
dim2_padding = padding[1]
if isinstance(dim2_padding, int):
dim2_padding = (dim2_padding, dim2_padding)
dim3_padding = padding[2]
if isinstance(dim3_padding, int):
dim3_padding = (dim3_padding, dim3_padding)
self.padding = (dim1_padding, dim2_padding, dim3_padding)
else:
raise ValueError(
"`padding` should be either an int, a tuple of 3 ints "
"(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), "
"or a tuple of 3 tuples of 2 ints "
"((left_dim1_pad, right_dim1_pad),"
" (left_dim2_pad, right_dim2_pad),"
" (left_dim3_pad, right_dim2_pad)). "
f"Received: {padding}."
)
self.input_spec = InputSpec(ndim=5)
def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
spatial_dims_offset = 2 if self.data_format == "channels_first" else 1
for index in range(0, 3):
if output_shape[index + spatial_dims_offset] is not None:
output_shape[index + spatial_dims_offset] += (
self.padding[index][0] + self.padding[index][1]
)
return tuple(output_shape)
def call(self, inputs):
if self.data_format == "channels_first":
all_dims_padding = ((0, 0), (0, 0), *self.padding)
else:
all_dims_padding = ((0, 0), *self.padding, (0, 0))
return ops.pad(inputs, all_dims_padding)
def get_config(self):
config = {"padding": self.padding, "data_format": self.data_format}
base_config = super().get_config()
return {**base_config, **config}

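As a quick sanity check of the shape arithmetic in `compute_output_shape`, each spatial dimension grows by its left plus right padding. A short usage sketch with the asymmetric padding used in the tests below, assuming the layer is exposed as `keras_core.layers.ZeroPadding3D` (per the import added to `layers/__init__.py` above):

import numpy as np

from keras_core import layers

x = np.zeros((1, 2, 3, 4, 5))  # (batch, dim1, dim2, dim3, channels)
y = layers.ZeroPadding3D(padding=((1, 2), (3, 4), (0, 2)))(x)
# dim1: 2 + (1 + 2) = 5, dim2: 3 + (3 + 4) = 10, dim3: 4 + (0 + 2) = 6
print(y.shape)  # (1, 5, 10, 6, 5)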
@@ -0,0 +1,84 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ZeroPaddingTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("channels_first", "channels_first"), ("channels_last", "channels_last")
)
def test_zero_padding_3d(self, data_format):
inputs = np.random.rand(1, 2, 3, 4, 5)
outputs = layers.ZeroPadding3D(
padding=((1, 2), (3, 4), (0, 2)), data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
for index in [-1, -2]:
self.assertAllClose(outputs[:, :, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 1:-2, 3:-4, 0:-2], inputs)
else:
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, index, :, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
for index in [-1, -2]:
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, 1:-2, 3:-4, 0:-2, :], inputs)
@parameterized.product(
(
{"padding": ((2, 2), (2, 2), (2, 2))}, # 3 tuples
{"padding": (2, 2, 2)}, # 1 tuple
{"padding": 2}, # 1 int
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_zero_padding_3d_with_same_padding(self, padding, data_format):
inputs = np.random.rand(1, 2, 3, 4, 5)
outputs = layers.ZeroPadding3D(
padding=padding, data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, :, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 2:-2, 2:-2, 2:-2], inputs)
else:
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, index, :, :, :], 0.0)
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_zero_padding_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5))
padded = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(padded.shape, (None, 5, 10, 15, 5))
def test_zero_padding_3d_errors_if_padding_argument_invalid(self):
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1,))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1, 2))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1, 2, 3, 4))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding="1")

@@ -21,15 +21,15 @@ class Adadelta(optimizer.Optimizer):
learning rate can be set, as in most other Keras optimizers.
Args:
learning_rate: Initial value for the learning rate: a floating
point value, Defaults to 0.001. Note that `Adadelta` tends
to benefit from higher initial learning rate values compared to
other optimizers.
To match the exact form in the original paper, use 1.0.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001. Note that `Adadelta`
tends to benefit from higher initial learning rate values compared
to other optimizers. To match the exact form in the original paper,
use 1.0.
rho: A floating point value. The decay rate. Defaults to 0.95.
epsilon: Small floating point value used to maintain numerical
stability.
Defaults to 1e-7.
epsilon: Small floating point value for maintaining numerical stability.
{{base_optimizer_keyword_args}}
Reference:

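The reworded `learning_rate` docstring here (and in the optimizer hunks that follow) describes the same three accepted forms everywhere: a plain float, a `keras_core.optimizers.schedules.LearningRateSchedule` instance, or a zero-argument callable. A hedged sketch of the three forms, assuming the `schedules` submodule exposes an `ExponentialDecay` schedule as in Keras:

from keras_core import optimizers
from keras_core.optimizers import schedules  # assumed to mirror Keras's schedules

opt_a = optimizers.Adadelta(learning_rate=1.0)  # constant float (the paper uses 1.0)
opt_b = optimizers.Adadelta(
    learning_rate=schedules.ExponentialDecay(  # schedule instance (assumed available)
        initial_learning_rate=1.0, decay_steps=1000, decay_rate=0.9
    )
)
opt_c = optimizers.Adadelta(learning_rate=lambda: 0.001)  # zero-argument callable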
@@ -17,8 +17,10 @@ class Adafactor(optimizer.Optimizer):
last 2 dimensions separately in its accumulator variables.
Args:
learning_rate: Initial value for the learning rate:
a floating point value, Defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
epsilon_1: float, defaults to 1e-30. A small offset to keep the
denominator away from 0.
@@ -129,9 +131,7 @@ class Adafactor(optimizer.Optimizer):
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
one = ops.cast(1.0, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
if self.relative_step: # TODO: add learning_rate_schedule logic
# If `relative_step=True` and learning rate is a constant, we
# apply the relative step algorithm.
if not callable(self._learning_rate) and self.relative_step:
lr = ops.minimum(lr, 1 / ops.sqrt(local_step))
r = self._r[self._get_variable_index(variable)]

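The new condition applies Adafactor's relative-step cap only when the learning rate is a constant, not a schedule or callable. Numerically the cap is `min(lr, 1 / sqrt(local_step))`; a tiny illustration, assuming a constant rate of 1.0:

import math

lr = 1.0  # constant learning rate, the case where the relative step applies
for local_step in (1, 4, 25, 10000):
    print(local_step, min(lr, 1.0 / math.sqrt(local_step)))
# 1 -> 1.0, 4 -> 0.5, 25 -> 0.2, 10000 -> 0.01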
@@ -14,18 +14,16 @@ class Adagrad(optimizer.Optimizer):
the smaller the updates.
Args:
learning_rate: Initial value for the learning rate:
a floating point value,
Defaults to 0.001.
Note that `Adagrad` tends to benefit from higher initial
learning rate values compared to other optimizers.
To match the exact form in the original paper, use 1.0.
initial_accumulator_value: Floating point value.
Starting value for the accumulators (per-parameter
momentum values).
Must be non-negative.
epsilon: Small floating point value used to maintain
numerical stability.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001. Note that `Adagrad` tends
to benefit from higher initial learning rate values compared to
other optimizers. To match the exact form in the original paper,
use 1.0.
initial_accumulator_value: Floating point value. Starting value for the
accumulators (per-parameter momentum values). Must be non-negative.
epsilon: Small floating point value for maintaining numerical stability.
{{base_optimizer_keyword_args}}
Reference:

@@ -18,9 +18,10 @@ class Adam(optimizer.Optimizer):
data/parameters*".
Args:
learning_rate: A floating point value or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates. Defaults to

@@ -34,9 +34,10 @@ class Adamax(optimizer.Optimizer):
```
Args:
learning_rate: A floating point value, or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor. The exponential decay
rate for the 1st moment estimates.
beta_2: A float value or a constant float tensor. The exponential decay

@@ -21,9 +21,10 @@ class AdamW(optimizer.Optimizer):
data/parameters*".
Args:
learning_rate: A floating point value or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates.

@@ -52,9 +52,10 @@ class Ftrl(optimizer.Optimizer):
is replaced with a gradient with shrinkage.
Args:
learning_rate: A floating point value, or a callable that
takes no arguments and returns the actual value to use. The learning
rate. Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate_power: A float value; must be less than or equal to zero.
Controls how the learning rate decreases during training. Use zero
for a fixed learning rate.

@@ -12,9 +12,10 @@ class Nadam(optimizer.Optimizer):
Nesterov momentum.
Args:
learning_rate: A floating point value or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates.

@@ -18,8 +18,10 @@ class RMSprop(optimizer.Optimizer):
gradients, and uses that average to estimate the variance.
Args:
learning_rate: Initial value for the learning rate: a floating point
value, defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
rho: float, defaults to 0.9. Discounting factor for the old gradients.
momentum: float, defaults to 0.0. If not 0.0, the optimizer tracks the
momentum value, with a decay rate equal to `1 - momentum`.

@@ -28,13 +28,13 @@ class SGD(optimizer.Optimizer):
```
Args:
learning_rate: A `Tensor`, floating point value, or a schedule that is a
`tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.01.
momentum: float hyperparameter >= 0 that accelerates gradient descent in
the relevant direction and dampens oscillations. Defaults to 0, i.e.,
vanilla gradient descent.
the relevant direction and dampens oscillations. Defaults to 0,
i.e., vanilla gradient descent.
nesterov: boolean. Whether to apply Nesterov momentum.
Defaults to `False`.
{{base_optimizer_keyword_args}}

@@ -228,6 +228,7 @@ class TestTrainer(testing.TestCase):
x = np.ones((100, 4))
y = np.ones((100, 1))
batch_size = 16
model = ExampleModel(units=1)
model.compile(
loss="mse",
@@ -240,8 +241,11 @@ class TestTrainer(testing.TestCase):
model_2 = ExampleModel(units=1)
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
model_2.fit(x=x, y=y, batch_size=16, verbose=0)
model_2.fit(x=x, y=y, batch_size=batch_size, verbose=0)
self.assertAllClose(model.get_weights(), model_2.get_weights())
self.assertAllClose(model.predict(x), model_2.predict(x))
self.assertAllClose(
model.predict(x, batch_size=batch_size),
model_2.predict(x, batch_size=batch_size),
)
self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))
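The updated test keeps the batch size consistent between `fit` and `predict` so that the `steps_per_execution=3` and `steps_per_execution=1` models see identical batching. A minimal stand-alone version of the same idea, using an assumed `Sequential`/`Dense` setup in place of the test's `ExampleModel`:

import numpy as np

from keras_core import layers, models

x = np.ones((100, 4))


def build(steps_per_execution):
    # "ones" initializer so both models start from identical weights.
    model = models.Sequential([layers.Dense(1, kernel_initializer="ones")])
    model.compile(
        loss="mse", optimizer="sgd", steps_per_execution=steps_per_execution
    )
    return model


model_a = build(steps_per_execution=3)
model_b = build(steps_per_execution=1)

# steps_per_execution only changes how many batches run per compiled call,
# not the numerics, so the predictions should match.
np.testing.assert_allclose(
    model_a.predict(x, batch_size=16, verbose=0),
    model_b.predict(x, batch_size=16, verbose=0),
)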