From 8fab33c245208f9c39f55292af9773574d891a16 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Fri, 30 Sep 2016 16:26:50 -0700 Subject: [PATCH] Make deconv VAE compatible with both dim orderings --- examples/variational_autoencoder_deconv.py | 51 ++++++++++++++-------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/examples/variational_autoencoder_deconv.py b/examples/variational_autoencoder_deconv.py index d70a1d243..25821eca0 100644 --- a/examples/variational_autoencoder_deconv.py +++ b/examples/variational_autoencoder_deconv.py @@ -1,4 +1,5 @@ -'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers. +'''This script demonstrates how to build a variational autoencoder +with Keras and deconvolution layers. Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114 ''' @@ -6,14 +7,12 @@ import numpy as np import matplotlib.pyplot as plt from keras.layers import Input, Dense, Lambda, Flatten, Reshape -from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D +from keras.layers import Convolution2D, Deconvolution2D from keras.models import Model from keras import backend as K from keras import objectives from keras.datasets import mnist -K.set_image_dim_ordering('th') # this is a Theano oriented example - # input image dimensions img_rows, img_cols, img_chns = 28, 28, 1 # number of convolutional filters to use @@ -22,14 +21,16 @@ nb_filters = 64 nb_conv = 3 batch_size = 100 -original_dim = (img_chns, img_rows, img_cols) +if K.image_dim_ordering() == 'th': + original_img_size = (img_chns, img_rows, img_cols) +else: + original_img_size = (img_rows, img_cols, img_chns) latent_dim = 2 intermediate_dim = 128 epsilon_std = 0.01 nb_epoch = 5 - -x = Input(batch_shape=(batch_size,) + original_dim) +x = Input(batch_shape=(batch_size,) + original_img_size) conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x) conv_2 = Convolution2D(nb_filters, 2, 2, border_mode='same', activation='relu', @@ -60,23 +61,35 @@ z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var]) # we instantiate these layers separately so as to reuse them later decoder_hid = Dense(intermediate_dim, activation='relu') decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu') -decoder_reshape = Reshape((nb_filters, 14, 14)) + +if K.image_dim_ordering() == 'th': + output_shape = (batch_size, nb_filters, 14, 14) +else: + output_shape = (batch_size, 14, 14, nb_filters) + +decoder_reshape = Reshape(output_shape[1:]) decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv, - (batch_size, nb_filters, 14, 14), + output_shape, border_mode='same', subsample=(1, 1), activation='relu') decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv, - (batch_size, nb_filters, 14, 14), + output_shape, border_mode='same', subsample=(1, 1), activation='relu') +if K.image_dim_ordering() == 'th': + output_shape = (batch_size, nb_filters, 29, 29) +else: + output_shape = (batch_size, 29, 29, nb_filters) decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2, - (batch_size, nb_filters, 29, 29), + output_shape, border_mode='valid', subsample=(2, 2), activation='relu') -decoder_mean_squash = Convolution2D(img_chns, 2, 2, border_mode='valid', activation='sigmoid') +decoder_mean_squash = Convolution2D(img_chns, 2, 2, + border_mode='valid', + activation='sigmoid') hid_decoded = decoder_hid(z) up_decoded = decoder_upsample(hid_decoded) @@ -87,7 +100,8 @@ x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded) x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu) def vae_loss(x, x_decoded_mean): - # NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these! + # NOTE: binary_crossentropy expects a batch_size by dim + # for x and x_decoded_mean, so we MUST flatten these! x = K.flatten(x) x_decoded_mean = K.flatten(x_decoded_mean) xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean) @@ -99,12 +113,14 @@ vae.compile(optimizer='rmsprop', loss=vae_loss) vae.summary() # train the VAE on MNIST digits -(x_train, y_train), (x_test, y_test) = mnist.load_data() +(x_train, _), (x_test, y_test) = mnist.load_data() -x_train = x_train.astype('float32')[:, None, :, :] / 255. -x_test = x_test.astype('float32')[:, None, :, :] / 255. +x_train = x_train.astype('float32') / 255. +x_train = x_train.reshape((x_train.shape[0],) + original_img_size) +x_test = x_test.astype('float32') / 255. +x_test = x_test.reshape((x_test.shape[0],) + original_img_size) -print(x_train.shape) +print('x_train.shape:', x_train.shape) vae.fit(x_train, x_train, shuffle=True, @@ -112,7 +128,6 @@ vae.fit(x_train, x_train, batch_size=batch_size, validation_data=(x_test, x_test)) - # build a model to project inputs on the latent space encoder = Model(x, z_mean)