Add the EarlyStopping callback (#44)

* add earlystopping callback

* addressing comments

* address comments

* addressing comments

* remove unused imports
Haifeng Jin 2023-04-27 16:47:34 -07:00 committed by Francois Chollet
parent 1a273b1de6
commit 038d7bb200
6 changed files with 509 additions and 0 deletions

@@ -1,5 +1,6 @@
from keras_core.callbacks.callback import Callback
from keras_core.callbacks.callback_list import CallbackList
from keras_core.callbacks.early_stopping import EarlyStopping
from keras_core.callbacks.history import History
from keras_core.callbacks.lambda_callback import LambdaCallback
from keras_core.callbacks.progbar_logger import ProgbarLogger
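With this export in place, the callback becomes reachable from the package root. A minimal sketch (assuming keras_core is installed from this revision; the monitor and patience values are arbitrary):

    from keras_core.callbacks import EarlyStopping

    early_stop = EarlyStopping(monitor="val_loss", patience=3)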

@@ -0,0 +1,185 @@
import warnings
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils
@keras_core_export("keras_core.callbacks.EarlyStopping")
class EarlyStopping(Callback):
"""Stop training when a monitored metric has stopped improving.
Assuming the goal of training is to minimize the loss, the metric to be
monitored would be `'loss'`, and the mode would be `'min'`. A
`model.fit()` training loop will check at the end of every epoch whether
the loss is no longer decreasing, considering the `min_delta` and
`patience` if applicable. Once it is found to be no longer decreasing,
`model.stop_training` is marked True and training terminates.
The quantity to be monitored needs to be available in the `logs` dict.
To make it so, pass the loss or metrics at `model.compile()`.
Args:
monitor: Quantity to be monitored. Defaults to `"val_loss"`.
min_delta: Minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than `min_delta` will
count as no improvement. Defaults to `0`.
patience: Number of epochs with no improvement after which training will
be stopped. Defaults to `0`.
verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays
messages when the callback takes an action. Defaults to `0`.
mode: One of `{"auto", "min", "max"}`. In `min` mode, training will stop
when the quantity monitored has stopped decreasing; in `"max"` mode
it will stop when the quantity monitored has stopped increasing; in
`"auto"` mode, the direction is automatically inferred from the name
of the monitored quantity. Defaults to `"auto"`.
baseline: Baseline value for the monitored quantity. If not `None`,
training will stop if the model doesn't show improvement over the
baseline. Defaults to `None`.
restore_best_weights: Whether to restore model weights from the epoch
with the best value of the monitored quantity. If `False`, the model
weights obtained at the last step of training are used. An epoch
will be restored regardless of the performance relative to the
`baseline`. If no epoch improves on `baseline`, training will run
for `patience` epochs and restore weights from the best epoch in
that set. Defaults to `False`.
start_from_epoch: Number of epochs to wait before starting to monitor
improvement. This allows for a warm-up period in which no
improvement is expected and thus training will not be stopped.
Defaults to `0`.
Example:
>>> callback = keras_core.callbacks.EarlyStopping(monitor='loss',
... patience=3)
>>> # This callback will stop the training when there is no improvement in
>>> # the loss for three consecutive epochs.
>>> model = keras_core.models.Sequential([keras_core.layers.Dense(10)])
>>> model.compile(keras_core.optimizers.SGD(), loss='mse')
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
... epochs=10, batch_size=1, callbacks=[callback],
... verbose=0)
>>> len(history.history['loss']) # Only 4 epochs are run.
4
"""
def __init__(
self,
monitor="val_loss",
min_delta=0,
patience=0,
verbose=0,
mode="auto",
baseline=None,
restore_best_weights=False,
start_from_epoch=0,
):
super().__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
self.baseline = baseline
self.min_delta = abs(min_delta)
self.wait = 0
self.stopped_epoch = 0
self.restore_best_weights = restore_best_weights
self.best_weights = None
self.start_from_epoch = start_from_epoch
if mode not in ["auto", "min", "max"]:
warnings.warn(
f"EarlyStopping mode {mode} is unknown, fallback to auto mode.",
stacklevel=2,
)
mode = "auto"
if mode == "min":
self.monitor_op = ops.less
elif mode == "max":
self.monitor_op = ops.greater
else:
if (
self.monitor.endswith("acc")
or self.monitor.endswith("accuracy")
or self.monitor.endswith("auc")
):
self.monitor_op = ops.greater
else:
self.monitor_op = ops.less
if self.monitor_op == ops.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = (
float("inf") if self.monitor_op == ops.less else -float("inf")
)
self.best_weights = None
self.best_epoch = 0
def on_epoch_end(self, epoch, logs=None):
current = self.get_monitor_value(logs)
if current is None or epoch < self.start_from_epoch:
# If no monitor value exists or still in initial warm-up stage.
return
if self.restore_best_weights and self.best_weights is None:
# Keep the weights from the first epoch so they can be restored
# later if no improvement is ever made.
self.best_weights = self.model.get_weights()
self.wait += 1
if self._is_improvement(current, self.best):
self.best = current
self.best_epoch = epoch
if self.restore_best_weights:
self.best_weights = self.model.get_weights()
# Only restart wait if we beat both the baseline and our previous
# best.
if self.baseline is None or self._is_improvement(
current, self.baseline
):
self.wait = 0
return
# Only check after the first epoch.
if self.wait >= self.patience and epoch > 0:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights and self.best_weights is not None:
if self.verbose > 0:
io_utils.print_msg(
"Restoring model weights from "
"the end of the best epoch: "
f"{self.best_epoch + 1}."
)
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
io_utils.print_msg(
f"Epoch {self.stopped_epoch + 1}: early stopping"
)
def get_monitor_value(self, logs):
logs = logs or {}
monitor_value = logs.get(self.monitor)
if monitor_value is None:
warnings.warn(
(
f"Early stopping conditioned on metric `{self.monitor}` "
"which is not available. "
f"Available metrics are: {','.join(list(logs.keys()))}"
),
stacklevel=2,
)
return monitor_value
def _is_improvement(self, monitor_value, reference_value):
return self.monitor_op(monitor_value - self.min_delta, reference_value)
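For reference, a minimal usage sketch of the callback added above, patterned on the tests later in this commit (the data shapes, layer sizes, and hyperparameters are arbitrary placeholders, not taken from the diff):

    import numpy as np

    from keras_core import callbacks, layers, models

    # Toy regression data; shapes are placeholders for illustration.
    x_train = np.random.random((64, 8))
    y_train = np.random.random((64, 1))
    x_val = np.random.random((16, 8))
    y_val = np.random.random((16, 1))

    model = models.Sequential(
        [layers.Dense(16, activation="relu"), layers.Dense(1)]
    )
    model.compile(optimizer="adam", loss="mse")

    # Stop once val_loss fails to improve by at least 1e-3 for 5 consecutive
    # epochs, ignore the first 2 warm-up epochs, and roll back to the best
    # weights seen. Internally, `min_delta` has its sign flipped in "min"
    # mode so that an improvement means current < best - min_delta.
    early_stop = callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=1e-3,
        patience=5,
        start_from_epoch=2,
        restore_best_weights=True,
        verbose=1,
    )

    model.fit(
        x_train,
        y_train,
        validation_data=(x_val, y_val),
        epochs=100,
        callbacks=[early_stop],
        verbose=0,
    )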

@@ -0,0 +1,207 @@
import numpy as np
from keras_core import callbacks
from keras_core import layers
from keras_core import models
from keras_core import testing
class EarlyStoppingTest(testing.TestCase):
def test_early_stopping(self):
x_train = np.random.random((10, 5))
y_train = np.random.random((10, 1))
x_test = np.random.random((10, 5))
y_test = np.random.random((10, 1))
model = models.Sequential(
(
layers.Dense(1, activation="relu"),
layers.Dense(1, activation="relu"),
)
)
model.compile(
loss="mae",
optimizer="adam",
metrics=["mse"],
)
cases = [
("max", "val_mse"),
("min", "val_loss"),
("auto", "val_mse"),
("auto", "loss"),
("unknown", "unknown"),
]
for mode, monitor in cases:
patience = 0
cbks = [
callbacks.EarlyStopping(
patience=patience, monitor=monitor, mode=mode
)
]
model.fit(
x_train,
y_train,
batch_size=5,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=5,
verbose=0,
)
def test_early_stopping_patience(self):
cases = [0, 1, 2, 3]
losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5]
for patience in cases:
stopper = callbacks.EarlyStopping(monitor="loss", patience=patience)
stopper.model = models.Sequential()
stopper.model.compile(loss="mse", optimizer="sgd")
stopper.on_train_begin()
for epoch, loss in enumerate(losses):
stopper.on_epoch_end(epoch=epoch, logs={"loss": loss})
if stopper.model.stop_training:
break
self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2)
def test_early_stopping_reuse(self):
patience = 3
data = np.random.random((100, 1))
labels = np.where(data > 0.5, 1, 0)
model = models.Sequential(
(
layers.Dense(1, activation="relu"),
layers.Dense(1, activation="relu"),
)
)
model.compile(
optimizer="sgd",
loss="mae",
metrics=["mse"],
)
weights = model.get_weights()
# This should allow training to go for at least `patience` epochs
model.set_weights(weights)
stopper = callbacks.EarlyStopping(monitor="mse", patience=patience)
hist = model.fit(
data, labels, callbacks=[stopper], verbose=0, epochs=20
)
assert len(hist.epoch) >= patience
def test_early_stopping_with_baseline(self):
baseline = 0.6
x_train = np.random.random((10, 5))
y_train = np.random.random((10, 1))
model = models.Sequential(
(
layers.Dense(1, activation="relu"),
layers.Dense(1, activation="relu"),
)
)
model.compile(optimizer="sgd", loss="mae", metrics=["mse"])
patience = 3
stopper = callbacks.EarlyStopping(
monitor="mse", patience=patience, baseline=baseline
)
hist = model.fit(
x_train, y_train, callbacks=[stopper], verbose=0, epochs=20
)
assert len(hist.epoch) >= patience
def test_early_stopping_final_weights_when_restoring_model_weights(self):
class DummyModel:
def __init__(self):
self.stop_training = False
self.weights = -1
def get_weights(self):
return self.weights
def set_weights(self, weights):
self.weights = weights
def set_weight_to_epoch(self, epoch):
self.weights = epoch
early_stop = callbacks.EarlyStopping(
monitor="val_loss", patience=2, restore_best_weights=True
)
early_stop.model = DummyModel()
losses = [0.2, 0.15, 0.1, 0.11, 0.12]
# The best configuration is at epoch 2 (loss = 0.1000).
epochs_trained = 0
early_stop.on_train_begin()
for epoch in range(len(losses)):
epochs_trained += 1
early_stop.model.set_weight_to_epoch(epoch=epoch)
early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]})
if early_stop.model.stop_training:
break
# The best configuration is at epoch 2 (loss = 0.1000). Although training
# continues for `patience` = 2 more epochs, `restore_best_weights` rolls
# the model back to the weights from the best epoch, i.e. epoch 2.
self.assertEqual(early_stop.model.get_weights(), 2)
# Check early stopping when no model beats the baseline.
early_stop = callbacks.EarlyStopping(
monitor="val_loss",
patience=5,
baseline=0.5,
restore_best_weights=True,
)
early_stop.model = DummyModel()
losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
# The best configuration is at epoch 2 (loss = 0.7000).
epochs_trained = 0
early_stop.on_train_begin()
for epoch in range(len(losses)):
epochs_trained += 1
early_stop.model.set_weight_to_epoch(epoch=epoch)
early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]})
if early_stop.model.stop_training:
break
# No epoch improves on the baseline, so training should run for only 5
# epochs (`patience`), and the weights from the best epoch (epoch 2)
# should be restored.
self.assertEqual(epochs_trained, 5)
self.assertEqual(early_stop.model.get_weights(), 2)
def test_early_stopping_with_start_from_epoch(self):
x_train = np.random.random((10, 5))
y_train = np.random.random((10, 1))
model = models.Sequential(
(
layers.Dense(1, activation="relu"),
layers.Dense(1, activation="relu"),
)
)
model.compile(optimizer="sgd", loss="mae", metrics=["mse"])
start_from_epoch = 2
patience = 3
stopper = callbacks.EarlyStopping(
monitor="mse",
patience=patience,
start_from_epoch=start_from_epoch,
)
history = model.fit(
x_train, y_train, callbacks=[stopper], verbose=0, epochs=20
)
# Test 'patience' argument functions correctly when used
# in conjunction with 'start_from_epoch'.
self.assertGreaterEqual(len(history.epoch), patience + start_from_epoch)
start_from_epoch = 2
patience = 0
stopper = callbacks.EarlyStopping(
monitor="mse",
patience=patience,
start_from_epoch=start_from_epoch,
)
history = model.fit(
x_train, y_train, callbacks=[stopper], verbose=0, epochs=20
)
# Test for boundary condition when 'patience' = 0.
self.assertGreaterEqual(len(history.epoch), start_from_epoch)

@@ -1,5 +1,8 @@
import numpy as np
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
def l2_normalize(x, axis=0):
@@ -7,3 +10,61 @@ def l2_normalize(x, axis=0):
square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
l2_norm = ops.reciprocal(ops.sqrt(ops.maximum(square_sum, epsilon)))
return ops.multiply(x, l2_norm)
@keras_core_export("keras_core.utils.to_categorical")
def to_categorical(x, num_classes=None):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with `categorical_crossentropy`.
Args:
x: Array-like with class values to be converted into a matrix
(integers from 0 to `num_classes - 1`).
num_classes: Total number of classes. If `None`, it is inferred as
`max(x) + 1`. Defaults to `None`.
Returns:
A binary matrix representation of the input as a NumPy array. The class
axis is placed last.
Example:
>>> a = keras_core.utils.to_categorical([0, 1, 2, 3], num_classes=4)
>>> print(a)
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
>>> b = np.array([.9, .04, .03, .03,
...               .3, .45, .15, .13,
...               .04, .01, .94, .05,
...               .12, .21, .5, .17]).reshape([4, 4])
>>> loss = keras_core.backend.categorical_crossentropy(a, b)
>>> print(np.around(loss, 5))
[0.10536 0.82807 0.1011 1.77196]
>>> loss = keras_core.backend.categorical_crossentropy(a, a)
>>> print(np.around(loss, 5))
[0. 0. 0. 0.]
"""
if backend.is_tensor(x):
return backend.nn.one_hot(x, num_classes)
x = np.array(x, dtype="int64")
input_shape = x.shape
# Shrink the last dimension if the shape is (..., 1).
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
input_shape = tuple(input_shape[:-1])
x = x.reshape(-1)
if not num_classes:
num_classes = np.max(x) + 1
batch_size = x.shape[0]
categorical = np.zeros((batch_size, num_classes))
categorical[np.arange(batch_size), x] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
return categorical
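A brief usage sketch of the NumPy path described above (the label values are illustrative; behavior follows the code in this hunk, which squeezes a trailing dimension of 1 and infers `num_classes` as `max(x) + 1` when it is not given):

    import numpy as np

    from keras_core.utils import numerical_utils

    labels = np.array([[0], [2], [1]])  # shape (3, 1)
    one_hot = numerical_utils.to_categorical(labels)

    # The trailing dimension of 1 is squeezed and num_classes is inferred
    # as 3, so the result has shape (3, 3).
    print(one_hot.shape)
    print(one_hot)
    # [[1. 0. 0.]
    #  [0. 0. 1.]
    #  [0. 1. 0.]]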

@@ -0,0 +1,54 @@
import numpy as np
from absl.testing import parameterized
from keras_core import backend
from keras_core import testing
from keras_core.utils import numerical_utils
NUM_CLASSES = 5
class TestNumericalUtils(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
[
((1,), (1, NUM_CLASSES)),
((3,), (3, NUM_CLASSES)),
((4, 3), (4, 3, NUM_CLASSES)),
((5, 4, 3), (5, 4, 3, NUM_CLASSES)),
((3, 1), (3, NUM_CLASSES)),
((3, 2, 1), (3, 2, NUM_CLASSES)),
]
)
def test_to_categorical(self, shape, expected_shape):
label = np.random.randint(0, NUM_CLASSES, shape)
one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)
# Check shape
self.assertEqual(one_hot.shape, expected_shape)
# Make sure there is only one 1 in a row
self.assertTrue(np.all(one_hot.sum(axis=-1) == 1))
# Get original labels back from one hots
self.assertTrue(
np.all(np.argmax(one_hot, -1).reshape(label.shape) == label)
)
def test_to_categorical_without_num_classes(self):
label = [0, 2, 5]
one_hot = numerical_utils.to_categorical(label)
self.assertEqual(one_hot.shape, (3, 5 + 1))
def test_to_categorical_with_backend_tensor(self):
label = backend.convert_to_tensor(np.array([0, 2, 1, 3, 4]))
expected = backend.convert_to_tensor(
np.array(
[
[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
]
)
)
one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)
assert backend.is_tensor(one_hot)
self.assertAllClose(one_hot, expected)

@@ -62,4 +62,5 @@ per-file-ignores =
keras_core/models/functional.py:E501
keras_core/layers/layer.py:E501
keras_core/initializers/random_initializers.py:E501
keras_core/saving/saving_lib_test.py:E501
max-line-length = 80