From 038d7bb200c0722bbffdfe14ea5d72353b63096e Mon Sep 17 00:00:00 2001 From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Date: Thu, 27 Apr 2023 16:47:34 -0700 Subject: [PATCH] Add the EarlyStopping callback (#44) * add earlystopping callback * addressing comments * address comments * addressing comments * remove unused imports --- keras_core/callbacks/__init__.py | 1 + keras_core/callbacks/early_stopping.py | 185 +++++++++++++++++ keras_core/callbacks/early_stopping_test.py | 207 ++++++++++++++++++++ keras_core/utils/numerical_utils.py | 61 ++++++ keras_core/utils/numerical_utils_test.py | 54 +++++ setup.cfg | 1 + 6 files changed, 509 insertions(+) create mode 100644 keras_core/callbacks/early_stopping.py create mode 100644 keras_core/callbacks/early_stopping_test.py create mode 100644 keras_core/utils/numerical_utils_test.py diff --git a/keras_core/callbacks/__init__.py b/keras_core/callbacks/__init__.py index d6aff49f3..cd03db235 100644 --- a/keras_core/callbacks/__init__.py +++ b/keras_core/callbacks/__init__.py @@ -1,5 +1,6 @@ from keras_core.callbacks.callback import Callback from keras_core.callbacks.callback_list import CallbackList +from keras_core.callbacks.early_stopping import EarlyStopping from keras_core.callbacks.history import History from keras_core.callbacks.lambda_callback import LambdaCallback from keras_core.callbacks.progbar_logger import ProgbarLogger diff --git a/keras_core/callbacks/early_stopping.py b/keras_core/callbacks/early_stopping.py new file mode 100644 index 000000000..ae8b24690 --- /dev/null +++ b/keras_core/callbacks/early_stopping.py @@ -0,0 +1,185 @@ +import warnings + +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.callbacks.callback import Callback +from keras_core.utils import io_utils + + +@keras_core_export("keras_core.callbacks.EarlyStopping") +class EarlyStopping(Callback): + """Stop training when a monitored metric has stopped improving. + + Assuming the goal of a training is to minimize the loss. With this, the + metric to be monitored would be `'loss'`, and mode would be `'min'`. A + `model.fit()` training loop will check at end of every epoch whether + the loss is no longer decreasing, considering the `min_delta` and + `patience` if applicable. Once it's found no longer decreasing, + `model.stop_training` is marked True and the training terminates. + + The quantity to be monitored needs to be available in `logs` dict. + To make it so, pass the loss or metrics at `model.compile()`. + + Args: + monitor: Quantity to be monitored. Defaults to `"val_loss"`. + min_delta: Minimum change in the monitored quantity to qualify as an + improvement, i.e. an absolute change of less than min_delta, will + count as no improvement. Defaults to `0`. + patience: Number of epochs with no improvement after which training will + be stopped. Defaults to `0`. + verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays + messages when the callback takes an action. Defaults to `0`. + mode: One of `{"auto", "min", "max"}`. In `min` mode, training will stop + when the quantity monitored has stopped decreasing; in `"max"` mode + it will stop when the quantity monitored has stopped increasing; in + `"auto"` mode, the direction is automatically inferred from the name + of the monitored quantity. Defaults to `"auto"`. + baseline: Baseline value for the monitored quantity. If not `None`, + training will stop if the model doesn't show improvement over the + baseline. Defaults to `None`. + restore_best_weights: Whether to restore model weights from the epoch + with the best value of the monitored quantity. If `False`, the model + weights obtained at the last step of training are used. An epoch + will be restored regardless of the performance relative to the + `baseline`. If no epoch improves on `baseline`, training will run + for `patience` epochs and restore weights from the best epoch in + that set. Defaults to `False`. + start_from_epoch: Number of epochs to wait before starting to monitor + improvement. This allows for a warm-up period in which no + improvement is expected and thus training will not be stopped. + Defaults to `0`. + + + Example: + + >>> callback = keras_core.callbacks.EarlyStopping(monitor='loss', + ... patience=3) + >>> # This callback will stop the training when there is no improvement in + >>> # the loss for three consecutive epochs. + >>> model = keras_core.models.Sequential([keras_core.layers.Dense(10)]) + >>> model.compile(keras_core.optimizers.SGD(), loss='mse') + >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), + ... epochs=10, batch_size=1, callbacks=[callback], + ... verbose=0) + >>> len(history.history['loss']) # Only 4 epochs are run. + 4 + """ + + def __init__( + self, + monitor="val_loss", + min_delta=0, + patience=0, + verbose=0, + mode="auto", + baseline=None, + restore_best_weights=False, + start_from_epoch=0, + ): + super().__init__() + + self.monitor = monitor + self.patience = patience + self.verbose = verbose + self.baseline = baseline + self.min_delta = abs(min_delta) + self.wait = 0 + self.stopped_epoch = 0 + self.restore_best_weights = restore_best_weights + self.best_weights = None + self.start_from_epoch = start_from_epoch + + if mode not in ["auto", "min", "max"]: + warnings.warn( + f"EarlyStopping mode {mode} is unknown, fallback to auto mode.", + stacklevel=2, + ) + mode = "auto" + + if mode == "min": + self.monitor_op = ops.less + elif mode == "max": + self.monitor_op = ops.greater + else: + if ( + self.monitor.endswith("acc") + or self.monitor.endswith("accuracy") + or self.monitor.endswith("auc") + ): + self.monitor_op = ops.greater + else: + self.monitor_op = ops.less + + if self.monitor_op == ops.greater: + self.min_delta *= 1 + else: + self.min_delta *= -1 + + def on_train_begin(self, logs=None): + # Allow instances to be re-used + self.wait = 0 + self.stopped_epoch = 0 + self.best = ( + float("inf") if self.monitor_op == ops.less else -float("inf") + ) + self.best_weights = None + self.best_epoch = 0 + + def on_epoch_end(self, epoch, logs=None): + current = self.get_monitor_value(logs) + if current is None or epoch < self.start_from_epoch: + # If no monitor value exists or still in initial warm-up stage. + return + if self.restore_best_weights and self.best_weights is None: + # Restore the weights after first epoch if no progress is ever made. + self.best_weights = self.model.get_weights() + + self.wait += 1 + if self._is_improvement(current, self.best): + self.best = current + self.best_epoch = epoch + if self.restore_best_weights: + self.best_weights = self.model.get_weights() + # Only restart wait if we beat both the baseline and our previous + # best. + if self.baseline is None or self._is_improvement( + current, self.baseline + ): + self.wait = 0 + return + + # Only check after the first epoch. + if self.wait >= self.patience and epoch > 0: + self.stopped_epoch = epoch + self.model.stop_training = True + if self.restore_best_weights and self.best_weights is not None: + if self.verbose > 0: + io_utils.print_msg( + "Restoring model weights from " + "the end of the best epoch: " + f"{self.best_epoch + 1}." + ) + self.model.set_weights(self.best_weights) + + def on_train_end(self, logs=None): + if self.stopped_epoch > 0 and self.verbose > 0: + io_utils.print_msg( + f"Epoch {self.stopped_epoch + 1}: early stopping" + ) + + def get_monitor_value(self, logs): + logs = logs or {} + monitor_value = logs.get(self.monitor) + if monitor_value is None: + warnings.warn( + ( + f"Early stopping conditioned on metric `{self.monitor}` " + "which is not available. " + f"Available metrics are: {','.join(list(logs.keys()))}" + ), + stacklevel=2, + ) + return monitor_value + + def _is_improvement(self, monitor_value, reference_value): + return self.monitor_op(monitor_value - self.min_delta, reference_value) diff --git a/keras_core/callbacks/early_stopping_test.py b/keras_core/callbacks/early_stopping_test.py new file mode 100644 index 000000000..0e4a2ba40 --- /dev/null +++ b/keras_core/callbacks/early_stopping_test.py @@ -0,0 +1,207 @@ +import numpy as np + +from keras_core import callbacks +from keras_core import layers +from keras_core import models +from keras_core import testing + + +class EarlyStoppingTest(testing.TestCase): + def test_early_stopping(self): + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + x_test = np.random.random((10, 5)) + y_test = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile( + loss="mae", + optimizer="adam", + metrics=["mse"], + ) + + cases = [ + ("max", "val_mse"), + ("min", "val_loss"), + ("auto", "val_mse"), + ("auto", "loss"), + ("unknown", "unknown"), + ] + for mode, monitor in cases: + patience = 0 + cbks = [ + callbacks.EarlyStopping( + patience=patience, monitor=monitor, mode=mode + ) + ] + model.fit( + x_train, + y_train, + batch_size=5, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=5, + verbose=0, + ) + + def test_early_stopping_patience(self): + cases = [0, 1, 2, 3] + losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5] + + for patience in cases: + stopper = callbacks.EarlyStopping(monitor="loss", patience=patience) + stopper.model = models.Sequential() + stopper.model.compile(loss="mse", optimizer="sgd") + stopper.on_train_begin() + + for epoch, loss in enumerate(losses): + stopper.on_epoch_end(epoch=epoch, logs={"loss": loss}) + if stopper.model.stop_training: + break + + self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2) + + def test_early_stopping_reuse(self): + patience = 3 + data = np.random.random((100, 1)) + labels = np.where(data > 0.5, 1, 0) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile( + optimizer="sgd", + loss="mae", + metrics=["mse"], + ) + weights = model.get_weights() + + # This should allow training to go for at least `patience` epochs + model.set_weights(weights) + + stopper = callbacks.EarlyStopping(monitor="mse", patience=patience) + hist = model.fit( + data, labels, callbacks=[stopper], verbose=0, epochs=20 + ) + assert len(hist.epoch) >= patience + + def test_early_stopping_with_baseline(self): + baseline = 0.6 + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile(optimizer="sgd", loss="mae", metrics=["mse"]) + + patience = 3 + stopper = callbacks.EarlyStopping( + monitor="mse", patience=patience, baseline=baseline + ) + hist = model.fit( + x_train, y_train, callbacks=[stopper], verbose=0, epochs=20 + ) + assert len(hist.epoch) >= patience + + def test_early_stopping_final_weights_when_restoring_model_weights(self): + class DummyModel: + def __init__(self): + self.stop_training = False + self.weights = -1 + + def get_weights(self): + return self.weights + + def set_weights(self, weights): + self.weights = weights + + def set_weight_to_epoch(self, epoch): + self.weights = epoch + + early_stop = callbacks.EarlyStopping( + monitor="val_loss", patience=2, restore_best_weights=True + ) + early_stop.model = DummyModel() + losses = [0.2, 0.15, 0.1, 0.11, 0.12] + # The best configuration is in the epoch 2 (loss = 0.1000). + epochs_trained = 0 + early_stop.on_train_begin() + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.model.set_weight_to_epoch(epoch=epoch) + early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]}) + if early_stop.model.stop_training: + break + # The best configuration is in epoch 2 (loss = 0.1000), + # and while patience = 2, we're restoring the best weights, + # so we end up at the epoch with the best weights, i.e. epoch 2 + self.assertEqual(early_stop.model.get_weights(), 2) + + # Check early stopping when no model beats the baseline. + early_stop = callbacks.EarlyStopping( + monitor="val_loss", + patience=5, + baseline=0.5, + restore_best_weights=True, + ) + early_stop.model = DummyModel() + losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73] + # The best configuration is in the epoch 2 (loss = 0.7000). + epochs_trained = 0 + early_stop.on_train_begin() + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.model.set_weight_to_epoch(epoch=epoch) + early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]}) + if early_stop.model.stop_training: + break + # No epoch improves on the baseline, so we should train for only 5 + # epochs, and restore the second model. + self.assertEqual(epochs_trained, 5) + self.assertEqual(early_stop.model.get_weights(), 2) + + def test_early_stopping_with_start_from_epoch(self): + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile(optimizer="sgd", loss="mae", metrics=["mse"]) + start_from_epoch = 2 + patience = 3 + stopper = callbacks.EarlyStopping( + monitor="mse", + patience=patience, + start_from_epoch=start_from_epoch, + ) + history = model.fit( + x_train, y_train, callbacks=[stopper], verbose=0, epochs=20 + ) + # Test 'patience' argument functions correctly when used + # in conjunction with 'start_from_epoch'. + self.assertGreaterEqual(len(history.epoch), patience + start_from_epoch) + + start_from_epoch = 2 + patience = 0 + stopper = callbacks.EarlyStopping( + monitor="mse", + patience=patience, + start_from_epoch=start_from_epoch, + ) + history = model.fit( + x_train, y_train, callbacks=[stopper], verbose=0, epochs=20 + ) + # Test for boundary condition when 'patience' = 0. + self.assertGreaterEqual(len(history.epoch), start_from_epoch) diff --git a/keras_core/utils/numerical_utils.py b/keras_core/utils/numerical_utils.py index e8c6c1a10..2a779d1d9 100644 --- a/keras_core/utils/numerical_utils.py +++ b/keras_core/utils/numerical_utils.py @@ -1,5 +1,8 @@ +import numpy as np + from keras_core import backend from keras_core import operations as ops +from keras_core.api_export import keras_core_export def l2_normalize(x, axis=0): @@ -7,3 +10,61 @@ def l2_normalize(x, axis=0): square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True) l2_norm = ops.reciprocal(ops.sqrt(ops.maximum(square_sum, epsilon))) return ops.multiply(x, l2_norm) + + +@keras_core_export("keras_core.utils.to_categorical") +def to_categorical(x, num_classes=None): + """Converts a class vector (integers) to binary class matrix. + + E.g. for use with `categorical_crossentropy`. + + Args: + x: Array-like with class values to be converted into a matrix + (integers from 0 to `num_classes - 1`). + num_classes: Total number of classes. If `None`, this would be inferred + as `max(x) + 1`. Defaults to `None`. + + Returns: + A binary matrix representation of the input as a NumPy array. The class + axis is placed last. + + Example: + + >>> a = keras_core.utils.to_categorical([0, 1, 2, 3], num_classes=4) + >>> print(a) + [[1. 0. 0. 0.] + [0. 1. 0. 0.] + [0. 0. 1. 0.] + [0. 0. 0. 1.]] + + >>> b = np.array([.9, .04, .03, .03, + ... .3, .45, .15, .13, + ... .04, .01, .94, .05, + ... .12, .21, .5, .17], + ... shape=[4, 4]) + >>> loss = keras_core.backend.categorical_crossentropy(a, b) + >>> print(np.around(loss, 5)) + [0.10536 0.82807 0.1011 1.77196] + + >>> loss = keras_core.backend.categorical_crossentropy(a, a) + >>> print(np.around(loss, 5)) + [0. 0. 0. 0.] + """ + if backend.is_tensor(x): + return backend.nn.one_hot(x, num_classes) + x = np.array(x, dtype="int64") + input_shape = x.shape + + # Shrink the last dimension if the shape is (..., 1). + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + + x = x.reshape(-1) + if not num_classes: + num_classes = np.max(x) + 1 + batch_size = x.shape[0] + categorical = np.zeros((batch_size, num_classes)) + categorical[np.arange(batch_size), x] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical diff --git a/keras_core/utils/numerical_utils_test.py b/keras_core/utils/numerical_utils_test.py new file mode 100644 index 000000000..1f241713f --- /dev/null +++ b/keras_core/utils/numerical_utils_test.py @@ -0,0 +1,54 @@ +import numpy as np +from absl.testing import parameterized + +from keras_core import backend +from keras_core import testing +from keras_core.utils import numerical_utils + +NUM_CLASSES = 5 + + +class TestNumericalUtils(testing.TestCase, parameterized.TestCase): + @parameterized.parameters( + [ + ((1,), (1, NUM_CLASSES)), + ((3,), (3, NUM_CLASSES)), + ((4, 3), (4, 3, NUM_CLASSES)), + ((5, 4, 3), (5, 4, 3, NUM_CLASSES)), + ((3, 1), (3, NUM_CLASSES)), + ((3, 2, 1), (3, 2, NUM_CLASSES)), + ] + ) + def test_to_categorical(self, shape, expected_shape): + label = np.random.randint(0, NUM_CLASSES, shape) + one_hot = numerical_utils.to_categorical(label, NUM_CLASSES) + # Check shape + self.assertEqual(one_hot.shape, expected_shape) + # Make sure there is only one 1 in a row + self.assertTrue(np.all(one_hot.sum(axis=-1) == 1)) + # Get original labels back from one hots + self.assertTrue( + np.all(np.argmax(one_hot, -1).reshape(label.shape) == label) + ) + + def test_to_categorial_without_num_classes(self): + label = [0, 2, 5] + one_hot = numerical_utils.to_categorical(label) + self.assertEqual(one_hot.shape, (3, 5 + 1)) + + def test_to_categorical_with_backend_tensor(self): + label = backend.convert_to_tensor(np.array([0, 2, 1, 3, 4])) + expected = backend.convert_to_tensor( + np.array( + [ + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ) + ) + one_hot = numerical_utils.to_categorical(label, NUM_CLASSES) + assert backend.is_tensor(one_hot) + self.assertAllClose(one_hot, expected) diff --git a/setup.cfg b/setup.cfg index 38a59923e..8c763769d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,4 +62,5 @@ per-file-ignores = keras_core/models/functional.py:E501 keras_core/layers/layer.py:E501 keras_core/initializers/random_initializers.py:E501 + keras_core/saving/saving_lib_test.py:E501 max-line-length = 80