Add model saving functionality (#3314)
* Add model saving functionality * Update model saving functionality * Fix py3 bytes/str issue * Fix tests
This commit is contained in:
parent
df84c69676
commit
ea561ba6d8
@ -239,16 +239,20 @@ class ModelCheckpoint(Callback):
|
||||
this should be `max`, for `val_loss` this should
|
||||
be `min`, etc. In `auto` mode, the direction is
|
||||
automatically inferred from the name of the monitored quantity.
|
||||
save_weights_only: if True, then only the model's weights will be
|
||||
saved (`model.save_weights(filepath)`), else the full model
|
||||
is saved (`model.save(filepath)`).
|
||||
|
||||
'''
|
||||
def __init__(self, filepath, monitor='val_loss', verbose=0,
|
||||
save_best_only=False, mode='auto'):
|
||||
|
||||
save_best_only=False, save_weights_only=False,
|
||||
mode='auto'):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self.monitor = monitor
|
||||
self.verbose = verbose
|
||||
self.filepath = filepath
|
||||
self.save_best_only = save_best_only
|
||||
self.save_weights_only = save_weights_only
|
||||
|
||||
if mode not in ['auto', 'min', 'max']:
|
||||
warnings.warn('ModelCheckpoint mode %s is unknown, '
|
||||
@ -285,7 +289,10 @@ class ModelCheckpoint(Callback):
|
||||
% (epoch, self.monitor, self.best,
|
||||
current, filepath))
|
||||
self.best = current
|
||||
self.model.save_weights(filepath, overwrite=True)
|
||||
if self.save_weights_only:
|
||||
self.model.save_weights(filepath, overwrite=True)
|
||||
else:
|
||||
self.model.save(filepath, overwrite=True)
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
print('Epoch %05d: %s did not improve' %
|
||||
@ -293,7 +300,10 @@ class ModelCheckpoint(Callback):
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
print('Epoch %05d: saving model to %s' % (epoch, filepath))
|
||||
self.model.save_weights(filepath, overwrite=True)
|
||||
if self.save_weights_only:
|
||||
self.model.save_weights(filepath, overwrite=True)
|
||||
else:
|
||||
self.model.save(filepath, overwrite=True)
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
|
@ -10,9 +10,11 @@ import marshal
|
||||
import types as python_types
|
||||
import warnings
|
||||
import copy
|
||||
import os
|
||||
from six.moves import zip
|
||||
|
||||
from keras import backend as K
|
||||
from .. import backend as K
|
||||
from ..utils.io_utils import ask_to_proceed_with_overwrite
|
||||
|
||||
|
||||
def to_list(x):
|
||||
@ -2310,7 +2312,38 @@ class Container(Layer):
|
||||
output_tensors.append(layer_output_tensors[tensor_index])
|
||||
return cls(input=input_tensors, output=output_tensors, name=name)
|
||||
|
||||
def save_weights(self, filepath, overwrite=False):
|
||||
def save(self, filepath, overwrite=True):
|
||||
'''Save into a single HDF5 file:
|
||||
- the model architecture, allowing to re-instantiate the model
|
||||
- the model weights
|
||||
- the state of the optimizer, allowing to resume training
|
||||
exactly where you left off.
|
||||
|
||||
This allows you to save the entirety of the state of a model
|
||||
in a single file.
|
||||
|
||||
Saved models can be reinstantiated via `keras.models.load_model`.
|
||||
The model returned by `load_model`
|
||||
is a compiled model ready to be used (unless the saved model
|
||||
was never compiled in the first place).
|
||||
|
||||
# Example usage
|
||||
|
||||
```python
|
||||
from keras.models import load_model
|
||||
|
||||
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
|
||||
del model # deletes the existing model
|
||||
|
||||
# returns a compiled model
|
||||
# identical to the previous one
|
||||
model = load_model('my_model.h5')
|
||||
```
|
||||
'''
|
||||
from ..models import save_model
|
||||
save_model(self, filepath, overwrite)
|
||||
|
||||
def save_weights(self, filepath, overwrite=True):
|
||||
'''Dumps all layer weights to a HDF5 file.
|
||||
|
||||
The weight file has:
|
||||
@ -2323,28 +2356,23 @@ class Container(Layer):
|
||||
storing the weight value, named after the weight tensor
|
||||
'''
|
||||
import h5py
|
||||
import os.path
|
||||
# if file exists and should not be overwritten
|
||||
if not overwrite and os.path.isfile(filepath):
|
||||
import sys
|
||||
get_input = input
|
||||
if sys.version_info[:2] <= (2, 7):
|
||||
get_input = raw_input
|
||||
overwrite = get_input('[WARNING] %s already exists - overwrite? '
|
||||
'[y/n]' % (filepath))
|
||||
while overwrite not in ['y', 'n']:
|
||||
overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
|
||||
if overwrite == 'n':
|
||||
proceed = ask_to_proceed_with_overwrite(filepath)
|
||||
if not proceed:
|
||||
return
|
||||
print('[TIP] Next time specify overwrite=True in save_weights!')
|
||||
f = h5py.File(filepath, 'w')
|
||||
self.save_weights_to_hdf5_group(f)
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
def save_weights_to_hdf5_group(self, f):
|
||||
if hasattr(self, 'flattened_layers'):
|
||||
# support for legacy Sequential/Merge behavior
|
||||
flattened_layers = self.flattened_layers
|
||||
else:
|
||||
flattened_layers = self.layers
|
||||
|
||||
f = h5py.File(filepath, 'w')
|
||||
f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in flattened_layers]
|
||||
|
||||
for layer in flattened_layers:
|
||||
@ -2362,21 +2390,38 @@ class Container(Layer):
|
||||
for name, val in zip(weight_names, weight_values):
|
||||
param_dset = g.create_dataset(name, val.shape,
|
||||
dtype=val.dtype)
|
||||
param_dset[:] = val
|
||||
f.flush()
|
||||
f.close()
|
||||
if not val.shape:
|
||||
# scalar
|
||||
param_dset[()] = val
|
||||
else:
|
||||
param_dset[:] = val
|
||||
|
||||
def load_weights(self, filepath):
|
||||
'''Load all layer weights from a HDF5 save file.
|
||||
'''
|
||||
import h5py
|
||||
f = h5py.File(filepath, mode='r')
|
||||
self.load_weights_from_hdf5_group(f)
|
||||
f.close()
|
||||
|
||||
def load_weights_from_hdf5_group(self, f):
|
||||
'''Weight loading is based on layer order in a list
|
||||
(matching model.flattened_layers for Sequential models,
|
||||
and model.layers for Model class instances), not
|
||||
on layer names.
|
||||
Layers that have no weights are skipped.
|
||||
'''
|
||||
if hasattr(self, 'flattened_layers'):
|
||||
# support for legacy Sequential/Merge behavior
|
||||
flattened_layers = self.flattened_layers
|
||||
else:
|
||||
flattened_layers = self.layers
|
||||
filtered_layers = []
|
||||
for layer in flattened_layers:
|
||||
weights = layer.trainable_weights + layer.non_trainable_weights
|
||||
if weights:
|
||||
filtered_layers.append(layer)
|
||||
flattened_layers = filtered_layers
|
||||
|
||||
if 'nb_layers' in f.attrs:
|
||||
# legacy format
|
||||
@ -2394,6 +2439,13 @@ class Container(Layer):
|
||||
else:
|
||||
# new file format
|
||||
layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
|
||||
filtered_layer_names = []
|
||||
for name in layer_names:
|
||||
g = f[name]
|
||||
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
|
||||
if len(weight_names):
|
||||
filtered_layer_names.append(name)
|
||||
layer_names = filtered_layer_names
|
||||
if len(layer_names) != len(flattened_layers):
|
||||
raise Exception('You are trying to load a weight file '
|
||||
'containing ' + str(len(layer_names)) +
|
||||
@ -2406,24 +2458,22 @@ class Container(Layer):
|
||||
for k, name in enumerate(layer_names):
|
||||
g = f[name]
|
||||
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
|
||||
if len(weight_names):
|
||||
weight_values = [g[weight_name] for weight_name in weight_names]
|
||||
layer = flattened_layers[k]
|
||||
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
|
||||
if len(weight_values) != len(symbolic_weights):
|
||||
raise Exception('Layer #' + str(k) +
|
||||
' (named "' + layer.name +
|
||||
'" in the current model) was found to '
|
||||
'correspond to layer ' + name +
|
||||
' in the save file. '
|
||||
'However the new layer ' + layer.name +
|
||||
' expects ' + str(len(symbolic_weights)) +
|
||||
' weights, but the saved weights have ' +
|
||||
str(len(weight_values)) +
|
||||
' elements.')
|
||||
weight_value_tuples += zip(symbolic_weights, weight_values)
|
||||
weight_values = [g[weight_name] for weight_name in weight_names]
|
||||
layer = flattened_layers[k]
|
||||
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
|
||||
if len(weight_values) != len(symbolic_weights):
|
||||
raise Exception('Layer #' + str(k) +
|
||||
' (named "' + layer.name +
|
||||
'" in the current model) was found to '
|
||||
'correspond to layer ' + name +
|
||||
' in the save file. '
|
||||
'However the new layer ' + layer.name +
|
||||
' expects ' + str(len(symbolic_weights)) +
|
||||
' weights, but the saved weights have ' +
|
||||
str(len(weight_values)) +
|
||||
' elements.')
|
||||
weight_value_tuples += zip(symbolic_weights, weight_values)
|
||||
K.batch_set_value(weight_value_tuples)
|
||||
f.close()
|
||||
|
||||
def _updated_config(self):
|
||||
'''shared between different serialization methods'''
|
||||
@ -2435,14 +2485,6 @@ class Container(Layer):
|
||||
'config': config,
|
||||
'keras_version': keras_version
|
||||
}
|
||||
|
||||
if hasattr(self, 'optimizer'):
|
||||
model_config['optimizer'] = self.optimizer.get_config()
|
||||
model_config['loss'] = getattr(self.loss, '__name__', self.loss)
|
||||
model_config['sample_weight_mode'] = self.sample_weight_mode
|
||||
|
||||
if hasattr(self, 'loss_weights'):
|
||||
model_config['loss_weights'] = self.loss_weights
|
||||
return model_config
|
||||
|
||||
def to_json(self, **kwargs):
|
||||
@ -2462,7 +2504,7 @@ class Container(Layer):
|
||||
if type(obj).__name__ == type.__name__:
|
||||
return obj.__name__
|
||||
|
||||
raise TypeError('Not JSON Serializable')
|
||||
raise TypeError('Not JSON Serializable:', obj)
|
||||
|
||||
model_config = self._updated_config()
|
||||
return json.dumps(model_config, default=get_json_type, **kwargs)
|
||||
|
@ -608,8 +608,9 @@ class Model(Container):
|
||||
self.targets.append(K.placeholder(ndim=len(shape), name=name + '_target'))
|
||||
|
||||
# prepare metrics
|
||||
self.metrics = metrics
|
||||
self.metrics_names = ['loss']
|
||||
self.metrics = []
|
||||
self.metrics_tensors = []
|
||||
|
||||
# compute total loss
|
||||
total_loss = None
|
||||
@ -623,7 +624,7 @@ class Model(Container):
|
||||
output_loss = weighted_loss(y_true, y_pred,
|
||||
sample_weight, mask)
|
||||
if len(self.outputs) > 1:
|
||||
self.metrics.append(output_loss)
|
||||
self.metrics_tensors.append(output_loss)
|
||||
self.metrics_names.append(self.output_names[i] + '_loss')
|
||||
if total_loss is None:
|
||||
total_loss = loss_weight * output_loss
|
||||
@ -648,21 +649,21 @@ class Model(Container):
|
||||
output_shape = self.internal_output_shapes[i]
|
||||
if output_shape[-1] == 1 or self.loss_functions[i] == objectives.binary_crossentropy:
|
||||
# case: binary accuracy
|
||||
self.metrics.append(metrics_module.binary_accuracy(y_true, y_pred))
|
||||
self.metrics_tensors.append(metrics_module.binary_accuracy(y_true, y_pred))
|
||||
elif self.loss_functions[i] == objectives.sparse_categorical_crossentropy:
|
||||
# case: categorical accuracy with sparse targets
|
||||
self.metrics.append(
|
||||
self.metrics_tensors.append(
|
||||
metrics_module.sparse_categorical_accuracy(y_true, y_pred))
|
||||
else:
|
||||
# case: categorical accuracy with dense targets
|
||||
self.metrics.append(metrics_module.categorical_accuracy(y_true, y_pred))
|
||||
self.metrics_tensors.append(metrics_module.categorical_accuracy(y_true, y_pred))
|
||||
if len(self.output_names) == 1:
|
||||
self.metrics_names.append('acc')
|
||||
else:
|
||||
self.metrics_names.append(self.output_layers[i].name + '_acc')
|
||||
else:
|
||||
metric_fn = metrics_module.get(metric)
|
||||
self.metrics.append(metric_fn(y_true, y_pred))
|
||||
self.metrics_tensors.append(metric_fn(y_true, y_pred))
|
||||
if len(self.output_names) == 1:
|
||||
self.metrics_names.append(metric_fn.__name__)
|
||||
else:
|
||||
@ -698,7 +699,7 @@ class Model(Container):
|
||||
|
||||
# returns loss and metrics. Updates weights at each call.
|
||||
self.train_function = K.function(inputs,
|
||||
[self.total_loss] + self.metrics,
|
||||
[self.total_loss] + self.metrics_tensors,
|
||||
updates=updates,
|
||||
**self._function_kwargs)
|
||||
|
||||
@ -713,7 +714,7 @@ class Model(Container):
|
||||
# return loss and metrics, no gradient updates.
|
||||
# Does update the network states.
|
||||
self.test_function = K.function(inputs,
|
||||
[self.total_loss] + self.metrics,
|
||||
[self.total_loss] + self.metrics_tensors,
|
||||
updates=self.state_updates,
|
||||
**self._function_kwargs)
|
||||
|
||||
|
162
keras/models.py
162
keras/models.py
@ -1,13 +1,172 @@
|
||||
from __future__ import print_function
|
||||
import warnings
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from . import backend as K
|
||||
from .utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from .engine.training import Model
|
||||
from .engine.topology import get_source_inputs, Node
|
||||
from .optimizers import optimizer_from_config
|
||||
from .legacy.models import Graph
|
||||
|
||||
|
||||
def save_model(model, filepath, overwrite=True):
|
||||
|
||||
def get_json_type(obj):
|
||||
# if obj is a serializable Keras class instance
|
||||
# e.g. optimizer, layer
|
||||
if hasattr(obj, 'get_config'):
|
||||
return {'class_name': obj.__class__.__name__,
|
||||
'config': obj.get_config()}
|
||||
|
||||
# if obj is any numpy type
|
||||
if type(obj).__module__ == np.__name__:
|
||||
return obj.item()
|
||||
|
||||
# misc functions (e.g. loss function)
|
||||
if hasattr(obj, '__call__'):
|
||||
return obj.__name__
|
||||
|
||||
# if obj is a python 'type'
|
||||
if type(obj).__name__ == type.__name__:
|
||||
return obj.__name__
|
||||
|
||||
raise TypeError('Not JSON Serializable:', obj)
|
||||
|
||||
import h5py
|
||||
# if file exists and should not be overwritten
|
||||
if not overwrite and os.path.isfile(filepath):
|
||||
proceed = ask_to_proceed_with_overwrite(filepath)
|
||||
if not proceed:
|
||||
return
|
||||
|
||||
f = h5py.File(filepath, 'w')
|
||||
|
||||
f.attrs['model_config'] = json.dumps({
|
||||
'class_name': model.__class__.__name__,
|
||||
'config': model.get_config()
|
||||
}, default=get_json_type).encode('utf8')
|
||||
|
||||
model_weights_group = f.create_group('model_weights')
|
||||
model.save_weights_to_hdf5_group(model_weights_group)
|
||||
|
||||
if hasattr(model, 'optimizer'):
|
||||
f.attrs['training_config'] = json.dumps({
|
||||
'optimizer_config': {
|
||||
'class_name': model.optimizer.__class__.__name__,
|
||||
'config': model.optimizer.get_config()
|
||||
},
|
||||
'loss': model.loss,
|
||||
'metrics': model.metrics,
|
||||
'sample_weight_mode': model.sample_weight_mode,
|
||||
'loss_weights': model.loss_weights,
|
||||
}, default=get_json_type).encode('utf8')
|
||||
|
||||
# save optimizer weights
|
||||
symbolic_weights = getattr(model.optimizer, 'weights')
|
||||
if symbolic_weights:
|
||||
optimizer_weights_group = f.create_group('optimizer_weights')
|
||||
weight_values = K.batch_get_value(symbolic_weights)
|
||||
weight_names = []
|
||||
for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
|
||||
if hasattr(w, 'name') and w.name:
|
||||
name = str(w.name)
|
||||
else:
|
||||
name = 'param_' + str(i)
|
||||
weight_names.append(name.encode('utf8'))
|
||||
optimizer_weights_group.attrs['weight_names'] = weight_names
|
||||
for name, val in zip(weight_names, weight_values):
|
||||
param_dset = optimizer_weights_group.create_dataset(
|
||||
name,
|
||||
val.shape,
|
||||
dtype=val.dtype)
|
||||
if not val.shape:
|
||||
# scalar
|
||||
param_dset[()] = val
|
||||
else:
|
||||
param_dset[:] = val
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
|
||||
def load_model(filepath, custom_objects={}):
|
||||
|
||||
def deserialize(obj):
|
||||
if type(obj) is list:
|
||||
deserialized = []
|
||||
for value in obj:
|
||||
if value in custom_objects:
|
||||
deserialized.append(custom_objects[value])
|
||||
else:
|
||||
deserialized.append(value)
|
||||
return deserialized
|
||||
if type(obj) is dict:
|
||||
deserialized = {}
|
||||
for key, value in obj.items():
|
||||
if value in custom_objects:
|
||||
deserialized[key] = custom_objects[value]
|
||||
else:
|
||||
deserialized[key] = value
|
||||
return deserialized
|
||||
if obj in custom_objects:
|
||||
return custom_objects[obj]
|
||||
return obj
|
||||
|
||||
import h5py
|
||||
f = h5py.File(filepath, mode='r')
|
||||
|
||||
# instantiate model
|
||||
model_config = f.attrs.get('model_config')
|
||||
if model_config is None:
|
||||
raise ValueError('No model found in config file.')
|
||||
model_config = json.loads(model_config.decode('utf-8'))
|
||||
model = model_from_config(model_config, custom_objects=custom_objects)
|
||||
|
||||
# set weights
|
||||
model.load_weights_from_hdf5_group(f['model_weights'])
|
||||
|
||||
# instantiate optimizer
|
||||
training_config = f.attrs.get('training_config')
|
||||
if training_config is None:
|
||||
warnings.warn('No training configuration found in save file: '
|
||||
'the model was *not* compiled. Compile it manually.')
|
||||
f.close()
|
||||
return model
|
||||
training_config = json.loads(training_config.decode('utf-8'))
|
||||
optimizer_config = training_config['optimizer_config']
|
||||
optimizer = optimizer_from_config(optimizer_config)
|
||||
|
||||
# recover loss functions and metrics
|
||||
loss = deserialize(training_config['loss'])
|
||||
metrics = deserialize(training_config['metrics'])
|
||||
sample_weight_mode = training_config['sample_weight_mode']
|
||||
loss_weights = training_config['loss_weights']
|
||||
|
||||
# compile model
|
||||
model.compile(optimizer=optimizer,
|
||||
loss=loss,
|
||||
metrics=metrics,
|
||||
loss_weights=loss_weights,
|
||||
sample_weight_mode=sample_weight_mode)
|
||||
|
||||
# set optimizer weights
|
||||
if 'optimizer_weights' in f:
|
||||
# build train function (to get weight updates)
|
||||
if model.__class__.__name__ == 'Sequential':
|
||||
model.model._make_train_function()
|
||||
else:
|
||||
model._make_train_function()
|
||||
optimizer_weights_group = f['optimizer_weights']
|
||||
optimizer_weight_names = [n.decode('utf8') for n in optimizer_weights_group.attrs['weight_names']]
|
||||
optimizer_weight_values = [optimizer_weights_group[n] for n in optimizer_weight_names]
|
||||
model.optimizer.set_weights(optimizer_weight_values)
|
||||
f.close()
|
||||
return model
|
||||
|
||||
|
||||
def model_from_config(config, custom_objects={}):
|
||||
from keras.utils.layer_utils import layer_from_config
|
||||
if isinstance(config, list):
|
||||
@ -362,6 +521,9 @@ class Sequential(Model):
|
||||
**kwargs)
|
||||
self.optimizer = self.model.optimizer
|
||||
self.loss = self.model.loss
|
||||
self.loss_weights = self.model.loss_weights
|
||||
self.metrics = self.model.metrics
|
||||
self.metrics_tensors = self.model.metrics_tensors
|
||||
self.metrics_names = self.model.metrics_names
|
||||
self.sample_weight_mode = self.model.sample_weight_mode
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import absolute_import
|
||||
from . import backend as K
|
||||
import numpy as np
|
||||
from .utils.generic_utils import get_from_module
|
||||
from six.moves import zip
|
||||
|
||||
@ -11,8 +10,24 @@ def clip_norm(g, c, n):
|
||||
return g
|
||||
|
||||
|
||||
def kl_divergence(p, p_hat):
|
||||
return p_hat - p + p * K.log(p / p_hat)
|
||||
def optimizer_from_config(config, custom_objects={}):
|
||||
all_classes = {
|
||||
'sgd': SGD,
|
||||
'rmsprop': RMSprop,
|
||||
'adagrad': Adagrad,
|
||||
'adadelta': Adadelta,
|
||||
'adam': Adam,
|
||||
'adamax': Adamax,
|
||||
'nadam': Nadam,
|
||||
}
|
||||
class_name = config['class_name']
|
||||
if class_name in custom_objects:
|
||||
cls = custom_objects[class_name]
|
||||
else:
|
||||
if class_name.lower() not in all_classes:
|
||||
raise ValueError('Optimizer class not found:', class_name)
|
||||
cls = all_classes[class_name.lower()]
|
||||
return cls.from_config(config['config'])
|
||||
|
||||
|
||||
class Optimizer(object):
|
||||
@ -90,13 +105,17 @@ class Optimizer(object):
|
||||
return K.batch_get_value(self.weights)
|
||||
|
||||
def get_config(self):
|
||||
config = {'name': self.__class__.__name__}
|
||||
config = {}
|
||||
if hasattr(self, 'clipnorm'):
|
||||
config['clipnorm'] = self.clipnorm
|
||||
if hasattr(self, 'clipvalue'):
|
||||
config['clipvalue'] = self.clipvalue
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
'''Stochastic gradient descent, with support for momentum,
|
||||
|
@ -1,6 +1,8 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
import h5py
|
||||
import numpy as np
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
@ -69,3 +71,17 @@ def load_array(name):
|
||||
a[:] = array[:]
|
||||
f.close()
|
||||
return a
|
||||
|
||||
|
||||
def ask_to_proceed_with_overwrite(filepath):
|
||||
get_input = input
|
||||
if sys.version_info[:2] <= (2, 7):
|
||||
get_input = raw_input
|
||||
overwrite = get_input('[WARNING] %s already exists - overwrite? '
|
||||
'[y/n]' % (filepath))
|
||||
while overwrite not in ['y', 'n']:
|
||||
overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).')
|
||||
if overwrite == 'n':
|
||||
return False
|
||||
print('[TIP] Next time specify overwrite=True!')
|
||||
return True
|
||||
|
@ -4,10 +4,12 @@ import pytest
|
||||
from keras.models import Sequential
|
||||
from keras.engine.training import weighted_objective
|
||||
from keras.layers.core import TimeDistributedDense, Masking
|
||||
from keras.utils.test_utils import keras_test
|
||||
from keras import objectives
|
||||
from keras import backend as K
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_masking():
|
||||
np.random.seed(1337)
|
||||
X = np.array([[[1], [1]],
|
||||
@ -22,6 +24,7 @@ def test_masking():
|
||||
assert loss == 0
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_loss_masking():
|
||||
weighted_loss = weighted_objective(objectives.get('mae'))
|
||||
shape = (3, 4, 2)
|
||||
|
@ -8,6 +8,7 @@ from keras.utils.test_utils import get_test_data
|
||||
from keras.models import Sequential, Graph
|
||||
from keras.layers import Dense, Activation, RepeatVector, TimeDistributedDense, GRU
|
||||
from keras.utils import np_utils
|
||||
from keras.utils.test_utils import keras_test
|
||||
|
||||
nb_classes = 10
|
||||
batch_size = 128
|
||||
@ -69,6 +70,7 @@ def create_temporal_sequential_model():
|
||||
return model
|
||||
|
||||
|
||||
@keras_test
|
||||
def _test_weights_sequential(model, class_weight=None, sample_weight=None,
|
||||
X_train=X_train, Y_train=Y_train,
|
||||
X_test=X_test, Y_test=Y_test):
|
||||
@ -108,6 +110,7 @@ model.compile(loss=loss, optimizer='rmsprop')
|
||||
standard_score_sequential = _test_weights_sequential(model)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_class_weights():
|
||||
model = create_sequential_model()
|
||||
model.compile(loss=loss, optimizer='rmsprop')
|
||||
@ -115,6 +118,7 @@ def test_sequential_class_weights():
|
||||
assert(score < standard_score_sequential)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_sample_weights():
|
||||
model = create_sequential_model()
|
||||
model.compile(loss=loss, optimizer='rmsprop')
|
||||
@ -122,6 +126,7 @@ def test_sequential_sample_weights():
|
||||
assert(score < standard_score_sequential)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_temporal_sample_weights():
|
||||
model = create_temporal_sequential_model()
|
||||
model.compile(loss=loss, optimizer='rmsprop',
|
||||
|
151
tests/test_model_saving.py
Normal file
151
tests/test_model_saving.py
Normal file
@ -0,0 +1,151 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from keras.models import Model, Sequential
|
||||
from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed
|
||||
from keras.layers import Input
|
||||
from keras import optimizers
|
||||
from keras import objectives
|
||||
from keras import metrics
|
||||
from keras.utils.test_utils import keras_test
|
||||
from keras.models import save_model, load_model
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_model_saving():
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3))
|
||||
model.add(Dense(3))
|
||||
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
|
||||
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
save_model(model, fname)
|
||||
|
||||
new_model = load_model(fname)
|
||||
out2 = new_model.predict(x)
|
||||
assert_allclose(out, out2)
|
||||
|
||||
# test that new updates are the same with both models
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
new_model.train_on_batch(x, y)
|
||||
out = model.predict(x)
|
||||
out2 = new_model.predict(x)
|
||||
assert_allclose(out, out2)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_model_saving_2():
|
||||
# test with funkier config
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3))
|
||||
model.add(RepeatVector(3))
|
||||
model.add(TimeDistributed(Dense(3)))
|
||||
model.compile(loss=objectives.MSE,
|
||||
optimizer=optimizers.RMSprop(lr=0.0001),
|
||||
metrics=[metrics.categorical_accuracy],
|
||||
sample_weight_mode='temporal')
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
save_model(model, fname)
|
||||
|
||||
new_model = load_model(fname)
|
||||
out2 = new_model.predict(x)
|
||||
assert_allclose(out, out2)
|
||||
|
||||
# test that new updates are the same with both models
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3, 3))
|
||||
model.train_on_batch(x, y)
|
||||
new_model.train_on_batch(x, y)
|
||||
out = model.predict(x)
|
||||
out2 = new_model.predict(x)
|
||||
assert_allclose(out, out2)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_model_saving_3():
|
||||
# test with custom optimizer, loss
|
||||
custom_opt = optimizers.rmsprop
|
||||
custom_loss = objectives.mse
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3))
|
||||
model.add(Dense(3))
|
||||
model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])
|
||||
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
save_model(model, fname)
|
||||
|
||||
model = load_model(fname,
|
||||
custom_objects={'custom_opt': custom_opt,
|
||||
'custom_loss': custom_loss})
|
||||
out2 = model.predict(x)
|
||||
assert_allclose(out, out2)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_fuctional_model_saving():
|
||||
input = Input(shape=(3,))
|
||||
x = Dense(2)(input)
|
||||
output = Dense(3)(x)
|
||||
|
||||
model = Model(input, output)
|
||||
model.compile(loss=objectives.MSE,
|
||||
optimizer=optimizers.RMSprop(lr=0.0001),
|
||||
metrics=[metrics.categorical_accuracy])
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
save_model(model, fname)
|
||||
|
||||
model = load_model(fname)
|
||||
out2 = model.predict(x)
|
||||
assert_allclose(out, out2)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_saving_without_compilation():
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3))
|
||||
model.add(Dense(3))
|
||||
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
|
||||
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
save_model(model, fname)
|
||||
model = load_model(fname)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_saving_right_after_compilation():
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3))
|
||||
model.add(Dense(3))
|
||||
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
|
||||
model.model._make_train_function()
|
||||
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
save_model(model, fname)
|
||||
model = load_model(fname)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
Loading…
Reference in New Issue
Block a user