Update callback system
This commit is contained in:
parent
1d61d18b9e
commit
54b999e661
@ -1,6 +1,6 @@
|
||||
## Usage of callbacks
|
||||
|
||||
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training.
|
||||
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training. You can pass a list of callbacks (as the keyword argument `callbacks`) to the `.fit()` method of the `Sequential` model. The relevant methods of the callbacks will then be called at each stage of the training.
|
||||
|
||||
---
|
||||
|
||||
@ -13,19 +13,24 @@ keras.callbacks.Callback()
|
||||
- __params__: dict. Training parameters (e.g. verbosity, batch size, number of epochs...).
|
||||
- __model__: `keras.models.Model`. Reference to the model being trained.
|
||||
- __Methods__:
|
||||
- __on_train_begin__(): Method called at the beginning of training.
|
||||
- __on_train_end__(): Method called at the end of training.
|
||||
- __on_epoch_begin__(epoch): Method called at the beginning of epoch `epoch`.
|
||||
- __on_epoch_end__(epoch): Method called at the end of epoch `epoch`.
|
||||
- __on_batch_begin__(batch): Method called at the beginning of batch `batch`.
|
||||
- __on_batch_end__(batch): Method called at the end of batch `batch`.
|
||||
- __on_train_begin__(logs={}): Method called at the beginning of training.
|
||||
- __on_train_end__(logs={}): Method called at the end of training.
|
||||
- __on_epoch_begin__(epoch, logs={}): Method called at the beginning of epoch `epoch`.
|
||||
- __on_epoch_end__(epoch, logs={}): Method called at the end of epoch `epoch`.
|
||||
- __on_batch_begin__(batch, logs={}): Method called at the beginning of batch `batch`.
|
||||
- __on_batch_end__(batch, logs={}): Method called at the end of batch `batch`.
|
||||
|
||||
The `logs` dictionary will contain keys for quantities relevant to the current batch or epoch. Currently, the `.fit()` method of the `Sequential` model class will include the following quantities in the `logs` that it passes to its callbacks:
|
||||
- __on_epoch_end__: logs optionally include `val_loss` (if validation is enabled in `fit`), and `val_accuracy` (if validation and accuracy monitoring are enabled).
|
||||
- __on_batch_begin__: logs include `size`, the number of samples in the current batch.
|
||||
- __on_batch_end__: logs include `loss`, and optionally `accuracy` (if accuracy monitoring is enabled).
|
||||
|
||||
---
|
||||
|
||||
|
||||
## Create a callback
|
||||
|
||||
You can create a custom callback by extending the base class `keras.callbacks.Callback`. A callback has access to its associated model through the class property `self.model`. Two properties of models will be of particular interest to callbacks: `self.model.epoch_history` and `self.model.batch_history`.
|
||||
You can create a custom callback by extending the base class `keras.callbacks.Callback`. A callback has access to its associated model through the instance attribute `self.model`.
|
||||
|
||||
Here's a simple example saving a list of losses over each batch during training:
|
||||
```python
|
||||
@ -33,8 +38,8 @@ class LossHistory(keras.callbacks.Callback):
|
||||
def on_train_begin(self):
|
||||
self.losses = []
|
||||
|
||||
def on_batch_end(self, batch):
|
||||
self.losses.append(self.model.batch_history.loss[-1])
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
self.losses.append(logs.get('loss'))
|
||||
```
|
||||
|
||||
---
|
||||
@ -46,8 +51,8 @@ class LossHistory(keras.callbacks.Callback):
|
||||
def on_train_begin(self):
|
||||
self.losses = []
|
||||
|
||||
def on_batch_end(self, batch):
|
||||
self.losses.append(self.model.batch_history.loss[-1])
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
self.losses.append(logs.get('loss'))
|
||||
|
||||
model = Sequential()
|
||||
model.add(Dense(784, 10, init='uniform'))
|
||||
|
@ -17,7 +17,7 @@ from keras.preprocessing.text import Tokenizer
|
||||
python examples/reuters_mlp.py
|
||||
'''
|
||||
|
||||
max_words = 10000
|
||||
max_words = 1000
|
||||
batch_size = 32
|
||||
|
||||
print("Loading data...")
|
||||
@ -51,8 +51,12 @@ model.add(Activation('softmax'))
|
||||
|
||||
model.compile(loss='categorical_crossentropy', optimizer='adam')
|
||||
|
||||
history = model.fit(X_train, Y_train, nb_epoch=3, batch_size=batch_size, verbose=1, show_accuracy=True, validation_split=0.1)
|
||||
print(history)
|
||||
score = model.evaluate(X_test, Y_test, batch_size=batch_size, verbose=1, show_accuracy=True)
|
||||
history = model.fit(X_train, Y_train, nb_epoch=3, batch_size=batch_size, verbose=1, show_accuracy=False, validation_split=0.1)
|
||||
print(history.epoch)
|
||||
print(history.loss)
|
||||
print(history.accuracy)
|
||||
print(history.validation_loss)
|
||||
print(history.validation_accuracy)
|
||||
score = model.evaluate(X_test, Y_test, batch_size=batch_size, verbose=1, show_accuracy=False)
|
||||
print('Test score:', score[0])
|
||||
print('Test accuracy:', score[1])
|
||||
|
@ -26,21 +26,21 @@ class CallbackList(object):
|
||||
for callback in self.callbacks:
|
||||
callback._set_model(model)
|
||||
|
||||
def on_epoch_begin(self, epoch):
|
||||
def on_epoch_begin(self, epoch, logs={}):
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_begin(epoch)
|
||||
callback.on_epoch_begin(epoch, logs)
|
||||
self._delta_t_batch = 0.
|
||||
self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
|
||||
self._delta_ts_batch_end = deque([], maxlen=self.queue_length)
|
||||
|
||||
def on_epoch_end(self, epoch):
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch)
|
||||
callback.on_epoch_end(epoch, logs)
|
||||
|
||||
def on_batch_begin(self, batch):
|
||||
def on_batch_begin(self, batch, logs={}):
|
||||
t_before_callbacks = time.time()
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_begin(batch)
|
||||
callback.on_batch_begin(batch, logs)
|
||||
self._delta_ts_batch_begin.append(time.time() - t_before_callbacks)
|
||||
delta_t_median = np.median(self._delta_ts_batch_begin)
|
||||
if self._delta_t_batch > 0. and delta_t_median > 0.95 * self._delta_t_batch \
|
||||
@ -49,11 +49,11 @@ class CallbackList(object):
|
||||
'to the batch update (%f). Check your callbacks.' % delta_t_median)
|
||||
self._t_enter_batch = time.time()
|
||||
|
||||
def on_batch_end(self, batch):
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
self._delta_t_batch = time.time() - self._t_enter_batch
|
||||
t_before_callbacks = time.time()
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_end(batch)
|
||||
callback.on_batch_end(batch, logs)
|
||||
self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
|
||||
delta_t_median = np.median(self._delta_ts_batch_end)
|
||||
if self._delta_t_batch > 0. and delta_t_median > 0.95 * self._delta_t_batch \
|
||||
@ -61,13 +61,13 @@ class CallbackList(object):
|
||||
warnings.warn('Method on_batch_end() is slow compared '
|
||||
'to the batch update (%f). Check your callbacks.' % delta_t_median)
|
||||
|
||||
def on_train_begin(self):
|
||||
def on_train_begin(self, logs={}):
|
||||
for callback in self.callbacks:
|
||||
callback.on_train_begin()
|
||||
callback.on_train_begin(logs)
|
||||
|
||||
def on_train_end(self):
|
||||
def on_train_end(self, logs={}):
|
||||
for callback in self.callbacks:
|
||||
callback.on_train_end()
|
||||
callback.on_train_end(logs)
|
||||
|
||||
|
||||
class Callback(object):
|
||||
@ -81,62 +81,102 @@ class Callback(object):
|
||||
def _set_model(self, model):
|
||||
self.model = model
|
||||
|
||||
def on_epoch_begin(self, epoch):
|
||||
def on_epoch_begin(self, epoch, logs={}):
|
||||
pass
|
||||
|
||||
def on_epoch_end(self, epoch):
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
pass
|
||||
|
||||
def on_batch_begin(self, batch):
|
||||
def on_batch_begin(self, batch, logs={}):
|
||||
pass
|
||||
|
||||
def on_batch_end(self, batch):
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
pass
|
||||
|
||||
def on_train_begin(self):
|
||||
def on_train_begin(self, logs={}):
|
||||
pass
|
||||
|
||||
def on_train_end(self):
|
||||
def on_train_end(self, logs={}):
|
||||
pass
|
||||
|
||||
class BaseLogger(Callback):
|
||||
|
||||
def on_train_begin(self):
|
||||
def on_train_begin(self, logs={}):
|
||||
self.verbose = self.params['verbose']
|
||||
|
||||
def on_epoch_begin(self, epoch):
|
||||
def on_epoch_begin(self, epoch, logs={}):
|
||||
if self.verbose:
|
||||
print('Epoch %d' % epoch)
|
||||
self.progbar = Progbar(target=self.params['nb_sample'], \
|
||||
verbose=self.verbose)
|
||||
self.current = 0
|
||||
self.tot_loss = 0.
|
||||
self.tot_acc = 0.
|
||||
|
||||
def on_batch_begin(self, batch):
|
||||
def on_batch_begin(self, batch, logs={}):
|
||||
if self.current < self.params['nb_sample']:
|
||||
self.log_values = []
|
||||
|
||||
def on_batch_end(self, batch_index):
|
||||
self.current += self.model.batch_history['batch_size'][-1]
|
||||
# skip progbar update for the last batch; will be handled by on_epoch_end
|
||||
if self.current < self.params['nb_sample']:
|
||||
loss = self.model.batch_history['loss'][-1]
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
batch_size = logs.get('size', 0)
|
||||
self.current += batch_size
|
||||
|
||||
loss = logs.get('loss')
|
||||
self.log_values.append(('loss', loss))
|
||||
self.tot_loss += loss * batch_size
|
||||
if self.params['show_accuracy']:
|
||||
accuracy = self.model.batch_history['accuracy'][-1]
|
||||
accuracy = logs.get('accuracy')
|
||||
self.log_values.append(('acc.', accuracy))
|
||||
if self.verbose:
|
||||
self.tot_acc += accuracy * batch_size
|
||||
# skip progbar update for the last batch; will be handled by on_epoch_end
|
||||
if self.verbose and self.current < self.params['nb_sample']:
|
||||
self.progbar.update(self.current, self.log_values)
|
||||
|
||||
def on_epoch_end(self, epoch):
|
||||
loss = self.model.batch_history['loss'][-1]
|
||||
self.log_values.append(('loss', loss))
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
self.log_values.append(('loss', self.tot_loss / self.current))
|
||||
if self.params['show_accuracy']:
|
||||
accuracy = self.model.batch_history['accuracy'][-1]
|
||||
self.log_values.append(('acc.', accuracy))
|
||||
self.log_values.append(('acc.', self.tot_acc / self.current))
|
||||
if self.params['do_validation']:
|
||||
val_loss = self.model.epoch_history['val_loss'][-1]
|
||||
val_loss = logs.get('val_loss')
|
||||
self.log_values.append(('val. loss', val_loss))
|
||||
if self.params['show_accuracy']:
|
||||
val_acc = self.model.epoch_history['val_accuracy'][-1]
|
||||
val_acc = logs.get('val_accuracy')
|
||||
self.log_values.append(('val. acc.', val_acc))
|
||||
self.progbar.update(self.current, self.log_values)
|
||||
|
||||
|
||||
class History(Callback):
|
||||
|
||||
def on_train_begin(self, logs={}):
|
||||
self.epoch = []
|
||||
self.loss = []
|
||||
if self.params['show_accuracy']:
|
||||
self.accuracy = []
|
||||
if self.params['do_validation']:
|
||||
self.validation_loss = []
|
||||
if self.params['show_accuracy']:
|
||||
self.validation_accuracy = []
|
||||
|
||||
def on_epoch_begin(self, epoch, logs={}):
|
||||
self.seen = 0
|
||||
self.tot_loss = 0.
|
||||
self.tot_accuracy = 0.
|
||||
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
batch_size = logs.get('size', 0)
|
||||
self.seen += batch_size
|
||||
self.tot_loss += logs.get('loss', 0.) * batch_size
|
||||
if self.params['show_accuracy']:
|
||||
self.tot_accuracy += logs.get('accuracy', 0.) * batch_size
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
val_loss = logs.get('val_loss')
|
||||
val_acc = logs.get('val_accuracy')
|
||||
self.epoch.append(epoch)
|
||||
self.loss.append(self.tot_loss / self.seen)
|
||||
if self.params['show_accuracy']:
|
||||
self.accuracy.append(self.tot_accuracy / self.seen)
|
||||
if self.params['do_validation']:
|
||||
self.validation_loss.append(val_loss)
|
||||
if self.params['show_accuracy']:
|
||||
self.validation_accuracy.append(val_acc)
|
||||
|
@ -218,6 +218,7 @@ class Sequential(Model):
|
||||
callbacks = cbks.CallbackList(callbacks)
|
||||
if verbose:
|
||||
callbacks.append(cbks.BaseLogger())
|
||||
callbacks.append(cbks.History())
|
||||
|
||||
callbacks._set_model(self)
|
||||
callbacks._set_params({
|
||||
@ -228,26 +229,9 @@ class Sequential(Model):
|
||||
'do_validation': do_validation,
|
||||
'show_accuracy': show_accuracy
|
||||
})
|
||||
self.batch_history = {
|
||||
'batch':[],
|
||||
'batch_size':[],
|
||||
'loss':[],
|
||||
'accuracy':[],
|
||||
'val_loss':[],
|
||||
'val_accuracy':[],
|
||||
}
|
||||
self.epoch_history = {
|
||||
'epoch':[],
|
||||
'epoch_size':[],
|
||||
'loss':[],
|
||||
'accuracy':[],
|
||||
'val_loss':[],
|
||||
'val_accuracy':[],
|
||||
}
|
||||
callbacks.on_train_begin()
|
||||
|
||||
for epoch in range(nb_epoch):
|
||||
self.epoch_history['epoch'] = epoch
|
||||
callbacks.on_epoch_begin(epoch)
|
||||
if shuffle:
|
||||
np.random.shuffle(index_array)
|
||||
@ -258,40 +242,38 @@ class Sequential(Model):
|
||||
X_batch = slice_X(X, batch_ids)
|
||||
y_batch = y[batch_ids]
|
||||
|
||||
self.batch_history['batch'].append(batch_index)
|
||||
self.batch_history['batch_size'].append(len(batch_ids))
|
||||
callbacks.on_batch_begin(batch_index)
|
||||
batch_logs = {}
|
||||
batch_logs['batch'] = batch_index
|
||||
batch_logs['size'] = len(batch_ids)
|
||||
callbacks.on_batch_begin(batch_index, batch_logs)
|
||||
|
||||
ins = X_batch + [y_batch]
|
||||
if show_accuracy:
|
||||
loss, acc = self._train_with_acc(*ins)
|
||||
self.batch_history['accuracy'].append(acc)
|
||||
batch_logs['accuracy'] = acc
|
||||
else:
|
||||
loss = self._train(*ins)
|
||||
self.batch_history['loss'].append(loss)
|
||||
batch_logs['loss'] = loss
|
||||
|
||||
callbacks.on_batch_end(batch_index)
|
||||
callbacks.on_batch_end(batch_index, batch_logs)
|
||||
|
||||
if batch_index == len(batches) - 1: # last batch
|
||||
# validation
|
||||
epoch_logs = {}
|
||||
if do_validation:
|
||||
if show_accuracy:
|
||||
val_loss, val_acc = self.evaluate(X_val, y_val, batch_size=batch_size, \
|
||||
verbose=0, show_accuracy=True)
|
||||
self.epoch_history['val_accuracy'].append(val_acc)
|
||||
epoch_logs['val_accuracy'] = val_acc
|
||||
else:
|
||||
val_loss = self.evaluate(X_val, y_val, batch_size=batch_size, verbose=0)
|
||||
self.epoch_history['val_loss'].append(val_loss)
|
||||
epoch_logs['val_loss'] = val_loss
|
||||
|
||||
epoch_loss = sum(map(lambda x: x[0]*x[1], zip(self.batch_history['batch_size'], self.batch_history['loss']))) / len(y)
|
||||
self.epoch_history['loss'].append(epoch_loss)
|
||||
if show_accuracy:
|
||||
epoch_acc = sum(map(lambda x: x[0]*x[1], zip(self.batch_history['batch_size'], self.batch_history['accuracy']))) / len(y)
|
||||
self.epoch_history['accuracy'].append(epoch_acc)
|
||||
callbacks.on_epoch_end(epoch)
|
||||
callbacks.on_epoch_end(epoch, epoch_logs)
|
||||
|
||||
callbacks.on_train_end()
|
||||
return self.epoch_history
|
||||
# return history
|
||||
return callbacks.callbacks[-1]
|
||||
|
||||
def predict(self, X, batch_size=128, verbose=1):
|
||||
X = standardize_X(X)
|
||||
@ -334,6 +316,7 @@ class Sequential(Model):
|
||||
if show_accuracy:
|
||||
tot_acc = 0.
|
||||
tot_score = 0.
|
||||
seen = 0
|
||||
|
||||
batches = make_batches(len(y), batch_size)
|
||||
if verbose:
|
||||
@ -345,21 +328,22 @@ class Sequential(Model):
|
||||
ins = X_batch + [y_batch]
|
||||
if show_accuracy:
|
||||
loss, acc = self._test_with_acc(*ins)
|
||||
tot_acc += acc
|
||||
tot_acc += acc * len(y_batch)
|
||||
log_values = [('loss', loss), ('acc.', acc)]
|
||||
else:
|
||||
loss = self._test(*ins)
|
||||
log_values = [('loss', loss)]
|
||||
tot_score += loss
|
||||
tot_score += loss * len(y_batch)
|
||||
seen += len(y_batch)
|
||||
|
||||
# logging
|
||||
if verbose:
|
||||
progbar.update(batch_end, log_values)
|
||||
|
||||
if show_accuracy:
|
||||
return tot_score/len(batches), tot_acc/len(batches)
|
||||
return tot_score / seen, tot_acc / seen
|
||||
else:
|
||||
return tot_score/len(batches)
|
||||
return tot_score / seen
|
||||
|
||||
def get_config(self, verbose=0):
|
||||
layers = []
|
||||
|
@ -60,7 +60,7 @@ class Progbar(object):
|
||||
'''
|
||||
for k, v in values:
|
||||
if k not in self.sum_values:
|
||||
self.sum_values[k] = [v * max(1, current-self.seen_so_far), current-self.seen_so_far]
|
||||
self.sum_values[k] = [v * (current-self.seen_so_far), current-self.seen_so_far]
|
||||
self.unique_values.append(k)
|
||||
else:
|
||||
self.sum_values[k][0] += v * (current-self.seen_so_far)
|
||||
|
Loading…
Reference in New Issue
Block a user