Allow to call load_weights on model save file
This commit is contained in:
parent
1855c49d1f
commit
a9fc2bed49
@ -2411,7 +2411,10 @@ class Container(Layer):
|
|||||||
'''
|
'''
|
||||||
import h5py
|
import h5py
|
||||||
f = h5py.File(filepath, mode='r')
|
f = h5py.File(filepath, mode='r')
|
||||||
|
if 'layer_names' not in f.attrs and 'model_weights' in f:
|
||||||
|
f = f['model_weights']
|
||||||
self.load_weights_from_hdf5_group(f)
|
self.load_weights_from_hdf5_group(f)
|
||||||
|
if hasattr(f, 'close'):
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def load_weights_from_hdf5_group(self, f):
|
def load_weights_from_hdf5_group(self, f):
|
||||||
|
@ -29,7 +29,6 @@ def test_sequential_model_saving():
|
|||||||
save_model(model, fname)
|
save_model(model, fname)
|
||||||
|
|
||||||
new_model = load_model(fname)
|
new_model = load_model(fname)
|
||||||
os.remove(fname)
|
|
||||||
|
|
||||||
out2 = new_model.predict(x)
|
out2 = new_model.predict(x)
|
||||||
assert_allclose(out, out2, atol=1e-05)
|
assert_allclose(out, out2, atol=1e-05)
|
||||||
@ -43,6 +42,10 @@ def test_sequential_model_saving():
|
|||||||
out2 = new_model.predict(x)
|
out2 = new_model.predict(x)
|
||||||
assert_allclose(out, out2, atol=1e-05)
|
assert_allclose(out, out2, atol=1e-05)
|
||||||
|
|
||||||
|
# test load_weights on model file
|
||||||
|
model.load_weights(fname)
|
||||||
|
os.remove(fname)
|
||||||
|
|
||||||
|
|
||||||
@keras_test
|
@keras_test
|
||||||
def test_sequential_model_saving_2():
|
def test_sequential_model_saving_2():
|
||||||
|
Loading…
Reference in New Issue
Block a user