diff --git a/keras/callbacks.py b/keras/callbacks.py index 7f2edfb43..6f9f9db17 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -239,16 +239,20 @@ class ModelCheckpoint(Callback): this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. + save_weights_only: if True, then only the model's weights will be + saved (`model.save_weights(filepath)`), else the full model + is saved (`model.save(filepath)`). ''' def __init__(self, filepath, monitor='val_loss', verbose=0, - save_best_only=False, mode='auto'): - + save_best_only=False, save_weights_only=False, + mode='auto'): super(ModelCheckpoint, self).__init__() self.monitor = monitor self.verbose = verbose self.filepath = filepath self.save_best_only = save_best_only + self.save_weights_only = save_weights_only if mode not in ['auto', 'min', 'max']: warnings.warn('ModelCheckpoint mode %s is unknown, ' @@ -285,7 +289,10 @@ class ModelCheckpoint(Callback): % (epoch, self.monitor, self.best, current, filepath)) self.best = current - self.model.save_weights(filepath, overwrite=True) + if self.save_weights_only: + self.model.save_weights(filepath, overwrite=True) + else: + self.model.save(filepath, overwrite=True) else: if self.verbose > 0: print('Epoch %05d: %s did not improve' % @@ -293,7 +300,10 @@ class ModelCheckpoint(Callback): else: if self.verbose > 0: print('Epoch %05d: saving model to %s' % (epoch, filepath)) - self.model.save_weights(filepath, overwrite=True) + if self.save_weights_only: + self.model.save_weights(filepath, overwrite=True) + else: + self.model.save(filepath, overwrite=True) class EarlyStopping(Callback): diff --git a/keras/engine/topology.py b/keras/engine/topology.py index 876a1b820..aa38dbca8 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -10,9 +10,11 @@ import marshal import types as python_types import warnings import copy +import os from six.moves import zip -from keras import backend as K +from .. import backend as K +from ..utils.io_utils import ask_to_proceed_with_overwrite def to_list(x): @@ -2310,7 +2312,38 @@ class Container(Layer): output_tensors.append(layer_output_tensors[tensor_index]) return cls(input=input_tensors, output=output_tensors, name=name) - def save_weights(self, filepath, overwrite=False): + def save(self, filepath, overwrite=True): + '''Save into a single HDF5 file: + - the model architecture, allowing to re-instantiate the model + - the model weights + - the state of the optimizer, allowing to resume training + exactly where you left off. + + This allows you to save the entirety of the state of a model + in a single file. + + Saved models can be reinstantiated via `keras.models.load_model`. + The model returned by `load_model` + is a compiled model ready to be used (unless the saved model + was never compiled in the first place). + + # Example usage + + ```python + from keras.models import load_model + + model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' + del model # deletes the existing model + + # returns a compiled model + # identical to the previous one + model = load_model('my_model.h5') + ``` + ''' + from ..models import save_model + save_model(self, filepath, overwrite) + + def save_weights(self, filepath, overwrite=True): '''Dumps all layer weights to a HDF5 file. The weight file has: @@ -2323,28 +2356,23 @@ class Container(Layer): storing the weight value, named after the weight tensor ''' import h5py - import os.path # if file exists and should not be overwritten if not overwrite and os.path.isfile(filepath): - import sys - get_input = input - if sys.version_info[:2] <= (2, 7): - get_input = raw_input - overwrite = get_input('[WARNING] %s already exists - overwrite? ' - '[y/n]' % (filepath)) - while overwrite not in ['y', 'n']: - overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).') - if overwrite == 'n': + proceed = ask_to_proceed_with_overwrite(filepath) + if not proceed: return - print('[TIP] Next time specify overwrite=True in save_weights!') + f = h5py.File(filepath, 'w') + self.save_weights_to_hdf5_group(f) + f.flush() + f.close() + def save_weights_to_hdf5_group(self, f): if hasattr(self, 'flattened_layers'): # support for legacy Sequential/Merge behavior flattened_layers = self.flattened_layers else: flattened_layers = self.layers - f = h5py.File(filepath, 'w') f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in flattened_layers] for layer in flattened_layers: @@ -2362,21 +2390,38 @@ class Container(Layer): for name, val in zip(weight_names, weight_values): param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) - param_dset[:] = val - f.flush() - f.close() + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val def load_weights(self, filepath): '''Load all layer weights from a HDF5 save file. ''' import h5py f = h5py.File(filepath, mode='r') + self.load_weights_from_hdf5_group(f) + f.close() + def load_weights_from_hdf5_group(self, f): + '''Weight loading is based on layer order in a list + (matching model.flattened_layers for Sequential models, + and model.layers for Model class instances), not + on layer names. + Layers that have no weights are skipped. + ''' if hasattr(self, 'flattened_layers'): # support for legacy Sequential/Merge behavior flattened_layers = self.flattened_layers else: flattened_layers = self.layers + filtered_layers = [] + for layer in flattened_layers: + weights = layer.trainable_weights + layer.non_trainable_weights + if weights: + filtered_layers.append(layer) + flattened_layers = filtered_layers if 'nb_layers' in f.attrs: # legacy format @@ -2394,6 +2439,13 @@ class Container(Layer): else: # new file format layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] + filtered_layer_names = [] + for name in layer_names: + g = f[name] + weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + if len(weight_names): + filtered_layer_names.append(name) + layer_names = filtered_layer_names if len(layer_names) != len(flattened_layers): raise Exception('You are trying to load a weight file ' 'containing ' + str(len(layer_names)) + @@ -2406,24 +2458,22 @@ class Container(Layer): for k, name in enumerate(layer_names): g = f[name] weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - if len(weight_names): - weight_values = [g[weight_name] for weight_name in weight_names] - layer = flattened_layers[k] - symbolic_weights = layer.trainable_weights + layer.non_trainable_weights - if len(weight_values) != len(symbolic_weights): - raise Exception('Layer #' + str(k) + - ' (named "' + layer.name + - '" in the current model) was found to ' - 'correspond to layer ' + name + - ' in the save file. ' - 'However the new layer ' + layer.name + - ' expects ' + str(len(symbolic_weights)) + - ' weights, but the saved weights have ' + - str(len(weight_values)) + - ' elements.') - weight_value_tuples += zip(symbolic_weights, weight_values) + weight_values = [g[weight_name] for weight_name in weight_names] + layer = flattened_layers[k] + symbolic_weights = layer.trainable_weights + layer.non_trainable_weights + if len(weight_values) != len(symbolic_weights): + raise Exception('Layer #' + str(k) + + ' (named "' + layer.name + + '" in the current model) was found to ' + 'correspond to layer ' + name + + ' in the save file. ' + 'However the new layer ' + layer.name + + ' expects ' + str(len(symbolic_weights)) + + ' weights, but the saved weights have ' + + str(len(weight_values)) + + ' elements.') + weight_value_tuples += zip(symbolic_weights, weight_values) K.batch_set_value(weight_value_tuples) - f.close() def _updated_config(self): '''shared between different serialization methods''' @@ -2435,14 +2485,6 @@ class Container(Layer): 'config': config, 'keras_version': keras_version } - - if hasattr(self, 'optimizer'): - model_config['optimizer'] = self.optimizer.get_config() - model_config['loss'] = getattr(self.loss, '__name__', self.loss) - model_config['sample_weight_mode'] = self.sample_weight_mode - - if hasattr(self, 'loss_weights'): - model_config['loss_weights'] = self.loss_weights return model_config def to_json(self, **kwargs): @@ -2462,7 +2504,7 @@ class Container(Layer): if type(obj).__name__ == type.__name__: return obj.__name__ - raise TypeError('Not JSON Serializable') + raise TypeError('Not JSON Serializable:', obj) model_config = self._updated_config() return json.dumps(model_config, default=get_json_type, **kwargs) diff --git a/keras/engine/training.py b/keras/engine/training.py index 217ec0206..030d12857 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -608,8 +608,9 @@ class Model(Container): self.targets.append(K.placeholder(ndim=len(shape), name=name + '_target')) # prepare metrics + self.metrics = metrics self.metrics_names = ['loss'] - self.metrics = [] + self.metrics_tensors = [] # compute total loss total_loss = None @@ -623,7 +624,7 @@ class Model(Container): output_loss = weighted_loss(y_true, y_pred, sample_weight, mask) if len(self.outputs) > 1: - self.metrics.append(output_loss) + self.metrics_tensors.append(output_loss) self.metrics_names.append(self.output_names[i] + '_loss') if total_loss is None: total_loss = loss_weight * output_loss @@ -648,21 +649,21 @@ class Model(Container): output_shape = self.internal_output_shapes[i] if output_shape[-1] == 1 or self.loss_functions[i] == objectives.binary_crossentropy: # case: binary accuracy - self.metrics.append(metrics_module.binary_accuracy(y_true, y_pred)) + self.metrics_tensors.append(metrics_module.binary_accuracy(y_true, y_pred)) elif self.loss_functions[i] == objectives.sparse_categorical_crossentropy: # case: categorical accuracy with sparse targets - self.metrics.append( + self.metrics_tensors.append( metrics_module.sparse_categorical_accuracy(y_true, y_pred)) else: # case: categorical accuracy with dense targets - self.metrics.append(metrics_module.categorical_accuracy(y_true, y_pred)) + self.metrics_tensors.append(metrics_module.categorical_accuracy(y_true, y_pred)) if len(self.output_names) == 1: self.metrics_names.append('acc') else: self.metrics_names.append(self.output_layers[i].name + '_acc') else: metric_fn = metrics_module.get(metric) - self.metrics.append(metric_fn(y_true, y_pred)) + self.metrics_tensors.append(metric_fn(y_true, y_pred)) if len(self.output_names) == 1: self.metrics_names.append(metric_fn.__name__) else: @@ -698,7 +699,7 @@ class Model(Container): # returns loss and metrics. Updates weights at each call. self.train_function = K.function(inputs, - [self.total_loss] + self.metrics, + [self.total_loss] + self.metrics_tensors, updates=updates, **self._function_kwargs) @@ -713,7 +714,7 @@ class Model(Container): # return loss and metrics, no gradient updates. # Does update the network states. self.test_function = K.function(inputs, - [self.total_loss] + self.metrics, + [self.total_loss] + self.metrics_tensors, updates=self.state_updates, **self._function_kwargs) diff --git a/keras/models.py b/keras/models.py index 64b35df50..ddea67d7a 100644 --- a/keras/models.py +++ b/keras/models.py @@ -1,13 +1,172 @@ from __future__ import print_function import warnings import copy +import json +import os +import numpy as np from . import backend as K +from .utils.io_utils import ask_to_proceed_with_overwrite from .engine.training import Model from .engine.topology import get_source_inputs, Node +from .optimizers import optimizer_from_config from .legacy.models import Graph +def save_model(model, filepath, overwrite=True): + + def get_json_type(obj): + # if obj is a serializable Keras class instance + # e.g. optimizer, layer + if hasattr(obj, 'get_config'): + return {'class_name': obj.__class__.__name__, + 'config': obj.get_config()} + + # if obj is any numpy type + if type(obj).__module__ == np.__name__: + return obj.item() + + # misc functions (e.g. loss function) + if hasattr(obj, '__call__'): + return obj.__name__ + + # if obj is a python 'type' + if type(obj).__name__ == type.__name__: + return obj.__name__ + + raise TypeError('Not JSON Serializable:', obj) + + import h5py + # if file exists and should not be overwritten + if not overwrite and os.path.isfile(filepath): + proceed = ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + + f = h5py.File(filepath, 'w') + + f.attrs['model_config'] = json.dumps({ + 'class_name': model.__class__.__name__, + 'config': model.get_config() + }, default=get_json_type).encode('utf8') + + model_weights_group = f.create_group('model_weights') + model.save_weights_to_hdf5_group(model_weights_group) + + if hasattr(model, 'optimizer'): + f.attrs['training_config'] = json.dumps({ + 'optimizer_config': { + 'class_name': model.optimizer.__class__.__name__, + 'config': model.optimizer.get_config() + }, + 'loss': model.loss, + 'metrics': model.metrics, + 'sample_weight_mode': model.sample_weight_mode, + 'loss_weights': model.loss_weights, + }, default=get_json_type).encode('utf8') + + # save optimizer weights + symbolic_weights = getattr(model.optimizer, 'weights') + if symbolic_weights: + optimizer_weights_group = f.create_group('optimizer_weights') + weight_values = K.batch_get_value(symbolic_weights) + weight_names = [] + for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): + if hasattr(w, 'name') and w.name: + name = str(w.name) + else: + name = 'param_' + str(i) + weight_names.append(name.encode('utf8')) + optimizer_weights_group.attrs['weight_names'] = weight_names + for name, val in zip(weight_names, weight_values): + param_dset = optimizer_weights_group.create_dataset( + name, + val.shape, + dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + f.flush() + f.close() + + +def load_model(filepath, custom_objects={}): + + def deserialize(obj): + if type(obj) is list: + deserialized = [] + for value in obj: + if value in custom_objects: + deserialized.append(custom_objects[value]) + else: + deserialized.append(value) + return deserialized + if type(obj) is dict: + deserialized = {} + for key, value in obj.items(): + if value in custom_objects: + deserialized[key] = custom_objects[value] + else: + deserialized[key] = value + return deserialized + if obj in custom_objects: + return custom_objects[obj] + return obj + + import h5py + f = h5py.File(filepath, mode='r') + + # instantiate model + model_config = f.attrs.get('model_config') + if model_config is None: + raise ValueError('No model found in config file.') + model_config = json.loads(model_config.decode('utf-8')) + model = model_from_config(model_config, custom_objects=custom_objects) + + # set weights + model.load_weights_from_hdf5_group(f['model_weights']) + + # instantiate optimizer + training_config = f.attrs.get('training_config') + if training_config is None: + warnings.warn('No training configuration found in save file: ' + 'the model was *not* compiled. Compile it manually.') + f.close() + return model + training_config = json.loads(training_config.decode('utf-8')) + optimizer_config = training_config['optimizer_config'] + optimizer = optimizer_from_config(optimizer_config) + + # recover loss functions and metrics + loss = deserialize(training_config['loss']) + metrics = deserialize(training_config['metrics']) + sample_weight_mode = training_config['sample_weight_mode'] + loss_weights = training_config['loss_weights'] + + # compile model + model.compile(optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=sample_weight_mode) + + # set optimizer weights + if 'optimizer_weights' in f: + # build train function (to get weight updates) + if model.__class__.__name__ == 'Sequential': + model.model._make_train_function() + else: + model._make_train_function() + optimizer_weights_group = f['optimizer_weights'] + optimizer_weight_names = [n.decode('utf8') for n in optimizer_weights_group.attrs['weight_names']] + optimizer_weight_values = [optimizer_weights_group[n] for n in optimizer_weight_names] + model.optimizer.set_weights(optimizer_weight_values) + f.close() + return model + + def model_from_config(config, custom_objects={}): from keras.utils.layer_utils import layer_from_config if isinstance(config, list): @@ -362,6 +521,9 @@ class Sequential(Model): **kwargs) self.optimizer = self.model.optimizer self.loss = self.model.loss + self.loss_weights = self.model.loss_weights + self.metrics = self.model.metrics + self.metrics_tensors = self.model.metrics_tensors self.metrics_names = self.model.metrics_names self.sample_weight_mode = self.model.sample_weight_mode diff --git a/keras/optimizers.py b/keras/optimizers.py index d1b1deb7e..74484f4dc 100644 --- a/keras/optimizers.py +++ b/keras/optimizers.py @@ -1,6 +1,5 @@ from __future__ import absolute_import from . import backend as K -import numpy as np from .utils.generic_utils import get_from_module from six.moves import zip @@ -11,8 +10,24 @@ def clip_norm(g, c, n): return g -def kl_divergence(p, p_hat): - return p_hat - p + p * K.log(p / p_hat) +def optimizer_from_config(config, custom_objects={}): + all_classes = { + 'sgd': SGD, + 'rmsprop': RMSprop, + 'adagrad': Adagrad, + 'adadelta': Adadelta, + 'adam': Adam, + 'adamax': Adamax, + 'nadam': Nadam, + } + class_name = config['class_name'] + if class_name in custom_objects: + cls = custom_objects[class_name] + else: + if class_name.lower() not in all_classes: + raise ValueError('Optimizer class not found:', class_name) + cls = all_classes[class_name.lower()] + return cls.from_config(config['config']) class Optimizer(object): @@ -90,13 +105,17 @@ class Optimizer(object): return K.batch_get_value(self.weights) def get_config(self): - config = {'name': self.__class__.__name__} + config = {} if hasattr(self, 'clipnorm'): config['clipnorm'] = self.clipnorm if hasattr(self, 'clipvalue'): config['clipvalue'] = self.clipvalue return config + @classmethod + def from_config(cls, config): + return cls(**config) + class SGD(Optimizer): '''Stochastic gradient descent, with support for momentum, diff --git a/keras/utils/io_utils.py b/keras/utils/io_utils.py index be42dbc8d..28ad67cd6 100644 --- a/keras/utils/io_utils.py +++ b/keras/utils/io_utils.py @@ -1,6 +1,8 @@ from __future__ import absolute_import +from __future__ import print_function import h5py import numpy as np +import sys from collections import defaultdict @@ -69,3 +71,17 @@ def load_array(name): a[:] = array[:] f.close() return a + + +def ask_to_proceed_with_overwrite(filepath): + get_input = input + if sys.version_info[:2] <= (2, 7): + get_input = raw_input + overwrite = get_input('[WARNING] %s already exists - overwrite? ' + '[y/n]' % (filepath)) + while overwrite not in ['y', 'n']: + overwrite = get_input('Enter "y" (overwrite) or "n" (cancel).') + if overwrite == 'n': + return False + print('[TIP] Next time specify overwrite=True!') + return True diff --git a/tests/test_loss_masking.py b/tests/test_loss_masking.py index 5996fb3de..2f5cbb645 100644 --- a/tests/test_loss_masking.py +++ b/tests/test_loss_masking.py @@ -4,10 +4,12 @@ import pytest from keras.models import Sequential from keras.engine.training import weighted_objective from keras.layers.core import TimeDistributedDense, Masking +from keras.utils.test_utils import keras_test from keras import objectives from keras import backend as K +@keras_test def test_masking(): np.random.seed(1337) X = np.array([[[1], [1]], @@ -22,6 +24,7 @@ def test_masking(): assert loss == 0 +@keras_test def test_loss_masking(): weighted_loss = weighted_objective(objectives.get('mae')) shape = (3, 4, 2) diff --git a/tests/test_loss_weighting.py b/tests/test_loss_weighting.py index bad7457f2..6ed059b78 100644 --- a/tests/test_loss_weighting.py +++ b/tests/test_loss_weighting.py @@ -8,6 +8,7 @@ from keras.utils.test_utils import get_test_data from keras.models import Sequential, Graph from keras.layers import Dense, Activation, RepeatVector, TimeDistributedDense, GRU from keras.utils import np_utils +from keras.utils.test_utils import keras_test nb_classes = 10 batch_size = 128 @@ -69,6 +70,7 @@ def create_temporal_sequential_model(): return model +@keras_test def _test_weights_sequential(model, class_weight=None, sample_weight=None, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test): @@ -108,6 +110,7 @@ model.compile(loss=loss, optimizer='rmsprop') standard_score_sequential = _test_weights_sequential(model) +@keras_test def test_sequential_class_weights(): model = create_sequential_model() model.compile(loss=loss, optimizer='rmsprop') @@ -115,6 +118,7 @@ def test_sequential_class_weights(): assert(score < standard_score_sequential) +@keras_test def test_sequential_sample_weights(): model = create_sequential_model() model.compile(loss=loss, optimizer='rmsprop') @@ -122,6 +126,7 @@ def test_sequential_sample_weights(): assert(score < standard_score_sequential) +@keras_test def test_sequential_temporal_sample_weights(): model = create_temporal_sequential_model() model.compile(loss=loss, optimizer='rmsprop', diff --git a/tests/test_model_saving.py b/tests/test_model_saving.py new file mode 100644 index 000000000..8c7c9e7e1 --- /dev/null +++ b/tests/test_model_saving.py @@ -0,0 +1,151 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose + +from keras.models import Model, Sequential +from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed +from keras.layers import Input +from keras import optimizers +from keras import objectives +from keras import metrics +from keras.utils.test_utils import keras_test +from keras.models import save_model, load_model + + +@keras_test +def test_sequential_model_saving(): + model = Sequential() + model.add(Dense(2, input_dim=3)) + model.add(Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + save_model(model, fname) + + new_model = load_model(fname) + out2 = new_model.predict(x) + assert_allclose(out, out2) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + new_model.train_on_batch(x, y) + out = model.predict(x) + out2 = new_model.predict(x) + assert_allclose(out, out2) + + +@keras_test +def test_sequential_model_saving_2(): + # test with funkier config + model = Sequential() + model.add(Dense(2, input_dim=3)) + model.add(RepeatVector(3)) + model.add(TimeDistributed(Dense(3))) + model.compile(loss=objectives.MSE, + optimizer=optimizers.RMSprop(lr=0.0001), + metrics=[metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + save_model(model, fname) + + new_model = load_model(fname) + out2 = new_model.predict(x) + assert_allclose(out, out2) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + new_model.train_on_batch(x, y) + out = model.predict(x) + out2 = new_model.predict(x) + assert_allclose(out, out2) + + +@keras_test +def test_sequential_model_saving_3(): + # test with custom optimizer, loss + custom_opt = optimizers.rmsprop + custom_loss = objectives.mse + model = Sequential() + model.add(Dense(2, input_dim=3)) + model.add(Dense(3)) + model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + save_model(model, fname) + + model = load_model(fname, + custom_objects={'custom_opt': custom_opt, + 'custom_loss': custom_loss}) + out2 = model.predict(x) + assert_allclose(out, out2) + + +@keras_test +def test_fuctional_model_saving(): + input = Input(shape=(3,)) + x = Dense(2)(input) + output = Dense(3)(x) + + model = Model(input, output) + model.compile(loss=objectives.MSE, + optimizer=optimizers.RMSprop(lr=0.0001), + metrics=[metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + save_model(model, fname) + + model = load_model(fname) + out2 = model.predict(x) + assert_allclose(out, out2) + + +@keras_test +def test_saving_without_compilation(): + model = Sequential() + model.add(Dense(2, input_dim=3)) + model.add(Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + save_model(model, fname) + model = load_model(fname) + + +@keras_test +def test_saving_right_after_compilation(): + model = Sequential() + model.add(Dense(2, input_dim=3)) + model.add(Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + model.model._make_train_function() + + fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + save_model(model, fname) + model = load_model(fname) + + +if __name__ == '__main__': + pytest.main([__file__])