2f197f8ac2
* Add Model Checkpoint * Adds Model Checkpoint * Adds Model Checkpoint * Adds Model Checkpoint * Adds Model Checkpoint * Adds Model Checkpoint * Adds Model Checkpoint * Adds Model Checkpoint
537 lines
17 KiB
Python
537 lines
17 KiB
Python
import os
|
|
import warnings
|
|
|
|
import pytest
|
|
|
|
from keras_core import callbacks
|
|
from keras_core import layers
|
|
from keras_core import metrics
|
|
from keras_core import models
|
|
from keras_core import saving
|
|
from keras_core import testing
|
|
from keras_core.models import Sequential
|
|
from keras_core.testing import test_utils
|
|
from keras_core.utils import numerical_utils
|
|
|
|
# `h5py` is optional: when it cannot be imported the name is bound to `None`
# so that tests can be skipped with `pytest.mark.skipif(h5py is None, ...)`.
try:
    import h5py
except ImportError:
    h5py = None
|
|
|
|
# Sizes of the synthetic dataset requested from `test_utils.get_test_data`.
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
NUM_CLASSES = 2
INPUT_DIM = 3
# Width of the hidden `Dense` layer in the test models.
NUM_HIDDEN = 5
# Default batch size passed to `model.fit` (some cases override it).
BATCH_SIZE = 5
|
|
|
|
|
|
class ModelCheckpointTest(testing.TestCase):
    """Tests for the `callbacks.ModelCheckpoint` callback.

    Covers which checkpoint files get written under different option
    combinations, and that written checkpoints can be loaded back.
    """

    @pytest.mark.skipif(
        h5py is None,
        reason="`h5py` is a required dependency for `ModelCheckpoint` tests.",
    )
    def test_model_checkpoint_options(self):
        """Check which files `ModelCheckpoint` writes for various options.

        The same small model is trained repeatedly with a differently
        configured callback each time; after each run the test asserts
        whether the expected checkpoint file(s) exist on disk.
        """

        def get_model():
            # Two-layer softmax classifier reporting an accuracy metric
            # named "acc" (monitored below as "val_acc").
            model = Sequential(
                [
                    layers.Dense(NUM_HIDDEN, activation="relu"),
                    layers.Dense(NUM_CLASSES, activation="softmax"),
                ]
            )
            model.compile(
                loss="categorical_crossentropy",
                optimizer="sgd",
                metrics=[metrics.Accuracy("acc")],
            )
            return model

        model = get_model()
        temp_dir = self.get_temp_dir()

        # Save model to a subdir inside the temp_dir so we can test
        # automatic directory creation.
        filepath = os.path.join(temp_dir, "subdir", "checkpoint.keras")
        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(
            train_samples=TRAIN_SAMPLES,
            test_samples=TEST_SAMPLES,
            input_shape=(INPUT_DIM,),
            num_classes=NUM_CLASSES,
        )
        y_test = numerical_utils.to_categorical(y_test)
        y_train = numerical_utils.to_categorical(y_train)

        # Case 1: monitor `val_loss`, `mode="auto"`, `save_best_only=False`.
        monitor = "val_loss"
        save_best_only = False
        mode = "auto"

        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        os.remove(filepath)

        # Case 2: same as case 1 but with `mode="min"`.
        mode = "min"
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        os.remove(filepath)

        # Case 3: monitor `val_acc` with `mode="max"`.
        mode = "max"
        monitor = "val_acc"
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        os.remove(filepath)

        # Case 4: same as case 3 but with `save_best_only=True`.
        save_best_only = True
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        os.remove(filepath)

        # Case 5: metric not available.
        cbks = [
            callbacks.ModelCheckpoint(
                filepath, monitor="unknown", save_best_only=True
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        # File won't be written.
        self.assertFalse(os.path.exists(filepath))

        # Case 6: an unrecognized `mode` is expected to emit a warning
        # when the callback is constructed.
        with warnings.catch_warnings(record=True) as warning_logs:
            # Record every warning, even ones already triggered earlier.
            warnings.simplefilter("always")
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode="unknown",
            )
            self.assertIn(
                "ModelCheckpoint mode 'unknown' is unknown",
                str(warning_logs[-1].message),
            )

        # Case 8a: `ModelCheckpoint` with an integer `save_freq`
        temp_dir = self.get_temp_dir()
        filepath = os.path.join(temp_dir, "checkpoint.epoch{epoch:02d}.keras")
        save_best_only = False
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
                save_freq=15,
            )
        ]
        self.assertFalse(os.path.exists(filepath.format(epoch=3)))
        model.fit(
            x_train,
            y_train,
            batch_size=2,  # 5 batches / epoch, so should backup every 3 epochs
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=10,
            verbose=0,
        )
        self.assertFalse(os.path.exists(filepath.format(epoch=1)))
        self.assertFalse(os.path.exists(filepath.format(epoch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=3)))
        self.assertFalse(os.path.exists(filepath.format(epoch=4)))
        self.assertFalse(os.path.exists(filepath.format(epoch=5)))
        self.assertTrue(os.path.exists(filepath.format(epoch=6)))
        self.assertFalse(os.path.exists(filepath.format(epoch=7)))
        self.assertFalse(os.path.exists(filepath.format(epoch=8)))
        self.assertTrue(os.path.exists(filepath.format(epoch=9)))
        os.remove(filepath.format(epoch=3))
        os.remove(filepath.format(epoch=6))
        os.remove(filepath.format(epoch=9))

        # Case 8b: `ModelCheckpoint` with int `save_freq` & `save_weights_only`
        temp_dir = self.get_temp_dir()
        filepath = os.path.join(
            temp_dir, "checkpoint.epoch{epoch:02d}.weights.h5"
        )
        cbks = [
            callbacks.ModelCheckpoint(
                filepath, monitor=monitor, save_freq=15, save_weights_only=True
            )
        ]
        self.assertFalse(os.path.exists(filepath.format(epoch=3)))
        model.fit(
            x_train,
            y_train,
            batch_size=2,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=10,
            verbose=0,
        )
        self.assertFalse(os.path.exists(filepath.format(epoch=1)))
        self.assertFalse(os.path.exists(filepath.format(epoch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=3)))
        self.assertFalse(os.path.exists(filepath.format(epoch=4)))
        self.assertFalse(os.path.exists(filepath.format(epoch=5)))
        self.assertTrue(os.path.exists(filepath.format(epoch=6)))
        self.assertFalse(os.path.exists(filepath.format(epoch=7)))
        self.assertFalse(os.path.exists(filepath.format(epoch=8)))
        self.assertTrue(os.path.exists(filepath.format(epoch=9)))

        # Case 9: `ModelCheckpoint` with valid and invalid save_freq argument.
        with self.assertRaisesRegex(ValueError, "Unrecognized save_freq"):
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
                save_freq="invalid_save_freq",
            )
        # The following should not raise ValueError.
        callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode,
            save_freq="epoch",
        )
        callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode,
            save_freq=3,
        )

        # Case 10a: `ModelCheckpoint` save with batch in filename.
        # `save_freq=1` with 2 batches per epoch: one file per batch is
        # expected for each of the 5 epochs.
        temp_dir = self.get_temp_dir()
        filepath = os.path.join(
            temp_dir, "checkpoint.epoch{epoch:02d}batch{batch:02d}.keras"
        )
        cbks = [
            callbacks.ModelCheckpoint(filepath, monitor=monitor, save_freq=1)
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=5,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=5,
            verbose=1,
        )
        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=2)))

        # Case 10b: `ModelCheckpoint` save weights with batch in filename.
        temp_dir = self.get_temp_dir()
        filepath = os.path.join(
            temp_dir, "checkpoint.epoch{epoch:02d}batch{batch:02d}.weights.h5"
        )
        cbks = [
            callbacks.ModelCheckpoint(
                filepath, monitor=monitor, save_freq=1, save_weights_only=True
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=5,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=5,
            verbose=1,
        )

        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=2)))
        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=1)))
        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=2)))

        # Case 11: ModelCheckpoint saves the model when `val_acc`
        # (mode "max") is checked against a negative
        # `initial_value_threshold`.
        mode = "max"
        monitor = "val_acc"
        initial_value_threshold = -0.01
        save_best_only = True
        filepath = os.path.join(temp_dir, "checkpoint.keras")
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                initial_value_threshold=initial_value_threshold,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        os.remove(filepath)

        # Case 12: ModelCheckpoint saves the model when
        # `initial_value_threshold` is None (auto mode, `val_loss`).
        mode = "auto"
        monitor = "val_loss"
        initial_value_threshold = None
        save_best_only = True
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                initial_value_threshold=initial_value_threshold,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        os.remove(filepath)

        # Case 13: ModelCheckpoint doesn't save the model if the loss was
        # minimum earlier (threshold of 0 in "min" mode).
        mode = "min"
        monitor = "val_loss"
        initial_value_threshold = 0
        save_best_only = True
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                initial_value_threshold=initial_value_threshold,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertFalse(os.path.exists(filepath))

        # Case 14: ModelCheckpoint doesn't save the model if the loss was
        # minimum earlier, in auto mode.
        mode = "auto"
        monitor = "val_loss"
        initial_value_threshold = 0
        save_best_only = True
        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                initial_value_threshold=initial_value_threshold,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertFalse(os.path.exists(filepath))

    @pytest.mark.skipif(
        h5py is None,
        reason="`h5py` is a required dependency for `ModelCheckpoint` tests.",
    )
    def test_model_checkpoint_loading(self):
        """Check that checkpoints written by `ModelCheckpoint` reload.

        Trains a functional model for one epoch with the callback, then
        compares the trained weights against those obtained by loading the
        checkpoint: first a full `.keras` model, then a `.weights.h5` file.
        """

        def get_model():
            # Functional equivalent of the two-layer softmax classifier,
            # built with a fixed input batch size of 2.
            inputs = layers.Input(shape=(INPUT_DIM,), batch_size=2)
            x = layers.Dense(NUM_HIDDEN, activation="relu")(inputs)
            outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
            functional_model = models.Model(inputs, outputs)
            functional_model.compile(
                loss="categorical_crossentropy",
                optimizer="sgd",
                metrics=[metrics.Accuracy("acc")],
            )
            return functional_model

        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(
            train_samples=TRAIN_SAMPLES,
            test_samples=TEST_SAMPLES,
            input_shape=(INPUT_DIM,),
            num_classes=NUM_CLASSES,
        )
        y_test = numerical_utils.to_categorical(y_test)
        y_train = numerical_utils.to_categorical(y_train)

        # Model Checkpoint load model (default)
        model = get_model()
        temp_dir = self.get_temp_dir()
        filepath = os.path.join(temp_dir, "checkpoint.model.keras")
        mode = "auto"
        monitor = "val_loss"
        save_best_only = True

        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        ref_weights = model.get_weights()
        self.assertTrue(os.path.exists(filepath))
        # Reload the whole model from the checkpoint and compare weights.
        new_model = saving.load_model(filepath)
        new_weights = new_model.get_weights()
        self.assertEqual(len(ref_weights), len(new_weights))
        for ref_w, w in zip(ref_weights, new_weights):
            self.assertAllClose(ref_w, w)

        # Model Checkpoint load model weights
        model = get_model()
        temp_dir = self.get_temp_dir()
        filepath = os.path.join(temp_dir, "checkpoint.weights.h5")
        mode = "auto"
        monitor = "val_loss"
        save_best_only = True

        cbks = [
            callbacks.ModelCheckpoint(
                filepath,
                monitor=monitor,
                save_best_only=save_best_only,
                save_weights_only=True,
                mode=mode,
            )
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        ref_weights = model.get_weights()
        self.assertTrue(os.path.exists(filepath))
        # Weights-only checkpoint: load into a freshly built model.
        new_model = get_model()
        new_model.load_weights(filepath)
        new_weights = new_model.get_weights()
        self.assertEqual(len(ref_weights), len(new_weights))
        for ref_w, w in zip(ref_weights, new_weights):
            self.assertAllClose(ref_w, w)
|