Added functions to save/load model weights to/from HDF5 file
This commit is contained in:
parent
4bb90a5bfd
commit
fbc7dece87
@ -211,7 +211,31 @@ class Sequential(object):
|
||||
return tot_score/len(batches)
|
||||
|
||||
|
||||
|
||||
|
||||
def save(self, filepath):
|
||||
# Save weights from all layers to HDF5
|
||||
import h5py
|
||||
# FIXME: fail if file exists, or add option to overwrite!
|
||||
f = h5py.File(filepath, 'w')
|
||||
f.attrs['n_layers'] = len(self.layers)
|
||||
for k, l in enumerate(self.layers):
|
||||
g = f.create_group('layer_{}'.format(k))
|
||||
weights = l.get_weights()
|
||||
g.attrs['n_params'] = len(weights)
|
||||
for n, param in enumerate(weights):
|
||||
param_name = 'param_{}'.format(n)
|
||||
param_dset = g.create_dataset(param_name, param.shape, dtype='float64')
|
||||
param_dset[:] = param
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
def load(self, filepath):
|
||||
# Loads weights from HDF5 file
|
||||
import h5py
|
||||
f = h5py.File(filepath)
|
||||
for k in range(f.attrs['n_layers']):
|
||||
g = f['layer_{}'.format(k)]
|
||||
weights = [g['param_{}'.format(p)] for p in range(g.attrs['n_params'])]
|
||||
self.layers[k].set_weights(weights)
|
||||
f.close()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user