Add docstrings in callbacks, models, optimizers
parent 06f5f43079
commit 9a93fc51cf
keras/callbacks.py

@@ -69,7 +69,30 @@ class CallbackList(object):


class Callback(object):
    '''Abstract base class used to build new callbacks.

    # Properties
        params: dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch.

    Currently, the `.fit()` method of the `Sequential` model class
    will include the following quantities in the `logs` that
    it passes to its callbacks:

        on_epoch_end: logs optionally include `val_loss`
            (if validation is enabled in `fit`), and `val_acc`
            (if validation and accuracy monitoring are enabled).
        on_batch_begin: logs include `size`,
            the number of samples in the current batch.
        on_batch_end: logs include `loss`, and optionally `acc`
            (if accuracy monitoring is enabled).
    '''
    def __init__(self):
        pass
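For illustration, a minimal sketch of a custom callback built on this base class; the `LossHistory` name is hypothetical, and the `loss` key it reads is the one documented above for `on_batch_end`:

import keras.callbacks

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        # one entry per batch, collected below in on_batch_end
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

# attached via the callbacks argument of fit, e.g.
# model.fit(X_train, y_train, callbacks=[LossHistory()])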

@@ -99,6 +122,12 @@ class Callback(object):


class BaseLogger(Callback):
    '''Callback that prints events to the standard output.

    This callback is automatically applied to
    every Keras model (it is the basis of the verbosity modes
    in models).
    '''
    def on_train_begin(self, logs={}):
        self.verbose = self.params['verbose']
        self.nb_epoch = self.params['nb_epoch']

@@ -128,7 +157,8 @@ class BaseLogger(Callback):
            if k in logs:
                self.log_values.append((k, logs[k]))

-        # skip progbar update for the last batch; will be handled by on_epoch_end
+        # skip progbar update for the last batch;
+        # will be handled by on_epoch_end
        if self.verbose and self.seen < self.params['nb_sample']:
            self.progbar.update(self.seen, self.log_values)

@@ -143,7 +173,13 @@ class BaseLogger(Callback):


class History(Callback):
    '''Callback that records events
    into a `History` object.

    This callback is automatically applied to
    every Keras model. The `History` object
    gets returned by the `fit` method of models.
    '''
    def on_train_begin(self, logs={}):
        self.epoch = []
        self.history = {}

@@ -175,26 +211,55 @@ class History(Callback):


class ModelCheckpoint(Callback):
-    def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False, mode='auto'):
+    '''Save the model after every epoch.
+
+    `filepath` can contain named formatting options,
+    which will be filled with the value of `epoch` and
+    keys in `logs` (passed in `on_epoch_end`).
+
+    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
+    then multiple files will be saved with the epoch number and
+    the validation loss.
+
+    # Arguments
+        filepath: string, path to save the model file.
+        monitor: quantity to monitor.
+        verbose: verbosity mode, 0 or 1.
+        save_best_only: if `save_best_only=True`,
+            the latest best model according to
+            the quantity monitored will not be overwritten.
+        mode: one of {auto, min, max}.
+            If `save_best_only=True`, the decision
+            to overwrite the current save file is made
+            based on either the maximization or the
+            minimization of the monitored quantity. For `val_acc`,
+            this should be `max`, for `val_loss` this should
+            be `min`, etc. In `auto` mode, the direction is
+            automatically inferred from the name of the monitored quantity.
+    '''
+    def __init__(self, filepath, monitor='val_loss', verbose=0,
+                 save_best_only=False, mode='auto'):

        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only

        if mode not in ['auto', 'min', 'max']:
-            warnings.warn("ModelCheckpoint mode %s is unknown, fallback to auto mode" % (self.mode), RuntimeWarning)
+            warnings.warn('ModelCheckpoint mode %s is unknown, '
+                          'fallback to auto mode' % (mode), RuntimeWarning)
            mode = 'auto'

-        if mode == "min":
+        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
-        elif mode == "max":
+        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
-            if "acc" in self.monitor:
+            if 'acc' in self.monitor:
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:

@@ -206,24 +271,36 @@ class ModelCheckpoint(Callback):
        if self.save_best_only:
            current = logs.get(self.monitor)
            if current is None:
-                warnings.warn("Can save best model only with %s available, skipping." % (self.monitor), RuntimeWarning)
+                warnings.warn('Can save best model only with %s available, '
+                              'skipping.' % (self.monitor), RuntimeWarning)
            else:
                if self.monitor_op(current, self.best):
                    if self.verbose > 0:
-                        print("Epoch %05d: %s improved from %0.5f to %0.5f, saving model to %s"
-                              % (epoch, self.monitor, self.best, current, filepath))
+                        print('Epoch %05d: %s improved from %0.5f to %0.5f,'
+                              ' saving model to %s'
+                              % (epoch, self.monitor, self.best,
+                                 current, filepath))
                    self.best = current
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    if self.verbose > 0:
-                        print("Epoch %05d: %s did not improve" % (epoch, self.monitor))
+                        print('Epoch %05d: %s did not improve' %
+                              (epoch, self.monitor))
        else:
            if self.verbose > 0:
-                print("Epoch %05d: saving model to %s" % (epoch, filepath))
+                print('Epoch %05d: saving model to %s' % (epoch, filepath))
            self.model.save_weights(filepath, overwrite=True)
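A usage sketch of the checkpointing behaviour described in the docstring above; the `model` instance and training arrays are assumed:

from keras.callbacks import ModelCheckpoint

checkpointer = ModelCheckpoint(filepath='weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                               monitor='val_loss', verbose=1,
                               save_best_only=True, mode='auto')
# model.fit(X_train, y_train, validation_split=0.1,
#           callbacks=[checkpointer])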


class EarlyStopping(Callback):
    '''Stop training when a monitored quantity has stopped improving.

    # Arguments
        monitor: quantity to be monitored.
        patience: number of epochs with no improvement
            after which training will be stopped.
        verbose: verbosity mode.
    '''
    def __init__(self, monitor='val_loss', patience=0, verbose=0):
        super(EarlyStopping, self).__init__()

@@ -236,7 +313,8 @@ class EarlyStopping(Callback):
    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
-            warnings.warn("Early stopping requires %s available!" % (self.monitor), RuntimeWarning)
+            warnings.warn('Early stopping requires %s available!' %
+                          (self.monitor), RuntimeWarning)

        if current < self.best:
            self.best = current

@@ -244,12 +322,16 @@ class EarlyStopping(Callback):
        else:
            if self.wait >= self.patience:
                if self.verbose > 0:
-                    print("Epoch %05d: early stopping" % (epoch))
+                    print('Epoch %05d: early stopping' % (epoch))
                self.model.stop_training = True
            self.wait += 1
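A usage sketch; it assumes a compiled model trained with validation enabled, so that `val_loss` appears in the logs:

from keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=2, verbose=1)
# model.fit(X_train, y_train, validation_split=0.1,
#           callbacks=[early_stopping])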


class RemoteMonitor(Callback):
    '''Experimental callback used to stream events to a server.

    Requires the `requests` library.
    '''
    def __init__(self, root='http://localhost:9000'):
        self.root = root

@@ -277,15 +359,20 @@ class RemoteMonitor(Callback):
                send[k] = v

        try:
-            r = requests.post(self.root + '/publish/epoch/end/', {'data': json.dumps(send)})
+            requests.post(self.root + '/publish/epoch/end/',
+                          {'data': json.dumps(send)})
        except:
-            print('Warning: could not reach RemoteMonitor root server at ' + str(self.root))
+            print('Warning: could not reach RemoteMonitor '
+                  'root server at ' + str(self.root))


class LearningRateScheduler(Callback):
-    '''LearningRateScheduler
-    schedule is a function that gets an epoch number as input and returns a new
-    learning rate as output.
+    '''Learning rate scheduler.
+
+    # Arguments
+        schedule: a function that gets an epoch index as input
+            (integer, indexed from 0) and returns a new
+            learning rate as output.
    '''
    def __init__(self, schedule):
        super(LearningRateScheduler, self).__init__()
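A sketch of a schedule function matching the contract above (epoch index in, learning rate out); the decay values are arbitrary:

from keras.callbacks import LearningRateScheduler

def schedule(epoch):
    # halve the base learning rate every 10 epochs (epoch is 0-indexed)
    return 0.01 * (0.5 ** (epoch // 10))

lr_scheduler = LearningRateScheduler(schedule)
# model.fit(X_train, y_train, callbacks=[lr_scheduler])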

@@ -21,6 +21,7 @@ class Layer(object):
    '''Abstract base layer class.

    All Keras layers accept certain keyword arguments:

        trainable: boolean. Set to "False" before model compilation
            to freeze layer weights (they won't be updated further
            during training).
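A sketch of the `trainable` keyword, using a hypothetical `Layer` subclass since constructor arguments differ per layer type:

from keras.layers.core import Layer

class MyLayer(Layer):
    pass

frozen = MyLayer(trainable=False)  # weights frozen until model compilation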

412 keras/models.py
@@ -2,8 +2,6 @@ from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import warnings
-import time
import copy
-import pprint
from six.moves import range
import six

@@ -11,11 +9,10 @@ import six
from . import backend as K
from . import optimizers
from . import objectives
from . import regularizers
from . import constraints
from . import callbacks as cbks
-from .utils.layer_utils import container_from_config, model_summary
-from .utils.generic_utils import Progbar, printv
+from .utils.layer_utils import container_from_config
+from .utils.layer_utils import model_summary
+from .utils.generic_utils import Progbar
from .layers import containers


@@ -52,6 +49,8 @@ def standardize_X(X):


def slice_X(X, start=None, stop=None):
    '''Slice an array or list of arrays.
    '''
    if type(X) == list:
        if hasattr(start, '__len__'):
            # hdf5 datasets only support list objects as indices

@@ -71,9 +70,7 @@ def slice_X(X, start=None, stop=None):

def weighted_objective(fn):
    def weighted(y_true, y_pred, weights, mask=None):
        '''To be called only with non-zero weights.
-
-        mask: binary
        '''
        # score_array has ndim >= 2
        score_array = fn(y_true, y_pred)

@@ -96,12 +93,15 @@ def weighted_objective(fn):


def standardize_weights(y, sample_weight=None, class_weight=None):
    '''Compute a flat array of per-sample weights
    from `sample_weight` and/or `class_weight`.
    '''
    if sample_weight is not None:
        assert len(sample_weight) == len(y)
        return sample_weight.flatten()
    elif isinstance(class_weight, dict):
        if len(y.shape) > 2:
-            raise Exception('class_weight not supported for 3+ dimensional targets.')
+            raise Exception('class_weight not supported for '
+                            '3+ dimensional targets.')
        if y.shape[1] > 1:
            y_classes = y.argmax(axis=1)
        elif y.shape[1] == 1:

@@ -132,6 +132,8 @@ def model_from_json(json_string, custom_objects={}):


def model_from_config(config, custom_objects={}):
    '''Instantiate a model from its configuration dictionary
    (as returned by `get_config`).
    '''
    model_name = config.get('name')
    if model_name not in {'Graph', 'Sequential'}:
        raise Exception('Unrecognized model:', model_name)

@@ -147,7 +149,6 @@ def model_from_config(config, custom_objects={}):
    # if it has an optimizer, the model is assumed to be compiled
    loss = config.get('loss')
    class_mode = config.get('class_mode')
-    theano_mode = config.get('theano_mode')

    optimizer_params = dict([(k, v) for k, v in config.get('optimizer').items()])
    optimizer_name = optimizer_params.pop('name')

@@ -155,10 +156,9 @@ def model_from_config(config, custom_objects={}):

    if model_name == 'Sequential':
        model.compile(loss=loss, optimizer=optimizer,
-                      class_mode=class_mode, theano_mode=theano_mode)
+                      class_mode=class_mode)
    elif model_name == 'Graph':
-        model.compile(loss=loss, optimizer=optimizer,
-                      theano_mode=theano_mode)
+        model.compile(loss=loss, optimizer=optimizer)
    return model

@@ -170,6 +170,8 @@ def get_function_name(o):


class Model(object):
    '''Abstract base model class.
    '''
    def _fit(self, f, ins, out_labels=[], batch_size=128,
             nb_epoch=100, verbose=1, callbacks=[],
             val_f=None, val_ins=None, shuffle=True, metrics=[]):

@@ -181,7 +183,8 @@ class Model(object):
        if val_f and val_ins:
            do_validation = True
            if verbose:
-                print("Train on %d samples, validate on %d samples" % (len(ins[0]), len(val_ins[0])))
+                print('Train on %d samples, validate on %d samples' %
+                      (len(ins[0]), len(val_ins[0])))

        nb_train_sample = len(ins[0])
        index_array = np.arange(nb_train_sample)

@@ -218,9 +221,9 @@ class Model(object):
                try:
                    ins_batch = slice_X(ins, batch_ids)
                except TypeError:
-                    raise Exception('TypeError while preparing batch. \
-                        If using HDF5 input data, pass shuffle="batch".\n')
+                    raise Exception('TypeError while preparing batch. '
+                                    'If using HDF5 input data, '
+                                    'pass shuffle="batch".')
                batch_logs = {}
                batch_logs['batch'] = batch_index
                batch_logs['size'] = len(batch_ids)

@@ -255,8 +258,7 @@ class Model(object):
        return history

    def _predict_loop(self, f, ins, batch_size=128, verbose=0):
-        '''
-        Abstract method to loop over some data in batches.
+        '''Abstract method to loop over some data in batches.
        '''
        nb_sample = len(ins[0])
        outs = []

@@ -283,8 +285,7 @@ class Model(object):
        return outs

    def _test_loop(self, f, ins, batch_size=128, verbose=0):
-        '''
-        Abstract method to loop over some data in batches.
+        '''Abstract method to loop over some data in batches.
        '''
        nb_sample = len(ins[0])
        outs = []

@@ -315,8 +316,14 @@ class Model(object):
        return outs

    def get_config(self, verbose=0):
        '''Return the configuration of the model
        as a dictionary.

        To load a model from its configuration, use
        `keras.models.model_from_config(config, custom_objects={})`.
        '''
        config = super(Model, self).get_config()
-        for p in ['class_mode', 'theano_mode']:
+        for p in ['class_mode']:
            if hasattr(self, p):
                config[p] = getattr(self, p)
        if hasattr(self, 'optimizer'):

@@ -333,38 +340,54 @@ class Model(object):
        return config

    def to_yaml(self, **kwargs):
-        # dump model configuration to yaml string
+        '''Return a yaml string containing the model configuration.
+
+        To load a model from a yaml save file, use
+        `keras.models.from_yaml(yaml_string, custom_objects={})`.
+
+        `custom_objects` should be a dictionary mapping
+        the names of custom losses / layers / etc to the corresponding
+        functions / classes.
+        '''
        import yaml
        config = self.get_config()
        return yaml.dump(config, **kwargs)

    def to_json(self, **kwargs):
-        # dump model configuration to json string
+        '''Return a JSON string containing the model configuration.
+
+        To load a model from a JSON save file, use
+        `keras.models.from_json(json_string, custom_objects={})`.
+        '''
        import json
        config = self.get_config()
        return json.dumps(config, **kwargs)

    def summary(self):
        '''Print out a summary of the model architecture,
        including parameter count information.
        '''
        model_summary(self)


class Sequential(Model, containers.Sequential):
-    '''
-    Inherits from Model the following methods:
-        - _fit
-        - _predict
-        - _evaluate
-    Inherits from containers.Sequential the following methods:
-        - __init__
-        - add
-        - get_output
-        - get_input
-        - get_weights
-        - set_weights
-    '''
+    '''Linear stack of layers.
+
+    Inherits from containers.Sequential.
+    '''
    def compile(self, optimizer, loss,
-                class_mode="categorical", theano_mode=None):
+                class_mode="categorical"):
        '''Configure the learning process.

        # Arguments
            optimizer: str (name of optimizer) or optimizer object.
                See [optimizers](optimizers.md).
            loss: str (name of objective function) or objective function.
                See [objectives](objectives.md).
            class_mode: one of "categorical", "binary".
                This is only used for computing classification accuracy or
                using the predict_classes method.
        '''
        self.optimizer = optimizers.get(optimizer)

        self.loss = objectives.get(loss)

@@ -401,7 +424,6 @@ class Sequential(Model, containers.Sequential):
        else:
            raise Exception("Invalid class mode: " + str(class_mode))
        self.class_mode = class_mode
-        self.theano_mode = theano_mode

        for r in self.regularizers:
            train_loss = r(train_loss)

@@ -426,37 +448,48 @@ class Sequential(Model, containers.Sequential):
        self._test = K.function(test_ins, [test_loss])
        self._test_with_acc = K.function(test_ins, [test_loss, test_accuracy])
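A compile usage sketch matching the docstring above; layers are elided, and the optimizer and loss are given by name:

from keras.models import Sequential

model = Sequential()
# ... model.add(...) layers here ...
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              class_mode='categorical')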

-    def train_on_batch(self, X, y, accuracy=False,
-                       class_weight=None, sample_weight=None):
-        X = standardize_X(X)
-        y = standardize_y(y)
-        sample_weight = standardize_weights(y, class_weight=class_weight,
-                                            sample_weight=sample_weight)
-        ins = X + [y, sample_weight]
-        if accuracy:
-            return self._train_with_acc(ins)
-        else:
-            return self._train(ins)
-
-    def test_on_batch(self, X, y, accuracy=False, sample_weight=None):
-        X = standardize_X(X)
-        y = standardize_y(y)
-        sample_weight = standardize_weights(y, sample_weight=sample_weight)
-
-        ins = X + [y, sample_weight]
-        if accuracy:
-            return self._test_with_acc(ins)
-        else:
-            return self._test(ins)
-
-    def predict_on_batch(self, X):
-        ins = standardize_X(X)
-        return self._predict(ins)

    def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
            validation_split=0., validation_data=None, shuffle=True,
            show_accuracy=False, class_weight=None, sample_weight=None):
        '''Train the model for a fixed number of epochs.

        Returns a history object. Its `history` attribute is a record of
        training loss values at successive epochs,
        as well as validation loss values (if applicable).

        # Arguments
            X: data, as a numpy array.
            y: labels, as a numpy array.
            batch_size: int. Number of samples per gradient update.
            nb_epoch: int.
            verbose: 0 for no logging to stdout,
                1 for progress bar logging, 2 for one log line per epoch.
            callbacks: `keras.callbacks.Callback` list.
                List of callbacks to apply during training.
                See [callbacks](callbacks.md).
            validation_split: float (0. < x < 1).
                Fraction of the data to use as held-out validation data.
            validation_data: tuple (X, y) to be used as held-out
                validation data. Will override validation_split.
            shuffle: boolean or str (for 'batch').
                Whether to shuffle the samples at each epoch.
                'batch' is a special option for dealing with the
                limitations of HDF5 data; it shuffles in batch-sized chunks.
            show_accuracy: boolean. Whether to display
                class accuracy in the logs to stdout at each epoch.
            class_weight: dictionary mapping classes to a weight value,
                used for scaling the loss function (during training only).
            sample_weight: list or numpy array with 1:1 mapping to
                the training samples, used for scaling the loss function
                (during training only). For time-distributed data,
                there is one weight per sample *per timestep*,
                i.e. if your output data is shaped
                `(nb_samples, timesteps, output_dim)`,
                your mask should be of shape `(nb_samples, timesteps, 1)`.
                This allows you to mask out or reweight individual
                output timesteps, which is useful
                in sequence to sequence learning.
        '''
        X = standardize_X(X)
        y = standardize_y(y)

@@ -480,8 +513,11 @@ class Sequential(Model, containers.Sequential):
                sample_weight_val = standardize_weights(y_val,
                                                        sample_weight=sample_weight_val)
            else:
-                raise Exception("Invalid format for validation data; provide a tuple (X_val, y_val) or (X_val, y_val, sample_weight). \
-                    X_val may be a numpy array or a list of numpy arrays depending on your model input.")
+                raise Exception('Invalid format for validation data; '
+                                'provide a tuple (X_val, y_val) or '
+                                '(X_val, y_val, sample_weight). '
+                                'X_val may be a numpy array or a list of '
+                                'numpy arrays depending on your model input.')
            val_ins = X_val + [y_val, sample_weight_val]

        elif 0 < validation_split < 1:

@@ -490,7 +526,8 @@ class Sequential(Model, containers.Sequential):
            y, y_val = (slice_X(y, 0, split_at), slice_X(y, split_at))
            if sample_weight is not None:
                sample_weight, sample_weight_val = (slice_X(sample_weight, 0, split_at), slice_X(sample_weight, split_at))
-                sample_weight_val = standardize_weights(y_val, sample_weight=sample_weight_val)
+                sample_weight_val = standardize_weights(y_val,
+                                                        sample_weight=sample_weight_val)
            else:
                sample_weight_val = standardize_weights(y_val)
            val_ins = X_val + [y_val, sample_weight_val]

@@ -502,7 +539,8 @@ class Sequential(Model, containers.Sequential):
            f = self._train
            out_labels = ['loss']

-        sample_weight = standardize_weights(y, class_weight=class_weight, sample_weight=sample_weight)
+        sample_weight = standardize_weights(y, class_weight=class_weight,
+                                            sample_weight=sample_weight)
        ins = X + [y, sample_weight]
        metrics = ['loss', 'acc', 'val_loss', 'val_acc']
        return self._fit(f, ins, out_labels=out_labels,

@@ -512,24 +550,67 @@ class Sequential(Model, containers.Sequential):
                         shuffle=shuffle, metrics=metrics)
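A fit usage sketch; it assumes `model` was compiled as above and that `X_train`, `y_train` are numpy arrays:

history = model.fit(X_train, y_train, batch_size=128, nb_epoch=10,
                    verbose=1, validation_split=0.1, shuffle=True,
                    show_accuracy=True)
# per-epoch records, including val_loss when validation is enabled:
print(history.history)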

    def predict(self, X, batch_size=128, verbose=0):
        '''Generate output predictions for the input samples
        batch by batch.

        # Arguments
            X: the input data, as a numpy array.
            batch_size: integer.
            verbose: verbosity mode, 0 or 1.

        # Returns
            A numpy array of predictions.
        '''
        X = standardize_X(X)
        return self._predict_loop(self._predict, X, batch_size, verbose)[0]

    def predict_proba(self, X, batch_size=128, verbose=1):
        '''Generate class probability predictions for the input samples
        batch by batch.

        # Arguments
            X: the input data, as a numpy array.
            batch_size: integer.
            verbose: verbosity mode, 0 or 1.

        # Returns
            A numpy array of probability predictions.
        '''
        preds = self.predict(X, batch_size, verbose)
        if preds.min() < 0 or preds.max() > 1:
-            warnings.warn("Network returning invalid probability values.")
+            warnings.warn('Network returning invalid probability values.')
        return preds

    def predict_classes(self, X, batch_size=128, verbose=1):
        '''Generate class predictions for the input samples
        batch by batch.

        # Arguments
            X: the input data, as a numpy array.
            batch_size: integer.
            verbose: verbosity mode, 0 or 1.

        # Returns
            A numpy array of class predictions.
        '''
        proba = self.predict(X, batch_size=batch_size, verbose=verbose)
-        if self.class_mode == "categorical":
+        if self.class_mode == 'categorical':
            return proba.argmax(axis=-1)
        else:
            return (proba > 0.5).astype('int32')

    def evaluate(self, X, y, batch_size=128, show_accuracy=False,
                 verbose=1, sample_weight=None):
        '''Compute the loss on some input data, batch by batch.

        # Arguments
            X: input data, as a numpy array.
            y: labels, as a numpy array.
            batch_size: integer.
            show_accuracy: boolean.
            verbose: verbosity mode, 0 or 1.
            sample_weight: sample weights, as a numpy array.
        '''
        X = standardize_X(X)
        y = standardize_y(y)
        sample_weight = standardize_weights(y, sample_weight=sample_weight)

@@ -545,8 +626,50 @@ class Sequential(Model, containers.Sequential):
        else:
            return outs[0]

    def train_on_batch(self, X, y, accuracy=False,
                       class_weight=None, sample_weight=None):
        '''Single gradient update over one batch of samples.

        Returns the loss over the data,
        or a tuple `(loss, accuracy)` if `accuracy=True`.

        Arguments: see `fit` method.
        '''
        X = standardize_X(X)
        y = standardize_y(y)
        sample_weight = standardize_weights(y, class_weight=class_weight,
                                            sample_weight=sample_weight)
        ins = X + [y, sample_weight]
        if accuracy:
            return self._train_with_acc(ins)
        else:
            return self._train(ins)

    def test_on_batch(self, X, y, accuracy=False, sample_weight=None):
        '''Returns the loss over a single batch of samples,
        or a tuple `(loss, accuracy)` if `accuracy=True`.

        Arguments: see `fit` method.
        '''
        X = standardize_X(X)
        y = standardize_y(y)
        sample_weight = standardize_weights(y, sample_weight=sample_weight)

        ins = X + [y, sample_weight]
        if accuracy:
            return self._test_with_acc(ins)
        else:
            return self._test(ins)

    def predict_on_batch(self, X):
        '''Returns predictions for a single batch of samples.
        '''
        ins = standardize_X(X)
        return self._predict(ins)

    def save_weights(self, filepath, overwrite=False):
-        # Save weights from all layers to HDF5
+        '''Dump all layer weights to an HDF5 file.
+        '''
        import h5py
        import os.path
        # if file exists and should not be overwritten

@@ -555,7 +678,8 @@ class Sequential(Model, containers.Sequential):
            get_input = input
            if sys.version_info[:2] <= (2, 7):
                get_input = raw_input
-            overwrite = get_input('[WARNING] %s already exists - overwrite? [y/n]' % (filepath))
+            overwrite = get_input('[WARNING] %s already exists - overwrite? '
+                                  '[y/n]' % (filepath))
            while overwrite not in ['y', 'n']:
                overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
            if overwrite == 'n':

@@ -570,20 +694,20 @@ class Sequential(Model, containers.Sequential):
            g.attrs['nb_params'] = len(weights)
            for n, param in enumerate(weights):
                param_name = 'param_{}'.format(n)
-                param_dset = g.create_dataset(param_name, param.shape, dtype=param.dtype)
+                param_dset = g.create_dataset(param_name, param.shape,
+                                              dtype=param.dtype)
                param_dset[:] = param
        f.flush()
        f.close()

    def load_weights(self, filepath):
-        # Loads weights from HDF5 file
+        '''Load all layer weights from an HDF5 save file.
+
+        This method does not make use of Sequential.set_weights()
+        for backwards compatibility.
+        '''
        import h5py
        f = h5py.File(filepath)
        for k in range(f.attrs['nb_layers']):
-            # This method does not make use of Sequential.set_weights()
-            # for backwards compatibility.
            g = f['layer_{}'.format(k)]
            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
            self.layers[k].set_weights(weights)

@@ -591,8 +715,24 @@ class Sequential(Model, containers.Sequential):


class Graph(Model, containers.Graph):
-    def compile(self, optimizer, loss, theano_mode=None):
-        # loss is a dictionary mapping output name to loss functions
+    '''Arbitrary connection graph.
+    It can have any number of inputs and outputs,
+    with each output trained with its own loss function.
+    The quantity being optimized by a Graph model is
+    the sum of all loss functions over the different outputs.
+
+    Inherits from `containers.Graph`.
+    '''
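A usage sketch of the Graph workflow; node construction via the containers.Graph API is elided, and the 'input'/'output' names and data arrays are assumed:

graph = Graph()
# graph.add_input(...), graph.add_node(...), graph.add_output(name='output', ...)
graph.compile(optimizer='rmsprop', loss={'output': 'mse'})
history = graph.fit({'input': X_train, 'output': y_train}, nb_epoch=10)
predictions = graph.predict({'input': X_test})  # dict keyed by output name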
+    def compile(self, optimizer, loss):
        '''Configure the learning process.

        # Arguments
            optimizer: str (name of optimizer) or optimizer object.
                See [optimizers](optimizers.md).
            loss: dictionary mapping the name(s) of the output(s) to
                a loss function (string name of objective function or
                objective function. See [objectives](objectives.md)).
        '''
        ys = []
        ys_train = []
        ys_test = []

@@ -627,38 +767,49 @@ class Graph(Model, containers.Graph):
        for r in self.regularizers:
            train_loss = r(train_loss)
        self.optimizer = optimizers.get(optimizer)
-        updates = self.optimizer.get_updates(self.params, self.constraints, train_loss)
+        updates = self.optimizer.get_updates(self.params,
+                                             self.constraints,
+                                             train_loss)
        updates += self.updates
-        self.theano_mode = theano_mode
        self.loss = loss

        self._train = K.function(train_ins, [train_loss], updates=updates)
        self._test = K.function(test_ins, [test_loss])
-        self._predict = K.function(inputs=ins, outputs=ys_test, updates=self.state_updates)
-
-    def train_on_batch(self, data, class_weight={}, sample_weight={}):
-        # data is a dictionary mapping output and input names to arrays
-        sample_weight = [standardize_weights(data[name],
-                                             sample_weight=sample_weight.get(name),
-                                             class_weight=class_weight.get(name)) for name in self.output_order]
-        ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
-        return self._train(ins)
-
-    def test_on_batch(self, data, sample_weight={}):
-        # data is a dictionary mapping input names to arrays
-        sample_weight = [standardize_weights(data[name],
-                                             sample_weight=sample_weight.get(name)) for name in self.output_order]
-        ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
-        return self._test(ins)
-
-    def predict_on_batch(self, data):
-        # data is a dictionary mapping input names to arrays
-        ins = [data[name] for name in self.input_order]
-        return self._predict(ins)
+        self._predict = K.function(inputs=ins, outputs=ys_test,
+                                   updates=self.state_updates)

    def fit(self, data, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
            validation_split=0., validation_data=None, shuffle=True,
            class_weight={}, sample_weight={}):
        '''Train the model for a fixed number of epochs.

        Returns a history object. Its `history` attribute is a record of
        training loss values at successive epochs,
        as well as validation loss values (if applicable).

        # Arguments
            data: dictionary mapping input names and outputs names to
                appropriate numpy arrays. All arrays should contain
                the same number of samples.
            batch_size: int. Number of samples per gradient update.
            nb_epoch: int.
            verbose: 0 for no logging to stdout,
                1 for progress bar logging, 2 for one log line per epoch.
            callbacks: `keras.callbacks.Callback` list. List of callbacks
                to apply during training. See [callbacks](callbacks.md).
            validation_split: float (0. < x < 1). Fraction of the data to
                use as held-out validation data.
            validation_data: dictionary mapping input names and outputs names
                to appropriate numpy arrays to be used as
                held-out validation data.
                All arrays should contain the same number of samples.
                Will override validation_split.
            shuffle: boolean. Whether to shuffle the samples at each epoch.
            class_weight: dictionary mapping output names to
                class weight dictionaries.
            sample_weight: dictionary mapping output names to
                numpy arrays of sample weights.
        '''
        X = [data[name] for name in self.input_order]
        y = [standardize_y(data[name]) for name in self.output_order]

@@ -691,13 +842,18 @@ class Graph(Model, containers.Graph):
                                                 sample_weight=sample_weight_list[i],
                                                 class_weight=class_weight_list[i]) for i in range(len(self.output_order))]
        ins = X + y + sample_weight_list
-        history = self._fit(f, ins, out_labels=out_labels, batch_size=batch_size, nb_epoch=nb_epoch,
+        history = self._fit(f, ins, out_labels=out_labels,
+                            batch_size=batch_size, nb_epoch=nb_epoch,
                            verbose=verbose, callbacks=callbacks,
                            val_f=val_f, val_ins=val_ins,
                            shuffle=shuffle, metrics=metrics)
        return history

    def evaluate(self, data, batch_size=128, verbose=0, sample_weight={}):
        '''Compute the loss on some input data, batch by batch.

        Arguments: see `fit` method.
        '''
        sample_weight = [standardize_weights(data[name],
                                             sample_weight=sample_weight.get(name)) for name in self.output_order]

@@ -706,12 +862,45 @@ class Graph(Model, containers.Graph):
        return outs[0]

    def predict(self, data, batch_size=128, verbose=0):
        '''Generate output predictions for the input samples
        batch by batch.

        Arguments: see `fit` method.
        '''
        ins = [data[name] for name in self.input_order]
        outs = self._predict_loop(self._predict, ins, batch_size, verbose)
        return dict(zip(self.output_order, outs))

    def train_on_batch(self, data, class_weight={}, sample_weight={}):
        '''Single gradient update on a batch of samples.

        Arguments: see `fit` method.
        '''
        sample_weight = [standardize_weights(data[name],
                                             sample_weight=sample_weight.get(name),
                                             class_weight=class_weight.get(name)) for name in self.output_order]
        ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
        return self._train(ins)

    def test_on_batch(self, data, sample_weight={}):
        '''Compute the loss on a single batch of samples.

        Arguments: see `fit` method.
        '''
        sample_weight = [standardize_weights(data[name],
                                             sample_weight=sample_weight.get(name)) for name in self.output_order]
        ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
        return self._test(ins)

    def predict_on_batch(self, data):
        '''Generate predictions for a single batch of samples.
        '''
        ins = [data[name] for name in self.input_order]
        return self._predict(ins)

    def save_weights(self, filepath, overwrite=False):
-        # Save weights from all layers to HDF5
+        '''Save weights from all layers to an HDF5 file.
+        '''
        import h5py
        import os.path
        # if file exists and should not be overwritten

@@ -720,7 +909,8 @@ class Graph(Model, containers.Graph):
            get_input = input
            if sys.version_info[:2] <= (2, 7):
                get_input = raw_input
-            overwrite = get_input('[WARNING] %s already exists - overwrite? [y/n]' % (filepath))
+            overwrite = get_input('[WARNING] %s already exists - overwrite? '
+                                  '[y/n]' % (filepath))
            while overwrite not in ['y', 'n']:
                overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
            if overwrite == 'n':

@@ -733,13 +923,15 @@ class Graph(Model, containers.Graph):
            g.attrs['nb_params'] = len(weights)
            for n, param in enumerate(weights):
                param_name = 'param_{}'.format(n)
-                param_dset = g.create_dataset(param_name, param.shape, dtype=param.dtype)
+                param_dset = g.create_dataset(param_name, param.shape,
+                                              dtype=param.dtype)
                param_dset[:] = param
        f.flush()
        f.close()

    def load_weights(self, filepath):
-        # Loads weights from HDF5 file
+        '''Load weights from an HDF5 file.
+        '''
        import h5py
        f = h5py.File(filepath)
        g = f['graph']

keras/optimizers.py

@@ -16,6 +16,18 @@ def kl_divergence(p, p_hat):


class Optimizer(object):
    '''Abstract optimizer base class.

    Note: this is the parent class of all optimizers, not an actual optimizer
    that can be used for training models.

    All Keras optimizers support the following keyword arguments:

        clipnorm: float >= 0. Gradients will be clipped
            when their L2 norm exceeds this value.
        clipvalue: float >= 0. Gradients will be clipped
            when their absolute value exceeds this value.
    '''
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        self.updates = []
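A sketch of the gradient-clipping keywords, which any concrete optimizer forwards to this base class; the values shown are illustrative:

from keras.optimizers import SGD

clipped = SGD(lr=0.01, clipnorm=1.0)        # clip when gradient L2 norm exceeds 1.0
clipped_abs = SGD(lr=0.01, clipvalue=0.5)   # clip elementwise at absolute value 0.5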

@@ -45,7 +57,15 @@ class Optimizer(object):


class SGD(Optimizer):
    '''Stochastic gradient descent, with support for momentum,
    decay, and Nesterov momentum.

    # Arguments
        lr: float >= 0. Learning rate.
        momentum: float >= 0. Momentum for parameter updates.
        decay: float >= 0. Learning rate decay over each update.
        nesterov: boolean. Whether to apply Nesterov momentum.
    '''
    def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False,
                 *args, **kwargs):
        super(SGD, self).__init__(**kwargs)
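A typical instantiation sketch; the hyperparameter values are illustrative:

from keras.optimizers import SGD

sgd = SGD(lr=0.01, momentum=0.9, decay=1e-6, nesterov=True)
# model.compile(optimizer=sgd, loss='categorical_crossentropy')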

@@ -82,6 +102,19 @@ class SGD(Optimizer):


class RMSprop(Optimizer):
    '''RMSProp optimizer.

    It is recommended to leave the parameters of this optimizer
    at their default values.

    This optimizer is usually a good choice for recurrent
    neural networks.

    # Arguments
        lr: float >= 0. Learning rate.
        rho: float >= 0.
        epsilon: float >= 0. Fuzz factor.
    '''
    def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs):
        super(RMSprop, self).__init__(**kwargs)
        self.__dict__.update(locals())

@@ -110,6 +143,15 @@ class RMSprop(Optimizer):


class Adagrad(Optimizer):
    '''Adagrad optimizer.

    It is recommended to leave the parameters of this optimizer
    at their default values.

    # Arguments
        lr: float >= 0. Learning rate.
        epsilon: float >= 0.
    '''
    def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs):
        super(Adagrad, self).__init__(**kwargs)
        self.__dict__.update(locals())

@@ -134,8 +176,18 @@ class Adagrad(Optimizer):


class Adadelta(Optimizer):
-    '''
-    Reference: http://arxiv.org/abs/1212.5701
+    '''Adadelta optimizer.
+
+    It is recommended to leave the parameters of this optimizer
+    at their default values.
+
+    # Arguments
+        lr: float >= 0. Learning rate. It is recommended to leave it at the default value.
+        rho: float >= 0.
+        epsilon: float >= 0. Fuzz factor.
+
+    # References
+        - [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
    '''
    def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):
        super(Adadelta, self).__init__(**kwargs)

@@ -173,10 +225,17 @@ class Adadelta(Optimizer):


class Adam(Optimizer):
-    '''
-    Reference: http://arxiv.org/abs/1412.6980v8
-
-    Default parameters follow those provided in the original paper.
+    '''Adam optimizer.
+
+    Default parameters follow those provided in the original paper.
+
+    # Arguments
+        lr: float >= 0. Learning rate.
+        beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
+        epsilon: float >= 0. Fuzz factor.
+
+    # References
+        - [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
    '''
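An instantiation sketch with the paper defaults spelled out (equivalent to `Adam()` with no arguments):

from keras.optimizers import Adam

adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
# model.compile(optimizer=adam, loss='mse')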
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8,
                 *args, **kwargs):