Implemented transposed (de-) convolutions into Keras (#3251)
* theano backend now supports transposed convolutions * working deconv * new example file with deconv vae * merged with #3273, fixed based on comments, pep8 tested * test fix * passes theano test * start fixing deconv test * fix deconv layer tests * fix the right test sorry, I "fixed" the wrong test last time * clean up * replace with_None with fixed_batch_size * with_None --> fixed_batch_size * comment edit * fixed comments online
This commit is contained in:
parent
09d75a4347
commit
c689b52dd1
122
examples/variational_autoencoder_deconv.py
Normal file
122
examples/variational_autoencoder_deconv.py
Normal file
@ -0,0 +1,122 @@
|
||||
'''This script demonstrates how to build a variational autoencoder with Keras and deconvolution layers.
|
||||
|
||||
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
|
||||
'''
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
|
||||
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D
|
||||
from keras.models import Model
|
||||
from keras import backend as K
|
||||
from keras import objectives
|
||||
from keras.datasets import mnist
|
||||
|
||||
# input image dimensions
|
||||
img_rows, img_cols, img_chns = 28, 28, 1
|
||||
# number of convolutional filters to use
|
||||
nb_filters = 32
|
||||
# convolution kernel size
|
||||
nb_conv = 3
|
||||
|
||||
batch_size = 16
|
||||
original_dim = (img_chns, img_rows, img_cols)
|
||||
latent_dim = 2
|
||||
intermediate_dim = 128
|
||||
epsilon_std = 0.01
|
||||
nb_epoch = 5
|
||||
|
||||
|
||||
x = Input(batch_shape=(batch_size,) + original_dim)
|
||||
c = Convolution2D(nb_filters, nb_conv, nb_conv, border_mode='same', activation='relu')(x)
|
||||
f = Flatten()(c)
|
||||
h = Dense(intermediate_dim, activation='relu')(f)
|
||||
|
||||
z_mean = Dense(latent_dim)(h)
|
||||
z_log_std = Dense(latent_dim)(h)
|
||||
|
||||
def sampling(args):
|
||||
z_mean, z_log_std = args
|
||||
epsilon = K.random_normal(shape=(batch_size, latent_dim),
|
||||
mean=0., std=epsilon_std)
|
||||
return z_mean + K.exp(z_log_std) * epsilon
|
||||
|
||||
# note that "output_shape" isn't necessary with the TensorFlow backend
|
||||
# so you could write `Lambda(sampling)([z_mean, z_log_std])`
|
||||
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_std])
|
||||
|
||||
# we instantiate these layers separately so as to reuse them later
|
||||
decoder_h = Dense(intermediate_dim, activation='relu')
|
||||
decoder_f = Dense(nb_filters*img_rows*img_cols, activation='relu')
|
||||
decoder_c = Reshape((nb_filters, img_rows, img_cols))
|
||||
decoder_mean = Deconvolution2D(img_chns, nb_conv, nb_conv,
|
||||
(batch_size, img_chns, img_rows, img_cols), border_mode='same')
|
||||
|
||||
h_decoded = decoder_h(z)
|
||||
f_decoded = decoder_f(h_decoded)
|
||||
c_decoded = decoder_c(f_decoded)
|
||||
x_decoded_mean = decoder_mean(c_decoded)
|
||||
|
||||
|
||||
def vae_loss(x, x_decoded_mean):
|
||||
# NOTE: binary_crossentropy expects a batch_size by dim for x and x_decoded_mean, so we MUST flatten these!
|
||||
x = K.flatten(x)
|
||||
x_decoded_mean = K.flatten(x_decoded_mean)
|
||||
xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
|
||||
kl_loss = - 0.5 * K.mean(1 + z_log_std - K.square(z_mean) - K.exp(z_log_std), axis=-1)
|
||||
return xent_loss + kl_loss
|
||||
|
||||
vae = Model(x, x_decoded_mean)
|
||||
vae.compile(optimizer='rmsprop', loss=vae_loss)
|
||||
vae.summary()
|
||||
|
||||
# train the VAE on MNIST digits
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
|
||||
x_train = x_train.astype('float32')[:, None, :, :] / 255.
|
||||
x_test = x_test.astype('float32')[:, None, :, :] / 255.
|
||||
|
||||
vae.fit(x_train, x_train,
|
||||
shuffle=True,
|
||||
nb_epoch=nb_epoch,
|
||||
batch_size=batch_size,
|
||||
validation_data=(x_test, x_test))
|
||||
|
||||
|
||||
# build a model to project inputs on the latent space
|
||||
encoder = Model(x, z_mean)
|
||||
|
||||
# display a 2D plot of the digit classes in the latent space
|
||||
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
|
||||
plt.figure(figsize=(6, 6))
|
||||
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
# build a digit generator that can sample from the learned distribution
|
||||
decoder_input = Input(shape=(latent_dim,))
|
||||
_h_decoded = decoder_h(decoder_input)
|
||||
_f_decoded = decoder_f(_h_decoded)
|
||||
_c_decoded = decoder_c(_f_decoded)
|
||||
_x_decoded_mean = decoder_mean(_c_decoded)
|
||||
generator = Model(decoder_input, _x_decoded_mean)
|
||||
|
||||
# display a 2D manifold of the digits
|
||||
n = 15 # figure with 15x15 digits
|
||||
digit_size = 28
|
||||
figure = np.zeros((digit_size * n, digit_size * n))
|
||||
# we will sample n points within [-15, 15] standard deviations
|
||||
grid_x = np.linspace(-15, 15, n)
|
||||
grid_y = np.linspace(-15, 15, n)
|
||||
|
||||
for i, yi in enumerate(grid_x):
|
||||
for j, xi in enumerate(grid_y):
|
||||
z_sample = np.array([[xi, yi]]) * epsilon_std
|
||||
x_decoded = generator.predict(z_sample)
|
||||
digit = x_decoded[0].reshape(digit_size, digit_size)
|
||||
figure[i * digit_size: (i + 1) * digit_size,
|
||||
j * digit_size: (j + 1) * digit_size] = digit
|
||||
|
||||
plt.figure(figsize=(10, 10))
|
||||
plt.imshow(figure)
|
||||
plt.show()
|
@ -1250,6 +1250,12 @@ def l2_normalize(x, axis):
|
||||
|
||||
# CONVOLUTIONS
|
||||
|
||||
def _preprocess_deconv_output_shape(sh, dim_ordering):
|
||||
if dim_ordering == "th":
|
||||
sh = (sh[0], sh[2], sh[3], sh[1])
|
||||
return sh
|
||||
|
||||
|
||||
def _preprocess_conv2d_input(x, dim_ordering):
|
||||
if _FLOATX == 'float64':
|
||||
x = tf.cast(x, 'float32')
|
||||
@ -1375,11 +1381,12 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
|
||||
raise Exception('Unknown dim_ordering ' + str(dim_ordering))
|
||||
|
||||
x = _preprocess_conv2d_input(x, dim_ordering)
|
||||
output_shape = _preprocess_deconv_output_shape(output_shape, dim_ordering)
|
||||
kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
|
||||
kernel = tf.transpose(kernel, (0, 1, 3, 2)) # tranpose kernel chanels
|
||||
padding = _preprocess_border_mode(border_mode)
|
||||
strides = (1,) + strides + (1,)
|
||||
|
||||
# TODO: pre-process output_shape if dim_ordering == th
|
||||
x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides,
|
||||
padding=padding)
|
||||
return _postprocess_conv2d_output(x, dim_ordering)
|
||||
|
@ -937,6 +937,79 @@ def l2_normalize(x, axis):
|
||||
|
||||
# CONVOLUTIONS
|
||||
|
||||
def _preprocess_conv2d_input(x, dim_ordering):
|
||||
if dim_ordering == 'tf':
|
||||
# TF uses the last dimension as channel dimension,
|
||||
# instead of the 2nd one.
|
||||
# TH input shape: (samples, input_depth, rows, cols)
|
||||
# TF input shape: (samples, rows, cols, input_depth)
|
||||
x = x.dimshuffle((0, 3, 1, 2))
|
||||
return x
|
||||
|
||||
|
||||
def _preprocess_conv2d_kernel(kernel, dim_ordering):
|
||||
if dim_ordering == 'tf':
|
||||
# TF uses the last dimension as channel dimension,
|
||||
# instead of the 2nd one.
|
||||
# TH kernel shape: (depth, input_depth, rows, cols)
|
||||
# TF kernel shape: (rows, cols, input_depth, depth)
|
||||
kernel = kernel.dimshuffle((3, 2, 0, 1))
|
||||
return kernel
|
||||
|
||||
|
||||
def _preprocess_border_mode(border_mode):
|
||||
if border_mode == 'same':
|
||||
th_border_mode = 'half'
|
||||
elif border_mode == 'valid':
|
||||
th_border_mode = 'valid'
|
||||
else:
|
||||
raise Exception('Border mode not supported: ' + str(border_mode))
|
||||
return th_border_mode
|
||||
|
||||
|
||||
def _preprocess_image_shape(dim_ordering, image_shape):
|
||||
# Theano might not accept long type
|
||||
def int_or_none(value):
|
||||
try:
|
||||
return int(value)
|
||||
except TypeError:
|
||||
return None
|
||||
if dim_ordering == 'tf':
|
||||
if image_shape:
|
||||
image_shape = (image_shape[0], image_shape[3],
|
||||
image_shape[1], image_shape[2])
|
||||
if image_shape is not None:
|
||||
image_shape = tuple(int_or_none(v) for v in image_shape)
|
||||
return image_shape
|
||||
|
||||
|
||||
def _preprocess_filter_shape(dim_ordering, filter_shape):
|
||||
# Theano might not accept long type
|
||||
def int_or_none(value):
|
||||
try:
|
||||
return int(value)
|
||||
except TypeError:
|
||||
return None
|
||||
if dim_ordering == 'tf':
|
||||
if filter_shape:
|
||||
filter_shape = (filter_shape[3], filter_shape[2],
|
||||
filter_shape[0], filter_shape[1])
|
||||
if filter_shape is not None:
|
||||
filter_shape = tuple(int_or_none(v) for v in filter_shape)
|
||||
return filter_shape
|
||||
|
||||
|
||||
def _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel, strides, dim_ordering):
|
||||
if border_mode == 'same':
|
||||
if np_kernel.shape[2] % 2 == 0:
|
||||
conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
|
||||
if np_kernel.shape[3] % 2 == 0:
|
||||
conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
|
||||
if dim_ordering == 'tf':
|
||||
conv_out = conv_out.dimshuffle((0, 2, 3, 1))
|
||||
return conv_out
|
||||
|
||||
|
||||
def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
|
||||
dim_ordering=_IMAGE_DIM_ORDERING, image_shape=None,
|
||||
filter_shape=None, filter_dilation=(1, 1)):
|
||||
@ -953,42 +1026,12 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
|
||||
if dim_ordering not in {'th', 'tf'}:
|
||||
raise Exception('Unknown dim_ordering ' + str(dim_ordering))
|
||||
|
||||
if dim_ordering == 'tf':
|
||||
# TF uses the last dimension as channel dimension,
|
||||
# instead of the 2nd one.
|
||||
# TH input shape: (samples, input_depth, rows, cols)
|
||||
# TF input shape: (samples, rows, cols, input_depth)
|
||||
# TH kernel shape: (depth, input_depth, rows, cols)
|
||||
# TF kernel shape: (rows, cols, input_depth, depth)
|
||||
x = x.dimshuffle((0, 3, 1, 2))
|
||||
kernel = kernel.dimshuffle((3, 2, 0, 1))
|
||||
if image_shape:
|
||||
image_shape = (image_shape[0], image_shape[3],
|
||||
image_shape[1], image_shape[2])
|
||||
if filter_shape:
|
||||
filter_shape = (filter_shape[3], filter_shape[2],
|
||||
filter_shape[0], filter_shape[1])
|
||||
|
||||
if border_mode == 'same':
|
||||
th_border_mode = 'half'
|
||||
np_kernel = kernel.eval()
|
||||
elif border_mode == 'valid':
|
||||
th_border_mode = 'valid'
|
||||
else:
|
||||
raise Exception('Border mode not supported: ' + str(border_mode))
|
||||
|
||||
# Theano might not accept long type
|
||||
def int_or_none(value):
|
||||
try:
|
||||
return int(value)
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
if image_shape is not None:
|
||||
image_shape = tuple(int_or_none(v) for v in image_shape)
|
||||
|
||||
if filter_shape is not None:
|
||||
filter_shape = tuple(int_or_none(v) for v in filter_shape)
|
||||
x = _preprocess_conv2d_input(x, dim_ordering)
|
||||
kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
|
||||
th_border_mode = _preprocess_border_mode(border_mode)
|
||||
np_kernel = kernel.eval()
|
||||
image_shape = _preprocess_image_shape(dim_ordering, image_shape)
|
||||
filter_shape = _preprocess_filter_shape(dim_ordering, filter_shape)
|
||||
|
||||
# TODO: remove the if statement when theano with no filter dilation is deprecated.
|
||||
if filter_dilation == (1, 1):
|
||||
@ -1005,14 +1048,8 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
|
||||
filter_shape=filter_shape,
|
||||
filter_dilation=filter_dilation)
|
||||
|
||||
if border_mode == 'same':
|
||||
if np_kernel.shape[2] % 2 == 0:
|
||||
conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
|
||||
if np_kernel.shape[3] % 2 == 0:
|
||||
conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
|
||||
|
||||
if dim_ordering == 'tf':
|
||||
conv_out = conv_out.dimshuffle((0, 2, 3, 1))
|
||||
conv_out = _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel,
|
||||
strides, dim_ordering)
|
||||
return conv_out
|
||||
|
||||
|
||||
@ -1020,7 +1057,38 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
|
||||
border_mode='valid',
|
||||
dim_ordering=_IMAGE_DIM_ORDERING,
|
||||
image_shape=None, filter_shape=None):
|
||||
raise NotImplementedError
|
||||
'''2D deconvolution (transposed convolution).
|
||||
|
||||
# Arguments
|
||||
kernel: kernel tensor.
|
||||
output_shape: desired dimensions of output.
|
||||
strides: strides tuple.
|
||||
border_mode: string, "same" or "valid".
|
||||
dim_ordering: "tf" or "th".
|
||||
Whether to use Theano or TensorFlow dimension ordering
|
||||
in inputs/kernels/ouputs.
|
||||
'''
|
||||
flip_filters = False
|
||||
if dim_ordering not in {'th', 'tf'}:
|
||||
raise Exception('Unknown dim_ordering ' + str(dim_ordering))
|
||||
|
||||
x = _preprocess_conv2d_input(x, dim_ordering)
|
||||
kernel = _preprocess_conv2d_kernel(kernel, dim_ordering)
|
||||
kernel = kernel.dimshuffle((1, 0, 2, 3))
|
||||
th_border_mode = _preprocess_border_mode(border_mode)
|
||||
np_kernel = kernel.eval()
|
||||
filter_shape = _preprocess_filter_shape(dim_ordering, filter_shape)
|
||||
|
||||
op = T.nnet.abstract_conv.AbstractConv2d_gradInputs(imshp=output_shape,
|
||||
kshp=filter_shape,
|
||||
subsample=strides,
|
||||
border_mode=th_border_mode,
|
||||
filter_flip=not flip_filters)
|
||||
conv_out = op(kernel, x, output_shape[2:])
|
||||
|
||||
conv_out = _postprocess_conv2d_output(conv_out, x, border_mode, np_kernel,
|
||||
strides, dim_ordering)
|
||||
return conv_out
|
||||
|
||||
|
||||
def atrous_conv2d(x, kernel, rate=1,
|
||||
|
@ -4,7 +4,7 @@ from __future__ import absolute_import
|
||||
from .. import backend as K
|
||||
from .. import activations, initializations, regularizers, constraints
|
||||
from ..engine import Layer, InputSpec
|
||||
from ..utils.np_utils import conv_output_length
|
||||
from ..utils.np_utils import conv_output_length, conv_input_length
|
||||
|
||||
# imports for backwards namespace compatibility
|
||||
from .pooling import AveragePooling1D, AveragePooling2D, AveragePooling3D
|
||||
@ -379,6 +379,79 @@ class Convolution2D(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Deconvolution2D(Convolution2D):
|
||||
'''Transposed convolution operator for filtering windows of two-dimensional inputs.
|
||||
When using this layer as the first layer in a model,
|
||||
provide the keyword argument `input_shape`
|
||||
(tuple of integers, does not include the sample axis),
|
||||
e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.
|
||||
'''
|
||||
def __init__(self, nb_filter, nb_row, nb_col, output_shape,
|
||||
init='glorot_uniform', activation='linear', weights=None,
|
||||
border_mode='valid', subsample=(1, 1),
|
||||
dim_ordering=K.image_dim_ordering(),
|
||||
W_regularizer=None, b_regularizer=None, activity_regularizer=None,
|
||||
W_constraint=None, b_constraint=None,
|
||||
bias=True, **kwargs):
|
||||
|
||||
if border_mode not in {'valid', 'same'}:
|
||||
raise Exception('Invalid border mode for AtrousConv2D:', border_mode)
|
||||
|
||||
self.output_shape_ = output_shape
|
||||
|
||||
super(Deconvolution2D, self).__init__(nb_filter, nb_row, nb_col,
|
||||
init=init, activation=activation,
|
||||
weights=weights, border_mode=border_mode,
|
||||
subsample=subsample, dim_ordering=dim_ordering,
|
||||
W_regularizer=W_regularizer, b_regularizer=b_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
W_constraint=W_constraint, b_constraint=b_constraint,
|
||||
bias=bias, **kwargs)
|
||||
|
||||
def get_output_shape_for(self, input_shape):
|
||||
if self.dim_ordering == 'th':
|
||||
rows = input_shape[2]
|
||||
cols = input_shape[3]
|
||||
elif self.dim_ordering == 'tf':
|
||||
rows = input_shape[1]
|
||||
cols = input_shape[2]
|
||||
else:
|
||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||
|
||||
rows = conv_input_length(rows, self.nb_row,
|
||||
self.border_mode, self.subsample[0])
|
||||
cols = conv_input_length(cols, self.nb_col,
|
||||
self.border_mode, self.subsample[1])
|
||||
|
||||
if self.dim_ordering == 'th':
|
||||
return (input_shape[0], self.nb_filter, rows, cols)
|
||||
elif self.dim_ordering == 'tf':
|
||||
return (input_shape[0], rows, cols, self.nb_filter)
|
||||
else:
|
||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||
|
||||
def call(self, x, mask=None):
|
||||
output = K.deconv2d(x, self.W, self.output_shape_,
|
||||
strides=self.subsample,
|
||||
border_mode=self.border_mode,
|
||||
dim_ordering=self.dim_ordering,
|
||||
filter_shape=self.W_shape)
|
||||
if self.bias:
|
||||
if self.dim_ordering == 'th':
|
||||
output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
|
||||
elif self.dim_ordering == 'tf':
|
||||
output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
|
||||
else:
|
||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||
output = self.activation(output)
|
||||
return output
|
||||
|
||||
def get_config(self):
|
||||
config = {'output_shape': self.output_shape}
|
||||
base_config = super(Deconvolution2D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class AtrousConvolution2D(Convolution2D):
|
||||
'''Atrous Convolution operator for filtering windows of two-dimensional inputs.
|
||||
A.k.a dilated convolution or convolution with holes.
|
||||
@ -1251,5 +1324,6 @@ class ZeroPadding3D(Layer):
|
||||
Conv1D = Convolution1D
|
||||
Conv2D = Convolution2D
|
||||
Conv3D = Convolution3D
|
||||
Deconv2D = Deconvolution2D
|
||||
AtrousConv2D = AtrousConvolution2D
|
||||
SeparableConv2D = SeparableConvolution2D
|
||||
|
@ -120,3 +120,13 @@ def conv_output_length(input_length, filter_size, border_mode, stride, dilation=
|
||||
elif border_mode == 'valid':
|
||||
output_length = input_length - dilated_filter_size + 1
|
||||
return (output_length + stride - 1) // stride
|
||||
|
||||
def conv_input_length(output_length, filter_size, border_mode, stride):
|
||||
if output_length is None:
|
||||
return None
|
||||
assert border_mode in {'same', 'valid'}
|
||||
if border_mode == 'same':
|
||||
pad = filter_size // 2
|
||||
elif border_mode == 'valid':
|
||||
pad = 0
|
||||
return (output_length - 1) * stride - 2 * pad + filter_size
|
||||
|
@ -36,7 +36,7 @@ def get_test_data(nb_train=1000, nb_test=500, input_shape=(10,),
|
||||
|
||||
|
||||
def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
|
||||
input_data=None, expected_output=None, expected_output_dtype=None):
|
||||
input_data=None, expected_output=None, expected_output_dtype=None, fixed_batch_size=False):
|
||||
'''Test routine for a layer with a single input tensor
|
||||
and single output tensor.
|
||||
'''
|
||||
@ -64,7 +64,10 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
|
||||
layer = layer_cls(**kwargs)
|
||||
|
||||
# test in functional API
|
||||
x = Input(shape=input_shape[1:], dtype=input_dtype)
|
||||
if fixed_batch_size:
|
||||
x = Input(batch_shape=input_shape, dtype=input_dtype)
|
||||
else:
|
||||
x = Input(shape=input_shape[1:], dtype=input_dtype)
|
||||
y = layer(x)
|
||||
assert K.dtype(y) == expected_output_dtype
|
||||
|
||||
|
@ -3,6 +3,7 @@ import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from keras.utils.test_utils import layer_test, keras_test
|
||||
from keras.utils.np_utils import conv_input_length
|
||||
from keras import backend as K
|
||||
from keras.layers import convolutional
|
||||
|
||||
@ -88,6 +89,45 @@ def test_convolution_2d():
|
||||
input_shape=(nb_samples, stack_size, nb_row, nb_col))
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_deconvolution_2d():
|
||||
nb_samples = 2
|
||||
nb_filter = 2
|
||||
stack_size = 3
|
||||
nb_row = 10
|
||||
nb_col = 6
|
||||
|
||||
for border_mode in ['valid', 'same']:
|
||||
for subsample in [(1, 1), (2, 2)]:
|
||||
if border_mode == 'same' and subsample != (1, 1):
|
||||
continue
|
||||
|
||||
rows = conv_input_length(nb_row, 3, border_mode, subsample[0])
|
||||
cols = conv_input_length(nb_col, 3, border_mode, subsample[1])
|
||||
layer_test(convolutional.Deconvolution2D,
|
||||
kwargs={'nb_filter': nb_filter,
|
||||
'nb_row': 3,
|
||||
'nb_col': 3,
|
||||
'output_shape': (nb_samples, nb_filter, rows, cols),
|
||||
'border_mode': border_mode,
|
||||
'subsample': subsample},
|
||||
input_shape=(nb_samples, stack_size, nb_row, nb_col),
|
||||
fixed_batch_size=True)
|
||||
|
||||
layer_test(convolutional.Deconvolution2D,
|
||||
kwargs={'nb_filter': nb_filter,
|
||||
'nb_row': 3,
|
||||
'nb_col': 3,
|
||||
'output_shape': (nb_samples, nb_filter, rows, cols),
|
||||
'border_mode': border_mode,
|
||||
'W_regularizer': 'l2',
|
||||
'b_regularizer': 'l2',
|
||||
'activity_regularizer': 'activity_l2',
|
||||
'subsample': subsample},
|
||||
input_shape=(nb_samples, stack_size, nb_row, nb_col),
|
||||
fixed_batch_size=True)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_atrous_conv_2d():
|
||||
nb_samples = 2
|
||||
|
Loading…
Reference in New Issue
Block a user