## Usage of callbacks

A callback is a set of functions applied at given stages of the training procedure. You can use callbacks to inspect the internal states and statistics of the model during training. Pass a list of callbacks (as the keyword argument `callbacks`) to the `.fit()` method of the `Sequential` model; the relevant methods of each callback are then called at each stage of training.

---

{{autogenerated}}

---

# Create a callback

You can create a custom callback by extending the base class `keras.callbacks.Callback`. A callback has access to its associated model through the class property `self.model`.

Here's a simple example that saves the loss of each batch to a list during training:

```python
import keras

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        # Start with an empty list at the beginning of training.
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        # `logs` holds the metrics of the batch that just finished.
        self.losses.append(logs.get('loss'))
```

---

### Example: recording loss history

```python
import keras
from keras.models import Sequential
from keras.layers.core import Dense, Activation

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# X_train and Y_train are assumed to be loaded already.
history = LossHistory()
model.fit(X_train, Y_train, batch_size=128, nb_epoch=20, verbose=0, callbacks=[history])

print(history.losses)
# outputs
# [0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
```

---

### Example: model checkpoints

```python
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# Save the model weights after each epoch if the validation loss decreased.
# X_train, Y_train, X_test and Y_test are assumed to be loaded already.
checkpointer = ModelCheckpoint(filepath="/tmp/weights.hdf5", verbose=1, save_best_only=True)
model.fit(X_train, Y_train, batch_size=128, nb_epoch=20, verbose=0, validation_data=(X_test, Y_test), callbacks=[checkpointer])
```
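
---

### Example: how callback hooks are dispatched (sketch)

The statement that "the relevant methods of the callbacks will then be called at each stage of the training" can be illustrated without Keras at all. The sketch below is a simplified stand-in, not Keras's actual implementation: the `Callback` base class and the `toy_fit` function are hypothetical names invented for this illustration, and the "training" merely walks over precomputed loss values. It only shows the order in which hooks such as `on_train_begin` and `on_batch_end` might fire.

```python
class Callback(object):
    """Toy stand-in for a callback base class (hypothetical): every hook is a no-op."""

    def on_train_begin(self, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass


class LossHistory(Callback):
    """Records the loss reported at the end of every batch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))


def toy_fit(batch_losses_per_epoch, callbacks):
    """Hypothetical training loop: no model is trained, only hooks are fired."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch, batch_losses in enumerate(batch_losses_per_epoch):
        for batch, loss in enumerate(batch_losses):
            # A real loop would run one optimisation step here.
            for cb in callbacks:
                cb.on_batch_end(batch, {'loss': loss})
        for cb in callbacks:
            cb.on_epoch_end(epoch, {'loss': batch_losses[-1]})


history = LossHistory()
toy_fit([[0.9, 0.7], [0.5, 0.4]], callbacks=[history])
print(history.losses)  # [0.9, 0.7, 0.5, 0.4]
```

Because the base class defines every hook as a no-op, a subclass only overrides the stages it cares about, which is why `LossHistory` above implements just two methods.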