Allow to call load_weights on model save file
This commit is contained in:
parent
1855c49d1f
commit
a9fc2bed49
@ -2411,7 +2411,10 @@ class Container(Layer):
|
||||
'''
|
||||
import h5py
|
||||
f = h5py.File(filepath, mode='r')
|
||||
if 'layer_names' not in f.attrs and 'model_weights' in f:
|
||||
f = f['model_weights']
|
||||
self.load_weights_from_hdf5_group(f)
|
||||
if hasattr(f, 'close'):
|
||||
f.close()
|
||||
|
||||
def load_weights_from_hdf5_group(self, f):
|
||||
|
@ -29,7 +29,6 @@ def test_sequential_model_saving():
|
||||
save_model(model, fname)
|
||||
|
||||
new_model = load_model(fname)
|
||||
os.remove(fname)
|
||||
|
||||
out2 = new_model.predict(x)
|
||||
assert_allclose(out, out2, atol=1e-05)
|
||||
@ -43,6 +42,10 @@ def test_sequential_model_saving():
|
||||
out2 = new_model.predict(x)
|
||||
assert_allclose(out, out2, atol=1e-05)
|
||||
|
||||
# test load_weights on model file
|
||||
model.load_weights(fname)
|
||||
os.remove(fname)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sequential_model_saving_2():
|
||||
|
Loading…
Reference in New Issue
Block a user