Optionally load weights by name (#3488)
* Adding feature to load_weights by name Squashed commit of the following: commit fd47e763855c34ed78d26ee441d83e0e63f08119 Author: Arel Cordero <arel@ditto.us.com> Date: Thu Aug 18 16:02:14 2016 +0000 typo commit d0b06c03080131c55ab4777064a196ff339ad7df Author: Arel Cordero <arel@ditto.us.com> Date: Thu Aug 18 15:52:35 2016 +0000 update documentation for "load_weights" commit 844cfc2e8c9c6f267799a22ed54ac4d75807c5ab Author: Arel Cordero <arel@ditto.us.com> Date: Thu Aug 18 02:42:10 2016 +0000 batch updating weights commit f361a70da4b40b961f1af9c8f1c3cd26273d0cad Author: Arel Cordero <arel@ditto.us.com> Date: Thu Aug 18 02:29:17 2016 +0000 removing pudb line commit 738de4c371503626b4c9dbae6428fb279b368a76 Author: Arel Cordero <arel@ditto.us.com> Date: Wed Aug 17 19:56:51 2016 +0000 adding unit tests for loading weights by name commit cb0971b3cfe62452ab445e4034098cab2be3031b Author: Arel Cordero <arel@ditto.us.com> Date: Tue Aug 16 23:45:32 2016 +0000 cleaning up code based on comments commit ef08fd2c9f5d3c65359cbdf5b090e08733a518de Author: Arel Cordero <arel@ditto.us.com> Date: Tue Aug 16 04:50:46 2016 +0000 debugging commit 0d74f0e997960886b1044c26001de6cd6ad90bb9 Author: Arel Cordero <arel@ditto.us.com> Date: Tue Aug 16 04:15:43 2016 +0000 optionally load model by name * changed random file names to use tempfile module * clean up documentation strings * clarifying documentation
This commit is contained in:
parent
b8fddc862e
commit
607635d2ce
29
docs/templates/getting-started/faq.md
vendored
29
docs/templates/getting-started/faq.md
vendored
@ -113,12 +113,39 @@ Note that you will first need to install HDF5 and the Python library h5py, which
|
||||
model.save_weights('my_model_weights.h5')
|
||||
```
|
||||
|
||||
Assuming you have code for instantiating your model, you can then load the weights you saved into a model with the same architecture:
|
||||
Assuming you have code for instantiating your model, you can then load the weights you saved into a model with the *same* architecture:
|
||||
|
||||
```python
|
||||
model.load_weights('my_model_weights.h5')
|
||||
```
|
||||
|
||||
If you need to load weights into a *different* architecture (with some layers in common), for instance for fine-tuning or transfer-learning, you can load weights by *layer name*:
|
||||
|
||||
```python
|
||||
model.load_weights('my_model_weights.h5', by_name=True)
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
"""
|
||||
Assume original model looks like this:
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3, name="dense_1"))
|
||||
model.add(Dense(3, name="dense_2"))
|
||||
...
|
||||
model.save_weights(fname)
|
||||
"""
|
||||
|
||||
# new model
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3, name="dense_1")) # will be loaded
|
||||
model.add(Dense(10, name="new_dense")) # will not be loaded
|
||||
|
||||
# load weights from first model; will only affect the first layer, dense_1.
|
||||
model.load_weights(fname, by_name=True)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Why is the training loss much higher than the testing loss?
|
||||
|
2
docs/templates/models/about-keras-models.md
vendored
2
docs/templates/models/about-keras-models.md
vendored
@ -30,4 +30,4 @@ yaml_string = model.to_yaml()
|
||||
model = model_from_yaml(yaml_string)
|
||||
```
|
||||
- `model.save_weights(filepath)`: saves the weights of the model as a HDF5 file.
|
||||
- `model.load_weights(filepath)`: loads the weights of the model from a HDF5 file (created by `save_weights`).
|
||||
- `model.load_weights(filepath, by_name=False)`: loads the weights of the model from a HDF5 file (created by `save_weights`). By default, the architecture is expected to be unchanged. To load weights into a different architecture (with some layers in common), use `by_name=True` to load only those layers with the same name.
|
@ -2469,14 +2469,30 @@ class Container(Layer):
|
||||
else:
|
||||
param_dset[:] = val
|
||||
|
||||
def load_weights(self, filepath):
|
||||
def load_weights(self, filepath, by_name=False):
|
||||
'''Load all layer weights from a HDF5 save file.
|
||||
|
||||
If `by_name` is False (default) weights are loaded
|
||||
based on the network's topology, meaning the architecture
|
||||
should be the same as when the weights were saved.
|
||||
Note that layers that don't have weights are not taken
|
||||
into account in the topological ordering, so adding or
|
||||
removing layers is fine as long as they don't have weights.
|
||||
|
||||
If `by_name` is True, weights are loaded into layers
|
||||
only if they share the same name. This is useful
|
||||
for fine-tuning or transfer-learning models where
|
||||
some of the layers have changed.
|
||||
'''
|
||||
import h5py
|
||||
f = h5py.File(filepath, mode='r')
|
||||
if 'layer_names' not in f.attrs and 'model_weights' in f:
|
||||
f = f['model_weights']
|
||||
self.load_weights_from_hdf5_group(f)
|
||||
if by_name:
|
||||
self.load_weights_from_hdf5_group_by_name(f)
|
||||
else:
|
||||
self.load_weights_from_hdf5_group(f)
|
||||
|
||||
if hasattr(f, 'close'):
|
||||
f.close()
|
||||
|
||||
@ -2552,6 +2568,54 @@ class Container(Layer):
|
||||
weight_value_tuples += zip(symbolic_weights, weight_values)
|
||||
K.batch_set_value(weight_value_tuples)
|
||||
|
||||
def load_weights_from_hdf5_group_by_name(self, f):
|
||||
''' Name-based weight loading
|
||||
(instead of topological weight loading).
|
||||
Layers that have no matching name are skipped.
|
||||
'''
|
||||
if hasattr(self, 'flattened_layers'):
|
||||
# support for legacy Sequential/Merge behavior
|
||||
flattened_layers = self.flattened_layers
|
||||
else:
|
||||
flattened_layers = self.layers
|
||||
|
||||
if 'nb_layers' in f.attrs:
|
||||
raise Exception('The weight file you are trying to load is' +
|
||||
' in a legacy format that does not support' +
|
||||
' name-based weight loading.')
|
||||
else:
|
||||
# new file format
|
||||
layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
|
||||
|
||||
# Reverse index of layer name to list of layers with name.
|
||||
index = {}
|
||||
for layer in flattened_layers:
|
||||
if layer.name:
|
||||
index.setdefault(layer.name, []).append(layer)
|
||||
|
||||
# we batch weight value assignments in a single backend call
|
||||
# which provides a speedup in TensorFlow.
|
||||
weight_value_tuples = []
|
||||
for k, name in enumerate(layer_names):
|
||||
g = f[name]
|
||||
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
|
||||
weight_values = [g[weight_name] for weight_name in weight_names]
|
||||
|
||||
for layer in index.get(name, []):
|
||||
symbolic_weights = layer.weights
|
||||
if len(weight_values) != len(symbolic_weights):
|
||||
raise Exception('Layer #' + str(k) +
|
||||
' (named "' + layer.name +
|
||||
'") expects ' +
|
||||
str(len(symbolic_weights)) +
|
||||
' weight(s), but the saved weights' +
|
||||
' have ' + str(len(weight_values)) +
|
||||
' element(s).')
|
||||
# set values
|
||||
for i in range(len(weight_values)):
|
||||
weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
|
||||
K.batch_set_value(weight_value_tuples)
|
||||
|
||||
def _updated_config(self):
|
||||
'''shared between different serialization methods'''
|
||||
from keras import __version__ as keras_version
|
||||
|
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
@ -28,7 +29,7 @@ def test_sequential_model_saving():
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
save_model(model, fname)
|
||||
|
||||
new_model = load_model(fname)
|
||||
@ -62,7 +63,7 @@ def test_sequential_model_saving_2():
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
save_model(model, fname)
|
||||
|
||||
model = load_model(fname,
|
||||
@ -89,7 +90,7 @@ def test_fuctional_model_saving():
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
save_model(model, fname)
|
||||
|
||||
model = load_model(fname)
|
||||
@ -106,7 +107,7 @@ def test_saving_without_compilation():
|
||||
model.add(Dense(3))
|
||||
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
|
||||
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
save_model(model, fname)
|
||||
model = load_model(fname)
|
||||
os.remove(fname)
|
||||
@ -120,11 +121,116 @@ def test_saving_right_after_compilation():
|
||||
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
|
||||
model.model._make_train_function()
|
||||
|
||||
fname = 'tmp_' + str(np.random.randint(10000)) + '.h5'
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
save_model(model, fname)
|
||||
model = load_model(fname)
|
||||
os.remove(fname)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_loading_weights_by_name():
|
||||
"""
|
||||
test loading model weights by name on:
|
||||
- sequential model
|
||||
"""
|
||||
|
||||
# test with custom optimizer, loss
|
||||
custom_opt = optimizers.rmsprop
|
||||
custom_loss = objectives.mse
|
||||
|
||||
# sequential model
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3, name="rick"))
|
||||
model.add(Dense(3, name="morty"))
|
||||
model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])
|
||||
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
old_weights = [layer.get_weights() for layer in model.layers]
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
|
||||
model.save_weights(fname)
|
||||
|
||||
# delete and recreate model
|
||||
del(model)
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3, name="rick"))
|
||||
model.add(Dense(3, name="morty"))
|
||||
model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])
|
||||
|
||||
# load weights from first model
|
||||
model.load_weights(fname, by_name=True)
|
||||
os.remove(fname)
|
||||
|
||||
out2 = model.predict(x)
|
||||
assert_allclose(out, out2, atol=1e-05)
|
||||
for i in range(len(model.layers)):
|
||||
new_weights = model.layers[i].get_weights()
|
||||
for j in range(len(new_weights)):
|
||||
assert_allclose(old_weights[i][j], new_weights[j], atol=1e-05)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_loading_weights_by_name_2():
|
||||
"""
|
||||
test loading model weights by name on:
|
||||
- both sequential and functional api models
|
||||
- different architecture with shared names
|
||||
"""
|
||||
|
||||
# test with custom optimizer, loss
|
||||
custom_opt = optimizers.rmsprop
|
||||
custom_loss = objectives.mse
|
||||
|
||||
# sequential model
|
||||
model = Sequential()
|
||||
model.add(Dense(2, input_dim=3, name="rick"))
|
||||
model.add(Dense(3, name="morty"))
|
||||
model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])
|
||||
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
old_weights = [layer.get_weights() for layer in model.layers]
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
|
||||
model.save_weights(fname)
|
||||
|
||||
# delete and recreate model using Functional API
|
||||
del(model)
|
||||
data = Input(shape=(3,))
|
||||
rick = Dense(2, name="rick")(data)
|
||||
jerry = Dense(3, name="jerry")(rick) # add 2 layers (but maintain shapes)
|
||||
jessica = Dense(2, name="jessica")(jerry)
|
||||
morty = Dense(3, name="morty")(jessica)
|
||||
|
||||
model = Model(input=[data], output=[morty])
|
||||
model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])
|
||||
|
||||
# load weights from first model
|
||||
model.load_weights(fname, by_name=True)
|
||||
os.remove(fname)
|
||||
|
||||
out2 = model.predict(x)
|
||||
assert np.max(np.abs(out - out2)) > 1e-05
|
||||
|
||||
rick = model.layers[1].get_weights()
|
||||
jerry = model.layers[2].get_weights()
|
||||
jessica = model.layers[3].get_weights()
|
||||
morty = model.layers[4].get_weights()
|
||||
|
||||
assert_allclose(old_weights[0][0], rick[0], atol=1e-05)
|
||||
assert_allclose(old_weights[0][1], rick[1], atol=1e-05)
|
||||
assert_allclose(old_weights[1][0], morty[0], atol=1e-05)
|
||||
assert_allclose(old_weights[1][1], morty[1], atol=1e-05)
|
||||
assert_allclose(np.zeros_like(jerry[1]), jerry[1]) # biases init to 0
|
||||
assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
Loading…
Reference in New Issue
Block a user