Make the deconvolution VAE example compatible with both image dimension orderings ('th' channels-first and channels-last)
This commit is contained in:
parent
3bf8964355
commit
8fab33c245
@ -1,4 +1,5 @@
|
|||||||
'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers.
|
'''This script demonstrates how to build a variational autoencoder
|
||||||
|
with Keras and deconvolution layers.
|
||||||
|
|
||||||
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
|
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
|
||||||
'''
|
'''
|
||||||
@ -6,14 +7,12 @@ import numpy as np
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
|
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
|
||||||
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D
|
from keras.layers import Convolution2D, Deconvolution2D
|
||||||
from keras.models import Model
|
from keras.models import Model
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
from keras import objectives
|
from keras import objectives
|
||||||
from keras.datasets import mnist
|
from keras.datasets import mnist
|
||||||
|
|
||||||
K.set_image_dim_ordering('th') # this is a Theano oriented example
|
|
||||||
|
|
||||||
# input image dimensions
|
# input image dimensions
|
||||||
img_rows, img_cols, img_chns = 28, 28, 1
|
img_rows, img_cols, img_chns = 28, 28, 1
|
||||||
# number of convolutional filters to use
|
# number of convolutional filters to use
|
||||||
@ -22,14 +21,16 @@ nb_filters = 64
|
|||||||
nb_conv = 3
|
nb_conv = 3
|
||||||
|
|
||||||
batch_size = 100
|
batch_size = 100
|
||||||
original_dim = (img_chns, img_rows, img_cols)
|
if K.image_dim_ordering() == 'th':
|
||||||
|
original_img_size = (img_chns, img_rows, img_cols)
|
||||||
|
else:
|
||||||
|
original_img_size = (img_rows, img_cols, img_chns)
|
||||||
latent_dim = 2
|
latent_dim = 2
|
||||||
intermediate_dim = 128
|
intermediate_dim = 128
|
||||||
epsilon_std = 0.01
|
epsilon_std = 0.01
|
||||||
nb_epoch = 5
|
nb_epoch = 5
|
||||||
|
|
||||||
|
x = Input(batch_shape=(batch_size,) + original_img_size)
|
||||||
x = Input(batch_shape=(batch_size,) + original_dim)
|
|
||||||
conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x)
|
conv_1 = Convolution2D(img_chns, 2, 2, border_mode='same', activation='relu')(x)
|
||||||
conv_2 = Convolution2D(nb_filters, 2, 2,
|
conv_2 = Convolution2D(nb_filters, 2, 2,
|
||||||
border_mode='same', activation='relu',
|
border_mode='same', activation='relu',
|
||||||
@ -60,23 +61,35 @@ z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
|
|||||||
# we instantiate these layers separately so as to reuse them later
|
# we instantiate these layers separately so as to reuse them later
|
||||||
decoder_hid = Dense(intermediate_dim, activation='relu')
|
decoder_hid = Dense(intermediate_dim, activation='relu')
|
||||||
decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu')
|
decoder_upsample = Dense(nb_filters * 14 * 14, activation='relu')
|
||||||
decoder_reshape = Reshape((nb_filters, 14, 14))
|
|
||||||
|
if K.image_dim_ordering() == 'th':
|
||||||
|
output_shape = (batch_size, nb_filters, 14, 14)
|
||||||
|
else:
|
||||||
|
output_shape = (batch_size, 14, 14, nb_filters)
|
||||||
|
|
||||||
|
decoder_reshape = Reshape(output_shape[1:])
|
||||||
decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
|
decoder_deconv_1 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
|
||||||
(batch_size, nb_filters, 14, 14),
|
output_shape,
|
||||||
border_mode='same',
|
border_mode='same',
|
||||||
subsample=(1, 1),
|
subsample=(1, 1),
|
||||||
activation='relu')
|
activation='relu')
|
||||||
decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
|
decoder_deconv_2 = Deconvolution2D(nb_filters, nb_conv, nb_conv,
|
||||||
(batch_size, nb_filters, 14, 14),
|
output_shape,
|
||||||
border_mode='same',
|
border_mode='same',
|
||||||
subsample=(1, 1),
|
subsample=(1, 1),
|
||||||
activation='relu')
|
activation='relu')
|
||||||
|
if K.image_dim_ordering() == 'th':
|
||||||
|
output_shape = (batch_size, nb_filters, 29, 29)
|
||||||
|
else:
|
||||||
|
output_shape = (batch_size, 29, 29, nb_filters)
|
||||||
decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2,
|
decoder_deconv_3_upsamp = Deconvolution2D(nb_filters, 2, 2,
|
||||||
(batch_size, nb_filters, 29, 29),
|
output_shape,
|
||||||
border_mode='valid',
|
border_mode='valid',
|
||||||
subsample=(2, 2),
|
subsample=(2, 2),
|
||||||
activation='relu')
|
activation='relu')
|
||||||
decoder_mean_squash = Convolution2D(img_chns, 2, 2, border_mode='valid', activation='sigmoid')
|
decoder_mean_squash = Convolution2D(img_chns, 2, 2,
|
||||||
|
border_mode='valid',
|
||||||
|
activation='sigmoid')
|
||||||
|
|
||||||
hid_decoded = decoder_hid(z)
|
hid_decoded = decoder_hid(z)
|
||||||
up_decoded = decoder_upsample(hid_decoded)
|
up_decoded = decoder_upsample(hid_decoded)
|
||||||
@ -87,7 +100,8 @@ x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
|
|||||||
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
|
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
|
||||||
|
|
||||||
def vae_loss(x, x_decoded_mean):
|
def vae_loss(x, x_decoded_mean):
|
||||||
# NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these!
|
# NOTE: binary_crossentropy expects a batch_size by dim
|
||||||
|
# for x and x_decoded_mean, so we MUST flatten these!
|
||||||
x = K.flatten(x)
|
x = K.flatten(x)
|
||||||
x_decoded_mean = K.flatten(x_decoded_mean)
|
x_decoded_mean = K.flatten(x_decoded_mean)
|
||||||
xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean)
|
xent_loss = img_rows * img_cols * objectives.binary_crossentropy(x, x_decoded_mean)
|
||||||
@ -99,12 +113,14 @@ vae.compile(optimizer='rmsprop', loss=vae_loss)
|
|||||||
vae.summary()
|
vae.summary()
|
||||||
|
|
||||||
# train the VAE on MNIST digits
|
# train the VAE on MNIST digits
|
||||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
(x_train, _), (x_test, y_test) = mnist.load_data()
|
||||||
|
|
||||||
x_train = x_train.astype('float32')[:, None, :, :] / 255.
|
x_train = x_train.astype('float32') / 255.
|
||||||
x_test = x_test.astype('float32')[:, None, :, :] / 255.
|
x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
|
||||||
|
x_test = x_test.astype('float32') / 255.
|
||||||
|
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)
|
||||||
|
|
||||||
print(x_train.shape)
|
print('x_train.shape:', x_train.shape)
|
||||||
|
|
||||||
vae.fit(x_train, x_train,
|
vae.fit(x_train, x_train,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
@ -112,7 +128,6 @@ vae.fit(x_train, x_train,
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
validation_data=(x_test, x_test))
|
validation_data=(x_test, x_test))
|
||||||
|
|
||||||
|
|
||||||
# build a model to project inputs on the latent space
|
# build a model to project inputs on the latent space
|
||||||
encoder = Model(x, z_mean)
|
encoder = Model(x, z_mean)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user