Implemented transposed (de-) convolutions into Keras (#3251)

* theano backend now supports transposed convolutions

* working deconv

* new example file with deconv vae

* merged with #3273, fixed based on comments, pep8 tested

* test fix

* passes theano test

* start fixing deconv test

* fix deconv layer tests

* fix the right test

sorry, I "fixed" the wrong test last time

* clean up

* replace with_None with fixed_batch_size

* with_None --> fixed_batch_size

* comment edit

* fixed comments online
This commit is contained in:
yaringal 2016-07-25 18:33:03 +01:00 committed by François Chollet
parent 09d75a4347
commit c689b52dd1
7 changed files with 373 additions and 49 deletions

@ -0,0 +1,122 @@
'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers.
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
# input image dimensions
img_rows, img_cols, img_chns = 28, 28, 1
# number of convolutional filters to use
nb_filters = 32
# convolution kernel size
nb_conv = 3
batch_size = 16
original_dim = (img_chns, img_rows, img_cols)
latent_dim = 2
intermediate_dim = 128
epsilon_std = 0.01
nb_epoch = 5
x = Input(batch_shape=(batch_size,) + original_dim)
c = Convolution2D(nb_filters, nb_conv, nb_conv, border_mode='same', activation='relu')(x)
f = Flatten()(c)
h = Dense(intermediate_dim, activation='relu')(f)
z_mean = Dense(latent_dim)(h)
z_log_std = Dense(latent_dim)(h)
def sampling(args):
z_mean, z_log_std = args
epsilon = K.random_normal(shape=(batch_size, latent_dim),
mean=0., std=epsilon_std)
return z_mean + K.exp(z_log_std) * epsilon
# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_std])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_std])
# we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_f = Dense(nb_filters*img_rows*img_cols, activation='relu')
decoder_c = Reshape((nb_filters, img_rows, img_cols))
decoder_mean = Deconvolution2D(img_chns, nb_conv, nb_conv,
(batch_size, img_chns, img_rows, img_cols), border_mode='same')
h_decoded = decoder_h(z)
f_decoded = decoder_f(h_decoded)
c_decoded = decoder_c(f_decoded)
x_decoded_mean = decoder_mean(c_decoded)
def vae_loss(x, x_decoded_mean):
# NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these!
x = K.flatten(x)
x_decoded_mean = K.flatten(x_decoded_mean)
xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.mean(1 + z_log_std - K.square(z_mean) - K.exp(z_log_std), axis=-1)
return xent_loss + kl_loss
vae = Model(x, x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=vae_loss)
vae.summary()
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')[:, None, :, :] / 255.
x_test = x_test.astype('float32')[:, None, :, :] / 255.
vae.fit(x_train, x_train,
shuffle=True,
nb_epoch=nb_epoch,
batch_size=batch_size,
validation_data=(x_test, x_test))
# build a model to project inputs on the latent space
encoder = Model(x, z_mean)
# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()
# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_f_decoded = decoder_f(_h_decoded)
_c_decoded = decoder_c(_f_decoded)
_x_decoded_mean = decoder_mean(_c_decoded)
generator = Model(decoder_input, _x_decoded_mean)
# display a 2D manifold of the digits
n = 15 # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# we will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi, yi]]) * epsilon_std
x_decoded = generator.predict(z_sample)
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()

@ -1250,6 +1250,12 @@ def l2_normalize(x, axis):
# CONVOLUTIONS # CONVOLUTIONS
def _preprocess_deconv_output_shape(sh, dim_ordering):
if dim_ordering == "th":
sh = (sh[0], sh[2], sh[3], sh[1])
return sh
def _preprocess_conv2d_input(x, dim_ordering): def _preprocess_conv2d_input(x, dim_ordering):
if _FLOATX == 'float64': if _FLOATX == 'float64':
x = tf.cast(x, 'float32') x = tf.cast(x, 'float32')
@ -1375,11 +1381,12 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
raise Exception('Unknown dim_ordering ' + str(dim_ordering)) raise Exception('Unknown dim_ordering ' + str(dim_ordering))
x = _preprocess_conv2d_input(x, dim_ordering) x = _preprocess_conv2d_input(x, dim_ordering)
output_shape = _preprocess_deconv_output_shape(output_shape, dim_ordering)
kernel = _preprocess_conv2d_kernel(kernel, dim_ordering) kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
kernel = tf.transpose(kernel, (0, 1, 3, 2)) # tranpose kernel chanels
padding = _preprocess_border_mode(border_mode) padding = _preprocess_border_mode(border_mode)
strides = (1,) + strides + (1,) strides = (1,) + strides + (1,)
# TODO: pre-process output_shape if dim_ordering == th
x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides, x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides,
padding=padding) padding=padding)
return _postprocess_conv2d_output(x, dim_ordering) return _postprocess_conv2d_output(x, dim_ordering)

@ -937,6 +937,79 @@ def l2_normalize(x, axis):
# CONVOLUTIONS # CONVOLUTIONS
def _preprocess_conv2d_input(x, dim_ordering):
if dim_ordering == 'tf':
# TF uses the last dimension as channel dimension,
# instead of the 2nd one.
# TH input shape: (samples, input_depth, rows, cols)
# TF input shape: (samples, rows, cols, input_depth)
x = x.dimshuffle((0, 3, 1, 2))
return x
def _preprocess_conv2d_kernel(kernel, dim_ordering):
if dim_ordering == 'tf':
# TF uses the last dimension as channel dimension,
# instead of the 2nd one.
# TH kernel shape: (depth, input_depth, rows, cols)
# TF kernel shape: (rows, cols, input_depth, depth)
kernel = kernel.dimshuffle((3, 2, 0, 1))
return kernel
def _preprocess_border_mode(border_mode):
if border_mode == 'same':
th_border_mode = 'half'
elif border_mode == 'valid':
th_border_mode = 'valid'
else:
raise Exception('Border mode not supported: ' + str(border_mode))
return th_border_mode
def _preprocess_image_shape(dim_ordering, image_shape):
# Theano might not accept long type
def int_or_none(value):
try:
return int(value)
except TypeError:
return None
if dim_ordering == 'tf':
if image_shape:
image_shape = (image_shape[0], image_shape[3],
image_shape[1], image_shape[2])
if image_shape is not None:
image_shape = tuple(int_or_none(v) for v in image_shape)
return image_shape
def _preprocess_filter_shape(dim_ordering, filter_shape):
# Theano might not accept long type
def int_or_none(value):
try:
return int(value)
except TypeError:
return None
if dim_ordering == 'tf':
if filter_shape:
filter_shape = (filter_shape[3], filter_shape[2],
filter_shape[0], filter_shape[1])
if filter_shape is not None:
filter_shape = tuple(int_or_none(v) for v in filter_shape)
return filter_shape
def _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel, strides, dim_ordering):
if border_mode == 'same':
if np_kernel.shape[2] % 2 == 0:
conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
if np_kernel.shape[3] % 2 == 0:
conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
if dim_ordering == 'tf':
conv_out = conv_out.dimshuffle((0, 2, 3, 1))
return conv_out
def conv2d(x, kernel, strides=(1, 1), border_mode='valid', def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
dim_ordering=_IMAGE_DIM_ORDERING, image_shape=None, dim_ordering=_IMAGE_DIM_ORDERING, image_shape=None,
filter_shape=None, filter_dilation=(1, 1)): filter_shape=None, filter_dilation=(1, 1)):
@ -953,42 +1026,12 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
if dim_ordering not in {'th', 'tf'}: if dim_ordering not in {'th', 'tf'}:
raise Exception('Unknown dim_ordering ' + str(dim_ordering)) raise Exception('Unknown dim_ordering ' + str(dim_ordering))
if dim_ordering == 'tf': x = _preprocess_conv2d_input(x, dim_ordering)
# TF uses the last dimension as channel dimension, kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
# instead of the 2nd one. th_border_mode = _preprocess_border_mode(border_mode)
# TH input shape: (samples, input_depth, rows, cols)
# TF input shape: (samples, rows, cols, input_depth)
# TH kernel shape: (depth, input_depth, rows, cols)
# TF kernel shape: (rows, cols, input_depth, depth)
x = x.dimshuffle((0, 3, 1, 2))
kernel = kernel.dimshuffle((3, 2, 0, 1))
if image_shape:
image_shape = (image_shape[0], image_shape[3],
image_shape[1], image_shape[2])
if filter_shape:
filter_shape = (filter_shape[3], filter_shape[2],
filter_shape[0], filter_shape[1])
if border_mode == 'same':
th_border_mode = 'half'
np_kernel = kernel.eval() np_kernel = kernel.eval()
elif border_mode == 'valid': image_shape = _preprocess_image_shape(dim_ordering, image_shape)
th_border_mode = 'valid' filter_shape = _preprocess_filter_shape(dim_ordering, filter_shape)
else:
raise Exception('Border mode not supported: ' + str(border_mode))
# Theano might not accept long type
def int_or_none(value):
try:
return int(value)
except TypeError:
return None
if image_shape is not None:
image_shape = tuple(int_or_none(v) for v in image_shape)
if filter_shape is not None:
filter_shape = tuple(int_or_none(v) for v in filter_shape)
# TODO: remove the if statement when theano with no filter dilation is deprecated. # TODO: remove the if statement when theano with no filter dilation is deprecated.
if filter_dilation == (1, 1): if filter_dilation == (1, 1):
@ -1005,14 +1048,8 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
filter_shape=filter_shape, filter_shape=filter_shape,
filter_dilation=filter_dilation) filter_dilation=filter_dilation)
if border_mode == 'same': conv_out = _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel,
if np_kernel.shape[2] % 2 == 0: strides, dim_ordering)
conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
if np_kernel.shape[3] % 2 == 0:
conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
if dim_ordering == 'tf':
conv_out = conv_out.dimshuffle((0, 2, 3, 1))
return conv_out return conv_out
@ -1020,7 +1057,38 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
border_mode='valid', border_mode='valid',
dim_ordering=_IMAGE_DIM_ORDERING, dim_ordering=_IMAGE_DIM_ORDERING,
image_shape=None, filter_shape=None): image_shape=None, filter_shape=None):
raise NotImplementedError '''2D deconvolution (transposed convolution).
# Arguments
kernel: kernel tensor.
output_shape: desired dimensions of output.
strides: strides tuple.
border_mode: string, "same" or "valid".
dim_ordering: "tf" or "th".
Whether to use Theano or TensorFlow dimension ordering
in inputs/kernels/ouputs.
'''
flip_filters = False
if dim_ordering not in {'th', 'tf'}:
raise Exception('Unknown dim_ordering ' + str(dim_ordering))
x = _preprocess_conv2d_input(x, dim_ordering)
kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
kernel = kernel.dimshuffle((1, 0, 2, 3))
th_border_mode = _preprocess_border_mode(border_mode)
np_kernel = kernel.eval()
filter_shape = _preprocess_filter_shape(dim_ordering, filter_shape)
op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=output_shape,
kshp=filter_shape,
subsample=strides,
border_mode=th_border_mode,
filter_flip=not flip_filters)
conv_out = op(kernel, x, output_shape[2:])
conv_out = _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel,
strides, dim_ordering)
return conv_out
def atrous_conv2d(x, kernel, rate=1, def atrous_conv2d(x, kernel, rate=1,

@ -4,7 +4,7 @@ from __future__ import absolute_import
from .. import backend as K from .. import backend as K
from .. import activations, initializations, regularizers, constraints from .. import activations, initializations, regularizers, constraints
from ..engine import Layer, InputSpec from ..engine import Layer, InputSpec
from ..utils.np_utils import conv_output_length from ..utils.np_utils import conv_output_length, conv_input_length
# imports for backwards namespace compatibility # imports for backwards namespace compatibility
from .pooling import AveragePooling1D, AveragePooling2D, AveragePooling3D from .pooling import AveragePooling1D, AveragePooling2D, AveragePooling3D
@ -379,6 +379,79 @@ class Convolution2D(Layer):
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
class Deconvolution2D(Convolution2D):
'''Transposed convolution operator for filtering windows of two-dimensional inputs.
When using this layer as the first layer in a model,
provide the keyword argument `input_shape`
(tuple of integers, does not include the sample axis),
e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.
'''
def __init__(self, nb_filter, nb_row, nb_col, output_shape,
init='glorot_uniform', activation='linear', weights=None,
border_mode='valid', subsample=(1, 1),
dim_ordering=K.image_dim_ordering(),
W_regularizer=None, b_regularizer=None, activity_regularizer=None,
W_constraint=None, b_constraint=None,
bias=True, **kwargs):
if border_mode not in {'valid', 'same'}:
raise Exception('Invalid border mode for AtrousConv2D:', border_mode)
self.output_shape_ = output_shape
super(Deconvolution2D, self).__init__(nb_filter, nb_row, nb_col,
init=init, activation=activation,
weights=weights, border_mode=border_mode,
subsample=subsample, dim_ordering=dim_ordering,
W_regularizer=W_regularizer, b_regularizer=b_regularizer,
activity_regularizer=activity_regularizer,
W_constraint=W_constraint, b_constraint=b_constraint,
bias=bias, **kwargs)
def get_output_shape_for(self, input_shape):
if self.dim_ordering == 'th':
rows = input_shape[2]
cols = input_shape[3]
elif self.dim_ordering == 'tf':
rows = input_shape[1]
cols = input_shape[2]
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
rows = conv_input_length(rows, self.nb_row,
self.border_mode, self.subsample[0])
cols = conv_input_length(cols, self.nb_col,
self.border_mode, self.subsample[1])
if self.dim_ordering == 'th':
return (input_shape[0], self.nb_filter, rows, cols)
elif self.dim_ordering == 'tf':
return (input_shape[0], rows, cols, self.nb_filter)
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
def call(self, x, mask=None):
output = K.deconv2d(x, self.W, self.output_shape_,
strides=self.subsample,
border_mode=self.border_mode,
dim_ordering=self.dim_ordering,
filter_shape=self.W_shape)
if self.bias:
if self.dim_ordering == 'th':
output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
elif self.dim_ordering == 'tf':
output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
output = self.activation(output)
return output
def get_config(self):
config = {'output_shape': self.output_shape}
base_config = super(Deconvolution2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class AtrousConvolution2D(Convolution2D): class AtrousConvolution2D(Convolution2D):
'''Atrous Convolution operator for filtering windows of two-dimensional inputs. '''Atrous Convolution operator for filtering windows of two-dimensional inputs.
A.k.a dilated convolution or convolution with holes. A.k.a dilated convolution or convolution with holes.
@ -1251,5 +1324,6 @@ class ZeroPadding3D(Layer):
Conv1D = Convolution1D Conv1D = Convolution1D
Conv2D = Convolution2D Conv2D = Convolution2D
Conv3D = Convolution3D Conv3D = Convolution3D
Deconv2D = Deconvolution2D
AtrousConv2D = AtrousConvolution2D AtrousConv2D = AtrousConvolution2D
SeparableConv2D = SeparableConvolution2D SeparableConv2D = SeparableConvolution2D

@ -120,3 +120,13 @@ def conv_output_length(input_length, filter_size, border_mode, stride, dilation=
elif border_mode == 'valid': elif border_mode == 'valid':
output_length = input_length - dilated_filter_size + 1 output_length = input_length - dilated_filter_size + 1
return (output_length + stride - 1) // stride return (output_length + stride - 1) // stride
def conv_input_length(output_length, filter_size, border_mode, stride):
if output_length is None:
return None
assert border_mode in {'same', 'valid'}
if border_mode == 'same':
pad = filter_size // 2
elif border_mode == 'valid':
pad = 0
return (output_length - 1) * stride - 2 * pad + filter_size

@ -36,7 +36,7 @@ def get_test_data(nb_train=1000, nb_test=500, input_shape=(10,),
def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None, def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
input_data=None, expected_output=None, expected_output_dtype=None): input_data=None, expected_output=None, expected_output_dtype=None, fixed_batch_size=False):
'''Test routine for a layer with a single input tensor '''Test routine for a layer with a single input tensor
and single output tensor. and single output tensor.
''' '''
@ -64,6 +64,9 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
layer = layer_cls(**kwargs) layer = layer_cls(**kwargs)
# test in functional API # test in functional API
if fixed_batch_size:
x = Input(batch_shape=input_shape, dtype=input_dtype)
else:
x = Input(shape=input_shape[1:], dtype=input_dtype) x = Input(shape=input_shape[1:], dtype=input_dtype)
y = layer(x) y = layer(x)
assert K.dtype(y) == expected_output_dtype assert K.dtype(y) == expected_output_dtype

@ -3,6 +3,7 @@ import numpy as np
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
from keras.utils.test_utils import layer_test, keras_test from keras.utils.test_utils import layer_test, keras_test
from keras.utils.np_utils import conv_input_length
from keras import backend as K from keras import backend as K
from keras.layers import convolutional from keras.layers import convolutional
@ -88,6 +89,45 @@ def test_convolution_2d():
input_shape=(nb_samples, stack_size, nb_row, nb_col)) input_shape=(nb_samples, stack_size, nb_row, nb_col))
@keras_test
def test_deconvolution_2d():
nb_samples = 2
nb_filter = 2
stack_size = 3
nb_row = 10
nb_col = 6
for border_mode in ['valid', 'same']:
for subsample in [(1, 1), (2, 2)]:
if border_mode == 'same' and subsample != (1, 1):
continue
rows = conv_input_length(nb_row, 3, border_mode, subsample[0])
cols = conv_input_length(nb_col, 3, border_mode, subsample[1])
layer_test(convolutional.Deconvolution2D,
kwargs={'nb_filter': nb_filter,
'nb_row': 3,
'nb_col': 3,
'output_shape': (nb_samples, nb_filter, rows, cols),
'border_mode': border_mode,
'subsample': subsample},
input_shape=(nb_samples, stack_size, nb_row, nb_col),
fixed_batch_size=True)
layer_test(convolutional.Deconvolution2D,
kwargs={'nb_filter': nb_filter,
'nb_row': 3,
'nb_col': 3,
'output_shape': (nb_samples, nb_filter, rows, cols),
'border_mode': border_mode,
'W_regularizer': 'l2',
'b_regularizer': 'l2',
'activity_regularizer': 'activity_l2',
'subsample': subsample},
input_shape=(nb_samples, stack_size, nb_row, nb_col),
fixed_batch_size=True)
@keras_test @keras_test
def test_atrous_conv_2d(): def test_atrous_conv_2d():
nb_samples = 2 nb_samples = 2