From a9fc2bed49c7988ccd6d365c1542818172b64ffe Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sun, 31 Jul 2016 22:58:44 -0700 Subject: [PATCH] Allow to call load_weights on model save file --- keras/engine/topology.py | 5 ++++- tests/test_model_saving.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/keras/engine/topology.py b/keras/engine/topology.py index e0d973b5d..bdef7c8f0 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -2411,8 +2411,11 @@ class Container(Layer): ''' import h5py f = h5py.File(filepath, mode='r') + if 'layer_names' not in f.attrs and 'model_weights' in f: + f = f['model_weights'] self.load_weights_from_hdf5_group(f) - f.close() + if hasattr(f, 'close'): + f.close() def load_weights_from_hdf5_group(self, f): '''Weight loading is based on layer order in a list diff --git a/tests/test_model_saving.py b/tests/test_model_saving.py index e6b7b31b7..cf7a612c1 100644 --- a/tests/test_model_saving.py +++ b/tests/test_model_saving.py @@ -29,7 +29,6 @@ def test_sequential_model_saving(): save_model(model, fname) new_model = load_model(fname) - os.remove(fname) out2 = new_model.predict(x) assert_allclose(out, out2, atol=1e-05) @@ -43,6 +42,10 @@ def test_sequential_model_saving(): out2 = new_model.predict(x) assert_allclose(out, out2, atol=1e-05) + # test load_weights on model file + model.load_weights(fname) + os.remove(fname) + @keras_test def test_sequential_model_saving_2():