Make deconv VAE compatible with both dim orderings

This commit is contained in:
Francois Chollet 2016-09-30 16:26:50 -07:00
parent 3bf8964355
commit 8fab33c245

@ -1,4 +1,5 @@
'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers.
'''This script demonstrates how to build a variational autoencoder
with Keras and deconvolution layers.
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''
@ -6,14 +7,12 @@ import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D
from keras.layers import Convolution2D, Deconvolution2D
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
K.set_image_dim_ordering('th') # this is a Theano oriented example
# input image dimensions
img_rows, img_cols, img_chns = 28, 28, 1
# number of convolutional filters to use
@ -22,14 +21,16 @@ nb_filters = 64
nb_conv = 3
batch_size = 100
original_dim = (img_chns, img_rows, img_cols)
if K.image_dim_ordering() == 'th':
original_img_size = (img_chns, img_rows, img_cols)
else:
original_img_size = (img_rows, img_cols, img_chns)
latent_dim = 2
intermediate_dim = 128
epsilon_std = 0.01
nb_epoch = 5
x = Input(batch_shape=(batch_size,) + original_dim)
x = Input(batch_shape=(batch_size,) + original_img_size)
conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x)
conv_2 = Convolution2D(nb_filters, 2, 2,
border_mode='same', activation='relu',
@ -60,23 +61,35 @@ z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# we instantiate these layers separately so as to reuse them later
decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu')
decoder_reshape = Reshape((nb_filters, 14, 14))
if K.image_dim_ordering() == 'th':
output_shape = (batch_size, nb_filters, 14, 14)
else:
output_shape = (batch_size, 14, 14, nb_filters)
decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
(batch_size, nb_filters, 14, 14),
output_shape,
border_mode='same',
subsample=(1, 1),
activation='relu')
decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
(batch_size, nb_filters, 14, 14),
output_shape,
border_mode='same',
subsample=(1, 1),
activation='relu')
if K.image_dim_ordering() == 'th':
output_shape = (batch_size, nb_filters, 29, 29)
else:
output_shape = (batch_size, 29, 29, nb_filters)
decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2,
(batch_size, nb_filters, 29, 29),
output_shape,
border_mode='valid',
subsample=(2, 2),
activation='relu')
decoder_mean_squash = Convolution2D(img_chns, 2, 2, border_mode='valid', activation='sigmoid')
decoder_mean_squash = Convolution2D(img_chns, 2, 2,
border_mode='valid',
activation='sigmoid')
hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
@ -87,7 +100,8 @@ x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
def vae_loss(x, x_decoded_mean):
# NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these!
# NOTE: binary_crossentropy expects a batch_size by dim
# for x and x_decoded_mean, so we MUST flatten these!
x = K.flatten(x)
x_decoded_mean = K.flatten(x_decoded_mean)
xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean)
@ -99,12 +113,14 @@ vae.compile(optimizer='rmsprop', loss=vae_loss)
vae.summary()
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')[:, None, :, :] / 255.
x_test = x_test.astype('float32')[:, None, :, :] / 255.
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)
print(x_train.shape)
print('x_train.shape:', x_train.shape)
vae.fit(x_train, x_train,
shuffle=True,
@ -112,7 +128,6 @@ vae.fit(x_train, x_train,
batch_size=batch_size,
validation_data=(x_test, x_test))
# build a model to project inputs on the latent space
encoder = Model(x, z_mean)