391 lines
17 KiB
Python
391 lines
17 KiB
Python
import os
|
|
import re
|
|
import warnings
|
|
|
|
import numpy as np
|
|
from tensorflow.io import gfile
|
|
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.callbacks.callback import Callback
|
|
from keras_core.utils import file_utils
|
|
from keras_core.utils import io_utils
|
|
|
|
|
|
@keras_core_export("keras_core.callbacks.ModelCheckpoint")
class ModelCheckpoint(Callback):
    """Callback to save the Keras model or model weights at some frequency.

    `ModelCheckpoint` callback is used in conjunction with training using
    `model.fit()` to save a model or weights (in a checkpoint file) at some
    interval, so the model or weights can be loaded later to continue the
    training from the state saved.

    A few options this callback provides include:

    - Whether to only keep the model that has achieved the "best performance" so
      far, or whether to save the model at the end of every epoch regardless of
      performance.
    - Definition of "best"; which quantity to monitor and whether it should be
      maximized or minimized.
    - The frequency it should save at. Currently, the callback supports saving
      at the end of every epoch, or after a fixed number of training batches.
    - Whether only weights are saved, or the whole model is saved.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
    model_checkpoint_callback = keras_core.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model is saved at the end of every epoch, if it's the best seen so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model (that are considered the best) can be loaded as -
    keras_core.models.load_model(checkpoint_filepath)

    # Alternatively, one could checkpoint just the model weights as -
    checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
    model_checkpoint_callback = keras_core.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model weights (that are considered the best) can be loaded as -
    model.load_weights(checkpoint_filepath)
    ```

    Args:
        filepath: string or `PathLike`, path to save the model file.
            `filepath` can contain named formatting options,
            which will be filled the value of `epoch` and keys in `logs`
            (passed in `on_epoch_end`).
            The `filepath` name needs to end with `".weights.h5"` when
            `save_weights_only=True` or should end with `".keras"` when
            checkpoint saving the whole model (default).
            For example:
            if `filepath` is `"{epoch:02d}-{val_loss:.2f}.keras"`, then the
            model checkpoints will be saved with the epoch number and the
            validation loss in the filename. The directory of the filepath
            should not be reused by any other callbacks to avoid conflicts.
        monitor: The metric name to monitor. Typically the metrics are set by
            the `Model.compile` method. Note:

            * Prefix the name with `"val_"` to monitor validation metrics.
            * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
            * If you specify metrics as strings, like `"accuracy"`, pass the
                same string (with or without the `"val_"` prefix).
            * If you pass `metrics.Metric` objects, `monitor` should be set to
                `metric.name`
            * If you're not sure about the metric names you can check the
                contents of the `history.history` dictionary returned by
                `history = model.fit()`
            * Multi-output models set additional prefixes on the metric names.
        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
            displays messages when the callback takes an action.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" and the latest best model according to the
            quantity monitored will not be overwritten. If `filepath` doesn't
            contain formatting options like `{epoch}` then `filepath` will be
            overwritten by each new better model.
        mode: one of {`"auto"`, `"min"`, `"max"`}. If `save_best_only=True`, the
            decision to overwrite the current save file is made based on either
            the maximization or the minimization of the monitored quantity.
            For `val_acc`, this should be `"max"`, for `val_loss` this should be
            `"min"`, etc. In `"auto"` mode, the mode is set to `"max"` if the
            quantities monitored are `"acc"` or start with `"fmeasure"` and are
            set to `"min"` for the rest of the quantities.
        save_weights_only: if True, then only the model's weights will be saved
            (`model.save_weights(filepath)`), else the full model is saved
            (`model.save(filepath)`).
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. If the `Model` is
            compiled with `steps_per_execution=N`, then the saving criteria will
            be checked every Nth batch. Note that if the saving isn't aligned to
            epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.
        initial_value_threshold: Floating point initial "best" value of the
            metric to be monitored. Only applies if `save_best_only=True`. Only
            overwrites the model weights already saved if the performance of
            current model is better than this value.
    """

    def __init__(
        self,
        filepath,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        initial_value_threshold=None,
    ):
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = file_utils.path_to_string(filepath)
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        # Initialized here so that batch-level saving cannot hit an
        # AttributeError if `on_train_batch_end` fires before
        # `on_epoch_begin` has run.
        self._current_epoch = 0
        # `best` starts at the user-provided threshold (may be `None`);
        # the mode resolution below fills in +/-inf when unset.
        self.best = initial_value_threshold

        if mode not in ["auto", "min", "max"]:
            warnings.warn(
                f"ModelCheckpoint mode '{mode}' is unknown, "
                "fallback to auto mode.",
                stacklevel=2,
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = np.less
            if self.best is None:
                # `np.inf` (lowercase) — `np.Inf` was removed in NumPy 2.0.
                self.best = np.inf
        elif mode == "max":
            self.monitor_op = np.greater
            if self.best is None:
                self.best = -np.inf
        else:
            # "auto" mode: accuracy-like and fmeasure-like metrics are
            # maximized; everything else (losses, errors) is minimized.
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = np.greater
                if self.best is None:
                    self.best = -np.inf
            else:
                self.monitor_op = np.less
                if self.best is None:
                    self.best = np.inf

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                "Expected save_freq are 'epoch' or integer values"
            )

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        if self.save_freq == "epoch":
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def _should_save_on_batch(self, batch):
        """Handles batch-level saving logic, supports steps_per_execution."""
        if self.save_freq == "epoch":
            return False
        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # batches are zero-indexed.
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `"epoch"`.
            logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
        """
        logs = logs or {}

        filepath = self._get_file_path(epoch, batch, logs)
        # Create host directory if it doesn't exist.
        dirname = os.path.dirname(filepath)
        if dirname and not gfile.exists(dirname):
            gfile.makedirs(dirname)

        try:
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        f"Can save best model only with {self.monitor} "
                        "available, skipping.",
                        stacklevel=2,
                    )
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            io_utils.print_msg(
                                f"\nEpoch {epoch + 1}: {self.monitor} "
                                "improved "
                                f"from {self.best:.5f} to {current:.5f}, "
                                f"saving model to {filepath}"
                            )
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            io_utils.print_msg(
                                f"\nEpoch {epoch + 1}: "
                                f"{self.monitor} did not improve "
                                f"from {self.best:.5f}"
                            )
            else:
                if self.verbose > 0:
                    io_utils.print_msg(
                        f"\nEpoch {epoch + 1}: saving model to {filepath}"
                    )
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)
        except IsADirectoryError as e:  # h5py 3.x
            raise IOError(
                "Please specify a non-directory filepath for "
                "ModelCheckpoint. Filepath used is an existing "
                f"directory: {filepath}"
            ) from e
        except IOError as e:  # h5py 2.x
            # `e.errno` appears to be `None` so checking the content of
            # `e.args[0]`.
            if "is a directory" in str(e.args[0]).lower():
                # Fixed: the original message contained a stray literal "f"
                # ("directory: f{filepath}").
                raise IOError(
                    "Please specify a non-directory filepath for "
                    "ModelCheckpoint. Filepath used is an existing "
                    f"directory: {filepath}"
                ) from e
            # Re-throw the error for any other causes.
            raise

    def _get_file_path(self, epoch, batch, logs):
        """Returns the file path for checkpoint."""

        try:
            # `filepath` may contain placeholders such as
            # `{epoch:02d}`,`{batch:02d}` and `{mape:.2f}`. A mismatch between
            # logged metrics and the path's placeholders can cause formatting to
            # fail.
            if batch is None or "batch" in logs:
                file_path = self.filepath.format(epoch=epoch + 1, **logs)
            else:
                file_path = self.filepath.format(
                    epoch=epoch + 1, batch=batch + 1, **logs
                )
        except KeyError as e:
            raise KeyError(
                f'Failed to format this callback filepath: "{self.filepath}". '
                f"Reason: {e}"
            ) from e
        return file_path

    def _checkpoint_exists(self, filepath):
        """Returns whether the checkpoint `filepath` refers to exists."""
        return gfile.exists(filepath)

    def _get_most_recently_modified_file_matching_pattern(self, pattern):
        """Returns the most recently modified filepath matching pattern.

        In the rare case where there are more than one pattern-matching file
        having the same modified time that is most recent among all, return the
        filepath that is largest (by `>` operator, lexicographically using the
        numeric equivalents). This provides a tie-breaker when multiple files
        are most recent. Note that a larger `filepath` can sometimes indicate a
        later time of modification (for instance, when epoch/batch is used as
        formatting option), but not necessarily (when accuracy or loss is used).
        The tie-breaker is put in the logic as best effort to return the most
        recent, and to avoid undeterministic result.

        Modified time of a file is obtained with `os.path.getmtime()`.

        This utility function is best demonstrated via an example:

        ```python
        file_pattern = 'batch{batch:02d}epoch{epoch:02d}.keras'
        test_dir = self.get_temp_dir()
        path_pattern = os.path.join(test_dir, file_pattern)
        file_paths = [
            os.path.join(test_dir, file_name) for file_name in
            ['batch03epoch02.keras',
             'batch02epoch02.keras', 'batch01epoch01.keras']
        ]
        for file_path in file_paths:
            # Write something to each of the files
            ...
        self.assertEqual(
            _get_most_recently_modified_file_matching_pattern(path_pattern),
            file_paths[-1])
        ```

        Args:
            pattern: The file pattern that may optionally contain python
                placeholder such as `{epoch:02d}`.

        Returns:
            The most recently modified file's full filepath matching `pattern`.
            If `pattern` does not contain any placeholder, this returns the
            filepath that exactly matches `pattern`. Returns `None` if no match
            is found.
        """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
        # Any `{...}` placeholder in the basename matches arbitrary text.
        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"

        latest_mod_time = 0
        file_path_with_latest_mod_time = None
        n_file_with_latest_mod_time = 0
        file_path_with_largest_file_name = None

        if gfile.exists(dir_name):
            for file_name in os.listdir(dir_name):
                # Only consider if `file_name` matches the pattern.
                if re.match(base_name_regex, file_name):
                    file_path = os.path.join(dir_name, file_name)
                    mod_time = os.path.getmtime(file_path)
                    if (
                        file_path_with_largest_file_name is None
                        or file_path > file_path_with_largest_file_name
                    ):
                        file_path_with_largest_file_name = file_path
                    if mod_time > latest_mod_time:
                        latest_mod_time = mod_time
                        file_path_with_latest_mod_time = file_path
                        # In the case a file with later modified time is found,
                        # reset the counter for the number of files with latest
                        # modified time.
                        n_file_with_latest_mod_time = 1
                    elif mod_time == latest_mod_time:
                        # In the case a file has modified time tied with the
                        # most recent, increment the counter for the number of
                        # files with latest modified time by 1.
                        n_file_with_latest_mod_time += 1

        if n_file_with_latest_mod_time == 1:
            # Return the sole file that has most recent modified time.
            return file_path_with_latest_mod_time
        else:
            # If there are more than one file having latest modified time,
            # return the file path with the largest file name.
            return file_path_with_largest_file_name
|