Make deconv VAE compatible with both dim orderings
This commit is contained in:
parent
3bf8964355
commit
8fab33c245
@ -1,4 +1,5 @@
|
||||
'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers.
|
||||
'''This script demonstrates how to build a variational autoencoder
|
||||
with Keras and deconvolution layers.
|
||||
|
||||
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
|
||||
'''
|
||||
@ -6,14 +7,12 @@ import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
|
||||
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D
|
||||
from keras.layers import Convolution2D, Deconvolution2D
|
||||
from keras.models import Model
|
||||
from keras import backend as K
|
||||
from keras import objectives
|
||||
from keras.datasets import mnist
|
||||
|
||||
K.set_image_dim_ordering('th') # this is a Theano oriented example
|
||||
|
||||
# input image dimensions
|
||||
img_rows, img_cols, img_chns = 28, 28, 1
|
||||
# number of convolutional filters to use
|
||||
@ -22,14 +21,16 @@ nb_filters = 64
|
||||
nb_conv = 3
|
||||
|
||||
batch_size = 100
|
||||
original_dim = (img_chns, img_rows, img_cols)
|
||||
if K.image_dim_ordering() == 'th':
|
||||
original_img_size = (img_chns, img_rows, img_cols)
|
||||
else:
|
||||
original_img_size = (img_rows, img_cols, img_chns)
|
||||
latent_dim = 2
|
||||
intermediate_dim = 128
|
||||
epsilon_std = 0.01
|
||||
nb_epoch = 5
|
||||
|
||||
|
||||
x = Input(batch_shape=(batch_size,) + original_dim)
|
||||
x = Input(batch_shape=(batch_size,) + original_img_size)
|
||||
conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x)
|
||||
conv_2 = Convolution2D(nb_filters, 2, 2,
|
||||
border_mode='same', activation='relu',
|
||||
@ -60,23 +61,35 @@ z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
|
||||
# we instantiate these layers separately so as to reuse them later
|
||||
decoder_hid = Dense(intermediate_dim, activation='relu')
|
||||
decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu')
|
||||
decoder_reshape = Reshape((nb_filters, 14, 14))
|
||||
|
||||
if K.image_dim_ordering() == 'th':
|
||||
output_shape = (batch_size, nb_filters, 14, 14)
|
||||
else:
|
||||
output_shape = (batch_size, 14, 14, nb_filters)
|
||||
|
||||
decoder_reshape = Reshape(output_shape[1:])
|
||||
decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
|
||||
(batch_size, nb_filters, 14, 14),
|
||||
output_shape,
|
||||
border_mode='same',
|
||||
subsample=(1, 1),
|
||||
activation='relu')
|
||||
decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
|
||||
(batch_size, nb_filters, 14, 14),
|
||||
output_shape,
|
||||
border_mode='same',
|
||||
subsample=(1, 1),
|
||||
activation='relu')
|
||||
if K.image_dim_ordering() == 'th':
|
||||
output_shape = (batch_size, nb_filters, 29, 29)
|
||||
else:
|
||||
output_shape = (batch_size, 29, 29, nb_filters)
|
||||
decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2,
|
||||
(batch_size, nb_filters, 29, 29),
|
||||
output_shape,
|
||||
border_mode='valid',
|
||||
subsample=(2, 2),
|
||||
activation='relu')
|
||||
decoder_mean_squash = Convolution2D(img_chns, 2, 2, border_mode='valid', activation='sigmoid')
|
||||
decoder_mean_squash = Convolution2D(img_chns, 2, 2,
|
||||
border_mode='valid',
|
||||
activation='sigmoid')
|
||||
|
||||
hid_decoded = decoder_hid(z)
|
||||
up_decoded = decoder_upsample(hid_decoded)
|
||||
@ -87,7 +100,8 @@ x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
|
||||
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
|
||||
|
||||
def vae_loss(x, x_decoded_mean):
|
||||
# NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these!
|
||||
# NOTE: binary_crossentropy expects a batch_size by dim
|
||||
# for x and x_decoded_mean, so we MUST flatten these!
|
||||
x = K.flatten(x)
|
||||
x_decoded_mean = K.flatten(x_decoded_mean)
|
||||
xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean)
|
||||
@ -99,12 +113,14 @@ vae.compile(optimizer='rmsprop', loss=vae_loss)
|
||||
vae.summary()
|
||||
|
||||
# train the VAE on MNIST digits
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
(x_train, _), (x_test, y_test) = mnist.load_data()
|
||||
|
||||
x_train = x_train.astype('float32')[:, None, :, :] / 255.
|
||||
x_test = x_test.astype('float32')[:, None, :, :] / 255.
|
||||
x_train = x_train.astype('float32') / 255.
|
||||
x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
|
||||
x_test = x_test.astype('float32') / 255.
|
||||
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)
|
||||
|
||||
print(x_train.shape)
|
||||
print('x_train.shape:', x_train.shape)
|
||||
|
||||
vae.fit(x_train, x_train,
|
||||
shuffle=True,
|
||||
@ -112,7 +128,6 @@ vae.fit(x_train, x_train,
|
||||
batch_size=batch_size,
|
||||
validation_data=(x_test, x_test))
|
||||
|
||||
|
||||
# build a model to project inputs on the latent space
|
||||
encoder = Model(x, z_mean)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user