keras/examples/mnist_swwae.py
2017-03-28 13:44:56 +02:00

204 lines
7.6 KiB
Python

'''Trains a stacked what-where autoencoder built on residual blocks on the
MNIST dataset. It exemplifies two influential methods that have been developed
in the past few years.
The first is the idea of properly 'unpooling.' During any max pool, the
exact location (the 'where') of the maximal value in a pooled receptive field
is lost, however it can be very useful in the overall reconstruction of an
input image. Therefore, if the 'where' is handed from the encoder
to the corresponding decoder layer, features being decoded can be 'placed' in
the right location, allowing for reconstructions of much higher fidelity.
References:
[1]
'Visualizing and Understanding Convolutional Networks'
Matthew D Zeiler, Rob Fergus
https://arxiv.org/abs/1311.2901v3
[2]
'Stacked What-Where Auto-encoders'
Junbo Zhao, Michael Mathieu, Ross Goroshin, Yann LeCun
https://arxiv.org/abs/1506.02351v8
The second idea exploited here is that of residual learning. Residual blocks
ease the training process by allowing skip connections that give the network
the ability to be as linear (or non-linear) as the data sees fit. This allows
for much deep networks to be easily trained. The residual element seems to
be advantageous in the context of this example as it allows a nice symmetry
between the encoder and decoder. Normally, in the decoder, the final
projection to the space where the image is reconstructed is linear, however
this does not have to be the case for a residual block as the degree to which
its output is linear or non-linear is determined by the data it is fed.
However, in order to cap the reconstruction in this example, a hard softmax is
applied as a bias because we know the MNIST digits are mapped to [0,1].
References:
[3]
'Deep Residual Learning for Image Recognition'
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
https://arxiv.org/abs/1512.03385v1
[4]
'Identity Mappings in Deep Residual Networks'
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
https://arxiv.org/abs/1603.05027v3
'''
from __future__ import print_function
import numpy as np
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Activation
from keras.layers import UpSampling2D, Conv2D, MaxPooling2D
from keras.layers import Input, BatchNormalization, ELU
import matplotlib.pyplot as plt
import keras.backend as K
from keras import layers
def convresblock(x, nfeats=8, ksize=3, nskipped=2, elu=True):
"""The proposed residual block from [4].
Running with elu=True will use ELU nonlinearity and running with
elu=False will use BatchNorm + RELU nonlinearity. While ELU's are fast
due to the fact they do not suffer from BatchNorm overhead, they may
overfit because they do not offer the stochastic element of the batch
formation process of BatchNorm, which acts as a good regularizer.
# Arguments
x: 4D tensor, the tensor to feed through the block
nfeats: Integer, number of feature maps for conv layers.
ksize: Integer, width and height of conv kernels in first convolution.
nskipped: Integer, number of conv layers for the residual function.
elu: Boolean, whether to use ELU or BN+RELU.
# Input shape
4D tensor with shape:
`(batch, channels, rows, cols)`
# Output shape
4D tensor with shape:
`(batch, filters, rows, cols)`
"""
y0 = Conv2D(nfeats, ksize, padding='same')(x)
y = y0
for i in range(nskipped):
if elu:
y = ELU()(y)
else:
y = BatchNormalization(axis=1)(y)
y = Activation('relu')(y)
y = Conv2D(nfeats, 1, padding='same')(y)
return layers.add([y0, y])
def getwhere(x):
''' Calculate the 'where' mask that contains switches indicating which
index contained the max value when MaxPool2D was applied. Using the
gradient of the sum is a nice trick to keep everything high level.'''
y_prepool, y_postpool = x
return K.gradients(K.sum(y_postpool), y_prepool)
if K.backend() == 'tensorflow':
raise RuntimeError('This example can only run with the '
'Theano backend for the time being, '
'because it requires taking the gradient '
'of a gradient, which isn\'t '
'supported for all TF ops.')
# This example assume 'channels_first' data format.
K.set_image_data_format('channels_first')
# input image dimensions
img_rows, img_cols = 28, 28
# the data, shuffled and split between train and test sets
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# The size of the kernel used for the MaxPooling2D
pool_size = 2
# The total number of feature maps at each layer
nfeats = [8, 16, 32, 64, 128]
# The sizes of the pooling kernel at each layer
pool_sizes = np.array([1, 1, 1, 1, 1]) * pool_size
# The convolution kernel size
ksize = 3
# Number of epochs to train for
epochs = 5
# Batch size during training
batch_size = 128
if pool_size == 2:
# if using a 5 layer net of pool_size = 2
x_train = np.pad(x_train, [[0, 0], [0, 0], [2, 2], [2, 2]],
mode='constant')
x_test = np.pad(x_test, [[0, 0], [0, 0], [2, 2], [2, 2]], mode='constant')
nlayers = 5
elif pool_size == 3:
# if using a 3 layer net of pool_size = 3
x_train = x_train[:, :, :-1, :-1]
x_test = x_test[:, :, :-1, :-1]
nlayers = 3
else:
import sys
sys.exit('Script supports pool_size of 2 and 3.')
# Shape of input to train on (note that model is fully convolutional however)
input_shape = x_train.shape[1:]
# The final list of the size of axis=1 for all layers, including input
nfeats_all = [input_shape[0]] + nfeats
# First build the encoder, all the while keeping track of the 'where' masks
img_input = Input(shape=input_shape)
# We push the 'where' masks to the following list
wheres = [None] * nlayers
y = img_input
for i in range(nlayers):
y_prepool = convresblock(y, nfeats=nfeats_all[i + 1], ksize=ksize)
y = MaxPooling2D(pool_size=(pool_sizes[i], pool_sizes[i]))(y_prepool)
wheres[i] = layers.Lambda(
getwhere, output_shape=lambda x: x[0])([y_prepool, y])
# Now build the decoder, and use the stored 'where' masks to place the features
for i in range(nlayers):
ind = nlayers - 1 - i
y = UpSampling2D(size=(pool_sizes[ind], pool_sizes[ind]))(y)
y = layers.multiply([y, wheres[ind]])
y = convresblock(y, nfeats=nfeats_all[ind], ksize=ksize)
# Use hard_simgoid to clip range of reconstruction
y = Activation('hard_sigmoid')(y)
# Define the model and it's mean square error loss, and compile it with Adam
model = Model(img_input, y)
model.compile('adam', 'mse')
# Fit the model
model.fit(x_train, x_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, x_test))
# Plot
x_recon = model.predict(x_test[:25])
x_plot = np.concatenate((x_test[:25], x_recon), axis=1)
x_plot = x_plot.reshape((5, 10, input_shape[-2], input_shape[-1]))
x_plot = np.vstack([np.hstack(x) for x in x_plot])
plt.figure()
plt.axis('off')
plt.title('Test Samples: Originals/Reconstructions')
plt.imshow(x_plot, interpolation='none', cmap='gray')
plt.savefig('reconstructions.png')