diff --git a/examples/variational_autoencoder_deconv.py b/examples/variational_autoencoder_deconv.py
new file mode 100644
index 000000000..70b1971c8
--- /dev/null
+++ b/examples/variational_autoencoder_deconv.py
@@ -0,0 +1,130 @@
+'''This script demonstrates how to build a variational autoencoder
+with Keras and deconvolution layers.
+
+Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
+'''
+import numpy as np
+import matplotlib.pyplot as plt
+
+from keras.layers import Input, Dense, Lambda, Flatten, Reshape
+from keras.layers import Convolution2D, Deconvolution2D
+from keras.models import Model
+from keras import backend as K
+from keras import objectives
+from keras.datasets import mnist
+
+# input image dimensions
+img_rows, img_cols, img_chns = 28, 28, 1
+# number of convolutional filters to use
+nb_filters = 32
+# convolution kernel size
+nb_conv = 3
+
+batch_size = 16
+original_dim = (img_chns, img_rows, img_cols)
+latent_dim = 2
+intermediate_dim = 128
+epsilon_std = 0.01
+nb_epoch = 5
+
+
+x = Input(batch_shape=(batch_size,) + original_dim)
+c = Convolution2D(nb_filters, nb_conv, nb_conv,
+                  border_mode='same', activation='relu')(x)
+f = Flatten()(c)
+h = Dense(intermediate_dim, activation='relu')(f)
+
+z_mean = Dense(latent_dim)(h)
+z_log_var = Dense(latent_dim)(h)
+
+def sampling(args):
+    z_mean, z_log_var = args
+    epsilon = K.random_normal(shape=(batch_size, latent_dim),
+                              mean=0., std=epsilon_std)
+    return z_mean + K.exp(z_log_var / 2) * epsilon
+
+# note that "output_shape" isn't necessary with the TensorFlow backend
+# so you could write `Lambda(sampling)([z_mean, z_log_var])`
+z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
+
+# we instantiate these layers separately so as to reuse them later
+decoder_h = Dense(intermediate_dim, activation='relu')
+decoder_f = Dense(nb_filters * img_rows * img_cols, activation='relu')
+decoder_c = Reshape((nb_filters, img_rows, img_cols))
+decoder_mean = Deconvolution2D(img_chns, nb_conv, nb_conv,
+                               (batch_size, img_chns, img_rows, img_cols),
+                               border_mode='same')
+
+h_decoded = decoder_h(z)
+f_decoded = decoder_f(h_decoded)
+c_decoded = decoder_c(f_decoded)
+x_decoded_mean = decoder_mean(c_decoded)
+
+
+def vae_loss(x, x_decoded_mean):
+    # NOTE: binary_crossentropy expects (batch_size, dim) tensors,
+    # so we MUST flatten x and x_decoded_mean first!
+    x = K.flatten(x)
+    x_decoded_mean = K.flatten(x_decoded_mean)
+    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
+    kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
+    return xent_loss + kl_loss
+
+vae = Model(x, x_decoded_mean)
+vae.compile(optimizer='rmsprop', loss=vae_loss)
+vae.summary()
+
+# train the VAE on MNIST digits
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+
+x_train = x_train.astype('float32')[:, None, :, :] / 255.
+x_test = x_test.astype('float32')[:, None, :, :] / 255.
+
+vae.fit(x_train, x_train,
+        shuffle=True,
+        nb_epoch=nb_epoch,
+        batch_size=batch_size,
+        validation_data=(x_test, x_test))
+
+
+# build a model to project inputs on the latent space
+encoder = Model(x, z_mean)
+
+# display a 2D plot of the digit classes in the latent space
+x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
+plt.figure(figsize=(6, 6))
+plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
+plt.colorbar()
+plt.show()
+
+# build a digit generator that can sample from the learned distribution
+decoder_input = Input(shape=(latent_dim,))
+_h_decoded = decoder_h(decoder_input)
+_f_decoded = decoder_f(_h_decoded)
+_c_decoded = decoder_c(_f_decoded)
+_x_decoded_mean = decoder_mean(_c_decoded)
+generator = Model(decoder_input, _x_decoded_mean)
+
+# display a 2D manifold of the digits
+n = 15  # figure with 15x15 digits
+digit_size = 28
+figure = np.zeros((digit_size * n, digit_size * n))
+# we sample n points per axis on [-15, 15], scaled by epsilon_std,
+# and decode each latent point into a digit
+grid_x = np.linspace(-15, 15, n)
+grid_y = np.linspace(-15, 15, n)
+
+for i, yi in enumerate(grid_x):
+    for j, xi in enumerate(grid_y):
+        z_sample = np.array([[xi, yi]]) * epsilon_std
+        # the decoder's output_shape is fixed at batch_size samples,
+        # so tile the latent point up to a full batch before predicting
+        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, latent_dim)
+        x_decoded = generator.predict(z_sample, batch_size=batch_size)
+        digit = x_decoded[0].reshape(digit_size, digit_size)
+        figure[i * digit_size: (i + 1) * digit_size,
+               j * digit_size: (j + 1) * digit_size] = digit
+
+plt.figure(figsize=(10, 10))
+plt.imshow(figure)
+plt.show()
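A quick way to check the closed-form KL term in `vae_loss` above: for a diagonal Gaussian posterior, KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log sigma^2 - mu^2 - sigma^2) per latent dimension. The following standalone NumPy sketch (illustrative only, not part of the patch) compares it against a Monte Carlo estimate:

```python
import numpy as np

rng = np.random.RandomState(0)
mu, log_var = 0.7, -0.4

# closed form used in vae_loss (for a single latent dimension)
closed_form = -0.5 * (1.0 + log_var - mu ** 2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z) - log p(z)]
sigma = np.exp(0.5 * log_var)
z = mu + sigma * rng.randn(1000000)
log_q = -0.5 * (np.log(2 * np.pi) + log_var + ((z - mu) / sigma) ** 2)
log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)

print(closed_form, np.mean(log_q - log_p))  # agree to ~3 decimal places
```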
diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index 0d3ab9e64..a3d161e51 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -1250,6 +1250,12 @@ def l2_normalize(x, axis):
 # CONVOLUTIONS
 
 
+def _preprocess_deconv_output_shape(sh, dim_ordering):
+    if dim_ordering == 'th':
+        sh = (sh[0], sh[2], sh[3], sh[1])
+    return sh
+
+
 def _preprocess_conv2d_input(x, dim_ordering):
     if _FLOATX == 'float64':
         x = tf.cast(x, 'float32')
@@ -1375,11 +1381,12 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
         raise Exception('Unknown dim_ordering ' + str(dim_ordering))
 
     x = _preprocess_conv2d_input(x, dim_ordering)
+    output_shape = _preprocess_deconv_output_shape(output_shape, dim_ordering)
     kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
+    kernel = tf.transpose(kernel, (0, 1, 3, 2))  # transpose kernel channel axes
     padding = _preprocess_border_mode(border_mode)
     strides = (1,) + strides + (1,)
 
-    # TODO: pre-process output_shape if dim_ordering == th
     x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides,
                                padding=padding)
     return _postprocess_conv2d_output(x, dim_ordering)
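For reviewers unfamiliar with `tf.nn.conv2d_transpose`: it expects the output shape in NHWC order and the kernel laid out as (rows, cols, output_channels, input_channels), which is why the 'th'-ordered `output_shape` is reshuffled and the kernel gets a second transpose above. A plain-tuple sketch of the shape bookkeeping (no TensorFlow required, values hypothetical):

```python
# 'th' ordering: (samples, channels, rows, cols)
output_shape = (16, 3, 28, 28)
sh = output_shape
# what _preprocess_deconv_output_shape returns: NHWC
print((sh[0], sh[2], sh[3], sh[1]))  # -> (16, 28, 28, 3)

# kernel path for 'th' inputs:
#   (out_depth, in_depth, rows, cols)
#   -> _preprocess_conv2d_kernel ->      (rows, cols, in_depth, out_depth)
#   -> tf.transpose(..., (0, 1, 3, 2)) -> (rows, cols, out_depth, in_depth)
# i.e. (rows, cols, output_channels, input_channels), as conv2d_transpose expects
```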
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index 9a22f88a4..ec9170a8f 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -937,6 +937,79 @@ def l2_normalize(x, axis):
 # CONVOLUTIONS
 
 
+def _preprocess_conv2d_input(x, dim_ordering):
+    if dim_ordering == 'tf':
+        # TF uses the last dimension as channel dimension,
+        # instead of the 2nd one.
+        # TH input shape: (samples, input_depth, rows, cols)
+        # TF input shape: (samples, rows, cols, input_depth)
+        x = x.dimshuffle((0, 3, 1, 2))
+    return x
+
+
+def _preprocess_conv2d_kernel(kernel, dim_ordering):
+    if dim_ordering == 'tf':
+        # TF uses the last dimension as channel dimension,
+        # instead of the 2nd one.
+        # TH kernel shape: (depth, input_depth, rows, cols)
+        # TF kernel shape: (rows, cols, input_depth, depth)
+        kernel = kernel.dimshuffle((3, 2, 0, 1))
+    return kernel
+
+
+def _preprocess_border_mode(border_mode):
+    if border_mode == 'same':
+        th_border_mode = 'half'
+    elif border_mode == 'valid':
+        th_border_mode = 'valid'
+    else:
+        raise Exception('Border mode not supported: ' + str(border_mode))
+    return th_border_mode
+
+
+def _preprocess_image_shape(dim_ordering, image_shape):
+    # Theano might not accept long type
+    def int_or_none(value):
+        try:
+            return int(value)
+        except TypeError:
+            return None
+    if dim_ordering == 'tf':
+        if image_shape:
+            image_shape = (image_shape[0], image_shape[3],
+                           image_shape[1], image_shape[2])
+    if image_shape is not None:
+        image_shape = tuple(int_or_none(v) for v in image_shape)
+    return image_shape
+
+
+def _preprocess_filter_shape(dim_ordering, filter_shape):
+    # Theano might not accept long type
+    def int_or_none(value):
+        try:
+            return int(value)
+        except TypeError:
+            return None
+    if dim_ordering == 'tf':
+        if filter_shape:
+            filter_shape = (filter_shape[3], filter_shape[2],
+                            filter_shape[0], filter_shape[1])
+    if filter_shape is not None:
+        filter_shape = tuple(int_or_none(v) for v in filter_shape)
+    return filter_shape
+
+
+def _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel, strides, dim_ordering):
+    if border_mode == 'same':
+        if np_kernel.shape[2] % 2 == 0:
+            conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
+        if np_kernel.shape[3] % 2 == 0:
+            conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
+    if dim_ordering == 'tf':
+        conv_out = conv_out.dimshuffle((0, 2, 3, 1))
+    return conv_out
+
+
 def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
            dim_ordering=_IMAGE_DIM_ORDERING,
            image_shape=None, filter_shape=None,
            filter_dilation=(1, 1)):
@@ -953,42 +1026,12 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
     if dim_ordering not in {'th', 'tf'}:
         raise Exception('Unknown dim_ordering ' + str(dim_ordering))
 
-    if dim_ordering == 'tf':
-        # TF uses the last dimension as channel dimension,
-        # instead of the 2nd one.
-        # TH input shape: (samples, input_depth, rows, cols)
-        # TF input shape: (samples, rows, cols, input_depth)
-        # TH kernel shape: (depth, input_depth, rows, cols)
-        # TF kernel shape: (rows, cols, input_depth, depth)
-        x = x.dimshuffle((0, 3, 1, 2))
-        kernel = kernel.dimshuffle((3, 2, 0, 1))
-        if image_shape:
-            image_shape = (image_shape[0], image_shape[3],
-                           image_shape[1], image_shape[2])
-        if filter_shape:
-            filter_shape = (filter_shape[3], filter_shape[2],
-                            filter_shape[0], filter_shape[1])
-
-    if border_mode == 'same':
-        th_border_mode = 'half'
-        np_kernel = kernel.eval()
-    elif border_mode == 'valid':
-        th_border_mode = 'valid'
-    else:
-        raise Exception('Border mode not supported: ' + str(border_mode))
-
-    # Theano might not accept long type
-    def int_or_none(value):
-        try:
-            return int(value)
-        except TypeError:
-            return None
-
-    if image_shape is not None:
-        image_shape = tuple(int_or_none(v) for v in image_shape)
-
-    if filter_shape is not None:
-        filter_shape = tuple(int_or_none(v) for v in filter_shape)
+    x = _preprocess_conv2d_input(x, dim_ordering)
+    kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
+    th_border_mode = _preprocess_border_mode(border_mode)
+    np_kernel = kernel.eval()
+    image_shape = _preprocess_image_shape(dim_ordering, image_shape)
+    filter_shape = _preprocess_filter_shape(dim_ordering, filter_shape)
 
     # TODO: remove the if statement when theano with no filter dilation is deprecated.
     if filter_dilation == (1, 1):
@@ -1005,14 +1048,8 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
                                    filter_shape=filter_shape,
                                    filter_dilation=filter_dilation)
 
-    if border_mode == 'same':
-        if np_kernel.shape[2] % 2 == 0:
-            conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
-        if np_kernel.shape[3] % 2 == 0:
-            conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
-
-    if dim_ordering == 'tf':
-        conv_out = conv_out.dimshuffle((0, 2, 3, 1))
+    conv_out = _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel,
+                                          strides, dim_ordering)
 
     return conv_out
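The cropping in `_postprocess_conv2d_output` exists because Theano's 'half' mode pads by `filter_size // 2` on each side, so even-sized kernels can yield one extra row or column relative to the `ceil(input / stride)` length that 'same' mode should produce. A small arithmetic sketch of that invariant (illustrative only, not part of the patch):

```python
def theano_half_length(x, k, stride):
    # output length of a 'valid' convolution after padding by k // 2 per side
    return (x + 2 * (k // 2) - k) // stride + 1

def same_length(x, stride):
    # desired 'same' output length, ceil(x / stride): the slice bound used above
    return (x + stride - 1) // stride

for x in range(5, 12):
    for k in (2, 3, 4, 5):
        for stride in (1, 2):
            raw, want = theano_half_length(x, k, stride), same_length(x, stride)
            # odd kernels already match; even kernels overshoot by at most one
            assert raw == want if k % 2 else want <= raw <= want + 1
```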
@@ -1020,7 +1057,38 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
              border_mode='valid',
              dim_ordering=_IMAGE_DIM_ORDERING,
              image_shape=None, filter_shape=None):
-    raise NotImplementedError
+    '''2D deconvolution (transposed convolution).
+
+    # Arguments
+        kernel: kernel tensor.
+        output_shape: desired dimensions of the output, including
+            the batch size.
+        strides: strides tuple.
+        border_mode: string, "same" or "valid".
+        dim_ordering: "tf" or "th".
+            Whether to use Theano or TensorFlow dimension ordering
+            in inputs/kernels/outputs.
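The Theano path below implements the transposed convolution as the gradient of a forward convolution with respect to its input, which is exactly what `AbstractConv2d_gradInputs` computes. A self-contained 1D NumPy sketch of that adjoint relationship (stride 1, 'valid' mode, illustrative only):

```python
import numpy as np

def conv1d_valid(x, w):
    # forward 'valid' correlation: out[i] = sum_j x[i + j] * w[j]
    n = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) for i in range(n)])

def conv1d_transpose(y, w):
    # transposed convolution: scatter each value back over its input window
    out = np.zeros(len(y) + len(w) - 1)
    for i, v in enumerate(y):
        out[i:i + len(w)] += v * w
    return out

rng = np.random.RandomState(0)
x, w = rng.randn(8), rng.randn(3)
y = rng.randn(6)  # same length as conv1d_valid(x, w)

# adjoint identity: <conv(x, w), y> == <x, conv_transpose(y, w)>
assert np.allclose(np.dot(conv1d_valid(x, w), y),
                   np.dot(x, conv1d_transpose(y, w)))
```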
+    '''
+    flip_filters = False
+    if dim_ordering not in {'th', 'tf'}:
+        raise Exception('Unknown dim_ordering ' + str(dim_ordering))
+
+    x = _preprocess_conv2d_input(x, dim_ordering)
+    kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
+    kernel = kernel.dimshuffle((1, 0, 2, 3))  # swap input and output channel axes
+    th_border_mode = _preprocess_border_mode(border_mode)
+    np_kernel = kernel.eval()
+    filter_shape = _preprocess_filter_shape(dim_ordering, filter_shape)
+
+    op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=output_shape,
+                                                        kshp=filter_shape,
+                                                        subsample=strides,
+                                                        border_mode=th_border_mode,
+                                                        filter_flip=not flip_filters)
+    conv_out = op(kernel, x, output_shape[2:])
+
+    conv_out = _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel,
+                                          strides, dim_ordering)
+    return conv_out
 
 
 def atrous_conv2d(x, kernel, rate=1,
diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py
index 183b668c4..66beb6561 100644
--- a/keras/layers/convolutional.py
+++ b/keras/layers/convolutional.py
@@ -4,7 +4,7 @@ from __future__ import absolute_import
 from .. import backend as K
 from .. import activations, initializations, regularizers, constraints
 from ..engine import Layer, InputSpec
-from ..utils.np_utils import conv_output_length
+from ..utils.np_utils import conv_output_length, conv_input_length
 
 # imports for backwards namespace compatibility
 from .pooling import AveragePooling1D, AveragePooling2D, AveragePooling3D
@@ -379,6 +379,79 @@ class Convolution2D(Layer):
         return dict(list(base_config.items()) + list(config.items()))
 
 
+class Deconvolution2D(Convolution2D):
+    '''Transposed convolution operator for filtering windows of two-dimensional inputs.
+
+    When using this layer as the first layer in a model,
+    provide the keyword argument `input_shape`
+    (tuple of integers, does not include the sample axis),
+    e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.
+
+    The `output_shape` argument is the full expected shape of the
+    transposed convolution's output, including the batch size,
+    e.g. `(batch_size, nb_filter, new_rows, new_cols)` for
+    dim_ordering='th'.
+    '''
+    def __init__(self, nb_filter, nb_row, nb_col, output_shape,
+                 init='glorot_uniform', activation='linear', weights=None,
+                 border_mode='valid', subsample=(1, 1),
+                 dim_ordering=K.image_dim_ordering(),
+                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
+                 W_constraint=None, b_constraint=None,
+                 bias=True, **kwargs):
+
+        if border_mode not in {'valid', 'same'}:
+            raise Exception('Invalid border mode for Deconvolution2D:', border_mode)
+
+        self.output_shape_ = output_shape
+
+        super(Deconvolution2D, self).__init__(nb_filter, nb_row, nb_col,
+                                              init=init, activation=activation,
+                                              weights=weights, border_mode=border_mode,
+                                              subsample=subsample, dim_ordering=dim_ordering,
+                                              W_regularizer=W_regularizer, b_regularizer=b_regularizer,
+                                              activity_regularizer=activity_regularizer,
+                                              W_constraint=W_constraint, b_constraint=b_constraint,
+                                              bias=bias, **kwargs)
+
+    def get_output_shape_for(self, input_shape):
+        if self.dim_ordering == 'th':
+            rows = input_shape[2]
+            cols = input_shape[3]
+        elif self.dim_ordering == 'tf':
+            rows = input_shape[1]
+            cols = input_shape[2]
+        else:
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+
+        rows = conv_input_length(rows, self.nb_row,
+                                 self.border_mode, self.subsample[0])
+        cols = conv_input_length(cols, self.nb_col,
+                                 self.border_mode, self.subsample[1])
+
+        if self.dim_ordering == 'th':
+            return (input_shape[0], self.nb_filter, rows, cols)
+        elif self.dim_ordering == 'tf':
+            return (input_shape[0], rows, cols, self.nb_filter)
+        else:
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+
+    def call(self, x, mask=None):
+        output = K.deconv2d(x, self.W, self.output_shape_,
+                            strides=self.subsample,
+                            border_mode=self.border_mode,
+                            dim_ordering=self.dim_ordering,
+                            filter_shape=self.W_shape)
+        if self.bias:
+            if self.dim_ordering == 'th':
+                output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
+            elif self.dim_ordering == 'tf':
+                output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
+            else:
+                raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+        output = self.activation(output)
+        return output
+
+    def get_config(self):
+        config = {'output_shape': self.output_shape_}
+        base_config = super(Deconvolution2D, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
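A minimal usage sketch for the new layer (hypothetical shapes, not part of the patch): `output_shape` must carry the fixed batch size, and its spatial dimensions follow `conv_input_length`, matching `get_output_shape_for` above.

```python
from keras.models import Sequential
from keras.layers import Deconvolution2D

# upsample (3, 14, 14) feature maps to (8, 16, 16) with a 3x3 kernel:
# conv_input_length(14, 3, 'valid', 1) == 16
model = Sequential()
model.add(Deconvolution2D(8, 3, 3, output_shape=(16, 8, 16, 16),
                          border_mode='valid', dim_ordering='th',
                          batch_input_shape=(16, 3, 14, 14)))
model.compile('sgd', 'mse')
print(model.output_shape)  # (16, 8, 16, 16)
```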
+
+
 class AtrousConvolution2D(Convolution2D):
     '''Atrous Convolution operator for filtering windows of two-dimensional inputs.
     A.k.a dilated convolution or convolution with holes.
@@ -1251,5 +1324,6 @@ class ZeroPadding3D(Layer):
 Conv1D = Convolution1D
 Conv2D = Convolution2D
 Conv3D = Convolution3D
+Deconv2D = Deconvolution2D
 AtrousConv2D = AtrousConvolution2D
 SeparableConv2D = SeparableConvolution2D
diff --git a/keras/utils/np_utils.py b/keras/utils/np_utils.py
index 4c8aaa089..687f1d2f8 100644
--- a/keras/utils/np_utils.py
+++ b/keras/utils/np_utils.py
@@ -120,3 +120,13 @@ def conv_output_length(input_length, filter_size, border_mode, stride, dilation=1):
     elif border_mode == 'valid':
         output_length = input_length - dilated_filter_size + 1
     return (output_length + stride - 1) // stride
+
+def conv_input_length(output_length, filter_size, border_mode, stride):
+    if output_length is None:
+        return None
+    assert border_mode in {'same', 'valid'}
+    if border_mode == 'same':
+        pad = filter_size // 2
+    elif border_mode == 'valid':
+        pad = 0
+    return (output_length - 1) * stride - 2 * pad + filter_size
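For configurations where 'same' padding is unambiguous (odd filter sizes), `conv_input_length` inverts `conv_output_length`; this is what lets the tests below pick compatible input and output shapes. A hypothetical standalone round-trip check (illustrative only):

```python
from keras.utils.np_utils import conv_output_length, conv_input_length

for border_mode in ('valid', 'same'):
    for stride in (1, 2, 3):
        for filter_size in (2, 3, 5):
            # even filters make 'same' output sizes ambiguous; skip them
            if border_mode == 'same' and filter_size % 2 == 0:
                continue
            for out_len in (4, 7, 10):
                in_len = conv_input_length(out_len, filter_size, border_mode, stride)
                assert conv_output_length(in_len, filter_size,
                                          border_mode, stride) == out_len
```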
diff --git a/keras/utils/test_utils.py b/keras/utils/test_utils.py
index 2814321d6..abdf7708f 100644
--- a/keras/utils/test_utils.py
+++ b/keras/utils/test_utils.py
@@ -36,7 +36,7 @@ def get_test_data(nb_train=1000, nb_test=500, input_shape=(10,),
 
 def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
-               input_data=None, expected_output=None, expected_output_dtype=None):
+               input_data=None, expected_output=None, expected_output_dtype=None, fixed_batch_size=False):
     '''Test routine for a layer with a single input tensor
     and single output tensor.
     '''
@@ -64,7 +64,10 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
     layer = layer_cls(**kwargs)
 
     # test in functional API
-    x = Input(shape=input_shape[1:], dtype=input_dtype)
+    if fixed_batch_size:
+        x = Input(batch_shape=input_shape, dtype=input_dtype)
+    else:
+        x = Input(shape=input_shape[1:], dtype=input_dtype)
     y = layer(x)
     assert K.dtype(y) == expected_output_dtype
diff --git a/tests/keras/layers/test_convolutional.py b/tests/keras/layers/test_convolutional.py
index 77c98e3de..8588ee44f 100644
--- a/tests/keras/layers/test_convolutional.py
+++ b/tests/keras/layers/test_convolutional.py
@@ -3,6 +3,7 @@ import numpy as np
 from numpy.testing import assert_allclose
 
 from keras.utils.test_utils import layer_test, keras_test
+from keras.utils.np_utils import conv_input_length
 from keras import backend as K
 from keras.layers import convolutional
 
@@ -88,6 +89,45 @@ def test_convolution_2d():
                input_shape=(nb_samples, stack_size, nb_row, nb_col))
 
 
+@keras_test
+def test_deconvolution_2d():
+    nb_samples = 2
+    nb_filter = 2
+    stack_size = 3
+    nb_row = 10
+    nb_col = 6
+
+    for border_mode in ['valid', 'same']:
+        for subsample in [(1, 1), (2, 2)]:
+            if border_mode == 'same' and subsample != (1, 1):
+                continue
+
+            rows = conv_input_length(nb_row, 3, border_mode, subsample[0])
+            cols = conv_input_length(nb_col, 3, border_mode, subsample[1])
+            layer_test(convolutional.Deconvolution2D,
+                       kwargs={'nb_filter': nb_filter,
+                               'nb_row': 3,
+                               'nb_col': 3,
+                               'output_shape': (nb_samples, nb_filter, rows, cols),
+                               'border_mode': border_mode,
+                               'subsample': subsample},
+                       input_shape=(nb_samples, stack_size, nb_row, nb_col),
+                       fixed_batch_size=True)
+
+            layer_test(convolutional.Deconvolution2D,
+                       kwargs={'nb_filter': nb_filter,
+                               'nb_row': 3,
+                               'nb_col': 3,
+                               'output_shape': (nb_samples, nb_filter, rows, cols),
+                               'border_mode': border_mode,
+                               'W_regularizer': 'l2',
+                               'b_regularizer': 'l2',
+                               'activity_regularizer': 'activity_l2',
+                               'subsample': subsample},
+                       input_shape=(nb_samples, stack_size, nb_row, nb_col),
+                       fixed_batch_size=True)
+
+
 @keras_test
 def test_atrous_conv_2d():
     nb_samples = 2