2016-09-30 23:26:50 +00:00
|
|
|
'''This script demonstrates how to build a variational autoencoder
|
|
|
|
with Keras and transposed-convolution ("deconvolution") layers.
|
2016-07-25 17:33:03 +00:00
|
|
|
|
|
|
|
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
|
|
|
|
'''
|
|
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
2016-11-23 22:08:19 +00:00
|
|
|
from scipy.stats import norm
|
2016-07-25 17:33:03 +00:00
|
|
|
|
|
|
|
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
|
2017-02-28 18:21:05 +00:00
|
|
|
from keras.layers import Conv2D, Conv2DTranspose
|
2016-07-25 17:33:03 +00:00
|
|
|
from keras.models import Model
|
|
|
|
from keras import backend as K
|
2017-02-28 18:21:05 +00:00
|
|
|
from keras import metrics
|
2016-07-25 17:33:03 +00:00
|
|
|
from keras.datasets import mnist
|
|
|
|
|
|
|
|
# --- Data / architecture hyper-parameters ------------------------------------

# Input image geometry: 28x28 pixels, one channel (the MNIST data used below).
img_rows, img_cols, img_chns = 28, 28, 1
# Filter count shared by the convolutional and transposed-convolutional layers.
filters = 64
# Square kernel size, passed as `kernel_size=num_conv` to several layers.
num_conv = 3

batch_size = 100
# Where the channel axis goes depends on the backend's configured data format.
original_img_size = ((img_chns, img_rows, img_cols)
                     if K.image_data_format() == 'channels_first'
                     else (img_rows, img_cols, img_chns))

latent_dim = 2
intermediate_dim = 128
# Standard deviation of the Gaussian noise drawn by `sampling`.
epsilon_std = 1.0
epochs = 5
|
2016-07-25 17:33:03 +00:00
|
|
|
|
2016-09-30 23:26:50 +00:00
|
|
|
# --- Encoder: image -> (z_mean, z_log_var) ----------------------------------

# The batch dimension is fixed: `sampling` draws noise of shape
# (batch_size, latent_dim), so every batch fed through the VAE must hold
# exactly batch_size samples.
x = Input(batch_shape=(batch_size,) + original_img_size)

# Keyword arguments shared by every encoder convolution.
_encoder_conv_kwargs = dict(padding='same', activation='relu')

conv_1 = Conv2D(img_chns, kernel_size=(2, 2), **_encoder_conv_kwargs)(x)
# Stride 2 halves the spatial resolution.
conv_2 = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2),
                **_encoder_conv_kwargs)(conv_1)
conv_3 = Conv2D(filters, kernel_size=num_conv, strides=1,
                **_encoder_conv_kwargs)(conv_2)
conv_4 = Conv2D(filters, kernel_size=num_conv, strides=1,
                **_encoder_conv_kwargs)(conv_3)

flat = Flatten()(conv_4)
hidden = Dense(intermediate_dim, activation='relu')(flat)

# Two linear heads: mean and log-variance of the approximate posterior.
z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)
|
2016-07-25 17:42:37 +00:00
|
|
|
|
2016-07-25 17:33:03 +00:00
|
|
|
|
|
|
|
def sampling(args):
    """Sample z with the reparameterization trick.

    Args:
        args: a pair ``(z_mean, z_log_var)`` of tensors produced by the
            encoder heads; presumably each of shape (batch_size, latent_dim)
            — must match the noise shape drawn below.

    Returns:
        ``z_mean + sigma * epsilon``, where ``sigma = exp(z_log_var / 2)`` is
        the posterior standard deviation and ``epsilon`` is Gaussian noise
        with mean 0 and standard deviation ``epsilon_std``.
    """
    z_mean, z_log_var = args
    # Noise shape is pinned to the module-level batch size, which is why the
    # encoder Input declares a fixed batch dimension.
    epsilon = K.random_normal(shape=(batch_size, latent_dim),
                              mean=0., stddev=epsilon_std)
    # z_log_var is log(sigma^2) (the KL term in vae_loss uses exp(z_log_var)
    # as the variance), so the standard deviation is exp(0.5 * z_log_var).
    # Multiplying by exp(z_log_var) would scale the noise by the variance.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
|
2016-07-25 17:33:03 +00:00
|
|
|
|
|
|
|
# Wrap `sampling` in a Lambda layer to obtain the latent code z.
# With the TensorFlow backend `output_shape` may be omitted, i.e.
# `Lambda(sampling)([z_mean, z_log_var])` works as well.
z = Lambda(function=sampling,
           output_shape=(latent_dim,))([z_mean, z_log_var])
|
2016-07-25 17:33:03 +00:00
|
|
|
|
|
|
|
# --- Decoder layers ----------------------------------------------------------
# The layers are instantiated once and kept in module-level names so the same
# (weight-sharing) layers can be applied again later to build a stand-alone
# generator model.
decoder_hid = Dense(intermediate_dim, activation='relu')
# Project up to a 14x14 feature map with `filters` channels, matching the
# resolution after the stride-2 encoder convolution on a 28x28 input.
decoder_upsample = Dense(filters * 14 * 14, activation='relu')

if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 14, 14)
else:
    output_shape = (batch_size, 14, 14, filters)

# Reshape takes the per-sample target shape, i.e. without the batch axis.
decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
# Upsample 14x14 -> 29x29: with 'valid' padding, stride 2 and a 3x3 kernel
# the output side is (14 - 1) * 2 + 3 = 29. The layer is given no explicit
# output shape, so none needs to be computed here.
decoder_deconv_3_upsamp = Conv2DTranspose(filters,
                                          kernel_size=(3, 3),
                                          strides=(2, 2),
                                          padding='valid',
                                          activation='relu')
# A 2x2 'valid' convolution trims 29x29 -> 28x28 and maps back to img_chns
# channels; the sigmoid yields per-pixel values in (0, 1) for the binary
# cross-entropy reconstruction loss.
decoder_mean_squash = Conv2D(img_chns,
                             kernel_size=2,
                             padding='valid',
                             activation='sigmoid')
|
2016-09-28 20:40:44 +00:00
|
|
|
|
|
|
|
# Reconstruction path of the VAE: run the sampled code z through the shared
# decoder layers, stage by stage.
hid_decoded = decoder_hid(z)                                  # dense hidden
up_decoded = decoder_upsample(hid_decoded)                    # dense -> map
reshape_decoded = decoder_reshape(up_decoded)                 # vector -> map
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)    # upsample
# Final sigmoid layer: the reconstructed image.
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
|
2016-07-25 17:33:03 +00:00
|
|
|
|
2017-01-11 19:39:58 +00:00
|
|
|
|
2016-07-25 17:33:03 +00:00
|
|
|
def vae_loss(x, x_decoded_mean):
    """Per-sample negative ELBO: reconstruction cross-entropy plus KL term.

    Args:
        x: batch of input images (targets).
        x_decoded_mean: batch of reconstructions (sigmoid outputs in (0, 1)).

    Returns:
        A tensor holding one loss value per sample in the batch.

    Note: the KL term reads the module-level encoder tensors ``z_mean`` and
    ``z_log_var``, so this loss is only meaningful for the ``vae`` model built
    from them.
    """
    # binary_crossentropy expects a (batch_size, dim) matrix. batch_flatten
    # keeps the batch axis; K.flatten would collapse the whole batch into one
    # vector and produce a single batch-averaged value instead.
    x = K.batch_flatten(x)
    x_decoded_mean = K.batch_flatten(x_decoded_mean)
    # Scale the per-pixel mean cross-entropy up to a per-image sum.
    xent_loss = img_rows * img_cols * metrics.binary_crossentropy(
        x, x_decoded_mean)
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian. It is a SUM
    # over latent dimensions; averaging would under-weight it by latent_dim.
    kl_loss = - 0.5 * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss
|
|
|
|
|
2016-09-28 20:40:44 +00:00
|
|
|
# End-to-end VAE: image in, reconstruction out, trained with the custom loss.
vae = Model(inputs=x, outputs=x_decoded_mean_squash)
vae.compile(loss=vae_loss, optimizer='rmsprop')
vae.summary()
|
|
|
|
|
|
|
|
# --- Train the VAE on MNIST digits -------------------------------------------
# Training labels are discarded; the test labels are kept to colour the
# latent-space scatter plot.
(x_train, _), (x_test, y_test) = mnist.load_data()


def _to_model_input(images):
    """Scale pixels by 1/255 to float32 and reshape to the model input layout.

    Presumably ``images`` holds 8-bit pixel values in [0, 255]; the result is
    shaped ``(num_images,) + original_img_size``.
    """
    scaled = images.astype('float32') / 255.
    return scaled.reshape((len(scaled),) + original_img_size)


x_train = _to_model_input(x_train)
x_test = _to_model_input(x_test)
print('x_train.shape:', x_train.shape)

# Autoencoder training: the images are both the inputs and the targets.
vae.fit(x_train, x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))
|
|
|
|
|
|
|
|
# --- Project inputs onto the latent space ------------------------------------
# The encoder reuses the trained layers and exposes z_mean as its output.
encoder = Model(inputs=x, outputs=z_mean)

# 2-D scatter of the encoded test set, coloured by digit class.
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()
|
|
|
|
|
|
|
|
# --- Digit generator ----------------------------------------------------------
# Re-apply the shared decoder layers (same trained weights) to a fresh latent
# input, giving a model that maps a latent code directly to an image.
decoder_input = Input(shape=(latent_dim,))
_decoded = decoder_input
for _decoder_layer in (decoder_hid,
                       decoder_upsample,
                       decoder_reshape,
                       decoder_deconv_1,
                       decoder_deconv_2,
                       decoder_deconv_3_upsamp,
                       decoder_mean_squash):
    _decoded = _decoder_layer(_decoded)
generator = Model(decoder_input, _decoded)
|
2016-07-25 17:33:03 +00:00
|
|
|
|
|
|
|
# --- Display a 2D manifold of the digits -------------------------------------
n = 15  # figure with 15x15 digits
digit_size = 28

# Linearly spaced coordinates in (0, 1) are mapped through the inverse CDF
# (ppf) of the standard Gaussian to obtain latent values z; the prior over the
# latent space is Gaussian, so this spaces the grid by probability mass.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# The latent code for figure cell (row, col) is [grid_y[col], grid_x[row]]:
# the first latent coordinate varies across columns, the second down rows.
# All n * n codes are decoded in one predict call instead of decoding each
# code tiled batch_size times (the generator input has no fixed batch size).
z_grid = np.array([[grid_y[col], grid_x[row]]
                   for row in range(n)
                   for col in range(n)])
x_decoded = generator.predict(z_grid, batch_size=batch_size)

# (n*n, pixels) -> (row, col, h, w) -> (row, h, col, w) -> one tiled image,
# so figure[row*digit_size + h, col*digit_size + w] is pixel (h, w) of the
# digit decoded for cell (row, col).
figure = (x_decoded.reshape(n, n, digit_size, digit_size)
          .transpose(0, 2, 1, 3)
          .reshape(n * digit_size, n * digit_size))

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
|