Added functions to save/load model weights to/from HDF5 file

This commit is contained in:
Joao Felipe Santos 2015-04-19 12:28:34 -04:00
parent 4bb90a5bfd
commit fbc7dece87

@ -211,7 +211,31 @@ class Sequential(object):
return tot_score/len(batches)
def save(self, filepath):
# Save weights from all layers to HDF5
import h5py
# FIXME: fail if file exists, or add option to overwrite!
f = h5py.File(filepath, 'w')
f.attrs['n_layers'] = len(self.layers)
for k, l in enumerate(self.layers):
g = f.create_group('layer_{}'.format(k))
weights = l.get_weights()
g.attrs['n_params'] = len(weights)
for n, param in enumerate(weights):
param_name = 'param_{}'.format(n)
param_dset = g.create_dataset(param_name, param.shape, dtype='float64')
param_dset[:] = param
f.flush()
f.close()
def load(self, filepath):
# Loads weights from HDF5 file
import h5py
f = h5py.File(filepath)
for k in range(f.attrs['n_layers']):
g = f['layer_{}'.format(k)]
weights = [g['param_{}'.format(p)] for p in range(g.attrs['n_params'])]
self.layers[k].set_weights(weights)
f.close()