"""
Title: Writing your own callbacks
Authors: Rick Chao, Francois Chollet
Date created: 2019/03/20
Last modified: 2023/06/25
Description: Complete guide to writing new Keras callbacks.
Accelerator: GPU
"""
"""
## Introduction
A callback is a powerful tool to customize the behavior of a Keras model during
training, evaluation, or inference. Examples include `keras.callbacks.TensorBoard`
to visualize training progress and results with TensorBoard, or
`keras.callbacks.ModelCheckpoint` to periodically save your model during training.
In this guide, you will learn what a Keras callback is, what it can do, and how you can
build your own. We provide a few demos of simple callback applications to get you
started.
"""
"""
## Setup
"""
import numpy as np
import keras
"""
## Keras callbacks overview
All callbacks subclass the `keras.callbacks.Callback` class, and
override a set of methods called at various stages of training, testing, and
predicting. Callbacks are useful to get a view on internal states and statistics of
the model during training.
You can pass a list of callbacks (as the keyword argument `callbacks`) to the following
model methods:
- `keras.Model.fit()`
- `keras.Model.evaluate()`
- `keras.Model.predict()`
"""
"""
## An overview of callback methods

### Global methods

#### `on_(train|test|predict)_begin(self, logs=None)`

Called at the beginning of `fit`/`evaluate`/`predict`.

#### `on_(train|test|predict)_end(self, logs=None)`

Called at the end of `fit`/`evaluate`/`predict`.

### Batch-level methods for training/testing/predicting

#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`

Called right before processing a batch during training/testing/predicting.

#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`

Called at the end of training/testing/predicting a batch. Within this method, `logs` is
a dict containing the metrics results.

### Epoch-level methods (training only)

#### `on_epoch_begin(self, epoch, logs=None)`

Called at the beginning of an epoch during training.

#### `on_epoch_end(self, epoch, logs=None)`

Called at the end of an epoch during training.
"""
"""
## A basic example
Let's take a look at a concrete example. To get started, let's define a simple
Sequential Keras model:
"""
# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model
"""
Then, load the MNIST data for training and testing from the Keras datasets API:
"""
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
"""
Now, define a simple custom callback that logs:
- When `fit`/`evaluate`/`predict` starts & ends
- When each epoch starts & ends
- When each training batch starts & ends
- When each evaluation (test) batch starts & ends
- When each inference (prediction) batch starts & ends
"""
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

"""
Let's try it out:
"""
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
"""
### Usage of `logs` dict
The `logs` dict contains the loss value and all the metrics at the end of a batch or
epoch. Examples include the loss and mean absolute error.
"""
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
"""
## Usage of `self.model` attribute
In addition to receiving log information when one of their methods is called,
callbacks have access to the model associated with the current round of
training/evaluation/inference: `self.model`.
Here are a few of the things you can do with `self.model` in a callback:
- Set `self.model.stop_training = True` to immediately interrupt training.
- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`),
such as `self.model.optimizer.learning_rate`.
- Save the model at periodic intervals.
- Record the output of `model.predict()` on a few test samples at the end of each
epoch, to use as a sanity check during training.
- Extract visualizations of intermediate features at the end of each epoch, to monitor
what the model is learning over time.
- etc.
Let's see this in action, starting with a quick sketch below, then in a couple of
complete examples.
"""
"""
## Examples of Keras callback applications
### Early stopping at minimum loss
This first example shows the creation of a `Callback` that stops training when the
loss has reached its minimum, by setting the attribute `self.model.stop_training`
(boolean). Optionally, you can provide an argument `patience` to specify how many
epochs we should wait before stopping after having reached a local minimum.
`keras.callbacks.EarlyStopping` provides a more complete and general implementation.
"""
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of epochs with no improvement, training stops.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs we have waited while the loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1}: early stopping")

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
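"""
For comparison, here is a rough built-in equivalent of the callback above using
`keras.callbacks.EarlyStopping` (a sketch; see the API docs for the full set of
arguments). `monitor="loss"` tracks the training loss, as our custom callback did:
"""

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=30,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        keras.callbacks.EarlyStopping(
            monitor="loss", patience=2, restore_best_weights=True
        ),
    ],
)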
"""
### Learning rate scheduling
In this example, we show how a custom Callback can be used to dynamically change the
learning rate of the optimizer during the course of training.
See `keras.callbacks.LearningRateScheduler` for a more general implementation.
"""
class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

    Arguments:
        schedule: a function that takes an epoch index
            (integer, indexed from 0) and current learning rate
            as inputs and returns a new learning rate as output (float).
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = self.model.optimizer.learning_rate
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts.
        self.model.optimizer.learning_rate = scheduled_lr
        print(f"\nEpoch {epoch}: Learning rate is {float(np.array(scheduled_lr))}.")

LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)
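"""
The built-in `keras.callbacks.LearningRateScheduler` accepts the same kind of
schedule function, so the custom callback above can be replaced with the
following (a sketch reusing the `lr_schedule` helper defined earlier):
"""

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        keras.callbacks.LearningRateScheduler(lr_schedule),
    ],
)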
"""
### Built-in Keras callbacks
Be sure to check out the existing Keras callbacks by
reading the [API docs](https://keras.io/api/callbacks/).
Applications include logging to CSV, saving
the model, visualizing metrics in TensorBoard, and a lot more!
"""