From 8e7e7552955168174eb652bc98b74aac96cf9dc3 Mon Sep 17 00:00:00 2001
From: Joao Felipe Santos
Date: Sun, 19 Apr 2015 14:33:49 -0400
Subject: [PATCH] Added test script to save_weights/load_weights

---
 keras/models.py           | 12 ++++++------
 test/test_save_weights.py | 40 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 6 deletions(-)
 create mode 100644 test/test_save_weights.py

diff --git a/keras/models.py b/keras/models.py
index 56cef2b77..72ae00b49 100644
--- a/keras/models.py
+++ b/keras/models.py
@@ -211,16 +211,16 @@ class Sequential(object):
 
         return tot_score/len(batches)
 
-    def save(self, filepath):
+    def save_weights(self, filepath):
         # Save weights from all layers to HDF5
         import h5py
         # FIXME: fail if file exists, or add option to overwrite!
         f = h5py.File(filepath, 'w')
-        f.attrs['n_layers'] = len(self.layers)
+        f.attrs['nb_layers'] = len(self.layers)
         for k, l in enumerate(self.layers):
             g = f.create_group('layer_{}'.format(k))
             weights = l.get_weights()
-            g.attrs['n_params'] = len(weights)
+            g.attrs['nb_params'] = len(weights)
             for n, param in enumerate(weights):
                 param_name = 'param_{}'.format(n)
                 param_dset = g.create_dataset(param_name, param.shape, dtype='float64')
@@ -228,13 +228,13 @@ class Sequential(object):
         f.flush()
         f.close()
 
-    def load(self, filepath):
+    def load_weights(self, filepath):
         # Loads weights from HDF5 file
         import h5py
         f = h5py.File(filepath)
-        for k in range(f.attrs['n_layers']):
+        for k in range(f.attrs['nb_layers']):
             g = f['layer_{}'.format(k)]
-            weights = [g['param_{}'.format(p)] for p in range(g.attrs['n_params'])]
+            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
             self.layers[k].set_weights(weights)
         f.close()
 
diff --git a/test/test_save_weights.py b/test/test_save_weights.py
new file mode 100644
index 000000000..098815f17
--- /dev/null
+++ b/test/test_save_weights.py
@@ -0,0 +1,40 @@
+from keras.models import Sequential
+from keras.layers.core import Dense, Dropout, Activation
+from keras.optimizers import SGD
+
+import sys
+sys.setrecursionlimit(10000) # to be able to pickle Theano compiled functions
+
+import pickle, numpy
+
+def create_model():
+    model = Sequential()
+    model.add(Dense(256, 2048, init='uniform', activation='relu'))
+    model.add(Dropout(0.5))
+    model.add(Dense(2048, 2048, init='uniform', activation='relu'))
+    model.add(Dropout(0.5))
+    model.add(Dense(2048, 2048, init='uniform', activation='relu'))
+    model.add(Dropout(0.5))
+    model.add(Dense(2048, 2048, init='uniform', activation='relu'))
+    model.add(Dropout(0.5))
+    model.add(Dense(2048, 256, init='uniform', activation='linear'))
+    return model
+
+model = create_model()
+sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
+model.compile(loss='mse', optimizer=sgd)
+
+pickle.dump(model, open('/tmp/model.pkl', 'wb'))
+model.save_weights('/tmp/model_weights.hdf5')
+
+model_loaded = create_model()
+model_loaded.load_weights('/tmp/model_weights.hdf5')
+
+for k in range(len(model.layers)):
+    weights_orig = model.layers[k].get_weights()
+    weights_loaded = model_loaded.layers[k].get_weights()
+    for x, y in zip(weights_orig, weights_loaded):
+        if numpy.any(x != y):
+            raise ValueError('Loaded weights are different from pickled weights!')
+