Added test script to save_weights/load_weights

Joao Felipe Santos 2015-04-19 14:33:49 -04:00
parent fbc7dece87
commit 8e7e755295
2 changed files with 46 additions and 6 deletions

@@ -211,16 +211,16 @@ class Sequential(object):
         return tot_score/len(batches)

-    def save(self, filepath):
+    def save_weights(self, filepath):
         # Save weights from all layers to HDF5
         import h5py
         # FIXME: fail if file exists, or add option to overwrite!
         f = h5py.File(filepath, 'w')
-        f.attrs['n_layers'] = len(self.layers)
+        f.attrs['nb_layers'] = len(self.layers)
         for k, l in enumerate(self.layers):
             g = f.create_group('layer_{}'.format(k))
             weights = l.get_weights()
-            g.attrs['n_params'] = len(weights)
+            g.attrs['nb_params'] = len(weights)
             for n, param in enumerate(weights):
                 param_name = 'param_{}'.format(n)
                 param_dset = g.create_dataset(param_name, param.shape, dtype='float64')
@@ -228,13 +228,13 @@ class Sequential(object):
         f.flush()
         f.close()

-    def load(self, filepath):
+    def load_weights(self, filepath):
         # Loads weights from HDF5 file
         import h5py
         f = h5py.File(filepath)
-        for k in range(f.attrs['n_layers']):
+        for k in range(f.attrs['nb_layers']):
             g = f['layer_{}'.format(k)]
-            weights = [g['param_{}'.format(p)] for p in range(g.attrs['n_params'])]
+            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
             self.layers[k].set_weights(weights)
         f.close()
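
For reference, the HDF5 file written by save_weights holds an nb_layers attribute at the root, one layer_{k} group per layer carrying an nb_params attribute, and one param_{n} dataset per parameter array. The snippet below is a minimal sketch (not part of the commit) that walks such a file with h5py and prints each layer's parameter shapes; the function name describe_weights_file and the example path are illustrative.

import h5py

def describe_weights_file(filepath):
    # Walk the layout produced by save_weights(): root attr 'nb_layers',
    # groups 'layer_{k}' with attr 'nb_params', datasets 'param_{n}'.
    with h5py.File(filepath, 'r') as f:
        for k in range(f.attrs['nb_layers']):
            g = f['layer_{}'.format(k)]
            shapes = [g['param_{}'.format(p)].shape for p in range(g.attrs['nb_params'])]
            print('layer_{}: {} params, shapes: {}'.format(k, g.attrs['nb_params'], shapes))

# e.g. after running the test script below:
# describe_weights_file('/tmp/model_weights.hdf5')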

40  test/test_save_weights.py  Normal file

@@ -0,0 +1,40 @@
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD

import sys
sys.setrecursionlimit(10000)  # to be able to pickle Theano compiled functions
import pickle, numpy

def create_model():
    model = Sequential()
    model.add(Dense(256, 2048, init='uniform', activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(2048, 2048, init='uniform', activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(2048, 2048, init='uniform', activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(2048, 2048, init='uniform', activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(2048, 256, init='uniform', activation='linear'))
    return model

model = create_model()
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mse', optimizer=sgd)

pickle.dump(model, open('/tmp/model.pkl', 'wb'))
model.save_weights('/tmp/model_weights.hdf5')

model_loaded = create_model()
model_loaded.load_weights('/tmp/model_weights.hdf5')

for k in range(len(model.layers)):
    weights_orig = model.layers[k].get_weights()
    weights_loaded = model_loaded.layers[k].get_weights()
    for x, y in zip(weights_orig, weights_loaded):
        if numpy.any(x != y):
            raise ValueError('Loaded weights are different from pickled weights!')