Allow re-use of EarlyStopping callback objects. (#3000)

An EarlyStopping callback object has internal state variables to tell it
when it has reached its stopping point.  These were initialized in __init__(),
so attempting to re-use the same object resulted in immediate stopping. That
made it impossible (for example) to perform early stopping during
cross-validation with the scikit-learn wrapper.

This patch initializes the variables in on_train_begin(), so they are reset
at the start of each training run (e.g. each cross-validation fold).  Tests
included.
This commit is contained in:
jkleint 2016-06-18 16:04:30 -06:00 committed by François Chollet
parent 60e0c96f6c
commit 3513472467
2 changed files with 25 additions and 4 deletions

@ -327,17 +327,17 @@ class EarlyStopping(Callback):
if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
self.best = -np.Inf
else:
self.monitor_op = np.less
self.best = np.Inf
def on_train_begin(self, logs={}):
    """Reset the stopping state when a training run starts.

    Re-initializing `wait` and `best` here, rather than relying on values
    left over from an earlier run, is what lets one callback instance be
    handed to several training runs in a row.

    `logs` is accepted to match the other callback hooks and is not read.
    """
    self.wait = 0  # Allow instances to be re-used
    # Seed `best` with the worst value for the comparison in use, so the
    # first monitored value always counts as an improvement.  `np.inf` is
    # used instead of the `np.Inf` alias, which NumPy 2.0 removed.
    self.best = np.inf if self.monitor_op is np.less else -np.inf
def on_epoch_end(self, epoch, logs={}):
current = logs.get(self.monitor)

@ -105,6 +105,27 @@ def test_EarlyStopping():
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=20)
def test_EarlyStopping_reuse():
    """Check that one EarlyStopping instance works across two fits.

    The same callback object is passed to two consecutive `fit` calls on a
    model restored to its starting weights; each call must run for at least
    `patience` epochs.
    """
    patience = 3
    samples = np.random.random((100, 1))
    targets = np.where(samples > 0.5, 1, 0)

    layers = (
        Dense(1, input_dim=1, activation='relu'),
        Dense(1, activation='sigmoid'),
    )
    model = Sequential(layers)
    model.compile(optimizer='sgd', loss='binary_crossentropy',
                  metrics=['accuracy'])

    stopper = callbacks.EarlyStopping(monitor='acc', patience=patience)
    initial_weights = model.get_weights()

    def epochs_trained():
        # Train with the shared callback and report how many epochs ran.
        history = model.fit(samples, targets, callbacks=[stopper])
        return len(history.epoch)

    assert epochs_trained() >= patience

    # Restore the starting weights; the re-used callback should again
    # allow training to go for at least `patience` epochs.
    model.set_weights(initial_weights)
    assert epochs_trained() >= patience
def test_LearningRateScheduler():
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
nb_test=test_samples,