Add model saving functionality (#3314)

* Add model saving functionality

* Update model saving functionality

* Fix py3 bytes/str issue

* Fix tests
Author: François Chollet (committed via GitHub)
Date: 2016-07-26 20:45:28 -07:00
Parent: df84c69676
Commit: ea561ba6d8
9 changed files with 468 additions and 59 deletions

@@ -239,16 +239,20 @@ class ModelCheckpoint(Callback):
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only: if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
'''
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, mode='auto'):
save_best_only=False, save_weights_only=False,
mode='auto'):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
if mode not in ['auto', 'min', 'max']:
warnings.warn('ModelCheckpoint mode %s is unknown, '
@@ -285,7 +289,10 @@ class ModelCheckpoint(Callback):
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
self.model.save_weights(filepath, overwrite=True)
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
print('Epoch %05d: %s did not improve' %
@@ -293,7 +300,10 @@ class ModelCheckpoint(Callback):
else:
if self.verbose > 0:
print('Epoch %05d: saving model to %s' % (epoch, filepath))
self.model.save_weights(filepath, overwrite=True)
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
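A hedged usage sketch of the new flag (the filepath and training call are illustrative, not part of this diff):

```python
from keras.callbacks import ModelCheckpoint

# With save_weights_only=False (the default), an improving checkpoint now
# calls model.save(), capturing architecture + weights + optimizer state;
# with save_weights_only=True it falls back to model.save_weights().
checkpoint = ModelCheckpoint('best_model.h5',
                             monitor='val_loss',
                             save_best_only=True,
                             save_weights_only=False)
# model.fit(X_train, y_train, validation_split=0.1, callbacks=[checkpoint])
```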
class EarlyStopping(Callback):

@@ -10,9 +10,11 @@ import marshal
import types as python_types
import warnings
import copy
import os
from six.moves import zip
from keras import backend as K
from .. import backend as K
from ..utils.io_utils import ask_to_proceed_with_overwrite
def to_list(x):
@@ -2310,7 +2312,38 @@ class Container(Layer):
output_tensors.append(layer_output_tensors[tensor_index])
return cls(input=input_tensors, output=output_tensors, name=name)
def save_weights(self, filepath, overwrite=False):
def save(self, filepath, overwrite=True):
'''Save into a single HDF5 file:
- the model architecture, allowing the model to be re-instantiated
- the model weights
- the state of the optimizer, allowing training to resume
exactly where you left off.
This allows you to save the entire state of a model
in a single file.
Saved models can be reinstantiated via `keras.models.load_model`.
The model returned by `load_model`
is a compiled model ready to be used (unless the saved model
was never compiled in the first place).
# Example usage
```python
from keras.models import load_model
model.save('my_model.h5')  # creates an HDF5 file 'my_model.h5'
del model # deletes the existing model
# returns a compiled model
# identical to the previous one
model = load_model('my_model.h5')
```
'''
from ..models import save_model
save_model(self, filepath, overwrite)
def save_weights(self, filepath, overwrite=True):
'''Dumps all layer weights to an HDF5 file.
The weight file has:
@@ -2323,28 +2356,23 @@ class Container(Layer):
storing the weight value, named after the weight tensor
'''
import h5py
import os.path
# if file exists and should not be overwritten
if not overwrite and os.path.isfile(filepath):
import sys
get_input = input
if sys.version_info[:2] <= (2, 7):
get_input = raw_input
overwrite = get_input('[WARNING] %s already exists - overwrite? '
'[y/n]' % (filepath))
while overwrite not in ['y', 'n']:
overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
if overwrite == 'n':
proceed = ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
print('[TIP] Next time specify overwrite=True in save_weights!')
f = h5py.File(filepath, 'w')
self.save_weights_to_hdf5_group(f)
f.flush()
f.close()
def save_weights_to_hdf5_group(self, f):
if hasattr(self, 'flattened_layers'):
# support for legacy Sequential/Merge behavior
flattened_layers = self.flattened_layers
else:
flattened_layers = self.layers
f = h5py.File(filepath, 'w')
f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in flattened_layers]
for layer in flattened_layers:
@@ -2362,21 +2390,38 @@ class Container(Layer):
for name, val in zip(weight_names, weight_values):
param_dset = g.create_dataset(name, val.shape,
dtype=val.dtype)
param_dset[:] = val
f.flush()
f.close()
if not val.shape:
# scalar
param_dset[()] = val
else:
param_dset[:] = val
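Context for the scalar branch above: an h5py dataset created with an empty shape has a scalar dataspace, which cannot be assigned through a slice. A minimal standalone demonstration (file name is illustrative):

```python
import h5py
import numpy as np

with h5py.File('scalar_demo.h5', 'w') as f:
    val = np.float32(0.5)  # 0-d value: val.shape == ()
    dset = f.create_dataset('iterations', val.shape, dtype=val.dtype)
    dset[()] = val         # `dset[:] = val` would raise on a scalar dataset
```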
def load_weights(self, filepath):
'''Load all layer weights from an HDF5 save file.
'''
import h5py
f = h5py.File(filepath, mode='r')
self.load_weights_from_hdf5_group(f)
f.close()
def load_weights_from_hdf5_group(self, f):
'''Weight loading is based on layer order in a list
(matching model.flattened_layers for Sequential models,
and model.layers for Model class instances), not
on layer names.
Layers that have no weights are skipped.
'''
if hasattr(self, 'flattened_layers'):
# support for legacy Sequential/Merge behavior
flattened_layers = self.flattened_layers
else:
flattened_layers = self.layers
filtered_layers = []
for layer in flattened_layers:
weights = layer.trainable_weights + layer.non_trainable_weights
if weights:
filtered_layers.append(layer)
flattened_layers = filtered_layers
if 'nb_layers' in f.attrs:
# legacy format
@@ -2394,6 +2439,13 @@ class Container(Layer):
else:
# new file format
layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
filtered_layer_names = []
for name in layer_names:
g = f[name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
if len(weight_names):
filtered_layer_names.append(name)
layer_names = filtered_layer_names
if len(layer_names) != len(flattened_layers):
raise Exception('You are trying to load a weight file '
'containing ' + str(len(layer_names)) +
@@ -2406,24 +2458,22 @@ class Container(Layer):
for k, name in enumerate(layer_names):
g = f[name]
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
if len(weight_names):
weight_values = [g[weight_name] for weight_name in weight_names]
layer = flattened_layers[k]
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
if len(weight_values) != len(symbolic_weights):
raise Exception('Layer #' + str(k) +
' (named "' + layer.name +
'" in the current model) was found to '
'correspond to layer ' + name +
' in the save file. '
'However the new layer ' + layer.name +
' expects ' + str(len(symbolic_weights)) +
' weights, but the saved weights have ' +
str(len(weight_values)) +
' elements.')
weight_value_tuples += zip(symbolic_weights, weight_values)
weight_values = [g[weight_name] for weight_name in weight_names]
layer = flattened_layers[k]
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
if len(weight_values) != len(symbolic_weights):
raise Exception('Layer #' + str(k) +
' (named "' + layer.name +
'" in the current model) was found to '
'correspond to layer ' + name +
' in the save file. '
'However the new layer ' + layer.name +
' expects ' + str(len(symbolic_weights)) +
' weights, but the saved weights have ' +
str(len(weight_values)) +
' elements.')
weight_value_tuples += zip(symbolic_weights, weight_values)
K.batch_set_value(weight_value_tuples)
f.close()
def _updated_config(self):
'''shared between different serialization methods'''
@@ -2435,14 +2485,6 @@ class Container(Layer):
'config': config,
'keras_version': keras_version
}
if hasattr(self, 'optimizer'):
model_config['optimizer'] = self.optimizer.get_config()
model_config['loss'] = getattr(self.loss, '__name__', self.loss)
model_config['sample_weight_mode'] = self.sample_weight_mode
if hasattr(self, 'loss_weights'):
model_config['loss_weights'] = self.loss_weights
return model_config
def to_json(self, **kwargs):
@@ -2462,7 +2504,7 @@ class Container(Layer):
if type(obj).__name__ == type.__name__:
return obj.__name__
raise TypeError('Not JSON Serializable')
raise TypeError('Not JSON Serializable:', obj)
model_config = self._updated_config()
return json.dumps(model_config, default=get_json_type, **kwargs)

@@ -608,8 +608,9 @@ class Model(Container):
self.targets.append(K.placeholder(ndim=len(shape), name=name + '_target'))
# prepare metrics
self.metrics = metrics
self.metrics_names = ['loss']
self.metrics = []
self.metrics_tensors = []
# compute total loss
total_loss = None
@@ -623,7 +624,7 @@ class Model(Container):
output_loss = weighted_loss(y_true, y_pred,
sample_weight, mask)
if len(self.outputs) > 1:
self.metrics.append(output_loss)
self.metrics_tensors.append(output_loss)
self.metrics_names.append(self.output_names[i] + '_loss')
if total_loss is None:
total_loss = loss_weight * output_loss
@@ -648,21 +649,21 @@ class Model(Container):
output_shape = self.internal_output_shapes[i]
if output_shape[-1] == 1 or self.loss_functions[i] == objectives.binary_crossentropy:
# case: binary accuracy
self.metrics.append(metrics_module.binary_accuracy(y_true, y_pred))
self.metrics_tensors.append(metrics_module.binary_accuracy(y_true, y_pred))
elif self.loss_functions[i] == objectives.sparse_categorical_crossentropy:
# case: categorical accuracy with sparse targets
self.metrics.append(
self.metrics_tensors.append(
metrics_module.sparse_categorical_accuracy(y_true, y_pred))
else:
# case: categorical accuracy with dense targets
self.metrics.append(metrics_module.categorical_accuracy(y_true, y_pred))
self.metrics_tensors.append(metrics_module.categorical_accuracy(y_true, y_pred))
if len(self.output_names) == 1:
self.metrics_names.append('acc')
else:
self.metrics_names.append(self.output_layers[i].name + '_acc')
else:
metric_fn = metrics_module.get(metric)
self.metrics.append(metric_fn(y_true, y_pred))
self.metrics_tensors.append(metric_fn(y_true, y_pred))
if len(self.output_names) == 1:
self.metrics_names.append(metric_fn.__name__)
else:
@@ -698,7 +699,7 @@ class Model(Container):
# returns loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics,
[self.total_loss] + self.metrics_tensors,
updates=updates,
**self._function_kwargs)
@@ -713,7 +714,7 @@ class Model(Container):
# return loss and metrics, no gradient updates.
# Does update the network states.
self.test_function = K.function(inputs,
[self.total_loss] + self.metrics,
[self.total_loss] + self.metrics_tensors,
updates=self.state_updates,
**self._function_kwargs)
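For context, the rename separates the user-supplied metric spec (`metrics`) from the compiled symbolic tensors (`metrics_tensors`), so that `save_model` can serialize the former. A sketch of how the attributes relate after `compile()` (model and data are illustrative):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

model = Sequential([Dense(1, input_dim=4)])
model.compile(optimizer='sgd', loss='mse', metrics=['acc'])
# model.metrics         -> ['acc']             (user-supplied spec, kept as-is)
# model.metrics_tensors -> [<symbolic tensor>] (what the K.function evaluates)
# model.metrics_names   -> ['loss', 'acc']     (labels for train/test outputs)

x, y = np.random.random((8, 4)), np.random.random((8, 1))
outs = model.train_on_batch(x, y)  # [loss_value, acc_value]
for name, value in zip(model.metrics_names, outs):
    print(name, value)
```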

@@ -1,13 +1,172 @@
from __future__ import print_function
import warnings
import copy
import json
import os
import numpy as np
from . import backend as K
from .utils.io_utils import ask_to_proceed_with_overwrite
from .engine.training import Model
from .engine.topology import get_source_inputs, Node
from .optimizers import optimizer_from_config
from .legacy.models import Graph
def save_model(model, filepath, overwrite=True):
def get_json_type(obj):
# if obj is a serializable Keras class instance
# e.g. optimizer, layer
if hasattr(obj, 'get_config'):
return {'class_name': obj.__class__.__name__,
'config': obj.get_config()}
# if obj is any numpy type
if type(obj).__module__ == np.__name__:
return obj.item()
# misc functions (e.g. loss function)
if hasattr(obj, '__call__'):
return obj.__name__
# if obj is a python 'type'
if type(obj).__name__ == type.__name__:
return obj.__name__
raise TypeError('Not JSON Serializable:', obj)
import h5py
# if file exists and should not be overwritten
if not overwrite and os.path.isfile(filepath):
proceed = ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
f = h5py.File(filepath, 'w')
f.attrs['model_config'] = json.dumps({
'class_name': model.__class__.__name__,
'config': model.get_config()
}, default=get_json_type).encode('utf8')
model_weights_group = f.create_group('model_weights')
model.save_weights_to_hdf5_group(model_weights_group)
if hasattr(model, 'optimizer'):
f.attrs['training_config'] = json.dumps({
'optimizer_config': {
'class_name': model.optimizer.__class__.__name__,
'config': model.optimizer.get_config()
},
'loss': model.loss,
'metrics': model.metrics,
'sample_weight_mode': model.sample_weight_mode,
'loss_weights': model.loss_weights,
}, default=get_json_type).encode('utf8')
# save optimizer weights
symbolic_weights = getattr(model.optimizer, 'weights')
if symbolic_weights:
optimizer_weights_group = f.create_group('optimizer_weights')
weight_values = K.batch_get_value(symbolic_weights)
weight_names = []
for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
if hasattr(w, 'name') and w.name:
name = str(w.name)
else:
name = 'param_' + str(i)
weight_names.append(name.encode('utf8'))
optimizer_weights_group.attrs['weight_names'] = weight_names
for name, val in zip(weight_names, weight_values):
param_dset = optimizer_weights_group.create_dataset(
name,
val.shape,
dtype=val.dtype)
if not val.shape:
# scalar
param_dset[()] = val
else:
param_dset[:] = val
f.flush()
f.close()
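A sketch of inspecting the resulting file layout with h5py, mirroring the attributes and groups written above (assumes a model was already saved to 'my_model.h5'):

```python
import h5py

with h5py.File('my_model.h5', 'r') as f:
    print(f.attrs['model_config'])           # JSON: class_name + config
    print(list(f['model_weights'].keys()))   # one HDF5 group per layer
    if 'training_config' in f.attrs:
        print(f.attrs['training_config'])    # optimizer, loss, metrics, ...
    if 'optimizer_weights' in f:
        print(f['optimizer_weights'].attrs['weight_names'])
```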
def load_model(filepath, custom_objects={}):
def deserialize(obj):
if type(obj) is list:
deserialized = []
for value in obj:
if value in custom_objects:
deserialized.append(custom_objects[value])
else:
deserialized.append(value)
return deserialized
if type(obj) is dict:
deserialized = {}
for key, value in obj.items():
if value in custom_objects:
deserialized[key] = custom_objects[value]
else:
deserialized[key] = value
return deserialized
if obj in custom_objects:
return custom_objects[obj]
return obj
import h5py
f = h5py.File(filepath, mode='r')
# instantiate model
model_config = f.attrs.get('model_config')
if model_config is None:
raise ValueError('No model found in config file.')
model_config = json.loads(model_config.decode('utf-8'))
model = model_from_config(model_config, custom_objects=custom_objects)
# set weights
model.load_weights_from_hdf5_group(f['model_weights'])
# instantiate optimizer
training_config = f.attrs.get('training_config')
if training_config is None:
warnings.warn('No training configuration found in save file: '
'the model was *not* compiled. Compile it manually.')
f.close()
return model
training_config = json.loads(training_config.decode('utf-8'))
optimizer_config = training_config['optimizer_config']
optimizer = optimizer_from_config(optimizer_config)
# recover loss functions and metrics
loss = deserialize(training_config['loss'])
metrics = deserialize(training_config['metrics'])
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
# compile model
model.compile(optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode)
# set optimizer weights
if 'optimizer_weights' in f:
# build train function (to get weight updates)
if model.__class__.__name__ == 'Sequential':
model.model._make_train_function()
else:
model._make_train_function()
optimizer_weights_group = f['optimizer_weights']
optimizer_weight_names = [n.decode('utf8') for n in optimizer_weights_group.attrs['weight_names']]
optimizer_weight_values = [optimizer_weights_group[n] for n in optimizer_weight_names]
model.optimizer.set_weights(optimizer_weight_values)
f.close()
return model
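A hedged example of the `custom_objects` hook: the file stores the loss by name, and `deserialize` maps that name back to a Python object at load time (`my_loss` is an illustrative custom function, not part of this diff):

```python
from keras import backend as K
from keras.models import load_model

def my_loss(y_true, y_pred):
    # illustrative custom objective, saved under the name 'my_loss'
    return K.mean(K.square(y_pred - y_true), axis=-1)

model = load_model('my_model.h5', custom_objects={'my_loss': my_loss})
```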
def model_from_config(config, custom_objects={}):
from keras.utils.layer_utils import layer_from_config
if isinstance(config, list):
@@ -362,6 +521,9 @@ class Sequential(Model):
**kwargs)
self.optimizer = self.model.optimizer
self.loss = self.model.loss
self.loss_weights = self.model.loss_weights
self.metrics = self.model.metrics
self.metrics_tensors = self.model.metrics_tensors
self.metrics_names = self.model.metrics_names
self.sample_weight_mode = self.model.sample_weight_mode

@@ -1,6 +1,5 @@
from __future__ import absolute_import
from . import backend as K
import numpy as np
from .utils.generic_utils import get_from_module
from six.moves import zip
@@ -11,8 +10,24 @@ def clip_norm(g, c, n):
return g
def kl_divergence(p, p_hat):
return p_hat - p + p * K.log(p / p_hat)
def optimizer_from_config(config, custom_objects={}):
all_classes = {
'sgd': SGD,
'rmsprop': RMSprop,
'adagrad': Adagrad,
'adadelta': Adadelta,
'adam': Adam,
'adamax': Adamax,
'nadam': Nadam,
}
class_name = config['class_name']
if class_name in custom_objects:
cls = custom_objects[class_name]
else:
if class_name.lower() not in all_classes:
raise ValueError('Optimizer class not found:', class_name)
cls = all_classes[class_name.lower()]
return cls.from_config(config['config'])
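A minimal sketch of the config shape this helper expects; it mirrors the 'optimizer_config' entry that `save_model` writes (values are illustrative):

```python
from keras.optimizers import optimizer_from_config

config = {'class_name': 'SGD',
          'config': {'lr': 0.01, 'momentum': 0.9,
                     'decay': 0.0, 'nesterov': False}}
opt = optimizer_from_config(config)  # lowercased lookup, then SGD.from_config
```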
class Optimizer(object):
@@ -90,13 +105,17 @@ class Optimizer(object):
return K.batch_get_value(self.weights)
def get_config(self):
config = {'name': self.__class__.__name__}
config = {}
if hasattr(self, 'clipnorm'):
config['clipnorm'] = self.clipnorm
if hasattr(self, 'clipvalue'):
config['clipvalue'] = self.clipvalue
return config
@classmethod
def from_config(cls, config):
return cls(**config)
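The new classmethod enables a simple config round-trip (sketch; the exact config keys depend on the optimizer subclass):

```python
from keras.optimizers import SGD

sgd = SGD(lr=0.1, momentum=0.9)
config = sgd.get_config()       # e.g. {'lr': 0.1, 'momentum': 0.9, ...}
sgd2 = SGD.from_config(config)  # an equivalent optimizer, rebuilt from config
```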
class SGD(Optimizer):
'''Stochastic gradient descent, with support for momentum,

@@ -1,6 +1,8 @@
from __future__ import absolute_import
from __future__ import print_function
import h5py
import numpy as np
import sys
from collections import defaultdict
@@ -69,3 +71,17 @@ def load_array(name):
a[:] = array[:]
f.close()
return a
def ask_to_proceed_with_overwrite(filepath):
get_input = input
if sys.version_info[:2] <= (2, 7):
get_input = raw_input
overwrite = get_input('[WARNING] %s already exists - overwrite? '
'[y/n]' % (filepath))
while overwrite not in ['y', 'n']:
overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
if overwrite == 'n':
return False
print('[TIP] Next time specify overwrite=True!')
return True

@@ -4,10 +4,12 @@ import pytest
from keras.models import Sequential
from keras.engine.training import weighted_objective
from keras.layers.core import TimeDistributedDense, Masking
from keras.utils.test_utils import keras_test
from keras import objectives
from keras import backend as K
@keras_test
def test_masking():
np.random.seed(1337)
X = np.array([[[1], [1]],
@@ -22,6 +24,7 @@ def test_masking():
assert loss == 0
@keras_test
def test_loss_masking():
weighted_loss = weighted_objective(objectives.get('mae'))
shape = (3, 4, 2)

@@ -8,6 +8,7 @@ from keras.utils.test_utils import get_test_data
from keras.models import Sequential, Graph
from keras.layers import Dense, Activation, RepeatVector, TimeDistributedDense, GRU
from keras.utils import np_utils
from keras.utils.test_utils import keras_test
nb_classes = 10
batch_size = 128
@@ -69,6 +70,7 @@ def create_temporal_sequential_model():
return model
@keras_test
def _test_weights_sequential(model, class_weight=None, sample_weight=None,
X_train=X_train, Y_train=Y_train,
X_test=X_test, Y_test=Y_test):
@@ -108,6 +110,7 @@ model.compile(loss=loss, optimizer='rmsprop')
standard_score_sequential = _test_weights_sequential(model)
@keras_test
def test_sequential_class_weights():
model = create_sequential_model()
model.compile(loss=loss, optimizer='rmsprop')
@@ -115,6 +118,7 @@ def test_sequential_class_weights():
assert(score < standard_score_sequential)
@keras_test
def test_sequential_sample_weights():
model = create_sequential_model()
model.compile(loss=loss, optimizer='rmsprop')
@@ -122,6 +126,7 @@ def test_sequential_sample_weights():
assert(score < standard_score_sequential)
@keras_test
def test_sequential_temporal_sample_weights():
model = create_temporal_sequential_model()
model.compile(loss=loss, optimizer='rmsprop',

tests/test_model_saving.py (new file, 151 lines)

@@ -0,0 +1,151 @@
import pytest
import numpy as np
from numpy.testing import assert_allclose
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed
from keras.layers import Input
from keras import optimizers
from keras import objectives
from keras import metrics
from keras.utils.test_utils import keras_test
from keras.models import save_model, load_model
@keras_test
def test_sequential_model_saving():
model = Sequential()
model.add(Dense(2, input_dim=3))
model.add(Dense(3))
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
save_model(model, fname)
new_model = load_model(fname)
out2 = new_model.predict(x)
assert_allclose(out, out2)
# test that new updates are the same with both models
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
out = model.predict(x)
out2 = new_model.predict(x)
assert_allclose(out, out2)
@keras_test
def test_sequential_model_saving_2():
# test with funkier config
model = Sequential()
model.add(Dense(2, input_dim=3))
model.add(RepeatVector(3))
model.add(TimeDistributed(Dense(3)))
model.compile(loss=objectives.MSE,
optimizer=optimizers.RMSprop(lr=0.0001),
metrics=[metrics.categorical_accuracy],
sample_weight_mode='temporal')
x = np.random.random((1, 3))
y = np.random.random((1, 3, 3))
model.train_on_batch(x, y)
out = model.predict(x)
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
save_model(model, fname)
new_model = load_model(fname)
out2 = new_model.predict(x)
assert_allclose(out, out2)
# test that new updates are the same with both models
x = np.random.random((1, 3))
y = np.random.random((1, 3, 3))
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
out = model.predict(x)
out2 = new_model.predict(x)
assert_allclose(out, out2)
@keras_test
def test_sequential_model_saving_3():
# test with custom optimizer, loss
custom_opt = optimizers.rmsprop
custom_loss = objectives.mse
model = Sequential()
model.add(Dense(2, input_dim=3))
model.add(Dense(3))
model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
save_model(model, fname)
model = load_model(fname,
custom_objects={'custom_opt': custom_opt,
'custom_loss': custom_loss})
out2 = model.predict(x)
assert_allclose(out, out2)
@keras_test
def test_functional_model_saving():
input = Input(shape=(3,))
x = Dense(2)(input)
output = Dense(3)(x)
model = Model(input, output)
model.compile(loss=objectives.MSE,
optimizer=optimizers.RMSprop(lr=0.0001),
metrics=[metrics.categorical_accuracy])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
save_model(model, fname)
model = load_model(fname)
out2 = model.predict(x)
assert_allclose(out, out2)
@keras_test
def test_saving_without_compilation():
model = Sequential()
model.add(Dense(2, input_dim=3))
model.add(Dense(3))
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
save_model(model, fname)
model = load_model(fname)
@keras_test
def test_saving_right_after_compilation():
model = Sequential()
model.add(Dense(2, input_dim=3))
model.add(Dense(3))
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
model.model._make_train_function()
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
save_model(model, fname)
model = load_model(fname)
if __name__ == '__main__':
pytest.main([__file__])