Catch KeyboardInterrupt exception if the training is aborted

This commit is contained in:
Tristan Deleu 2015-05-28 11:14:24 +02:00
parent 2b1bb90bf0
commit dced184426

@ -239,16 +239,23 @@ class Sequential(Model):
for batch_index, (batch_start, batch_end) in enumerate(batches):
callbacks.on_batch_begin(batch_index)
batch_ids = index_array[batch_start:batch_end]
X_batch = slice_X(X, batch_ids)
y_batch = y[batch_ids]
try:
batch_ids = index_array[batch_start:batch_end]
X_batch = slice_X(X, batch_ids)
y_batch = y[batch_ids]
ins = X_batch + [y_batch]
if show_accuracy:
loss, acc = self._train_with_acc(*ins)
else:
loss = self._train(*ins)
acc = None
ins = X_batch + [y_batch]
if show_accuracy:
loss, acc = self._train_with_acc(*ins)
else:
loss = self._train(*ins)
acc = None
except KeyboardInterrupt:
# If training is aborted, call the callbacks anyway before terminating
callbacks.on_batch_end(batch_index, [], 0., 0.)
callbacks.on_epoch_end(epoch, 0., 0.)
callbacks.on_train_end()
raise KeyboardInterrupt # TODO: Raise a more explicit Excpetion (?)
callbacks.on_batch_end(batch_index, batch_ids, loss, acc)