Catch KeyboardInterrupt exception if the training is aborted
This commit is contained in:
parent
2b1bb90bf0
commit
dced184426
@ -239,16 +239,23 @@ class Sequential(Model):
|
||||
for batch_index, (batch_start, batch_end) in enumerate(batches):
|
||||
callbacks.on_batch_begin(batch_index)
|
||||
|
||||
batch_ids = index_array[batch_start:batch_end]
|
||||
X_batch = slice_X(X, batch_ids)
|
||||
y_batch = y[batch_ids]
|
||||
try:
|
||||
batch_ids = index_array[batch_start:batch_end]
|
||||
X_batch = slice_X(X, batch_ids)
|
||||
y_batch = y[batch_ids]
|
||||
|
||||
ins = X_batch + [y_batch]
|
||||
if show_accuracy:
|
||||
loss, acc = self._train_with_acc(*ins)
|
||||
else:
|
||||
loss = self._train(*ins)
|
||||
acc = None
|
||||
ins = X_batch + [y_batch]
|
||||
if show_accuracy:
|
||||
loss, acc = self._train_with_acc(*ins)
|
||||
else:
|
||||
loss = self._train(*ins)
|
||||
acc = None
|
||||
except KeyboardInterrupt:
|
||||
# If training is aborted, call the callbacks anyway before terminating
|
||||
callbacks.on_batch_end(batch_index, [], 0., 0.)
|
||||
callbacks.on_epoch_end(epoch, 0., 0.)
|
||||
callbacks.on_train_end()
|
||||
raise KeyboardInterrupt # TODO: Raise a more explicit Excpetion (?)
|
||||
|
||||
callbacks.on_batch_end(batch_index, batch_ids, loss, acc)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user