Make deconv VAE compatible with both dim orderings

This commit is contained in:
Francois Chollet 2016-09-30 16:26:50 -07:00
parent 3bf8964355
commit 8fab33c245

@ -1,4 +1,5 @@
'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers. '''This script demonstrates how to build a variational autoencoder
with Keras and deconvolution layers.
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114 Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
''' '''
@ -6,14 +7,12 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Flatten, Reshape from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D from keras.layers import Convolution2D, Deconvolution2D
from keras.models import Model from keras.models import Model
from keras import backend as K from keras import backend as K
from keras import objectives from keras import objectives
from keras.datasets import mnist from keras.datasets import mnist
K.set_image_dim_ordering('th') # this is a Theano oriented example
# input image dimensions # input image dimensions
img_rows, img_cols, img_chns = 28, 28, 1 img_rows, img_cols, img_chns = 28, 28, 1
# number of convolutional filters to use # number of convolutional filters to use
@ -22,14 +21,16 @@ nb_filters = 64
nb_conv = 3 nb_conv = 3
batch_size = 100 batch_size = 100
original_dim = (img_chns, img_rows, img_cols) if K.image_dim_ordering() == 'th':
original_img_size = (img_chns, img_rows, img_cols)
else:
original_img_size = (img_rows, img_cols, img_chns)
latent_dim = 2 latent_dim = 2
intermediate_dim = 128 intermediate_dim = 128
epsilon_std = 0.01 epsilon_std = 0.01
nb_epoch = 5 nb_epoch = 5
x = Input(batch_shape=(batch_size,) + original_img_size)
x = Input(batch_shape=(batch_size,) + original_dim)
conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x) conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x)
conv_2 = Convolution2D(nb_filters, 2, 2, conv_2 = Convolution2D(nb_filters, 2, 2,
border_mode='same', activation='relu', border_mode='same', activation='relu',
@ -60,23 +61,35 @@ z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# we instantiate these layers separately so as to reuse them later # we instantiate these layers separately so as to reuse them later
decoder_hid = Dense(intermediate_dim, activation='relu') decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu') decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu')
decoder_reshape = Reshape((nb_filters, 14, 14))
if K.image_dim_ordering() == 'th':
output_shape = (batch_size, nb_filters, 14, 14)
else:
output_shape = (batch_size, 14, 14, nb_filters)
decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv, decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
(batch_size, nb_filters, 14, 14), output_shape,
border_mode='same', border_mode='same',
subsample=(1, 1), subsample=(1, 1),
activation='relu') activation='relu')
decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv, decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
(batch_size, nb_filters, 14, 14), output_shape,
border_mode='same', border_mode='same',
subsample=(1, 1), subsample=(1, 1),
activation='relu') activation='relu')
if K.image_dim_ordering() == 'th':
output_shape = (batch_size, nb_filters, 29, 29)
else:
output_shape = (batch_size, 29, 29, nb_filters)
decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2, decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2,
(batch_size, nb_filters, 29, 29), output_shape,
border_mode='valid', border_mode='valid',
subsample=(2, 2), subsample=(2, 2),
activation='relu') activation='relu')
decoder_mean_squash = Convolution2D(img_chns, 2, 2, border_mode='valid', activation='sigmoid') decoder_mean_squash = Convolution2D(img_chns, 2, 2,
border_mode='valid',
activation='sigmoid')
hid_decoded = decoder_hid(z) hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded) up_decoded = decoder_upsample(hid_decoded)
@ -87,7 +100,8 @@ x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu) x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
def vae_loss(x, x_decoded_mean): def vae_loss(x, x_decoded_mean):
# NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these! # NOTE: binary_crossentropy expects a batch_size by dim
# for x and x_decoded_mean, so we MUST flatten these!
x = K.flatten(x) x = K.flatten(x)
x_decoded_mean = K.flatten(x_decoded_mean) x_decoded_mean = K.flatten(x_decoded_mean)
xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean) xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean)
@ -99,12 +113,14 @@ vae.compile(optimizer='rmsprop', loss=vae_loss)
vae.summary() vae.summary()
# train the VAE on MNIST digits # train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data() (x_train, _), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')[:, None, :, :] / 255. x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32')[:, None, :, :] / 255. x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)
print(x_train.shape) print('x_train.shape:', x_train.shape)
vae.fit(x_train, x_train, vae.fit(x_train, x_train,
shuffle=True, shuffle=True,
@ -112,7 +128,6 @@ vae.fit(x_train, x_train,
batch_size=batch_size, batch_size=batch_size,
validation_data=(x_test, x_test)) validation_data=(x_test, x_test))
# build a model to project inputs on the latent space # build a model to project inputs on the latent space
encoder = Model(x, z_mean) encoder = Model(x, z_mean)