diff --git a/docs/templates/getting-started/faq.md b/docs/templates/getting-started/faq.md index 428aad33e..8818fde26 100644 --- a/docs/templates/getting-started/faq.md +++ b/docs/templates/getting-started/faq.md @@ -113,12 +113,39 @@ Note that you will first need to install HDF5 and the Python library h5py, which model.save_weights('my_model_weights.h5') ``` -Assuming you have code for instantiating your model, you can then load the weights you saved into a model with the same architecture: +Assuming you have code for instantiating your model, you can then load the weights you saved into a model with the *same* architecture: ```python model.load_weights('my_model_weights.h5') ``` +If you need to load weights into a *different* architecture (with some layers in common), for instance for fine-tuning or transfer-learning, you can load weights by *layer name*: + +```python +model.load_weights('my_model_weights.h5', by_name=True) +``` + +For example: + +```python +""" +Assume original model looks like this: + model = Sequential() + model.add(Dense(2, input_dim=3, name="dense_1")) + model.add(Dense(3, name="dense_2")) + ... + model.save_weights(fname) +""" + +# new model +model = Sequential() +model.add(Dense(2, input_dim=3, name="dense_1")) # will be loaded +model.add(Dense(10, name="new_dense")) # will not be loaded + +# load weights from first model; will only affect the first layer, dense_1. +model.load_weights(fname, by_name=True) +``` + --- ### Why is the training loss much higher than the testing loss? diff --git a/docs/templates/models/about-keras-models.md b/docs/templates/models/about-keras-models.md index b4112f426..bb0c579a4 100644 --- a/docs/templates/models/about-keras-models.md +++ b/docs/templates/models/about-keras-models.md @@ -30,4 +30,4 @@ yaml_string = model.to_yaml() model = model_from_yaml(yaml_string) ``` - `model.save_weights(filepath)`: saves the weights of the model as a HDF5 file. -- `model.load_weights(filepath)`: loads the weights of the model from a HDF5 file (created by `save_weights`). \ No newline at end of file +- `model.load_weights(filepath, by_name=False)`: loads the weights of the model from a HDF5 file (created by `save_weights`). By default, the architecture is expected to be unchanged. To load weights into a different architecture (with some layers in common), use `by_name=True` to load only those layers with the same name. \ No newline at end of file diff --git a/keras/engine/topology.py b/keras/engine/topology.py index 034764a5d..76af57a2d 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -2469,14 +2469,30 @@ class Container(Layer): else: param_dset[:] = val - def load_weights(self, filepath): + def load_weights(self, filepath, by_name=False): '''Load all layer weights from a HDF5 save file. + + If `by_name` is False (default) weights are loaded + based on the network's topology, meaning the architecture + should be the same as when the weights were saved. + Note that layers that don't have weights are not taken + into account in the topological ordering, so adding or + removing layers is fine as long as they don't have weights. + + If `by_name` is True, weights are loaded into layers + only if they share the same name. This is useful + for fine-tuning or transfer-learning models where + some of the layers have changed. ''' import h5py f = h5py.File(filepath, mode='r') if 'layer_names' not in f.attrs and 'model_weights' in f: f = f['model_weights'] - self.load_weights_from_hdf5_group(f) + if by_name: + self.load_weights_from_hdf5_group_by_name(f) + else: + self.load_weights_from_hdf5_group(f) + if hasattr(f, 'close'): f.close() @@ -2552,6 +2568,54 @@ class Container(Layer): weight_value_tuples += zip(symbolic_weights, weight_values) K.batch_set_value(weight_value_tuples) + def load_weights_from_hdf5_group_by_name(self, f): + ''' Name-based weight loading + (instead of topological weight loading). + Layers that have no matching name are skipped. + ''' + if hasattr(self, 'flattened_layers'): + # support for legacy Sequential/Merge behavior + flattened_layers = self.flattened_layers + else: + flattened_layers = self.layers + + if 'nb_layers' in f.attrs: + raise Exception('The weight file you are trying to load is' + + ' in a legacy format that does not support' + + ' name-based weight loading.') + else: + # new file format + layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] + + # Reverse index of layer name to list of layers with name. + index = {} + for layer in flattened_layers: + if layer.name: + index.setdefault(layer.name, []).append(layer) + + # we batch weight value assignments in a single backend call + # which provides a speedup in TensorFlow. + weight_value_tuples = [] + for k, name in enumerate(layer_names): + g = f[name] + weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + weight_values = [g[weight_name] for weight_name in weight_names] + + for layer in index.get(name, []): + symbolic_weights = layer.weights + if len(weight_values) != len(symbolic_weights): + raise Exception('Layer #' + str(k) + + ' (named "' + layer.name + + '") expects ' + + str(len(symbolic_weights)) + + ' weight(s), but the saved weights' + + ' have ' + str(len(weight_values)) + + ' element(s).') + # set values + for i in range(len(weight_values)): + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + K.batch_set_value(weight_value_tuples) + def _updated_config(self): '''shared between different serialization methods''' from keras import __version__ as keras_version diff --git a/tests/test_model_saving.py b/tests/test_model_saving.py index 9b64e7411..3610f2868 100644 --- a/tests/test_model_saving.py +++ b/tests/test_model_saving.py @@ -1,5 +1,6 @@ import pytest import os +import tempfile import numpy as np from numpy.testing import assert_allclose @@ -28,7 +29,7 @@ def test_sequential_model_saving(): model.train_on_batch(x, y) out = model.predict(x) - fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + _, fname = tempfile.mkstemp('.h5') save_model(model, fname) new_model = load_model(fname) @@ -62,7 +63,7 @@ def test_sequential_model_saving_2(): model.train_on_batch(x, y) out = model.predict(x) - fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + _, fname = tempfile.mkstemp('.h5') save_model(model, fname) model = load_model(fname, @@ -89,7 +90,7 @@ def test_fuctional_model_saving(): model.train_on_batch(x, y) out = model.predict(x) - fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + _, fname = tempfile.mkstemp('.h5') save_model(model, fname) model = load_model(fname) @@ -106,7 +107,7 @@ def test_saving_without_compilation(): model.add(Dense(3)) model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + _, fname = tempfile.mkstemp('.h5') save_model(model, fname) model = load_model(fname) os.remove(fname) @@ -120,11 +121,116 @@ def test_saving_right_after_compilation(): model.compile(loss='mse', optimizer='sgd', metrics=['acc']) model.model._make_train_function() - fname = 'tmp_' + str(np.random.randint(10000)) + '.h5' + _, fname = tempfile.mkstemp('.h5') save_model(model, fname) model = load_model(fname) os.remove(fname) +@keras_test +def test_loading_weights_by_name(): + """ + test loading model weights by name on: + - sequential model + """ + + # test with custom optimizer, loss + custom_opt = optimizers.rmsprop + custom_loss = objectives.mse + + # sequential model + model = Sequential() + model.add(Dense(2, input_dim=3, name="rick")) + model.add(Dense(3, name="morty")) + model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + old_weights = [layer.get_weights() for layer in model.layers] + _, fname = tempfile.mkstemp('.h5') + + model.save_weights(fname) + + # delete and recreate model + del(model) + model = Sequential() + model.add(Dense(2, input_dim=3, name="rick")) + model.add(Dense(3, name="morty")) + model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc']) + + # load weights from first model + model.load_weights(fname, by_name=True) + os.remove(fname) + + out2 = model.predict(x) + assert_allclose(out, out2, atol=1e-05) + for i in range(len(model.layers)): + new_weights = model.layers[i].get_weights() + for j in range(len(new_weights)): + assert_allclose(old_weights[i][j], new_weights[j], atol=1e-05) + + +@keras_test +def test_loading_weights_by_name_2(): + """ + test loading model weights by name on: + - both sequential and functional api models + - different architecture with shared names + """ + + # test with custom optimizer, loss + custom_opt = optimizers.rmsprop + custom_loss = objectives.mse + + # sequential model + model = Sequential() + model.add(Dense(2, input_dim=3, name="rick")) + model.add(Dense(3, name="morty")) + model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + old_weights = [layer.get_weights() for layer in model.layers] + _, fname = tempfile.mkstemp('.h5') + + model.save_weights(fname) + + # delete and recreate model using Functional API + del(model) + data = Input(shape=(3,)) + rick = Dense(2, name="rick")(data) + jerry = Dense(3, name="jerry")(rick) # add 2 layers (but maintain shapes) + jessica = Dense(2, name="jessica")(jerry) + morty = Dense(3, name="morty")(jessica) + + model = Model(input=[data], output=[morty]) + model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc']) + + # load weights from first model + model.load_weights(fname, by_name=True) + os.remove(fname) + + out2 = model.predict(x) + assert np.max(np.abs(out - out2)) > 1e-05 + + rick = model.layers[1].get_weights() + jerry = model.layers[2].get_weights() + jessica = model.layers[3].get_weights() + morty = model.layers[4].get_weights() + + assert_allclose(old_weights[0][0], rick[0], atol=1e-05) + assert_allclose(old_weights[0][1], rick[1], atol=1e-05) + assert_allclose(old_weights[1][0], morty[0], atol=1e-05) + assert_allclose(old_weights[1][1], morty[1], atol=1e-05) + assert_allclose(np.zeros_like(jerry[1]), jerry[1]) # biases init to 0 + assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0 + + if __name__ == '__main__': pytest.main([__file__])