import os
import re
import warnings

import numpy as np
from tensorflow.io import gfile

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import file_utils
from keras_core.utils import io_utils


@keras_core_export("keras_core.callbacks.ModelCheckpoint")
class ModelCheckpoint(Callback):
    """Callback to save the Keras model or model weights at some frequency.

    `ModelCheckpoint` callback is used in conjunction with training using
    `model.fit()` to save a model or weights (in a checkpoint file) at some
    interval, so the model or weights can be loaded later to continue the
    training from the state saved.

    A few options this callback provides include:

    - Whether to only keep the model that has achieved the "best performance"
      so far, or whether to save the model at the end of every epoch
      regardless of performance.
    - Definition of "best"; which quantity to monitor and whether it should be
      maximized or minimized.
    - The frequency it should save at. Currently, the callback supports saving
      at the end of every epoch, or after a fixed number of training batches.
    - Whether only weights are saved, or the whole model is saved.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
    model_checkpoint_callback = keras_core.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model is saved at the end of every epoch, if it's the best seen so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model (that is considered the best) can be loaded as -
    keras_core.models.load_model(checkpoint_filepath)

    # Alternatively, one could checkpoint just the model weights as -
    checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
    model_checkpoint_callback = keras_core.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model weights (that are considered the best) can be loaded as -
    model.load_weights(checkpoint_filepath)
    ```

    Args:
        filepath: string or `PathLike`, path to save the model file.
            `filepath` can contain named formatting options,
            which will be filled with the value of `epoch` and keys in `logs`
            (passed in `on_epoch_end`).
            The `filepath` name needs to end with `".weights.h5"` when
            `save_weights_only=True` or should end with `".keras"` when
            checkpoint saving the whole model (default).
            For example:
            if `filepath` is `"{epoch:02d}-{val_loss:.2f}.keras"`, then the
            model checkpoints will be saved with the epoch number and the
            validation loss in the filename. The directory of the filepath
            should not be reused by any other callbacks to avoid conflicts.
        monitor: The metric name to monitor. Typically the metrics are set by
            the `Model.compile` method. Note:
            * Prefix the name with `"val_"` to monitor validation metrics.
            * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
            * If you specify metrics as strings, like `"accuracy"`, pass the
                same string (with or without the `"val_"` prefix).
            * If you pass `metrics.Metric` objects, `monitor` should be set to
                `metric.name`
            * If you're not sure about the metric names you can check the
                contents of the `history.history` dictionary returned by
                `history = model.fit()`
            * Multi-output models set additional prefixes on the metric names.
        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
            displays messages when the callback takes an action.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" and the latest best model according to the
            quantity monitored will not be overwritten. If `filepath` doesn't
            contain formatting options like `{epoch}` then `filepath` will be
            overwritten by each new better model.
        mode: one of {`"auto"`, `"min"`, `"max"`}. If `save_best_only=True`, the
            decision to overwrite the current save file is made based on either
            the maximization or the minimization of the monitored quantity.
            For `val_acc`, this should be `"max"`, for `val_loss` this should be
            `"min"`, etc. In `"auto"` mode, the mode is set to `"max"` if the
            name of the monitored quantity contains `"acc"` or starts with
            `"fmeasure"`, and is set to `"min"` for the rest of the quantities.
            An unrecognized value triggers a warning and falls back to
            `"auto"`.
        save_weights_only: if True, then only the model's weights will be saved
            (`model.save_weights(filepath)`), else the full model is saved
            (`model.save(filepath)`).
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. If the `Model` is
            compiled with `steps_per_execution=N`, then the saving criteria will
            be checked every Nth batch. Note that if the saving isn't aligned to
            epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.
        initial_value_threshold: Floating point initial "best" value of the
            metric to be monitored. Only applies if `save_best_only=True`. Only
            overwrites the model weights already saved if the performance of
            current model is better than this value.

    Raises:
        ValueError: if `save_freq` is neither `"epoch"` nor an integer.
    """

    def __init__(
        self,
        filepath,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        initial_value_threshold=None,
    ):
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = file_utils.path_to_string(filepath)
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        # Batch bookkeeping used only when `save_freq` is an integer.
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        self.best = initial_value_threshold

        if mode not in ["auto", "min", "max"]:
            warnings.warn(
                f"ModelCheckpoint mode '{mode}' is unknown, "
                "fallback to auto mode.",
                stacklevel=2,
            )
            mode = "auto"

        # Decide whether "better" means larger or smaller. In auto mode the
        # direction is inferred from the monitored metric's name.
        if mode == "auto":
            maximize = "acc" in self.monitor or self.monitor.startswith(
                "fmeasure"
            )
        else:
            maximize = mode == "max"

        # `np.inf` is the canonical spelling; the `np.Inf` alias was removed
        # in NumPy 2.0.
        if maximize:
            self.monitor_op = np.greater
            worst_value = -np.inf
        else:
            self.monitor_op = np.less
            worst_value = np.inf
        # Without a user-supplied threshold, start from the worst possible
        # value so that the first observed metric always counts as improved.
        if self.best is None:
            self.best = worst_value

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                "Expected save_freq are 'epoch' or integer values"
            )

    def on_train_batch_end(self, batch, logs=None):
        """Saves a checkpoint if the batch-based `save_freq` is reached."""
        if self._should_save_on_batch(batch):
            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        """Records the current epoch for use by batch-level saving."""
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        """Saves a checkpoint when `save_freq` is `"epoch"`."""
        if self.save_freq == "epoch":
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def _should_save_on_batch(self, batch):
        """Handles batch-level saving logic, supports steps_per_execution.

        Args:
            batch: zero-indexed batch number reported by the training loop.

        Returns:
            `True` if at least `save_freq` batches have been seen since the
            last save (the counter is then reset), `False` otherwise or when
            `save_freq` is `"epoch"`.
        """
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # batches are zero-indexed.
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `"epoch"`.
            logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.

        Raises:
            IOError: if the formatted `filepath` is an existing directory, or
                for any other I/O failure raised while saving.
        """
        logs = logs or {}

        filepath = self._get_file_path(epoch, batch, logs)
        # Create host directory if it doesn't exist.
        dirname = os.path.dirname(filepath)
        if dirname and not gfile.exists(dirname):
            gfile.makedirs(dirname)

        try:
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        f"Can save best model only with {self.monitor} "
                        "available, skipping.",
                        stacklevel=2,
                    )
                elif self.monitor_op(current, self.best):
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: {self.monitor} "
                            "improved "
                            f"from {self.best:.5f} to {current:.5f}, "
                            f"saving model to {filepath}"
                        )
                    self.best = current
                    if self.save_weights_only:
                        self.model.save_weights(filepath, overwrite=True)
                    else:
                        self.model.save(filepath, overwrite=True)
                else:
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: "
                            f"{self.monitor} did not improve "
                            f"from {self.best:.5f}"
                        )
            else:
                if self.verbose > 0:
                    io_utils.print_msg(
                        f"\nEpoch {epoch + 1}: saving model to {filepath}"
                    )
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)
        except IsADirectoryError as e:  # h5py 3.x
            raise IOError(
                "Please specify a non-directory filepath for "
                "ModelCheckpoint. Filepath used is an existing "
                f"directory: {filepath}"
            ) from e
        except IOError as e:  # h5py 2.x
            # `e.errno` appears to be `None` so checking the content of
            # `e.args[0]`. Guard against an empty `args` tuple so that an
            # `IndexError` never masks the original failure.
            message = str(e.args[0]) if e.args else ""
            if "is a directory" in message.lower():
                raise IOError(
                    "Please specify a non-directory filepath for "
                    "ModelCheckpoint. Filepath used is an existing "
                    f"directory: {filepath}"
                ) from e
            # Re-throw the error for any other causes.
            raise

    def _get_file_path(self, epoch, batch, logs):
        """Returns the file path for checkpoint.

        Args:
            epoch: zero-indexed epoch; formatted into the path as `epoch + 1`.
            batch: zero-indexed batch or `None`; formatted as `batch + 1` when
                given and `logs` has no `"batch"` key of its own.
            logs: dict whose keys are also available as placeholders.

        Raises:
            KeyError: if `filepath` references a placeholder that is neither
                `epoch`, `batch` nor a key of `logs`.
        """
        try:
            # `filepath` may contain placeholders such as
            # `{epoch:02d}`,`{batch:02d}` and `{mape:.2f}`. A mismatch between
            # logged metrics and the path's placeholders can cause formatting to
            # fail.
            if batch is None or "batch" in logs:
                # When `logs` already carries "batch", passing it again as an
                # explicit keyword would be a duplicate keyword argument.
                file_path = self.filepath.format(epoch=epoch + 1, **logs)
            else:
                file_path = self.filepath.format(
                    epoch=epoch + 1, batch=batch + 1, **logs
                )
        except KeyError as e:
            raise KeyError(
                f'Failed to format this callback filepath: "{self.filepath}". '
                f"Reason: {e}"
            ) from e
        return file_path

    def _checkpoint_exists(self, filepath):
        """Returns whether the checkpoint `filepath` refers to exists."""
        return gfile.exists(filepath)

    def _get_most_recently_modified_file_matching_pattern(self, pattern):
        """Returns the most recently modified filepath matching pattern.

        In the rare case where there are more than one pattern-matching file
        having the same modified time that is most recent among all, return the
        filepath that is largest (by `>` operator, lexicographically using the
        numeric equivalents). This provides a tie-breaker when multiple files
        are most recent. Note that a larger `filepath` can sometimes indicate a
        later time of modification (for instance, when epoch/batch is used as
        formatting option), but not necessarily (when accuracy or loss is used).
        The tie-breaker is put in the logic as best effort to return the most
        recent, and to avoid nondeterministic result.

        Modified time of a file is obtained with `os.path.getmtime()`.

        This utility function is best demonstrated via an example:

        ```python
        file_pattern = 'batch{batch:02d}epoch{epoch:02d}.keras'
        test_dir = self.get_temp_dir()
        path_pattern = os.path.join(test_dir, file_pattern)
        file_paths = [
            os.path.join(test_dir, file_name) for file_name in
            ['batch03epoch02.keras',
             'batch02epoch02.keras', 'batch01epoch01.keras']
        ]
        for file_path in file_paths:
            # Write something to each of the files
            ...
        self.assertEqual(
            _get_most_recently_modified_file_matching_pattern(path_pattern),
            file_paths[-1])
        ```

        Args:
            pattern: The file pattern that may optionally contain python
                placeholder such as `{epoch:02d}`.

        Returns:
            The most recently modified file's full filepath matching `pattern`.
            If `pattern` does not contain any placeholder, this returns the
            filepath that exactly matches `pattern`. Returns `None` if no match
            is found.
        """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
        # NOTE: `{.*}` is greedy, so everything from the first `{` to the last
        # `}` in the base name collapses into a single `.*` wildcard.
        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"

        latest_mod_time = 0
        file_path_with_latest_mod_time = None
        n_file_with_latest_mod_time = 0
        file_path_with_largest_file_name = None

        if gfile.exists(dir_name):
            for file_name in os.listdir(dir_name):
                # Only consider if `file_name` matches the pattern.
                if re.match(base_name_regex, file_name):
                    file_path = os.path.join(dir_name, file_name)
                    mod_time = os.path.getmtime(file_path)
                    if (
                        file_path_with_largest_file_name is None
                        or file_path > file_path_with_largest_file_name
                    ):
                        file_path_with_largest_file_name = file_path
                    if mod_time > latest_mod_time:
                        latest_mod_time = mod_time
                        file_path_with_latest_mod_time = file_path
                        # In the case a file with later modified time is found,
                        # reset the counter for the number of files with latest
                        # modified time.
                        n_file_with_latest_mod_time = 1
                    elif mod_time == latest_mod_time:
                        # In the case a file has modified time tied with the
                        # most recent, increment the counter for the number of
                        # files with latest modified time by 1.
                        n_file_with_latest_mod_time += 1

        if n_file_with_latest_mod_time == 1:
            # Return the sole file that has most recent modified time.
            return file_path_with_latest_mod_time
        else:
            # If there are more than one file having latest modified time,
            # return the file path with the largest file name.
            return file_path_with_largest_file_name