"""
|
|
Title: Writing your own callbacks
|
|
Authors: Rick Chao, Francois Chollet
|
|
Date created: 2019/03/20
|
|
Last modified: 2023/06/25
|
|
Description: Complete guide to writing new Keras callbacks.
|
|
Accelerator: GPU
|
|
"""

"""
## Introduction

A callback is a powerful tool to customize the behavior of a Keras model during
training, evaluation, or inference. Examples include `keras.callbacks.TensorBoard`
to visualize training progress and results with TensorBoard, or
`keras.callbacks.ModelCheckpoint` to periodically save your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how you can
build your own. We provide a few demos of simple callback applications to get you
started.
"""

"""
## Setup
"""

import numpy as np
import keras
"""
|
|
## Keras callbacks overview
|
|
|
|
All callbacks subclass the `keras.callbacks.Callback` class, and
|
|
override a set of methods called at various stages of training, testing, and
|
|
predicting. Callbacks are useful to get a view on internal states and statistics of
|
|
the model during training.
|
|
|
|
You can pass a list of callbacks (as the keyword argument `callbacks`) to the following
|
|
model methods:
|
|
|
|
- `keras.Model.fit()`
|
|
- `keras.Model.evaluate()`
|
|
- `keras.Model.predict()`
|
|
"""

"""
## An overview of callback methods

### Global methods

#### `on_(train|test|predict)_begin(self, logs=None)`

Called at the beginning of `fit`/`evaluate`/`predict`.

#### `on_(train|test|predict)_end(self, logs=None)`

Called at the end of `fit`/`evaluate`/`predict`.

### Batch-level methods for training/testing/predicting

#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`

Called right before processing a batch during training/testing/predicting.

#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`

Called at the end of training/testing/predicting a batch. Within this method, `logs` is
a dict containing the metrics results.

### Epoch-level methods (training only)

#### `on_epoch_begin(self, epoch, logs=None)`

Called at the beginning of an epoch during training.

#### `on_epoch_end(self, epoch, logs=None)`

Called at the end of an epoch during training.
"""

"""
## A basic example

Let's take a look at a concrete example. To get started, let's define a simple
Sequential Keras model:
"""


# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model
"""
|
|
Then, load the MNIST data for training and testing from Keras datasets API:
|
|
"""
|
|
|
|
# Load example MNIST data and pre-process it
|
|
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
|
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
|
|
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
|
|
|
|
# Limit the data to 1000 samples
|
|
x_train = x_train[:1000]
|
|
y_train = y_train[:1000]
|
|
x_test = x_test[:1000]
|
|
y_test = y_test[:1000]
|
|
|
|
"""
|
|
Now, define a simple custom callback that logs:
|
|
|
|
- When `fit`/`evaluate`/`predict` starts & ends
|
|
- When each epoch starts & ends
|
|
- When each training batch starts & ends
|
|
- When each evaluation (test) batch starts & ends
|
|
- When each inference (prediction) batch starts & ends
|
|
"""
|
|
|
|
|
|
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
"""
|
|
Let's try it out:
|
|
"""
|
|
|
|
model = get_model()
|
|
model.fit(
|
|
x_train,
|
|
y_train,
|
|
batch_size=128,
|
|
epochs=1,
|
|
verbose=0,
|
|
validation_split=0.5,
|
|
callbacks=[CustomCallback()],
|
|
)
|
|
|
|
res = model.evaluate(
|
|
x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
|
|
)
|
|
|
|
res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
|
|
|
|
"""
|
|
### Usage of `logs` dict
|
|
|
|
The `logs` dict contains the loss value, and all the metrics at the end of a batch or
|
|
epoch. Example includes the loss and mean absolute error.
|
|
"""
|
|
|
|
|
|
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print(
            "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
        )

    def on_test_batch_end(self, batch, logs=None):
        print(
            "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
        )

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
"""
|
|
## Usage of `self.model` attribute
|
|
|
|
In addition to receiving log information when one of their methods is called,
|
|
callbacks have access to the model associated with the current round of
|
|
training/evaluation/inference: `self.model`.
|
|
|
|
Here are a few of the things you can do with `self.model` in a callback:
|
|
|
|
- Set `self.model.stop_training = True` to immediately interrupt training.
|
|
- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`),
|
|
such as `self.model.optimizer.learning_rate`.
|
|
- Save the model at period intervals.
|
|
- Record the output of `model.predict()` on a few test samples at the end of each
|
|
epoch, to use as a sanity check during training.
|
|
- Extract visualizations of intermediate features at the end of each epoch, to monitor
|
|
what the model is learning over time.
|
|
- etc.
|
|
|
|
Let's see this in action in a couple of examples.
|
|
"""


"""
## Examples of Keras callback applications

### Early stopping at minimum loss

This first example shows how to create a `Callback` that stops training when the
minimum of the loss has been reached, by setting the attribute
`self.model.stop_training` (boolean). Optionally, you can provide an argument
`patience` to specify how many epochs we should wait before stopping after having
reached a local minimum.

`keras.callbacks.EarlyStopping` provides a more complete and general implementation.
"""


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of epochs with no improvement, training stops.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower loss).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1}: early stopping")


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
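
"""
For comparison, here is a roughly equivalent configuration of the built-in callback (a
hedged sketch; the built-in also supports validation metrics, `min_delta`, and more):
"""

builtin_early_stopping = keras.callbacks.EarlyStopping(
    monitor="loss", patience=0, restore_best_weights=True
)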

"""
### Learning rate scheduling

In this example, we show how a custom Callback can be used to dynamically change the
learning rate of the optimizer during the course of training.

See `keras.callbacks.LearningRateScheduler` for a more general implementation.
"""


class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

    Arguments:
        schedule: a function that takes an epoch index
            (integer, indexed from 0) and current learning rate
            as inputs and returns a new learning rate as output (float).
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from the model's optimizer.
        lr = self.model.optimizer.learning_rate
        # Call the schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts.
        self.model.optimizer.learning_rate = scheduled_lr
        print(f"\nEpoch {epoch}: Learning rate is {float(np.array(scheduled_lr))}.")


LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)
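
"""
Since `lr_schedule` already has the `(epoch, lr)` signature that the built-in scheduler
expects, the custom class above could be swapped for the built-in like this (a hedged,
roughly equivalent sketch):
"""

builtin_lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule)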

"""
### Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by
reading the [API docs](https://keras.io/api/callbacks/).
Applications include logging to CSV, saving
the model, visualizing metrics in TensorBoard, and a lot more!
"""
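
"""
As a closing, hedged illustration (the file names below are arbitrary placeholders),
a few of those built-ins could be bundled together and passed to `fit()` just like the
custom callbacks in this guide:
"""

builtin_callbacks = [
    keras.callbacks.CSVLogger("training_log.csv"),
    keras.callbacks.ModelCheckpoint(
        "best_model.keras", monitor="loss", save_best_only=True
    ),
    keras.callbacks.TensorBoard(log_dir="./logs"),
]
# e.g.: model.fit(x_train, y_train, epochs=2, callbacks=builtin_callbacks)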