Add SeparableConv2D layer (TF only)

Francois Chollet 2016-07-14 11:22:27 -07:00
parent b35b943364
commit 47c09d9557
3 changed files with 283 additions and 11 deletions

@@ -1164,8 +1164,8 @@ def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
padding = _preprocess_border_mode(border_mode)
strides = (1,) + strides + (1,)
- tf.nn.separable_conv2d(x, depthwise_kernel, pointwise_kernel,
- strides, padding)
+ x = tf.nn.separable_conv2d(x, depthwise_kernel, pointwise_kernel,
+ strides, padding)
return _postprocess_conv2d_output(x, dim_ordering)
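The fix above is small but meaningful: `tf.nn.separable_conv2d` builds and returns a new tensor rather than modifying `x` in place, so the removed call silently discarded its result. As a hedged illustration (not part of the commit, written against present-day TensorFlow rather than the 2016 graph API, with sizes chosen arbitrarily), the kernel shapes the op expects in 'tf' dim ordering are:

import numpy as np
import tensorflow as tf

# Illustrative sizes only (not from the commit).
batch, rows, cols, in_channels = 2, 10, 6, 4
depth_multiplier, nb_filter = 2, 8

x = tf.constant(np.random.rand(batch, rows, cols, in_channels), dtype=tf.float32)
# Depthwise kernel: (nb_row, nb_col, in_channels, depth_multiplier).
depthwise = tf.constant(np.random.rand(3, 3, in_channels, depth_multiplier),
                        dtype=tf.float32)
# Pointwise kernel: (1, 1, in_channels * depth_multiplier, nb_filter).
pointwise = tf.constant(np.random.rand(1, 1, in_channels * depth_multiplier, nb_filter),
                        dtype=tf.float32)

# The op returns a fresh tensor; as the removed line shows, forgetting the
# assignment would leave `x` unconvolved.
y = tf.nn.separable_conv2d(x, depthwise, pointwise,
                           strides=(1, 1, 1, 1), padding='VALID')
print(y.shape)  # (2, 8, 4, 8): 3x3 'VALID' kernels trim 2 from each spatial dim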

@@ -59,7 +59,8 @@ class Convolution1D(Layer):
(e.g. maxnorm, nonneg), applied to the main weights matrix.
b_constraint: instance of the [constraints](../constraints.md) module,
applied to the bias.
- bias: whether to include a bias (i.e. make the layer affine rather than linear).
+ bias: whether to include a bias
+ (i.e. make the layer affine rather than linear).
input_dim: Number of channels/dimensions in the input.
Either this argument or the keyword argument `input_shape` must be
provided when using this layer as the first layer in a model.
@@ -235,7 +236,8 @@ class Convolution2D(Layer):
It defaults to the `image_dim_ordering` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "th".
- bias: whether to include a bias (i.e. make the layer affine rather than linear).
+ bias: whether to include a bias
+ (i.e. make the layer affine rather than linear).
# Input shape
4D tensor with shape:
@@ -450,7 +452,8 @@ class AtrousConv2D(Convolution2D):
'''
def __init__(self, nb_filter, nb_row, nb_col,
init='glorot_uniform', activation='linear', weights=None,
- border_mode='valid', subsample=(1, 1), atrous_rate=(1, 1), dim_ordering=K.image_dim_ordering(),
+ border_mode='valid', subsample=(1, 1),
+ atrous_rate=(1, 1), dim_ordering=K.image_dim_ordering(),
W_regularizer=None, b_regularizer=None, activity_regularizer=None,
W_constraint=None, b_constraint=None,
bias=True, **kwargs):
@@ -513,6 +516,236 @@ class AtrousConv2D(Convolution2D):
return dict(list(base_config.items()) + list(config.items()))
class SeparableConv2D(Layer):
'''Separable convolution operator for 2D inputs.
Separable convolutions consist of first performing
a depthwise spatial convolution
(which acts on each input channel separately)
followed by a pointwise convolution which mixes together the resulting
output channels. The `depth_multiplier` argument controls how many
output channels are generated per input channel in the depthwise step.
Intuitively, separable convolutions can be understood as
a way to factorize a convolution kernel into two smaller kernels,
or as an extreme version of an Inception block.
When using this layer as the first layer in a model,
provide the keyword argument `input_shape`
(tuple of integers, does not include the sample axis),
e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.
# Arguments
nb_filter: Number of convolution filters to use.
nb_row: Number of rows in the convolution kernel.
nb_col: Number of columns in the convolution kernel.
init: name of initialization function for the weights of the layer
(see [initializations](../initializations.md)), or alternatively,
Theano function to use for weights initialization.
This parameter is only relevant if you don't pass
a `weights` argument.
activation: name of activation function to use
(see [activations](../activations.md)),
or alternatively, elementwise Theano function.
If you don't specify anything, no activation is applied
(i.e. "linear" activation: a(x) = x).
weights: list of numpy arrays to set as initial weights.
border_mode: 'valid' or 'same'.
subsample: tuple of length 2. Factor by which to subsample output.
Also called strides elsewhere.
depth_multiplier: how many output channels to use per input channel
for the depthwise convolution step.
depthwise_regularizer: instance of [WeightRegularizer](../regularizers.md)
(e.g. L1 or L2 regularization), applied to the depthwise weights matrix.
pointwise_regularizer: instance of [WeightRegularizer](../regularizers.md)
(e.g. L1 or L2 regularization), applied to the pointwise weights matrix.
b_regularizer: instance of [WeightRegularizer](../regularizers.md),
applied to the bias.
activity_regularizer: instance of [ActivityRegularizer](../regularizers.md),
applied to the network output.
depthwise_constraint: instance of the [constraints](../constraints.md) module
(e.g. maxnorm, nonneg), applied to the depthwise weights matrix.
pointwise_constraint: instance of the [constraints](../constraints.md) module
(e.g. maxnorm, nonneg), applied to the pointwise weights matrix.
b_constraint: instance of the [constraints](../constraints.md) module,
applied to the bias.
dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
(the depth) is at index 1, in 'tf' mode it is at index 3.
It defaults to the `image_dim_ordering` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "th".
bias: whether to include a bias
(i.e. make the layer affine rather than linear).
# Input shape
4D tensor with shape:
`(samples, channels, rows, cols)` if dim_ordering='th'
or 4D tensor with shape:
`(samples, rows, cols, channels)` if dim_ordering='tf'.
# Output shape
4D tensor with shape:
`(samples, nb_filter, new_rows, new_cols)` if dim_ordering='th'
or 4D tensor with shape:
`(samples, new_rows, new_cols, nb_filter)` if dim_ordering='tf'.
`rows` and `cols` values might have changed due to padding.
'''
def __init__(self, nb_filter, nb_row, nb_col,
init='glorot_uniform', activation='linear', weights=None,
border_mode='valid', subsample=(1, 1),
depth_multiplier=1, dim_ordering=K.image_dim_ordering(),
depthwise_regularizer=None, pointwise_regularizer=None,
b_regularizer=None, activity_regularizer=None,
depthwise_constraint=None, pointwise_constraint=None,
b_constraint=None,
bias=True, **kwargs):
if K._BACKEND != 'tensorflow':
raise Exception('SeparableConv2D is only available '
'with TensorFlow for the time being.')
if border_mode not in {'valid', 'same'}:
raise Exception('Invalid border mode for SeparableConv2D:', border_mode)
self.nb_filter = nb_filter
self.nb_row = nb_row
self.nb_col = nb_col
self.init = initializations.get(init, dim_ordering=dim_ordering)
self.activation = activations.get(activation)
assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}'
self.border_mode = border_mode
self.subsample = tuple(subsample)
self.depth_multiplier = depth_multiplier
assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
self.dim_ordering = dim_ordering
self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
self.b_regularizer = regularizers.get(b_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.depthwise_constraint = constraints.get(depthwise_constraint)
self.pointwise_constraint = constraints.get(pointwise_constraint)
self.b_constraint = constraints.get(b_constraint)
self.bias = bias
self.input_spec = [InputSpec(ndim=4)]
self.initial_weights = weights
super(SeparableConv2D, self).__init__(**kwargs)
def build(self, input_shape):
if self.dim_ordering == 'th':
stack_size = input_shape[1]
depthwise_shape = (self.depth_multiplier, stack_size, self.nb_row, self.nb_col)
pointwise_shape = (self.nb_filter, self.depth_multiplier * stack_size, 1, 1)
elif self.dim_ordering == 'tf':
stack_size = input_shape[3]
depthwise_shape = (self.nb_row, self.nb_col, stack_size, self.depth_multiplier)
pointwise_shape = (1, 1, self.depth_multiplier * stack_size, self.nb_filter)
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
self.depthwise_kernel = self.init(depthwise_shape,
name='{}_depthwise_kernel'.format(self.name))
self.pointwise_kernel = self.init(pointwise_shape,
name='{}_pointwise_kernel'.format(self.name))
if self.bias:
self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
self.trainable_weights = [self.depthwise_kernel,
self.pointwise_kernel,
self.b]
else:
self.trainable_weights = [self.depthwise_kernel,
self.pointwise_kernel]
self.regularizers = []
if self.depthwise_regularizer:
self.depthwise_regularizer.set_param(self.depthwise_kernel)
self.regularizers.append(self.depthwise_regularizer)
if self.pointwise_regularizer:
self.pointwise_regularizer.set_param(self.pointwise_kernel)
self.regularizers.append(self.pointwise_regularizer)
if self.bias and self.b_regularizer:
self.b_regularizer.set_param(self.b)
self.regularizers.append(self.b_regularizer)
if self.activity_regularizer:
self.activity_regularizer.set_layer(self)
self.regularizers.append(self.activity_regularizer)
self.constraints = {}
if self.depthwise_constraint:
self.constraints[self.depthwise_kernel] = self.depthwise_constraint
if self.pointwise_constraint:
self.constraints[self.pointwise_kernel] = self.pointwise_constraint
if self.bias and self.b_constraint:
self.constraints[self.b] = self.b_constraint
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
def get_output_shape_for(self, input_shape):
if self.dim_ordering == 'th':
rows = input_shape[2]
cols = input_shape[3]
elif self.dim_ordering == 'tf':
rows = input_shape[1]
cols = input_shape[2]
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
rows = conv_output_length(rows, self.nb_row,
self.border_mode, self.subsample[0])
cols = conv_output_length(cols, self.nb_col,
self.border_mode, self.subsample[1])
if self.dim_ordering == 'th':
return (input_shape[0], self.nb_filter, rows, cols)
elif self.dim_ordering == 'tf':
return (input_shape[0], rows, cols, self.nb_filter)
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
def call(self, x, mask=None):
output = K.separable_conv2d(x, self.depthwise_kernel,
self.pointwise_kernel,
strides=self.subsample,
border_mode=self.border_mode,
dim_ordering=self.dim_ordering)
if self.bias:
if self.dim_ordering == 'th':
output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
elif self.dim_ordering == 'tf':
output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
output = self.activation(output)
return output
def get_config(self):
config = {'nb_filter': self.nb_filter,
'nb_row': self.nb_row,
'nb_col': self.nb_col,
'init': self.init.__name__,
'activation': self.activation.__name__,
'border_mode': self.border_mode,
'subsample': self.subsample,
'depth_multiplier': self.depth_multiplier,
'dim_ordering': self.dim_ordering,
'depthwise_regularizer': self.depthwise_regularizer.get_config() if self.depthwise_regularizer else None,
'pointwise_regularizer': self.pointwise_regularizer.get_config() if self.pointwise_regularizer else None,
'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
'depthwise_constraint': self.depthwise_constraint.get_config() if self.depthwise_constraint else None,
'pointwise_constraint': self.pointwise_constraint.get_config() if self.pointwise_constraint else None,
'b_constraint': self.b_constraint.get_config() if self.b_constraint else None,
'bias': self.bias}
base_config = super(SeparableConv2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class Convolution3D(Layer):
'''Convolution operator for filtering windows of three-dimensional inputs.
When using this layer as the first layer in a model,

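As a usage note (not part of the diff): a minimal sketch of how the new SeparableConv2D layer slots into a Keras model of this vintage, assuming the TensorFlow backend, 'tf' dim ordering, the import path of the module edited in this hunk, and filter counts and input sizes chosen purely for illustration. The trailing comment works through the kernel factorization the docstring alludes to.

from keras.models import Sequential
from keras.layers.convolutional import SeparableConv2D

model = Sequential()
# 32 output filters, 3x3 kernels, 2 depthwise outputs per input channel.
model.add(SeparableConv2D(32, 3, 3, depth_multiplier=2,
                          border_mode='same', dim_ordering='tf',
                          input_shape=(64, 64, 3)))
print(model.output_shape)  # (None, 64, 64, 32) because border_mode='same'

# Weight count vs. a full 3x3 Convolution2D producing the same output, ignoring biases:
#   depthwise kernel: 3 * 3 * 3 * 2        = 54
#   pointwise kernel: 1 * 1 * (3 * 2) * 32 = 192   -> 246 weights total
#   full convolution: 3 * 3 * 3 * 32       = 864 weights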
@@ -53,7 +53,7 @@ def test_averagepooling_1d():
def test_convolution_2d():
- nb_samples = 8
+ nb_samples = 2
nb_filter = 3
stack_size = 4
nb_row = 10
@@ -85,7 +85,7 @@ def test_convolution_2d():
def test_atrous_conv_2d():
- nb_samples = 8
+ nb_samples = 2
nb_filter = 3
stack_size = 4
nb_row = 10
@@ -121,6 +121,45 @@ def test_atrous_conv_2d():
input_shape=(nb_samples, stack_size, nb_row, nb_col))
@pytest.mark.skipif(K._BACKEND != 'tensorflow', reason="Requires TF backend")
def test_separable_conv_2d():
nb_samples = 2
nb_filter = 8
stack_size = 4
nb_row = 10
nb_col = 6
for border_mode in ['valid', 'same']:
for subsample in [(1, 1), (2, 2)]:
for multiplier in [1, 2]:
if border_mode == 'same' and subsample != (1, 1):
continue
layer_test(convolutional.SeparableConv2D,
kwargs={'nb_filter': nb_filter,
'nb_row': 3,
'nb_col': 3,
'border_mode': border_mode,
'subsample': subsample,
'depth_multiplier': multiplier},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
layer_test(convolutional.SeparableConv2D,
kwargs={'nb_filter': nb_filter,
'nb_row': 3,
'nb_col': 3,
'border_mode': border_mode,
'depthwise_regularizer': 'l2',
'pointwise_regularizer': 'l2',
'b_regularizer': 'l2',
'activity_regularizer': 'activity_l2',
'pointwise_constraint': 'unitnorm',
'depthwise_constraint': 'unitnorm',
'subsample': subsample,
'depth_multiplier': multiplier},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
def test_maxpooling_2d():
pool_size = (3, 3)
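For context on the test_separable_conv_2d test added above (an aside, not part of the diff): `layer_test` is Keras' shared test helper; for each kwargs/input_shape combination it roughly builds the layer, runs real data through a small model, and compares the observed output shape with the layer's own shape inference. A simplified, hedged approximation of that check, specialised to SeparableConv2D with made-up sizes (not the helper's actual code):

import numpy as np
from keras.layers import Input
from keras.models import Model
from keras.layers.convolutional import SeparableConv2D

def check_separable_conv(kwargs, input_shape):
    # Build the layer inside a tiny functional model and push random data through it.
    data = np.random.random(input_shape)
    x = Input(shape=input_shape[1:])
    layer = SeparableConv2D(**kwargs)
    model = Model(x, layer(x))
    model.compile('sgd', 'mse')
    output = model.predict(data)
    # The observed output shape must agree with the layer's own shape inference.
    assert output.shape == layer.get_output_shape_for(input_shape)

check_separable_conv({'nb_filter': 8, 'nb_row': 3, 'nb_col': 3, 'dim_ordering': 'tf'},
                     input_shape=(2, 10, 6, 4))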
@@ -209,7 +248,7 @@ def test_averagepooling_3d():
def test_zero_padding_2d():
- nb_samples = 9
+ nb_samples = 2
stack_size = 7
input_nb_row = 11
input_nb_col = 12
@@ -235,7 +274,7 @@ def test_zero_padding_2d():
@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend")
def test_zero_padding_3d():
- nb_samples = 9
+ nb_samples = 2
stack_size = 7
input_len_dim1 = 10
input_len_dim2 = 11
@@ -268,7 +307,7 @@ def test_upsampling_1d():
def test_upsampling_2d():
- nb_samples = 9
+ nb_samples = 2
stack_size = 7
input_nb_row = 11
input_nb_col = 12
@@ -309,7 +348,7 @@ def test_upsampling_2d():
@pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend")
def test_upsampling_3d():
- nb_samples = 9
+ nb_samples = 2
stack_size = 7
input_len_dim1 = 10
input_len_dim2 = 11