Update callback system

This commit is contained in:
fchollet 2015-06-01 20:11:44 -07:00
parent 1d61d18b9e
commit 54b999e661
5 changed files with 125 additions and 92 deletions

@ -1,6 +1,6 @@
## Usage of callbacks
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training.
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training. You can pass a list of callback (as the keyword argument `callbacks`) to the `.fit()` method of the `Sequential` model. The relevant methods of the callbacks will then be called at each stage of the training.
---
@ -13,19 +13,24 @@ keras.callbacks.Callback()
- __params__: dict. Training parameters (eg. verbosity, batch size, number of epochs...).
- __model__: `keras.models.Model`. Reference of the model being trained.
- __Methods__:
- __on_train_begin__(): Method called at the beginning of training.
- __on_train_end__(): Method called at the end of training.
- __on_epoch_begin__(epoch): Method called at the beginning of epoch `epoch`.
- __on_epoch_end__(epoch): Method called at the end of epoch `epoch`.
- __on_batch_begin__(batch): Method called at the beginning of batch `batch`.
- __on_batch_end__(batch): Method called at the end of batch `batch`.
- __on_train_begin__(logs={}): Method called at the beginning of training.
- __on_train_end__(logs={}): Method called at the end of training.
- __on_epoch_begin__(epoch, logs={}): Method called at the beginning of epoch `epoch`.
- __on_epoch_end__(epoch, logs={}): Method called at the end of epoch `epoch`.
- __on_batch_begin__(batch, logs={}): Method called at the beginning of batch `batch`.
- __on_batch_end__(batch, logs={}): Method called at the end of batch `batch`.
The `logs` dictionary will contain keys for quantities relevant to the current batch or epoch. Currently, the `.fit()` method of the `Sequential` model class will include the following quantities in the `logs` that it passes to its callbacks:
- __on_epoch_end__: logs optionally include `val_loss` (if validation is enabled in `fit`), and `val_accuracy` (if validation and accuracy monitoring are enabled).
- __on_batch_begin__: logs include `size`, the number of samples in the current batch.
- __on_batch_end__: logs include `loss`, and optionally `accuracy` (if accuracy monitoring is enabled).
---
## Create a callback
You can create a custom callback by extending the base class `keras.callbacks.Callback`. A callback has access to its associated model through the class property `self.model`. Two properties of models will be of particular interest to callbacks: `self.model.epoch_history` and `self.model.batch_history`.
You can create a custom callback by extending the base class `keras.callbacks.Callback`. A callback has access to its associated model through the class property `self.model`.
Here's a simple example saving a list of losses over each batch during training:
```python
@ -33,8 +38,8 @@ class LossHistory(keras.callbacks.Callback):
def on_train_begin(self):
self.losses = []
def on_batch_end(self, batch):
self.losses.append(self.model.batch_history.loss[-1])
def on_batch_end(self, batch, logs={}):
self.losses.append(logs.get('loss'))
```
---
@ -46,8 +51,8 @@ class LossHistory(keras.callbacks.Callback):
def on_train_begin(self):
self.losses = []
def on_batch_end(self, batch):
self.losses.append(self.model.batch_history.loss[-1])
def on_batch_end(self, batch, logs={}):
self.losses.append(logs.get('loss'))
model = Sequential()
model.add(Dense(784, 10, init='uniform'))

@ -17,7 +17,7 @@ from keras.preprocessing.text import Tokenizer
python examples/reuters_mlp.py
'''
max_words = 10000
max_words = 1000
batch_size = 32
print("Loading data...")
@ -51,8 +51,12 @@ model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
history = model.fit(X_train, Y_train, nb_epoch=3, batch_size=batch_size, verbose=1, show_accuracy=True, validation_split=0.1)
print(history)
score = model.evaluate(X_test, Y_test, batch_size=batch_size, verbose=1, show_accuracy=True)
history = model.fit(X_train, Y_train, nb_epoch=3, batch_size=batch_size, verbose=1, show_accuracy=False, validation_split=0.1)
print(history.epoch)
print(history.loss)
print(history.accuracy)
print(history.validation_loss)
print(history.validation_accuracy)
score = model.evaluate(X_test, Y_test, batch_size=batch_size, verbose=1, show_accuracy=False)
print('Test score:', score[0])
print('Test accuracy:', score[1])

@ -26,21 +26,21 @@ class CallbackList(object):
for callback in self.callbacks:
callback._set_model(model)
def on_epoch_begin(self, epoch):
def on_epoch_begin(self, epoch, logs={}):
for callback in self.callbacks:
callback.on_epoch_begin(epoch)
callback.on_epoch_begin(epoch, logs)
self._delta_t_batch = 0.
self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
self._delta_ts_batch_end = deque([], maxlen=self.queue_length)
def on_epoch_end(self, epoch):
def on_epoch_end(self, epoch, logs={}):
for callback in self.callbacks:
callback.on_epoch_end(epoch)
callback.on_epoch_end(epoch, logs)
def on_batch_begin(self, batch):
def on_batch_begin(self, batch, logs={}):
t_before_callbacks = time.time()
for callback in self.callbacks:
callback.on_batch_begin(batch)
callback.on_batch_begin(batch, logs)
self._delta_ts_batch_begin.append(time.time() - t_before_callbacks)
delta_t_median = np.median(self._delta_ts_batch_begin)
if self._delta_t_batch > 0. and delta_t_median > 0.95 * self._delta_t_batch \
@ -49,11 +49,11 @@ class CallbackList(object):
'to the batch update (%f). Check your callbacks.' % delta_t_median)
self._t_enter_batch = time.time()
def on_batch_end(self, batch):
def on_batch_end(self, batch, logs={}):
self._delta_t_batch = time.time() - self._t_enter_batch
t_before_callbacks = time.time()
for callback in self.callbacks:
callback.on_batch_end(batch)
callback.on_batch_end(batch, logs)
self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
delta_t_median = np.median(self._delta_ts_batch_end)
if self._delta_t_batch > 0. and delta_t_median > 0.95 * self._delta_t_batch \
@ -61,13 +61,13 @@ class CallbackList(object):
warnings.warn('Method on_batch_end() is slow compared '
'to the batch update (%f). Check your callbacks.' % delta_t_median)
def on_train_begin(self):
def on_train_begin(self, logs={}):
for callback in self.callbacks:
callback.on_train_begin()
callback.on_train_begin(logs)
def on_train_end(self):
def on_train_end(self, logs={}):
for callback in self.callbacks:
callback.on_train_end()
callback.on_train_end(logs)
class Callback(object):
@ -81,62 +81,102 @@ class Callback(object):
def _set_model(self, model):
self.model = model
def on_epoch_begin(self, epoch):
def on_epoch_begin(self, epoch, logs={}):
pass
def on_epoch_end(self, epoch):
def on_epoch_end(self, epoch, logs={}):
pass
def on_batch_begin(self, batch):
def on_batch_begin(self, batch, logs={}):
pass
def on_batch_end(self, batch):
def on_batch_end(self, batch, logs={}):
pass
def on_train_begin(self):
def on_train_begin(self, logs={}):
pass
def on_train_end(self):
def on_train_end(self, logs={}):
pass
class BaseLogger(Callback):
def on_train_begin(self):
def on_train_begin(self, logs={}):
self.verbose = self.params['verbose']
def on_epoch_begin(self, epoch):
def on_epoch_begin(self, epoch, logs={}):
if self.verbose:
print('Epoch %d' % epoch)
self.progbar = Progbar(target=self.params['nb_sample'], \
verbose=self.verbose)
self.current = 0
self.tot_loss = 0.
self.tot_acc = 0.
def on_batch_begin(self, batch):
self.log_values = []
def on_batch_end(self, batch_index):
self.current += self.model.batch_history['batch_size'][-1]
# skip progbar update for the last batch; will be handled by on_epoch_end
def on_batch_begin(self, batch, logs={}):
if self.current < self.params['nb_sample']:
loss = self.model.batch_history['loss'][-1]
self.log_values.append(('loss', loss))
if self.params['show_accuracy']:
accuracy = self.model.batch_history['accuracy'][-1]
self.log_values.append(('acc.', accuracy))
if self.verbose:
self.progbar.update(self.current, self.log_values)
self.log_values = []
def on_epoch_end(self, epoch):
loss = self.model.batch_history['loss'][-1]
def on_batch_end(self, batch, logs={}):
batch_size = logs.get('size', 0)
self.current += batch_size
loss = logs.get('loss')
self.log_values.append(('loss', loss))
self.tot_loss += loss * batch_size
if self.params['show_accuracy']:
accuracy = self.model.batch_history['accuracy'][-1]
accuracy = logs.get('accuracy')
self.log_values.append(('acc.', accuracy))
self.tot_acc += accuracy * batch_size
# skip progbar update for the last batch; will be handled by on_epoch_end
if self.verbose and self.current < self.params['nb_sample']:
self.progbar.update(self.current, self.log_values)
def on_epoch_end(self, epoch, logs={}):
self.log_values.append(('loss', self.tot_loss / self.current))
if self.params['show_accuracy']:
self.log_values.append(('acc.', self.tot_acc / self.current))
if self.params['do_validation']:
val_loss = self.model.epoch_history['val_loss'][-1]
val_loss = logs.get('val_loss')
self.log_values.append(('val. loss', val_loss))
if self.params['show_accuracy']:
val_acc = self.model.epoch_history['val_accuracy'][-1]
val_acc = logs.get('val_accuracy')
self.log_values.append(('val. acc.', val_acc))
self.progbar.update(self.current, self.log_values)
class History(Callback):
def on_train_begin(self, logs={}):
self.epoch = []
self.loss = []
if self.params['show_accuracy']:
self.accuracy = []
if self.params['do_validation']:
self.validation_loss = []
if self.params['show_accuracy']:
self.validation_accuracy = []
def on_epoch_begin(self, epoch, logs={}):
self.seen = 0
self.tot_loss = 0.
self.tot_accuracy = 0.
def on_batch_end(self, batch, logs={}):
batch_size = logs.get('size', 0)
self.seen += batch_size
self.tot_loss += logs.get('loss', 0.) * batch_size
if self.params['show_accuracy']:
self.tot_accuracy += logs.get('accuracy', 0.) * batch_size
def on_epoch_end(self, epoch, logs={}):
val_loss = logs.get('val_loss')
val_acc = logs.get('val_accuracy')
self.epoch.append(epoch)
self.loss.append(self.tot_loss / self.seen)
if self.params['show_accuracy']:
self.accuracy.append(self.tot_accuracy / self.seen)
if self.params['do_validation']:
self.validation_loss.append(val_loss)
if self.params['show_accuracy']:
self.validation_accuracy.append(val_acc)

@ -218,6 +218,7 @@ class Sequential(Model):
callbacks = cbks.CallbackList(callbacks)
if verbose:
callbacks.append(cbks.BaseLogger())
callbacks.append(cbks.History())
callbacks._set_model(self)
callbacks._set_params({
@ -228,26 +229,9 @@ class Sequential(Model):
'do_validation': do_validation,
'show_accuracy': show_accuracy
})
self.batch_history = {
'batch':[],
'batch_size':[],
'loss':[],
'accuracy':[],
'val_loss':[],
'val_accuracy':[],
}
self.epoch_history = {
'epoch':[],
'epoch_size':[],
'loss':[],
'accuracy':[],
'val_loss':[],
'val_accuracy':[],
}
callbacks.on_train_begin()
for epoch in range(nb_epoch):
self.epoch_history['epoch'] = epoch
callbacks.on_epoch_begin(epoch)
if shuffle:
np.random.shuffle(index_array)
@ -258,40 +242,38 @@ class Sequential(Model):
X_batch = slice_X(X, batch_ids)
y_batch = y[batch_ids]
self.batch_history['batch'].append(batch_index)
self.batch_history['batch_size'].append(len(batch_ids))
callbacks.on_batch_begin(batch_index)
batch_logs = {}
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
callbacks.on_batch_begin(batch_index, batch_logs)
ins = X_batch + [y_batch]
if show_accuracy:
loss, acc = self._train_with_acc(*ins)
self.batch_history['accuracy'].append(acc)
batch_logs['accuracy'] = acc
else:
loss = self._train(*ins)
self.batch_history['loss'].append(loss)
batch_logs['loss'] = loss
callbacks.on_batch_end(batch_index)
callbacks.on_batch_end(batch_index, batch_logs)
if batch_index == len(batches) - 1: # last batch
# validation
epoch_logs = {}
if do_validation:
if show_accuracy:
val_loss, val_acc = self.evaluate(X_val, y_val, batch_size=batch_size, \
verbose=0, show_accuracy=True)
self.epoch_history['val_accuracy'].append(val_acc)
epoch_logs['val_accuracy'] = val_acc
else:
val_loss = self.evaluate(X_val, y_val, batch_size=batch_size, verbose=0)
self.epoch_history['val_loss'].append(val_loss)
epoch_logs['val_loss'] = val_loss
epoch_loss = sum(map(lambda x: x[0]*x[1], zip(self.batch_history['batch_size'], self.batch_history['loss']))) / len(y)
self.epoch_history['loss'].append(epoch_loss)
if show_accuracy:
epoch_acc = sum(map(lambda x: x[0]*x[1], zip(self.batch_history['batch_size'], self.batch_history['accuracy']))) / len(y)
self.epoch_history['accuracy'].append(epoch_acc)
callbacks.on_epoch_end(epoch)
callbacks.on_epoch_end(epoch, epoch_logs)
callbacks.on_train_end()
return self.epoch_history
# return history
return callbacks.callbacks[-1]
def predict(self, X, batch_size=128, verbose=1):
X = standardize_X(X)
@ -334,6 +316,7 @@ class Sequential(Model):
if show_accuracy:
tot_acc = 0.
tot_score = 0.
seen = 0
batches = make_batches(len(y), batch_size)
if verbose:
@ -345,21 +328,22 @@ class Sequential(Model):
ins = X_batch + [y_batch]
if show_accuracy:
loss, acc = self._test_with_acc(*ins)
tot_acc += acc
tot_acc += acc * len(y_batch)
log_values = [('loss', loss), ('acc.', acc)]
else:
loss = self._test(*ins)
log_values = [('loss', loss)]
tot_score += loss
tot_score += loss * len(y_batch)
seen += len(y_batch)
# logging
if verbose:
progbar.update(batch_end, log_values)
if show_accuracy:
return tot_score/len(batches), tot_acc/len(batches)
return tot_score / seen, tot_acc / seen
else:
return tot_score/len(batches)
return tot_score / seen
def get_config(self, verbose=0):
layers = []

@ -60,7 +60,7 @@ class Progbar(object):
'''
for k, v in values:
if k not in self.sum_values:
self.sum_values[k] = [v * max(1, current-self.seen_so_far), current-self.seen_so_far]
self.sum_values[k] = [v * (current-self.seen_so_far), current-self.seen_so_far]
self.unique_values.append(k)
else:
self.sum_values[k][0] += v * (current-self.seen_so_far)