Optionally load weights by name (#3488)

* Adding feature to load_weights by name
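
In short, the new optional argument can be used like this (a minimal sketch that just restates the calls from the FAQ update further below):

```python
# default: topological loading, the architecture must match the saved one
model.load_weights('my_model_weights.h5')

# new: name-based loading into a (possibly different) architecture
model.load_weights('my_model_weights.h5', by_name=True)
```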

Squashed commit of the following:

commit fd47e763855c34ed78d26ee441d83e0e63f08119
Author: Arel Cordero <arel@ditto.us.com>
Date:   Thu Aug 18 16:02:14 2016 +0000

    typo

commit d0b06c03080131c55ab4777064a196ff339ad7df
Author: Arel Cordero <arel@ditto.us.com>
Date:   Thu Aug 18 15:52:35 2016 +0000

    update documentation for "load_weights"

commit 844cfc2e8c9c6f267799a22ed54ac4d75807c5ab
Author: Arel Cordero <arel@ditto.us.com>
Date:   Thu Aug 18 02:42:10 2016 +0000

    batch updating weights

commit f361a70da4b40b961f1af9c8f1c3cd26273d0cad
Author: Arel Cordero <arel@ditto.us.com>
Date:   Thu Aug 18 02:29:17 2016 +0000

    removing pudb line

commit 738de4c371503626b4c9dbae6428fb279b368a76
Author: Arel Cordero <arel@ditto.us.com>
Date:   Wed Aug 17 19:56:51 2016 +0000

    adding unit tests for loading weights by name

commit cb0971b3cfe62452ab445e4034098cab2be3031b
Author: Arel Cordero <arel@ditto.us.com>
Date:   Tue Aug 16 23:45:32 2016 +0000

    cleaning up code based on comments

commit ef08fd2c9f5d3c65359cbdf5b090e08733a518de
Author: Arel Cordero <arel@ditto.us.com>
Date:   Tue Aug 16 04:50:46 2016 +0000

    debugging

commit 0d74f0e997960886b1044c26001de6cd6ad90bb9
Author: Arel Cordero <arel@ditto.us.com>
Date:   Tue Aug 16 04:15:43 2016 +0000

    optionally load model by name

* Changed random file names to use the tempfile module

* Cleaned up documentation strings

* Clarified documentation
Author: Arel Cordero <arel@ditto.us.com>  2016-09-06 14:42:31 -04:00
Committed by: François Chollet
Commit: 607635d2ce (parent: b8fddc862e)
4 changed files with 206 additions and 9 deletions

@@ -113,12 +113,39 @@ Note that you will first need to install HDF5 and the Python library h5py, which
model.save_weights('my_model_weights.h5')
```
Assuming you have code for instantiating your model, you can then load the weights you saved into a model with the *same* architecture:
```python
model.load_weights('my_model_weights.h5')
```
If you need to load weights into a *different* architecture (with some layers in common), for instance for fine-tuning or transfer-learning, you can load weights by *layer name*:
```python
model.load_weights('my_model_weights.h5', by_name=True)
```
For example:
```python
"""
Assume original model looks like this:
model = Sequential()
model.add(Dense(2, input_dim=3, name="dense_1"))
model.add(Dense(3, name="dense_2"))
...
model.save_weights(fname)
"""
# new model
model = Sequential()
model.add(Dense(2, input_dim=3, name="dense_1")) # will be loaded
model.add(Dense(10, name="new_dense")) # will not be loaded
# load weights from first model; will only affect the first layer, dense_1.
model.load_weights(fname, by_name=True)
```
---
### Why is the training loss much higher than the testing loss?

@@ -30,4 +30,4 @@ yaml_string = model.to_yaml()
model = model_from_yaml(yaml_string)
```
- `model.save_weights(filepath)`: saves the weights of the model as a HDF5 file.
- `model.load_weights(filepath, by_name=False)`: loads the weights of the model from a HDF5 file (created by `save_weights`). By default, the architecture is expected to be unchanged. To load weights into a different architecture (with some layers in common), use `by_name=True` to load only those layers with the same name.

@@ -2469,14 +2469,30 @@ class Container(Layer):
                else:
                    param_dset[:] = val

    def load_weights(self, filepath, by_name=False):
        '''Load all layer weights from a HDF5 save file.

        If `by_name` is False (default) weights are loaded
        based on the network's topology, meaning the architecture
        should be the same as when the weights were saved.
        Note that layers that don't have weights are not taken
        into account in the topological ordering, so adding or
        removing layers is fine as long as they don't have weights.

        If `by_name` is True, weights are loaded into layers
        only if they share the same name. This is useful
        for fine-tuning or transfer-learning models where
        some of the layers have changed.
        '''
        import h5py
        f = h5py.File(filepath, mode='r')
        if 'layer_names' not in f.attrs and 'model_weights' in f:
            f = f['model_weights']
        if by_name:
            self.load_weights_from_hdf5_group_by_name(f)
        else:
            self.load_weights_from_hdf5_group(f)
        if hasattr(f, 'close'):
            f.close()
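
The docstring's point about weight-less layers is easy to miss, so here is a minimal sketch of it (not part of this diff; the Dropout layer, the second model, and the temp-file handling are illustrative, and the Dense layer names are borrowed from the FAQ example above):

```python
import os
import tempfile
from keras.models import Sequential
from keras.layers import Dense, Dropout

_, fname = tempfile.mkstemp('.h5')

# original model: two Dense layers
model = Sequential()
model.add(Dense(2, input_dim=3, name="dense_1"))
model.add(Dense(3, name="dense_2"))
model.save_weights(fname)

# new model: a Dropout layer is inserted between the two Dense layers.
# Dropout has no weights, so it is skipped in the topological ordering
# and the default (by_name=False) loading still lines up.
model2 = Sequential()
model2.add(Dense(2, input_dim=3, name="dense_1"))
model2.add(Dropout(0.5))
model2.add(Dense(3, name="dense_2"))
model2.load_weights(fname)

os.remove(fname)
```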
@@ -2552,6 +2568,54 @@ class Container(Layer):
                weight_value_tuples += zip(symbolic_weights, weight_values)
            K.batch_set_value(weight_value_tuples)

    def load_weights_from_hdf5_group_by_name(self, f):
        ''' Name-based weight loading
        (instead of topological weight loading).
        Layers that have no matching name are skipped.
        '''
        if hasattr(self, 'flattened_layers'):
            # support for legacy Sequential/Merge behavior
            flattened_layers = self.flattened_layers
        else:
            flattened_layers = self.layers

        if 'nb_layers' in f.attrs:
            raise Exception('The weight file you are trying to load is' +
                            ' in a legacy format that does not support' +
                            ' name-based weight loading.')
        else:
            # new file format
            layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]

            # Reverse index of layer name to list of layers with name.
            index = {}
            for layer in flattened_layers:
                if layer.name:
                    index.setdefault(layer.name, []).append(layer)

            # we batch weight value assignments in a single backend call
            # which provides a speedup in TensorFlow.
            weight_value_tuples = []
            for k, name in enumerate(layer_names):
                g = f[name]
                weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
                weight_values = [g[weight_name] for weight_name in weight_names]
                for layer in index.get(name, []):
                    symbolic_weights = layer.weights
                    if len(weight_values) != len(symbolic_weights):
                        raise Exception('Layer #' + str(k) +
                                        ' (named "' + layer.name +
                                        '") expects ' +
                                        str(len(symbolic_weights)) +
                                        ' weight(s), but the saved weights' +
                                        ' have ' + str(len(weight_values)) +
                                        ' element(s).')
                    # set values
                    for i in range(len(weight_values)):
                        weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
            K.batch_set_value(weight_value_tuples)

    def _updated_config(self):
        '''shared between different serialization methods'''
        from keras import __version__ as keras_version
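
For context, the loaders above read a simple HDF5 layout: a top-level `layer_names` attribute, one group per layer, and a `weight_names` attribute per group. A small inspection sketch (assuming a weights file previously written by `save_weights`; the file name is the hypothetical one from the FAQ):

```python
import h5py

with h5py.File('my_model_weights.h5', mode='r') as f:
    # top-level attribute: ordered list of layer names
    layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
    for name in layer_names:
        g = f[name]
        # per-layer attribute: names of that layer's weight datasets
        weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
        for weight_name in weight_names:
            print(name, weight_name, g[weight_name].shape)
```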

@@ -1,5 +1,6 @@
import pytest
import os
import tempfile
import numpy as np
from numpy.testing import assert_allclose
@@ -28,7 +29,7 @@ def test_sequential_model_saving():
    model.train_on_batch(x, y)
    out = model.predict(x)

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)
    new_model = load_model(fname)
@@ -62,7 +63,7 @@ def test_sequential_model_saving_2():
    model.train_on_batch(x, y)
    out = model.predict(x)

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)
    model = load_model(fname,
@@ -89,7 +90,7 @@ def test_fuctional_model_saving():
    model.train_on_batch(x, y)
    out = model.predict(x)

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)
    model = load_model(fname)
@@ -106,7 +107,7 @@ def test_saving_without_compilation():
    model.add(Dense(3))
    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)
    model = load_model(fname)
    os.remove(fname)
@@ -120,11 +121,116 @@ def test_saving_right_after_compilation():
    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
    model.model._make_train_function()

    _, fname = tempfile.mkstemp('.h5')
    save_model(model, fname)
    model = load_model(fname)
    os.remove(fname)

@keras_test
def test_loading_weights_by_name():
    """
    test loading model weights by name on:
        - sequential model
    """
    # test with custom optimizer, loss
    custom_opt = optimizers.rmsprop
    custom_loss = objectives.mse

    # sequential model
    model = Sequential()
    model.add(Dense(2, input_dim=3, name="rick"))
    model.add(Dense(3, name="morty"))
    model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])

    x = np.random.random((1, 3))
    y = np.random.random((1, 3))
    model.train_on_batch(x, y)

    out = model.predict(x)
    old_weights = [layer.get_weights() for layer in model.layers]
    _, fname = tempfile.mkstemp('.h5')

    model.save_weights(fname)

    # delete and recreate model
    del(model)
    model = Sequential()
    model.add(Dense(2, input_dim=3, name="rick"))
    model.add(Dense(3, name="morty"))
    model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])

    # load weights from first model
    model.load_weights(fname, by_name=True)
    os.remove(fname)

    out2 = model.predict(x)
    assert_allclose(out, out2, atol=1e-05)
    for i in range(len(model.layers)):
        new_weights = model.layers[i].get_weights()
        for j in range(len(new_weights)):
            assert_allclose(old_weights[i][j], new_weights[j], atol=1e-05)

@keras_test
def test_loading_weights_by_name_2():
    """
    test loading model weights by name on:
        - both sequential and functional api models
        - different architecture with shared names
    """
    # test with custom optimizer, loss
    custom_opt = optimizers.rmsprop
    custom_loss = objectives.mse

    # sequential model
    model = Sequential()
    model.add(Dense(2, input_dim=3, name="rick"))
    model.add(Dense(3, name="morty"))
    model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])

    x = np.random.random((1, 3))
    y = np.random.random((1, 3))
    model.train_on_batch(x, y)

    out = model.predict(x)
    old_weights = [layer.get_weights() for layer in model.layers]
    _, fname = tempfile.mkstemp('.h5')

    model.save_weights(fname)

    # delete and recreate model using Functional API
    del(model)
    data = Input(shape=(3,))
    rick = Dense(2, name="rick")(data)
    jerry = Dense(3, name="jerry")(rick)  # add 2 layers (but maintain shapes)
    jessica = Dense(2, name="jessica")(jerry)
    morty = Dense(3, name="morty")(jessica)

    model = Model(input=[data], output=[morty])
    model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])

    # load weights from first model
    model.load_weights(fname, by_name=True)
    os.remove(fname)

    out2 = model.predict(x)
    assert np.max(np.abs(out - out2)) > 1e-05

    rick = model.layers[1].get_weights()
    jerry = model.layers[2].get_weights()
    jessica = model.layers[3].get_weights()
    morty = model.layers[4].get_weights()

    assert_allclose(old_weights[0][0], rick[0], atol=1e-05)
    assert_allclose(old_weights[0][1], rick[1], atol=1e-05)
    assert_allclose(old_weights[1][0], morty[0], atol=1e-05)
    assert_allclose(old_weights[1][1], morty[1], atol=1e-05)
    assert_allclose(np.zeros_like(jerry[1]), jerry[1])  # biases init to 0
    assert_allclose(np.zeros_like(jessica[1]), jessica[1])  # biases init to 0

if __name__ == '__main__':
    pytest.main([__file__])