Merge branch 'main' of github.com:keras-team/keras-core

Francois Chollet 2023-05-17 16:06:25 -07:00
parent f94df3479a
commit b781ef7e53
16 changed files with 353 additions and 110 deletions

@@ -481,22 +481,28 @@ class JAXTrainer(base_trainer.Trainer):
)
def _predict_step(trainable_variables, non_trainable_variables, data):
return [
self.stateless_call(
trainable_variables, non_trainable_variables, data[0]
)
]
outputs, _ = self.stateless_call(
trainable_variables, non_trainable_variables, data[0]
)
return outputs
def _predict_multi_step(
trainable_variables, non_trainable_variables, data
):
outputs = []
for single_step_data in data:
outputs += _predict_step(
outputs = _predict_step(
trainable_variables, non_trainable_variables, data[:1]
)
for single_step_data in data[1:]:
step_outputs = _predict_step(
trainable_variables,
non_trainable_variables,
[single_step_data],
)
outputs = tf.nest.map_structure(
lambda t1, t2: jax.numpy.concatenate([t1, t2]),
outputs,
step_outputs,
)
return outputs
if self.steps_per_execution > 1:
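An aside on the pattern above: _predict_multi_step runs the first batch to get a template structure, then concatenates each subsequent step's outputs leaf by leaf. A minimal standalone sketch of that concatenation step, assuming tensorflow and jax are installed (the helper name and the toy output structures are hypothetical):

import jax.numpy as jnp
import tensorflow as tf

def concat_step_outputs(outputs, step_outputs):
    # Concatenate matching leaves of two (possibly nested) output structures
    # along the batch axis, mirroring the map_structure call above.
    return tf.nest.map_structure(
        lambda t1, t2: jnp.concatenate([t1, t2]), outputs, step_outputs
    )

# Toy usage: two "steps" of predictions with a nested output structure.
step_1 = {"probs": jnp.ones((2, 4)), "logits": jnp.zeros((2, 4))}
step_2 = {"probs": jnp.ones((3, 4)), "logits": jnp.zeros((3, 4))}
merged = concat_step_outputs(step_1, step_2)
print(merged["probs"].shape)  # (5, 4)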
@@ -519,29 +525,30 @@ class JAXTrainer(base_trainer.Trainer):
callbacks.on_predict_begin()
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
return outputs
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
multi_step_return_values = predict_step(
batch_outputs = predict_step(
trainable_variables, non_trainable_variables, x
)
for batch_outputs, _ in multi_step_return_values:
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(
batch_output
),
outputs,
batch_outputs,
)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
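The append_to_outputs helper factored out here accumulates each leaf of the per-batch outputs into a Python list; the trailing map_structure_up_to call then stitches those lists back into full arrays. A small self-contained sketch of that accumulate-then-stitch pattern (the toy structures and the NumPy concatenation at the end are illustrative assumptions):

import numpy as np
import tensorflow as tf

def append_to_outputs(batch_outputs, outputs):
    if outputs is None:
        # First batch: wrap every leaf in a one-element list.
        return tf.nest.map_structure(lambda b: [b], batch_outputs)
    # Later batches: append each leaf to its accumulator list. Mapping only
    # "up to" the structure of batch_outputs treats those lists as leaves.
    tf.__internal__.nest.map_structure_up_to(
        batch_outputs,
        lambda acc, b: acc.append(b),
        outputs,
        batch_outputs,
    )
    return outputs

outputs = None
for _ in range(3):
    batch_outputs = {"probs": np.ones((2, 4)), "logits": np.zeros((2, 1))}
    outputs = append_to_outputs(batch_outputs, outputs)

# Stitch the accumulated lists back into full arrays, one leaf at a time.
merged = tf.__internal__.nest.map_structure_up_to(
    batch_outputs, lambda acc: np.concatenate(acc, axis=0), outputs
)
print(merged["probs"].shape)  # (6, 4)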

@@ -115,13 +115,13 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def mutli_step_on_iterator(iterator):
def multi_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
train_function = mutli_step_on_iterator
train_function = multi_step_on_iterator
else:
train_function = one_step_on_iterator
@@ -157,13 +157,13 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
def mutli_step_on_iterator(iterator):
def multi_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
if self.steps_per_execution > 1:
test_function = mutli_step_on_iterator
test_function = multi_step_on_iterator
else:
test_function = one_step_on_iterator
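Both the train and test paths now select multi_step_on_iterator whenever steps_per_execution > 1. From the user's side this is controlled entirely through compile(); a minimal usage sketch, assuming keras_core exposes Sequential and Dense as in Keras (the model and data are arbitrary stand-ins):

import numpy as np
import keras_core

model = keras_core.Sequential([keras_core.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", steps_per_execution=4)

x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

# Each traced call now consumes up to 4 batches of 8 samples.
model.fit(x, y, batch_size=8, epochs=1, verbose=0)
model.evaluate(x, y, batch_size=8, verbose=0)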
@@ -186,9 +186,8 @@ class TensorFlowTrainer(base_trainer.Trainer):
one_step_on_data, jit_compile=True, reduce_retracing=True
)
def one_step_on_iterator(iterator):
"""Runs a single predict step given a Dataset iterator."""
data = next(iterator)
def one_step_on_data_distributed(data):
data = data[0]
outputs = self.distribute_strategy.run(
one_step_on_data, args=(data,)
)
@@ -197,20 +196,21 @@ class TensorFlowTrainer(base_trainer.Trainer):
self.distribute_strategy,
reduction=self.distribute_reduction_method,
)
return [outputs]
return outputs
def mutli_step_on_iterator(iterator):
outputs = []
for _ in tf.range(self.steps_per_execution):
outputs += one_step_on_iterator(iterator)
def multi_step_on_data(data):
outputs = one_step_on_data_distributed(data[:1])
for single_step_data in data[1:]:
step_outputs = one_step_on_data_distributed([single_step_data])
outputs = tf.nest.map_structure(
lambda t1, t2: concat([t1, t2]), outputs, step_outputs
)
return outputs
if self.steps_per_execution > 1:
# TODO(haifengj): Use multi_step_on_iterator.
# predict_function = mutli_step_on_iterator
predict_function = one_step_on_iterator
predict_function = multi_step_on_data
else:
predict_function = one_step_on_iterator
predict_function = one_step_on_data_distributed
if not self.run_eagerly:
predict_function = tf.function(
@@ -431,28 +431,48 @@ class TensorFlowTrainer(base_trainer.Trainer):
model=self,
)
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
return outputs
def get_data(iterator):
"""Returns data for the next execution."""
data = []
for _ in range(self.steps_per_execution):
try:
single_step_data = next(iterator)
except (StopIteration, tf.errors.OutOfRangeError) as e:
if len(data) > 0:
# Suppress the error when we still have remaining data.
return data
else:
# Re-raise the error for
# TFEpochIterator.catch_stop_iteration() to catch when
# no data is left.
raise e
data.append(single_step_data)
return data
self.make_predict_function()
callbacks.on_predict_begin()
outputs = None
with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_predict_batch_begin(step)
multi_batch_outputs = self.predict_function(iterator)
for batch_outputs in multi_batch_outputs:
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(
batch_output
),
outputs,
batch_outputs,
)
data = get_data(iterator)
batch_outputs = self.predict_function(data)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
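The new get_data helper pulls up to steps_per_execution batches from the iterator per execution, returning a shorter final chunk instead of raising when the dataset runs out mid-chunk. A standalone sketch of that behavior (the toy dataset and fixed steps_per_execution value are assumptions):

import tensorflow as tf

steps_per_execution = 4

def get_data(iterator):
    # Collect up to `steps_per_execution` batches for one execution.
    data = []
    for _ in range(steps_per_execution):
        try:
            data.append(next(iterator))
        except (StopIteration, tf.errors.OutOfRangeError):
            if data:
                return data  # partial final chunk: return what we have
            raise  # nothing left; let the epoch loop stop the iteration
    return data

dataset = tf.data.Dataset.range(12).batch(2)  # 6 batches in total
iterator = iter(dataset)
print(len(get_data(iterator)))  # 4
print(len(get_data(iterator)))  # 2 (the remaining partial chunk)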

@@ -103,6 +103,7 @@ from keras_core.layers.reshaping.permute import Permute
from keras_core.layers.reshaping.repeat_vector import RepeatVector
from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.reshaping.zero_padding3d import ZeroPadding3D
from keras_core.layers.rnn.bidirectional import Bidirectional
from keras_core.layers.rnn.conv_lstm1d import ConvLSTM1D
from keras_core.layers.rnn.conv_lstm2d import ConvLSTM2D

@@ -0,0 +1,122 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.ZeroPadding3D")
class ZeroPadding3D(Layer):
"""Zero-padding layer for 3D data (spatial or spatio-temporal).
Examples:
>>> input_shape = (1, 1, 2, 2, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> y = keras_core.layers.ZeroPadding3D(padding=2)(x)
>>> y.shape
(1, 5, 6, 6, 3)
Args:
padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding is applied to depth, height,
and width.
- If tuple of 3 ints: interpreted as three different symmetric
padding values for depth, height, and width:
`(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
- If tuple of 3 tuples of 2 ints: interpreted as
`((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
data_format: A string, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
When unspecified, uses `image_data_format` value found in your Keras
config file at `~/.keras/keras.json` (if it exists). Defaults to
`"channels_last"`.
Input shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_axis_to_pad, second_axis_to_pad,
third_axis_to_pad, depth)`
- If `data_format` is `"channels_first"`:
`(batch_size, depth, first_axis_to_pad, second_axis_to_pad,
third_axis_to_pad)`
Output shape:
5D tensor with shape:
- If `data_format` is `"channels_last"`:
`(batch_size, first_padded_axis, second_padded_axis,
third_padded_axis, depth)`
- If `data_format` is `"channels_first"`:
`(batch_size, depth, first_padded_axis, second_padded_axis,
third_padded_axis)`
"""
def __init__(
self,
padding=((1, 1), (1, 1), (1, 1)),
data_format=None,
name=None,
dtype=None,
):
super().__init__(name=name, dtype=dtype)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(padding, int):
self.padding = (
(padding, padding),
(padding, padding),
(padding, padding),
)
elif hasattr(padding, "__len__"):
if len(padding) != 3:
raise ValueError(
f"`padding` should have 3 elements. Received: {padding}."
)
dim1_padding = padding[0]
if isinstance(dim1_padding, int):
dim1_padding = (dim1_padding, dim1_padding)
dim2_padding = padding[1]
if isinstance(dim2_padding, int):
dim2_padding = (dim2_padding, dim2_padding)
dim3_padding = padding[2]
if isinstance(dim3_padding, int):
dim3_padding = (dim3_padding, dim3_padding)
self.padding = (dim1_padding, dim2_padding, dim3_padding)
else:
raise ValueError(
"`padding` should be either an int, a tuple of 3 ints "
"(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), "
"or a tuple of 3 tuples of 2 ints "
"((left_dim1_pad, right_dim1_pad),"
" (left_dim2_pad, right_dim2_pad),"
" (left_dim3_pad, right_dim2_pad)). "
f"Received: {padding}."
)
self.input_spec = InputSpec(ndim=5)
def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
spatial_dims_offset = 2 if self.data_format == "channels_first" else 1
for index in range(0, 3):
if output_shape[index + spatial_dims_offset] is not None:
output_shape[index + spatial_dims_offset] += (
self.padding[index][0] + self.padding[index][1]
)
return tuple(output_shape)
def call(self, inputs):
if self.data_format == "channels_first":
all_dims_padding = ((0, 0), (0, 0), *self.padding)
else:
all_dims_padding = ((0, 0), *self.padding, (0, 0))
return ops.pad(inputs, all_dims_padding)
def get_config(self):
config = {"padding": self.padding, "data_format": self.data_format}
base_config = super().get_config()
return {**base_config, **config}
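A short usage sketch of the new layer, showing the three accepted padding forms and the resulting output shapes (input values are arbitrary):

import numpy as np
from keras_core import layers

x = np.zeros((1, 4, 4, 4, 3))  # (batch, dim1, dim2, dim3, channels)

# An int, a tuple of 3 ints, and a tuple of 3 tuples of 2 ints are
# equivalent ways to request symmetric padding of 2 on every side.
print(layers.ZeroPadding3D(padding=2)(x).shape)                         # (1, 8, 8, 8, 3)
print(layers.ZeroPadding3D(padding=(2, 2, 2))(x).shape)                 # (1, 8, 8, 8, 3)
print(layers.ZeroPadding3D(padding=((2, 2), (2, 2), (2, 2)))(x).shape)  # (1, 8, 8, 8, 3)

# Asymmetric padding per spatial dimension:
print(layers.ZeroPadding3D(padding=((1, 2), (0, 3), (4, 0)))(x).shape)  # (1, 7, 7, 8, 3)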

@@ -0,0 +1,84 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ZeroPaddingTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("channels_first", "channels_first"), ("channels_last", "channels_last")
)
def test_zero_padding_3d(self, data_format):
inputs = np.random.rand(1, 2, 3, 4, 5)
outputs = layers.ZeroPadding3D(
padding=((1, 2), (3, 4), (0, 2)), data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
for index in [-1, -2]:
self.assertAllClose(outputs[:, :, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 1:-2, 3:-4, 0:-2], inputs)
else:
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, index, :, :, :], 0.0)
for index in [0, 1, 2, -1, -2, -3, -4]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
for index in [-1, -2]:
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, 1:-2, 3:-4, 0:-2, :], inputs)
@parameterized.product(
(
{"padding": ((2, 2), (2, 2), (2, 2))}, # 3 tuples
{"padding": (2, 2, 2)}, # 1 tuple
{"padding": 2}, # 1 int
),
(
{"data_format": "channels_first"},
{"data_format": "channels_last"},
),
)
def test_zero_padding_3d_with_same_padding(self, padding, data_format):
inputs = np.random.rand(1, 2, 3, 4, 5)
outputs = layers.ZeroPadding3D(
padding=padding, data_format=data_format
)(inputs)
if data_format == "channels_first":
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, :, :, :, index], 0.0)
self.assertAllClose(outputs[:, :, 2:-2, 2:-2, 2:-2], inputs)
else:
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, index, :, :, :], 0.0)
self.assertAllClose(outputs[:, :, index, :, :], 0.0)
self.assertAllClose(outputs[:, :, :, index, :], 0.0)
self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_zero_padding_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5))
padded = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer)
self.assertEqual(padded.shape, (None, 5, 10, 15, 5))
def test_zero_padding_3d_errors_if_padding_argument_invalid(self):
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1,))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1, 2))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding=(1, 2, 3, 4))
with self.assertRaises(ValueError):
layers.ZeroPadding3D(padding="1")

@@ -21,15 +21,15 @@ class Adadelta(optimizer.Optimizer):
learning rate can be set, as in most other Keras optimizers.
Args:
learning_rate: Initial value for the learning rate: a floating
point value, Defaults to 0.001. Note that `Adadelta` tends
to benefit from higher initial learning rate values compared to
other optimizers.
To match the exact form in the original paper, use 1.0.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001. Note that `Adadelta`
tends to benefit from higher initial learning rate values compared
to other optimizers. To match the exact form in the original paper,
use 1.0.
rho: A floating point value. The decay rate. Defaults to 0.95.
epsilon: Small floating point value used to maintain numerical
stability.
Defaults to 1e-7.
epsilon: Small floating point value for maintaining numerical stability.
{{base_optimizer_keyword_args}}
Reference:

@@ -17,8 +17,10 @@ class Adafactor(optimizer.Optimizer):
last 2 dimensions separately in its accumulator variables.
Args:
learning_rate: Initial value for the learning rate:
a floating point value, Defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
epsilon_1: float, defaults to 1e-30. A small offset to keep the denominator
away from 0.
@@ -129,9 +131,7 @@ class Adafactor(optimizer.Optimizer):
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
one = ops.cast(1.0, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
if self.relative_step: # TODO: add learning_rate_schedule logic
# If `relative_step=True` and learning rate is a constant, we
# apply the relative step algorithm.
if not callable(self._learning_rate) and self.relative_step:
lr = ops.minimum(lr, 1 / ops.sqrt(local_step))
r = self._r[self._get_variable_index(variable)]
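The reworked condition means the relative-step cap on the learning rate is only applied when the rate is a plain constant; a schedule or callable now takes precedence. A toy illustration of the rule (the helper and its values are hypothetical, using NumPy in place of backend ops):

import numpy as np

def effective_lr(lr, step, relative_step=True, lr_is_callable=False):
    # Mirrors `if not callable(self._learning_rate) and self.relative_step:`
    if not lr_is_callable and relative_step:
        lr = min(lr, 1.0 / np.sqrt(step))
    return lr

print(effective_lr(1.0, step=4))                       # 0.5, capped by 1/sqrt(4)
print(effective_lr(1.0, step=4, lr_is_callable=True))  # 1.0, the schedule wins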

@@ -14,18 +14,16 @@ class Adagrad(optimizer.Optimizer):
the smaller the updates.
Args:
learning_rate: Initial value for the learning rate:
a floating point value,
Defaults to 0.001.
Note that `Adagrad` tends to benefit from higher initial
learning rate values compared to other optimizers.
To match the exact form in the original paper, use 1.0.
initial_accumulator_value: Floating point value.
Starting value for the accumulators (per-parameter
momentum values).
Must be non-negative.
epsilon: Small floating point value used to maintain
numerical stability.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001. Note that `Adagrad` tends
to benefit from higher initial learning rate values compared to
other optimizers. To match the exact form in the original paper,
use 1.0.
initial_accumulator_value: Floating point value. Starting value for the
accumulators (per-parameter momentum values). Must be non-negative.
epsilon: Small floating point value for maintaining numerical stability.
{{base_optimizer_keyword_args}}
Reference:

@@ -18,9 +18,10 @@ class Adam(optimizer.Optimizer):
data/parameters*".
Args:
learning_rate: A floating point value or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates. Defaults to

@@ -34,9 +34,10 @@ class Adamax(optimizer.Optimizer):
```
Args:
learning_rate: A floating point value, or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor. The exponential decay
rate for the 1st moment estimates.
beta_2: A float value or a constant float tensor. The exponential decay

@@ -21,9 +21,10 @@ class AdamW(optimizer.Optimizer):
data/parameters*".
Args:
learning_rate: A floating point value or a callable that takes no
arguments and returns the actual value to use. The learning rate.
Defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates.
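These learning_rate docstring updates all describe the same three accepted forms: a float, a keras_core.optimizers.schedules.LearningRateSchedule instance, or a zero-argument callable. A minimal sketch, assuming ExponentialDecay is available under keras_core.optimizers.schedules as in Keras (the schedule parameters are arbitrary):

import keras_core

schedule = keras_core.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001, decay_steps=1000, decay_rate=0.9
)

# All three forms are accepted by the updated optimizers, e.g. AdamW:
opt_float = keras_core.optimizers.AdamW(learning_rate=0.001)
opt_schedule = keras_core.optimizers.AdamW(learning_rate=schedule)
opt_callable = keras_core.optimizers.AdamW(learning_rate=lambda: 0.001)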

@@ -52,9 +52,10 @@ class Ftrl(optimizer.Optimizer):
is replaced with a gradient with shrinkage.
Args:
learning_rate: A floating point value, or a callable that
takes no arguments and returns the actual value to use. The learning
rate. Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
learning_rate_power: A float value, must be less than or equal to zero.
Controls how the learning rate decreases during training. Use zero
for a fixed learning rate.

@@ -12,9 +12,10 @@ class Nadam(optimizer.Optimizer):
Nesterov momentum.
Args:
learning_rate: A floating point value or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to `0.001`.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimates.

@@ -18,8 +18,10 @@ class RMSprop(optimizer.Optimizer):
gradients, and uses that average to estimate the variance.
Args:
learning_rate: Initial value for the learning rate: a floating point
value, defaults to 0.001.
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
rho: float, defaults to 0.9. Discounting factor for the old gradients.
momentum: float, defaults to 0.0. If not 0.0, the optimizer tracks the
momentum value, with a decay rate equal to `1 - momentum`.

@@ -28,16 +28,16 @@ class SGD(optimizer.Optimizer):
```
Args:
learning_rate: A `Tensor`, floating point value, or a schedule that is a
`tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
that takes no arguments and returns the actual value to use. The
learning rate. Defaults to 0.001.
momentum: float hyperparameter >= 0 that accelerates gradient descent in
the relevant direction and dampens oscillations. Defaults to 0, i.e.,
vanilla gradient descent.
nesterov: boolean. Whether to apply Nesterov momentum.
Defaults to `False`.
{{base_optimizer_keyword_args}}
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.01.
momentum: float hyperparameter >= 0 that accelerates gradient descent in
the relevant direction and dampens oscillations. Defaults to 0,
i.e., vanilla gradient descent.
nesterov: boolean. Whether to apply Nesterov momentum.
Defaults to `False`.
{{base_optimizer_keyword_args}}
"""
def __init__(

@@ -228,6 +228,7 @@ class TestTrainer(testing.TestCase):
x = np.ones((100, 4))
y = np.ones((100, 1))
batch_size = 16
model = ExampleModel(units=1)
model.compile(
loss="mse",
@@ -240,8 +241,11 @@ class TestTrainer(testing.TestCase):
model_2 = ExampleModel(units=1)
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
model_2.fit(x=x, y=y, batch_size=16, verbose=0)
model_2.fit(x=x, y=y, batch_size=batch_size, verbose=0)
self.assertAllClose(model.get_weights(), model_2.get_weights())
self.assertAllClose(model.predict(x), model_2.predict(x))
self.assertAllClose(
model.predict(x, batch_size=batch_size),
model_2.predict(x, batch_size=batch_size),
)
self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))