Add docstrings in callbacks, models, optimizers

Francois Chollet 2015-12-12 12:36:23 -08:00
parent 06f5f43079
commit 9a93fc51cf
4 changed files with 474 additions and 135 deletions

@ -69,7 +69,30 @@ class CallbackList(object):
class Callback(object):
'''Abstract base class used to build new callbacks.
# Properties
params: dict. Training parameters
(e.g. verbosity, batch size, number of epochs...).
model: instance of `keras.models.Model`.
Reference of the model being trained.
The `logs` dictionary that callback methods
take as argument will contain keys for quantities relevant to
the current batch or epoch.
Currently, the `.fit()` method of the `Sequential` model class
will include the following quantities in the `logs` that
it passes to its callbacks:
on_epoch_end: logs optionally include `val_loss`
(if validation is enabled in `fit`), and `val_acc`
(if validation and accuracy monitoring are enabled).
on_batch_begin: logs include `size`,
the number of samples in the current batch.
on_batch_end: logs include `loss`, and optionally `acc`
(if accuracy monitoring is enabled).
'''
def __init__(self):
pass
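As a minimal sketch of the interface described in the docstring above, a custom callback can record the per-batch training loss using only the documented `loss` key of `on_batch_end` (the class name is illustrative):
class LossHistory(Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        # `loss` is documented to be present in the on_batch_end logs
        self.losses.append(logs.get('loss'))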
@ -99,6 +122,12 @@ class Callback(object):
class BaseLogger(Callback):
'''Callback that prints events to the standard output.
This callback is automatically applied to
every Keras model (it is the basis of the verbosity modes
in models).
'''
def on_train_begin(self, logs={}):
self.verbose = self.params['verbose']
self.nb_epoch = self.params['nb_epoch']
@ -128,7 +157,8 @@ class BaseLogger(Callback):
if k in logs:
self.log_values.append((k, logs[k]))
# skip progbar update for the last batch; will be handled by on_epoch_end
# skip progbar update for the last batch;
# will be handled by on_epoch_end
if self.verbose and self.seen < self.params['nb_sample']:
self.progbar.update(self.seen, self.log_values)
@ -143,7 +173,13 @@ class BaseLogger(Callback):
class History(Callback):
'''Callback that records events
into a `History` object.
This callback is automatically applied to
every Keras model. The `History` object
gets returned by the `fit` method of models.
'''
def on_train_begin(self, logs={}):
self.epoch = []
self.history = {}
@ -175,26 +211,55 @@ class History(Callback):
class ModelCheckpoint(Callback):
def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False, mode='auto'):
'''Save the model after every epoch.
`filepath` can contain named formatting options,
which will be filled with the value of `epoch` and
keys in `logs` (passed in `on_epoch_end`).
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
then multiple files will be saved, with the epoch number and
the validation loss in the filename.
# Arguments
filepath: string, path to save the model file.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`,
the latest best model according to
the quantity monitored will not be overwritten.
mode: one of {auto, min, max}.
If `save_best_only=True`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
'''
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, mode='auto'):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
if mode not in ['auto', 'min', 'max']:
warnings.warn("ModelCheckpoint mode %s is unknown, fallback to auto mode" % (self.mode), RuntimeWarning)
warnings.warn('ModelCheckpoint mode %s is unknown, '
'falling back to auto mode.' % (mode), RuntimeWarning)
mode = 'auto'
if mode == "min":
if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
elif mode == "max":
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
else:
if "acc" in self.monitor:
if 'acc' in self.monitor:
self.monitor_op = np.greater
self.best = -np.Inf
else:
@ -206,24 +271,36 @@ class ModelCheckpoint(Callback):
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn("Can save best model only with %s available, skipping." % (self.monitor), RuntimeWarning)
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
print("Epoch %05d: %s improved from %0.5f to %0.5f, saving model to %s"
% (epoch, self.monitor, self.best, current, filepath))
print('Epoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
self.model.save_weights(filepath, overwrite=True)
else:
if self.verbose > 0:
print("Epoch %05d: %s did not improve" % (epoch, self.monitor))
print('Epoch %05d: %s did not improve' %
(epoch, self.monitor))
else:
if self.verbose > 0:
print("Epoch %05d: saving model to %s" % (epoch, filepath))
print('Epoch %05d: saving model to %s' % (epoch, filepath))
self.model.save_weights(filepath, overwrite=True)
class EarlyStopping(Callback):
'''Stop training when a monitored quantity has stopped improving.
# Arguments
monitor: quantity to be monitored.
patience: number of epochs with no improvement
after which training will be stopped.
verbose: verbosity mode.
'''
def __init__(self, monitor='val_loss', patience=0, verbose=0):
super(EarlyStopping, self).__init__()
@ -236,7 +313,8 @@ class EarlyStopping(Callback):
def on_epoch_end(self, epoch, logs={}):
current = logs.get(self.monitor)
if current is None:
warnings.warn("Early stopping requires %s available!" % (self.monitor), RuntimeWarning)
warnings.warn('Early stopping requires %s available!' %
(self.monitor), RuntimeWarning)
if current < self.best:
self.best = current
@ -244,12 +322,16 @@ class EarlyStopping(Callback):
else:
if self.wait >= self.patience:
if self.verbose > 0:
print("Epoch %05d: early stopping" % (epoch))
print('Epoch %05d: early stopping' % (epoch))
self.model.stop_training = True
self.wait += 1
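A hedged usage sketch combining the two callbacks documented above; `model`, `X_train` and `y_train` are assumed to already exist (illustrative names), and the filepath follows the formatting pattern from the ModelCheckpoint docstring:
checkpointer = ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                               monitor='val_loss', save_best_only=True,
                               mode='min', verbose=1)
early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1)
# validation_split provides the `val_loss` that both callbacks monitor
model.fit(X_train, y_train, validation_split=0.1,
          callbacks=[checkpointer, early_stop])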
class RemoteMonitor(Callback):
'''Experimental callback used to stream events to a server.
Requires the `requests` library.
'''
def __init__(self, root='http://localhost:9000'):
self.root = root
@ -277,15 +359,20 @@ class RemoteMonitor(Callback):
send[k] = v
try:
r = requests.post(self.root + '/publish/epoch/end/', {'data': json.dumps(send)})
requests.post(self.root + '/publish/epoch/end/',
{'data': json.dumps(send)})
except:
print('Warning: could not reach RemoteMonitor root server at ' + str(self.root))
print('Warning: could not reach RemoteMonitor '
'root server at ' + str(self.root))
class LearningRateScheduler(Callback):
'''LearningRateScheduler
schedule is a function that gets an epoch number as input and returns a new
learning rate as output.
'''Learning rate scheduler.
# Arguments
schedule: a function that gets an epoch index as input
(integer, indexed from 0) and returns a new
learning rate as output.
'''
def __init__(self, schedule):
super(LearningRateScheduler, self).__init__()
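A minimal sketch of a schedule function as described in the docstring: it receives the epoch index (starting from 0) and returns a new learning rate; the decay policy itself is illustrative.
def halving_schedule(epoch):
    # halve the learning rate every 10 epochs
    return 0.01 * (0.5 ** (epoch // 10))

lr_scheduler = LearningRateScheduler(halving_schedule)
# then: model.fit(..., callbacks=[lr_scheduler])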

@ -21,6 +21,7 @@ class Layer(object):
'''Abstract base layer class.
All Keras layers accept certain keyword arguments:
trainable: boolean. Set to "False" before model compilation
to freeze layer weights (they won't be updated further
during training).

@ -2,8 +2,6 @@ from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import warnings
import time
import copy
import pprint
from six.moves import range
import six
@ -11,11 +9,10 @@ import six
from . import backend as K
from . import optimizers
from . import objectives
from . import regularizers
from . import constraints
from . import callbacks as cbks
from .utils.layer_utils import container_from_config, model_summary
from .utils.generic_utils import Progbar, printv
from .utils.layer_utils import container_from_config
from .utils.layer_utils import model_summary
from .utils.generic_utils import Progbar
from .layers import containers
@ -52,6 +49,8 @@ def standardize_X(X):
def slice_X(X, start=None, stop=None):
'''Slice an array or list of arrays along the first axis,
either with a list/array of indices (passed as `start`)
or with a `[start:stop]` range.
'''
if type(X) == list:
if hasattr(start, '__len__'):
# hdf5 datasets only support list objects as indices
@ -71,9 +70,7 @@ def slice_X(X, start=None, stop=None):
def weighted_objective(fn):
def weighted(y_true, y_pred, weights, mask=None):
'''To be called only with non-zero weights.
mask: binary
'''
'''
# score_array has ndim >= 2
score_array = fn(y_true, y_pred)
@ -96,12 +93,15 @@ def weighted_objective(fn):
def standardize_weights(y, sample_weight=None, class_weight=None):
'''Standardize `sample_weight` or `class_weight` into a single
numpy array of per-sample weights (defaults to an array of ones).
'''
if sample_weight is not None:
assert len(sample_weight) == len(y)
return sample_weight.flatten()
elif isinstance(class_weight, dict):
if len(y.shape) > 2:
raise Exception('class_weight not supported for 3+ dimensional targets.')
raise Exception('class_weight not supported for '
'3+ dimensional targets.')
if y.shape[1] > 1:
y_classes = y.argmax(axis=1)
elif y.shape[1] == 1:
@ -132,6 +132,8 @@ def model_from_json(json_string, custom_objects={}):
def model_from_config(config, custom_objects={}):
'''Instantiate a `Sequential` or `Graph` model from its config
dictionary, compiling it if an optimizer is included.
'''
model_name = config.get('name')
if model_name not in {'Graph', 'Sequential'}:
raise Exception('Unrecognized model:', model_name)
@ -147,7 +149,6 @@ def model_from_config(config, custom_objects={}):
# if it has an optimizer, the model is assumed to be compiled
loss = config.get('loss')
class_mode = config.get('class_mode')
theano_mode = config.get('theano_mode')
optimizer_params = dict([(k, v) for k, v in config.get('optimizer').items()])
optimizer_name = optimizer_params.pop('name')
@ -155,10 +156,9 @@ def model_from_config(config, custom_objects={}):
if model_name == 'Sequential':
model.compile(loss=loss, optimizer=optimizer,
class_mode=class_mode, theano_mode=theano_mode)
class_mode=class_mode)
elif model_name == 'Graph':
model.compile(loss=loss, optimizer=optimizer,
theano_mode=theano_mode)
model.compile(loss=loss, optimizer=optimizer)
return model
@ -170,6 +170,8 @@ def get_function_name(o):
class Model(object):
'''Abstract base model class.
'''
def _fit(self, f, ins, out_labels=[], batch_size=128,
nb_epoch=100, verbose=1, callbacks=[],
val_f=None, val_ins=None, shuffle=True, metrics=[]):
@ -181,7 +183,8 @@ class Model(object):
if val_f and val_ins:
do_validation = True
if verbose:
print("Train on %d samples, validate on %d samples" % (len(ins[0]), len(val_ins[0])))
print('Train on %d samples, validate on %d samples' %
(len(ins[0]), len(val_ins[0])))
nb_train_sample = len(ins[0])
index_array = np.arange(nb_train_sample)
@ -218,9 +221,9 @@ class Model(object):
try:
ins_batch = slice_X(ins, batch_ids)
except TypeError:
raise Exception('TypeError while preparing batch. \
If using HDF5 input data, pass shuffle="batch".\n')
raise Exception('TypeError while preparing batch. '
'If using HDF5 input data, '
'pass shuffle="batch".')
batch_logs = {}
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
@ -255,8 +258,7 @@ class Model(object):
return history
def _predict_loop(self, f, ins, batch_size=128, verbose=0):
'''
Abstract method to loop over some data in batches.
'''Abstract method to loop over some data in batches.
'''
nb_sample = len(ins[0])
outs = []
@ -283,8 +285,7 @@ class Model(object):
return outs
def _test_loop(self, f, ins, batch_size=128, verbose=0):
'''
Abstract method to loop over some data in batches.
'''Abstract method to loop over some data in batches.
'''
nb_sample = len(ins[0])
outs = []
@ -315,8 +316,14 @@ class Model(object):
return outs
def get_config(self, verbose=0):
'''Return the configuration of the model
as a dictionary.
To load a model from its configuration, use
`keras.models.model_from_config(config, custom_objects={})`.
'''
config = super(Model, self).get_config()
for p in ['class_mode', 'theano_mode']:
for p in ['class_mode']:
if hasattr(self, p):
config[p] = getattr(self, p)
if hasattr(self, 'optimizer'):
@ -333,38 +340,54 @@ class Model(object):
return config
def to_yaml(self, **kwargs):
# dump model configuration to yaml string
'''Return a yaml string containing the model configuration.
To load a model from a yaml save file, use
`keras.models.model_from_yaml(yaml_string, custom_objects={})`.
`custom_objects` should be a dictionary mapping
the names of custom losses / layers / etc to the corresponding
functions / classes.
'''
import yaml
config = self.get_config()
return yaml.dump(config, **kwargs)
def to_json(self, **kwargs):
# dump model configuration to json string
'''Return a JSON string containing the model configuration.
To load a model from a JSON save file, use
`keras.models.model_from_json(json_string, custom_objects={})`.
'''
import json
config = self.get_config()
return json.dumps(config, **kwargs)
def summary(self):
'''Print out a summary of the model architecture,
including parameter count information.
'''
model_summary(self)
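A round-trip sketch for the serialization methods above, assuming a built `model`; reloading from JSON restores the architecture only, so weights are handled separately via `save_weights` / `load_weights`.
from keras.models import model_from_json

json_string = model.to_json()
reloaded = model_from_json(json_string)  # same architecture, fresh weights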
class Sequential(Model, containers.Sequential):
'''
Inherits from Model the following methods:
- _fit
- _predict
- _evaluate
Inherits from containers.Sequential the following methods:
- __init__
- add
- get_output
- get_input
- get_weights
- set_weights
'''
'''Linear stack of layers.
Inherits from containers.Sequential.
'''
def compile(self, optimizer, loss,
class_mode="categorical", theano_mode=None):
class_mode="categorical"):
'''Configure the learning process.
# Arguments
optimizer: str (name of optimizer) or optimizer object.
See [optimizers](optimizers.md).
loss: str (name of objective function) or objective function.
See [objectives](objectives.md).
class_mode: one of "categorical", "binary".
This is only used for computing classification accuracy or
using the predict_classes method.
'''
self.optimizer = optimizers.get(optimizer)
self.loss = objectives.get(loss)
@ -401,7 +424,6 @@ class Sequential(Model, containers.Sequential):
else:
raise Exception("Invalid class mode:" + str(class_mode))
self.class_mode = class_mode
self.theano_mode = theano_mode
for r in self.regularizers:
train_loss = r(train_loss)
@ -426,37 +448,48 @@ class Sequential(Model, containers.Sequential):
self._test = K.function(test_ins, [test_loss])
self._test_with_acc = K.function(test_ins, [test_loss, test_accuracy])
def train_on_batch(self, X, y, accuracy=False,
class_weight=None, sample_weight=None):
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_weights(y, class_weight=class_weight,
sample_weight=sample_weight)
ins = X + [y, sample_weight]
if accuracy:
return self._train_with_acc(ins)
else:
return self._train(ins)
def test_on_batch(self, X, y, accuracy=False, sample_weight=None):
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_weights(y, sample_weight=sample_weight)
ins = X + [y, sample_weight]
if accuracy:
return self._test_with_acc(ins)
else:
return self._test(ins)
def predict_on_batch(self, X):
ins = standardize_X(X)
return self._predict(ins)
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
validation_split=0., validation_data=None, shuffle=True,
show_accuracy=False, class_weight=None, sample_weight=None):
'''Train the model for a fixed number of epochs.
Returns a history object. Its `history` attribute is a record of
training loss values at successive epochs,
as well as validation loss values (if applicable).
# Arguments
X: data, as a numpy array.
y: labels, as a numpy array.
batch_size: int. Number of samples per gradient update.
nb_epoch: int. Number of epochs to train the model.
verbose: 0 for no logging to stdout,
1 for progress bar logging, 2 for one log line per epoch.
callbacks: `keras.callbacks.Callback` list.
List of callbacks to apply during training.
See [callbacks](callbacks.md).
validation_split: float (0. < x < 1).
Fraction of the data to use as held-out validation data.
validation_data: tuple (X, y) to be used as held-out
validation data. Will override validation_split.
shuffle: boolean or str (for 'batch').
Whether to shuffle the samples at each epoch.
'batch' is a special option for dealing with the
limitations of HDF5 data; it shuffles in batch-sized chunks.
show_accuracy: boolean. Whether to display
class accuracy in the logs to stdout at each epoch.
class_weight: dictionary mapping classes to a weight value,
used for scaling the loss function (during training only).
sample_weight: list or numpy array with 1:1 mapping to
the training samples, used for scaling the loss function
(during training only). For time-distributed data,
there is one weight per sample *per timestep*,
i.e. if your output data is shaped
`(nb_samples, timesteps, output_dim)`,
your sample weights should be of shape `(nb_samples, timesteps, 1)`.
This allows you to mask out or reweight individual
output timesteps, which is useful
in sequence to sequence learning.
'''
X = standardize_X(X)
y = standardize_y(y)
@ -480,8 +513,11 @@ class Sequential(Model, containers.Sequential):
sample_weight_val = standardize_weights(y_val,
sample_weight=sample_weight_val)
else:
raise Exception("Invalid format for validation data; provide a tuple (X_val, y_val) or (X_val, y_val, sample_weight). \
X_val may be a numpy array or a list of numpy arrays depending on your model input.")
raise Exception('Invalid format for validation data; '
'provide a tuple (X_val, y_val) or '
'(X_val, y_val, sample_weight). '
'X_val may be a numpy array or a list of '
'numpy arrays depending on your model input.')
val_ins = X_val + [y_val, sample_weight_val]
elif 0 < validation_split < 1:
@ -490,7 +526,8 @@ class Sequential(Model, containers.Sequential):
y, y_val = (slice_X(y, 0, split_at), slice_X(y, split_at))
if sample_weight is not None:
sample_weight, sample_weight_val = (slice_X(sample_weight, 0, split_at), slice_X(sample_weight, split_at))
sample_weight_val = standardize_weights(y_val, sample_weight=sample_weight_val)
sample_weight_val = standardize_weights(y_val,
sample_weight=sample_weight_val)
else:
sample_weight_val = standardize_weights(y_val)
val_ins = X_val + [y_val, sample_weight_val]
@ -502,7 +539,8 @@ class Sequential(Model, containers.Sequential):
f = self._train
out_labels = ['loss']
sample_weight = standardize_weights(y, class_weight=class_weight, sample_weight=sample_weight)
sample_weight = standardize_weights(y, class_weight=class_weight,
sample_weight=sample_weight)
ins = X + [y, sample_weight]
metrics = ['loss', 'acc', 'val_loss', 'val_acc']
return self._fit(f, ins, out_labels=out_labels,
@ -512,24 +550,67 @@ class Sequential(Model, containers.Sequential):
shuffle=shuffle, metrics=metrics)
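An end-to-end sketch for `compile` and `fit` as documented above, assuming numpy arrays `X_train` and `y_train` (one-hot labels for the default `class_mode="categorical"`); the hyperparameter values are illustrative.
model.compile(loss='categorical_crossentropy', optimizer='sgd',
              class_mode='categorical')
history = model.fit(X_train, y_train, batch_size=128, nb_epoch=10,
                    verbose=1, validation_split=0.1, shuffle=True,
                    show_accuracy=True)
# per-epoch records, e.g. history.history['val_acc']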
def predict(self, X, batch_size=128, verbose=0):
'''Generate output predictions for the input samples
batch by batch.
# Arguments
X: the input data, as a numpy array.
batch_size: integer.
verbose: verbosity mode, 0 or 1.
# Returns
A numpy array of predictions.
'''
X = standardize_X(X)
return self._predict_loop(self._predict, X, batch_size, verbose)[0]
def predict_proba(self, X, batch_size=128, verbose=1):
'''Generate class probability predictions for the input samples
batch by batch.
# Arguments
X: the input data, as a numpy array.
batch_size: integer.
verbose: verbosity mode, 0 or 1.
# Returns
A numpy array of probability predictions.
'''
preds = self.predict(X, batch_size, verbose)
if preds.min() < 0 or preds.max() > 1:
warnings.warn("Network returning invalid probability values.")
warnings.warn('Network returning invalid probability values.')
return preds
def predict_classes(self, X, batch_size=128, verbose=1):
'''Generate class predictions for the input samples
batch by batch.
# Arguments
X: the input data, as a numpy array.
batch_size: integer.
verbose: verbosity mode, 0 or 1.
# Returns
A numpy array of class predictions.
'''
proba = self.predict(X, batch_size=batch_size, verbose=verbose)
if self.class_mode == "categorical":
if self.class_mode == 'categorical':
return proba.argmax(axis=-1)
else:
return (proba > 0.5).astype('int32')
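A brief prediction sketch for the three methods above, assuming a compiled categorical model and a numpy array `X_test` (illustrative name):
y_pred = model.predict(X_test, batch_size=128, verbose=0)
probas = model.predict_proba(X_test, batch_size=128, verbose=0)
classes = model.predict_classes(X_test, batch_size=128, verbose=0)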
def evaluate(self, X, y, batch_size=128, show_accuracy=False,
verbose=1, sample_weight=None):
'''Compute the loss on some input data, batch by batch.
# Arguments
X: input data, as a numpy array.
y: labels, as a numpy array.
batch_size: integer.
show_accuracy: boolean. If True, also compute and return the classification accuracy.
verbose: verbosity mode, 0 or 1.
sample_weight: sample weights, as a numpy array.
'''
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_weights(y, sample_weight=sample_weight)
@ -545,8 +626,50 @@ class Sequential(Model, containers.Sequential):
else:
return outs[0]
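A sketch of `evaluate` with accuracy reporting, assuming `X_test` / `y_test` arrays; with `show_accuracy=True` both the loss and the accuracy are returned.
loss, acc = model.evaluate(X_test, y_test, batch_size=128,
                           show_accuracy=True, verbose=0)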
def train_on_batch(self, X, y, accuracy=False,
class_weight=None, sample_weight=None):
'''Single gradient update over one batch of samples.
Returns the loss over the data,
or a tuple `(loss, accuracy)` if `accuracy=True`.
Arguments: see `fit` method.
'''
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_weights(y, class_weight=class_weight,
sample_weight=sample_weight)
ins = X + [y, sample_weight]
if accuracy:
return self._train_with_acc(ins)
else:
return self._train(ins)
def test_on_batch(self, X, y, accuracy=False, sample_weight=None):
'''Returns the loss over a single batch of samples,
or a tuple `(loss, accuracy)` if `accuracy=True`.
Arguments: see `fit` method.
'''
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_weights(y, sample_weight=sample_weight)
ins = X + [y, sample_weight]
if accuracy:
return self._test_with_acc(ins)
else:
return self._test(ins)
def predict_on_batch(self, X):
'''Returns predictions for a single batch of samples.
'''
ins = standardize_X(X)
return self._predict(ins)
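A manual training-loop sketch using the batch-level methods above, assuming a compiled model and some iterable of `(X_batch, y_batch)` pairs (the generator name is illustrative):
for X_batch, y_batch in batch_iterator():
    # returns the batch loss, or (loss, accuracy) if accuracy=True
    loss = model.train_on_batch(X_batch, y_batch)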
def save_weights(self, filepath, overwrite=False):
# Save weights from all layers to HDF5
'''Dump all layer weights to an HDF5 file.
'''
import h5py
import os.path
# if file exists and should not be overwritten
@ -555,7 +678,8 @@ class Sequential(Model, containers.Sequential):
get_input = input
if sys.version_info[:2] <= (2, 7):
get_input = raw_input
overwrite = get_input('[WARNING] %s already exists - overwrite? [y/n]' % (filepath))
overwrite = get_input('[WARNING] %s already exists - overwrite? '
'[y/n]' % (filepath))
while overwrite not in ['y', 'n']:
overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
if overwrite == 'n':
@ -570,20 +694,20 @@ class Sequential(Model, containers.Sequential):
g.attrs['nb_params'] = len(weights)
for n, param in enumerate(weights):
param_name = 'param_{}'.format(n)
param_dset = g.create_dataset(param_name, param.shape, dtype=param.dtype)
param_dset = g.create_dataset(param_name, param.shape,
dtype=param.dtype)
param_dset[:] = param
f.flush()
f.close()
def load_weights(self, filepath):
'''Load all layer weights from an HDF5 save file.
'''
This method does not make use of Sequential.set_weights()
for backwards compatibility.
'''
# Loads weights from HDF5 file
import h5py
f = h5py.File(filepath)
for k in range(f.attrs['nb_layers']):
# This method does not make use of Sequential.set_weights()
# for backwards compatibility.
g = f['layer_{}'.format(k)]
weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
self.layers[k].set_weights(weights)
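A weight-persistence sketch for the two methods above; the HDF5 path is illustrative, `h5py` must be installed, and the model loading the weights must have the same architecture.
model.save_weights('my_model_weights.h5', overwrite=True)
model.load_weights('my_model_weights.h5')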
@ -591,8 +715,24 @@ class Sequential(Model, containers.Sequential):
class Graph(Model, containers.Graph):
def compile(self, optimizer, loss, theano_mode=None):
# loss is a dictionary mapping output name to loss functions
'''Arbitrary connection graph.
It can have any number of inputs and outputs,
with each output trained with its own loss function.
The quantity being optimized by a Graph model is
the sum of all loss functions over the different outputs.
Inherits from `containers.Graph`.
'''
def compile(self, optimizer, loss):
'''Configure the learning process.
# Arguments
optimizer: str (name of optimizer) or optimizer object.
See [optimizers](optimizers.md).
loss: dictionary mapping the name(s) of the output(s) to
a loss function (the string name of an objective function,
or an objective function; see [objectives](objectives.md)).
'''
ys = []
ys_train = []
ys_test = []
@ -627,38 +767,49 @@ class Graph(Model, containers.Graph):
for r in self.regularizers:
train_loss = r(train_loss)
self.optimizer = optimizers.get(optimizer)
updates = self.optimizer.get_updates(self.params, self.constraints, train_loss)
updates = self.optimizer.get_updates(self.params,
self.constraints,
train_loss)
updates += self.updates
self.theano_mode = theano_mode
self.loss = loss
self._train = K.function(train_ins, [train_loss], updates=updates)
self._test = K.function(test_ins, [test_loss])
self._predict = K.function(inputs=ins, outputs=ys_test, updates=self.state_updates)
def train_on_batch(self, data, class_weight={}, sample_weight={}):
# data is a dictionary mapping output and input names to arrays
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name),
class_weight=class_weight.get(name)) for name in self.output_order]
ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
return self._train(ins)
def test_on_batch(self, data, sample_weight={}):
# data is a dictionary mapping input names to arrays
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name)) for name in self.output_order]
ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
return self._test(ins)
def predict_on_batch(self, data):
# data is a dictionary mapping input names to arrays
ins = [data[name] for name in self.input_order]
return self._predict(ins)
self._predict = K.function(inputs=ins, outputs=ys_test,
updates=self.state_updates)
def fit(self, data, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
validation_split=0., validation_data=None, shuffle=True,
class_weight={}, sample_weight={}):
'''Train the model for a fixed number of epochs.
Returns a history object. Its `history` attribute is a record of
training loss values at successive epochs,
as well as validation loss values (if applicable).
# Arguments
data: dictionary mapping input names and outputs names to
appropriate numpy arrays. All arrays should contain
the same number of samples.
batch_size: int. Number of samples per gradient update.
nb_epoch: int. Number of epochs to train the model.
verbose: 0 for no logging to stdout,
1 for progress bar logging, 2 for one log line per epoch.
callbacks: `keras.callbacks.Callback` list. List of callbacks
to apply during training. See [callbacks](callbacks.md).
validation_split: float (0. < x < 1). Fraction of the data to
use as held-out validation data.
validation_data: dictionary mapping input names and outputs names
to appropriate numpy arrays to be used as
held-out validation data.
All arrays should contain the same number of samples.
Will override validation_split.
shuffle: boolean. Whether to shuffle the samples at each epoch.
class_weight: dictionary mapping output names to
class weight dictionaries.
sample_weight: dictionary mapping output names to
numpy arrays of sample weights.
'''
X = [data[name] for name in self.input_order]
y = [standardize_y(data[name]) for name in self.output_order]
@ -691,13 +842,18 @@ class Graph(Model, containers.Graph):
sample_weight=sample_weight_list[i],
class_weight=class_weight_list[i]) for i in range(len(self.output_order))]
ins = X + y + sample_weight_list
history = self._fit(f, ins, out_labels=out_labels, batch_size=batch_size, nb_epoch=nb_epoch,
history = self._fit(f, ins, out_labels=out_labels,
batch_size=batch_size, nb_epoch=nb_epoch,
verbose=verbose, callbacks=callbacks,
val_f=val_f, val_ins=val_ins,
shuffle=shuffle, metrics=metrics)
return history
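A sketch of compiling and fitting a `Graph` model as documented above, assuming a graph with a single input named 'input' and a single output named 'output' has already been defined; the data dictionary keys must match those names.
graph.compile(optimizer='rmsprop', loss={'output': 'mse'})
history = graph.fit({'input': X_train, 'output': y_train},
                    nb_epoch=10, validation_split=0.1)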
def evaluate(self, data, batch_size=128, verbose=0, sample_weight={}):
'''Compute the loss on some input data, batch by batch.
Arguments: see `fit` method.
'''
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name)) for name in self.output_order]
@ -706,12 +862,45 @@ class Graph(Model, containers.Graph):
return outs[0]
def predict(self, data, batch_size=128, verbose=0):
'''Generate output predictions for the input samples
batch by batch.
Arguments: see `fit` method.
'''
ins = [data[name] for name in self.input_order]
outs = self._predict_loop(self._predict, ins, batch_size, verbose)
return dict(zip(self.output_order, outs))
def train_on_batch(self, data, class_weight={}, sample_weight={}):
'''Single gradient update on a batch of samples.
Arguments: see `fit` method.
'''
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name),
class_weight=class_weight.get(name)) for name in self.output_order]
ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
return self._train(ins)
def test_on_batch(self, data, sample_weight={}):
'''Compute the loss on a single batch of samples.
Arguments: see `fit` method.
'''
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name)) for name in self.output_order]
ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
return self._test(ins)
def predict_on_batch(self, data):
'''Generate predictions for a single batch of samples.
'''
ins = [data[name] for name in self.input_order]
return self._predict(ins)
def save_weights(self, filepath, overwrite=False):
# Save weights from all layers to HDF5
'''Save weights from all layers to an HDF5 file.
'''
import h5py
import os.path
# if file exists and should not be overwritten
@ -720,7 +909,8 @@ class Graph(Model, containers.Graph):
get_input = input
if sys.version_info[:2] <= (2, 7):
get_input = raw_input
overwrite = get_input('[WARNING] %s already exists - overwrite? [y/n]' % (filepath))
overwrite = get_input('[WARNING] %s already exists - overwrite? '
'[y/n]' % (filepath))
while overwrite not in ['y', 'n']:
overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
if overwrite == 'n':
@ -733,13 +923,15 @@ class Graph(Model, containers.Graph):
g.attrs['nb_params'] = len(weights)
for n, param in enumerate(weights):
param_name = 'param_{}'.format(n)
param_dset = g.create_dataset(param_name, param.shape, dtype=param.dtype)
param_dset = g.create_dataset(param_name, param.shape,
dtype=param.dtype)
param_dset[:] = param
f.flush()
f.close()
def load_weights(self, filepath):
# Loads weights from HDF5 file
'''Load weights from a HDF5 file.
'''
import h5py
f = h5py.File(filepath)
g = f['graph']

@ -16,6 +16,18 @@ def kl_divergence(p, p_hat):
class Optimizer(object):
'''Abstract optimizer base class.
Note: this is the parent class of all optimizers, not an actual optimizer
that can be used for training models.
All Keras optimizers support the following keyword arguments:
clipnorm: float >= 0. Gradients will be clipped
when their L2 norm exceeds this value.
clipvalue: float >= 0. Gradients will be clipped
when their absolute value exceeds this value.
'''
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
self.updates = []
@ -45,7 +57,15 @@ class Optimizer(object):
class SGD(Optimizer):
'''Stochastic gradient descent, with support for momentum,
decay, and Nesterov momentum.
# Arguments
lr: float >= 0. Learning rate.
momentum: float >= 0. Parameter update momentum.
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
'''
def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False,
*args, **kwargs):
super(SGD, self).__init__(**kwargs)
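A sketch combining the SGD arguments above with the `clipnorm` keyword accepted by every optimizer; the hyperparameter values are illustrative.
sgd = SGD(lr=0.01, momentum=0.9, decay=1e-6, nesterov=True, clipnorm=1.0)
model.compile(loss='categorical_crossentropy', optimizer=sgd)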
@ -82,6 +102,19 @@ class SGD(Optimizer):
class RMSprop(Optimizer):
'''RMSProp optimizer.
It is recommended to leave the parameters of this optimizer
at their default values.
This optimizer is usually a good choice for recurrent
neural networks.
# Arguments
lr: float >= 0. Learning rate.
rho: float >= 0.
epsilon: float >= 0. Fuzz factor.
'''
def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs):
super(RMSprop, self).__init__(**kwargs)
self.__dict__.update(locals())
@ -110,6 +143,15 @@ class RMSprop(Optimizer):
class Adagrad(Optimizer):
'''Adagrad optimizer.
It is recommended to leave the parameters of this optimizer
at their default values.
# Arguments
lr: float >= 0. Learning rate.
epsilon: float >= 0.
'''
def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs):
super(Adagrad, self).__init__(**kwargs)
self.__dict__.update(locals())
@ -134,8 +176,18 @@ class Adagrad(Optimizer):
class Adadelta(Optimizer):
'''
Reference: http://arxiv.org/abs/1212.5701
'''Adadelta optimizer.
It is recommended to leave the parameters of this optimizer
at their default values.
# Arguments
lr: float >= 0. Learning rate. It is recommended to leave it at the default value.
rho: float >= 0.
epsilon: float >= 0. Fuzz factor.
# References
- [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
'''
def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):
super(Adadelta, self).__init__(**kwargs)
@ -173,10 +225,17 @@ class Adadelta(Optimizer):
class Adam(Optimizer):
'''
Reference: http://arxiv.org/abs/1412.6980v8
'''Adam optimizer.
Default parameters follow those provided in the original paper.
Default parameters follow those provided in the original paper.
# Arguments
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor.
# References
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
'''
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8,
*args, **kwargs):
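A closing sketch for the adaptive optimizers above: they can be passed to `compile` either as objects (to override the defaults) or by name (to keep the recommended default parameters).
model.compile(loss='mse', optimizer=Adam(lr=0.001))
model.compile(loss='mse', optimizer='rmsprop')  # keeps the default parameters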