Fix autoencoder serialization
This commit is contained in:
parent
5e9579aeac
commit
78feed7fa9
@ -5,7 +5,7 @@ import theano
|
||||
import copy
|
||||
|
||||
from ..layers.advanced_activations import LeakyReLU, PReLU
|
||||
from ..layers.core import Dense, Merge, Dropout, Activation, Reshape, Flatten, RepeatVector, Layer
|
||||
from ..layers.core import Dense, Merge, Dropout, Activation, Reshape, Flatten, RepeatVector, Layer, AutoEncoder
|
||||
from ..layers.core import ActivityRegularization, TimeDistributedDense, AutoEncoder, MaxoutDense
|
||||
from ..layers.convolutional import Convolution1D, Convolution2D, MaxPooling1D, MaxPooling2D, ZeroPadding2D
|
||||
from ..layers.embeddings import Embedding, WordContextProduct
|
||||
@ -58,6 +58,14 @@ def container_from_config(original_layer_dict):
|
||||
graph_layer.add_output(**output)
|
||||
return graph_layer
|
||||
|
||||
elif name == 'AutoEncoder':
|
||||
kwargs = {'encoder': container_from_config(layer_dict.get('encoder_config')),
|
||||
'decoder': container_from_config(layer_dict.get('decoder_config'))}
|
||||
for kwarg in ['output_reconstruction', 'weights']:
|
||||
if kwarg in layer_dict:
|
||||
kwargs[kwarg] = layer_dict[kwarg]
|
||||
return AutoEncoder(**kwargs)
|
||||
|
||||
else:
|
||||
layer_dict.pop('name')
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from keras.datasets import mnist
|
||||
from keras.models import Sequential
|
||||
from keras.models import Sequential, model_from_config
|
||||
from keras.layers.core import AutoEncoder, Dense, Activation, TimeDistributedDense, Flatten
|
||||
from keras.layers.recurrent import LSTM
|
||||
from keras.layers.embeddings import Embedding
|
||||
@ -57,6 +57,7 @@ print('\nclassical_score:', classical_score)
|
||||
# autoencoder model test #
|
||||
##########################
|
||||
|
||||
|
||||
def build_lstm_autoencoder(autoencoder, X_train, X_test):
|
||||
X_train = X_train[:, np.newaxis, :]
|
||||
X_test = X_test[:, np.newaxis, :]
|
||||
@ -95,7 +96,6 @@ for autoencoder_type in ['classical', 'lstm']:
|
||||
print("Error: unknown autoencoder type!")
|
||||
exit(-1)
|
||||
|
||||
autoencoder.get_config(verbose=1)
|
||||
autoencoder.compile(loss='mean_squared_error', optimizer='adam')
|
||||
# Do NOT use validation data with return output_reconstruction=True
|
||||
autoencoder.fit(X_train, X_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=False, verbose=1)
|
||||
@ -128,3 +128,7 @@ for autoencoder_type in ['classical', 'lstm']:
|
||||
|
||||
print('Loss change:', (score[0] - classical_score[0])/classical_score[0], '%')
|
||||
print('Accuracy change:', (score[1] - classical_score[1])/classical_score[1], '%')
|
||||
|
||||
# check serialization
|
||||
config = autoencoder.get_config(verbose=1)
|
||||
autoencoder = model_from_config(config)
|
||||
|
Loading…
Reference in New Issue
Block a user