diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index 0cb383a07..925c8f154 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -1164,8 +1164,8 @@ def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
     padding = _preprocess_border_mode(border_mode)
     strides = (1,) + strides + (1,)
 
-    tf.nn.separable_conv2d(x, depthwise_kernel, pointwise_kernel,
-                           strides, padding)
+    x = tf.nn.separable_conv2d(x, depthwise_kernel, pointwise_kernel,
+                               strides, padding)
 
     return _postprocess_conv2d_output(x, dim_ordering)
 
diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py
index fa84ff0c2..e598e279a 100644
--- a/keras/layers/convolutional.py
+++ b/keras/layers/convolutional.py
@@ -59,7 +59,8 @@ class Convolution1D(Layer):
             (eg. maxnorm, nonneg), applied to the main weights matrix.
         b_constraint: instance of the [constraints](../constraints.md) module,
             applied to the bias.
-        bias: whether to include a bias (i.e. make the layer affine rather than linear).
+        bias: whether to include a bias
+            (i.e. make the layer affine rather than linear).
         input_dim: Number of channels/dimensions in the input.
             Either this argument or the keyword argument `input_shape`must
             be provided when using this layer as the first layer in a model.
@@ -235,7 +236,8 @@ class Convolution2D(Layer):
             It defaults to the `image_dim_ordering` value found in your
             Keras config file at `~/.keras/keras.json`.
             If you never set it, then it will be "th".
-        bias: whether to include a bias (i.e. make the layer affine rather than linear).
+        bias: whether to include a bias
+            (i.e. make the layer affine rather than linear).
 
     # Input shape
         4D tensor with shape:
@@ -450,7 +452,8 @@ class AtrousConv2D(Convolution2D):
     '''
     def __init__(self, nb_filter, nb_row, nb_col,
                  init='glorot_uniform', activation='linear', weights=None,
-                 border_mode='valid', subsample=(1, 1), atrous_rate=(1, 1), dim_ordering=K.image_dim_ordering(),
+                 border_mode='valid', subsample=(1, 1),
+                 atrous_rate=(1, 1), dim_ordering=K.image_dim_ordering(),
                  W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                  W_constraint=None, b_constraint=None, bias=True, **kwargs):
 
@@ -513,6 +516,230 @@ class AtrousConv2D(Convolution2D):
         return dict(list(base_config.items()) + list(config.items()))
 
+class SeparableConv2D(Layer):
+    '''Separable convolution operator for 2D inputs.
+
+    Separable convolutions consist of first performing
+    a depthwise spatial convolution
+    (which acts on each input channel separately)
+    followed by a pointwise convolution which mixes together the resulting
+    output channels. The `depth_multiplier` argument controls how many
+    output channels are generated per input channel in the depthwise step.
+
+    Intuitively, separable convolutions can be understood as
+    a way to factorize a convolution kernel into two smaller kernels,
+    or as an extreme version of an Inception block.
+
+    When using this layer as the first layer in a model,
+    provide the keyword argument `input_shape`
+    (tuple of integers, does not include the sample axis),
+    e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.
+
+    # Arguments
+        nb_filter: Number of convolution filters to use.
+        nb_row: Number of rows in the convolution kernel.
+        nb_col: Number of columns in the convolution kernel.
+        init: name of initialization function for the weights of the layer
+            (see [initializations](../initializations.md)), or alternatively,
+            Theano function to use for weights initialization.
+            This parameter is only relevant if you don't pass
+            a `weights` argument.
+        activation: name of activation function to use
+            (see [activations](../activations.md)),
+            or alternatively, elementwise Theano function.
+            If you don't specify anything, no activation is applied
+            (ie. "linear" activation: a(x) = x).
+        weights: list of numpy arrays to set as initial weights.
+        border_mode: 'valid' or 'same'.
+        subsample: tuple of length 2. Factor by which to subsample output.
+            Also called strides elsewhere.
+        depth_multiplier: how many output channels to use per input channel
+            for the depthwise convolution step.
+        depthwise_regularizer: instance of [WeightRegularizer](../regularizers.md)
+            (eg. L1 or L2 regularization), applied to the depthwise weights matrix.
+        pointwise_regularizer: instance of [WeightRegularizer](../regularizers.md)
+            (eg. L1 or L2 regularization), applied to the pointwise weights matrix.
+        b_regularizer: instance of [WeightRegularizer](../regularizers.md),
+            applied to the bias.
+        activity_regularizer: instance of [ActivityRegularizer](../regularizers.md),
+            applied to the network output.
+        depthwise_constraint: instance of the [constraints](../constraints.md) module
+            (eg. maxnorm, nonneg), applied to the depthwise weights matrix.
+        pointwise_constraint: instance of the [constraints](../constraints.md) module
+            (eg. maxnorm, nonneg), applied to the pointwise weights matrix.
+        b_constraint: instance of the [constraints](../constraints.md) module,
+            applied to the bias.
+        dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
+            (the depth) is at index 1, in 'tf' mode it is at index 3.
+            It defaults to the `image_dim_ordering` value found in your
+            Keras config file at `~/.keras/keras.json`.
+            If you never set it, then it will be "th".
+        bias: whether to include a bias
+            (i.e. make the layer affine rather than linear).
+
+    # Input shape
+        4D tensor with shape:
+        `(samples, channels, rows, cols)` if dim_ordering='th'
+        or 4D tensor with shape:
+        `(samples, rows, cols, channels)` if dim_ordering='tf'.
+
+    # Output shape
+        4D tensor with shape:
+        `(samples, nb_filter, new_rows, new_cols)` if dim_ordering='th'
+        or 4D tensor with shape:
+        `(samples, new_rows, new_cols, nb_filter)` if dim_ordering='tf'.
+        `rows` and `cols` values might have changed due to padding.
+    '''
+    def __init__(self, nb_filter, nb_row, nb_col,
+                 init='glorot_uniform', activation='linear', weights=None,
+                 border_mode='valid', subsample=(1, 1),
+                 depth_multiplier=1, dim_ordering=K.image_dim_ordering(),
+                 depthwise_regularizer=None, pointwise_regularizer=None,
+                 b_regularizer=None, activity_regularizer=None,
+                 depthwise_constraint=None, pointwise_constraint=None,
+                 b_constraint=None,
+                 bias=True, **kwargs):
+
+        if K._BACKEND != 'tensorflow':
+            raise Exception('SeparableConv2D is only available '
+                            'with TensorFlow for the time being.')
+
+        if border_mode not in {'valid', 'same'}:
+            raise Exception('Invalid border mode for SeparableConv2D:', border_mode)
+        self.nb_filter = nb_filter
+        self.nb_row = nb_row
+        self.nb_col = nb_col
+        self.init = initializations.get(init, dim_ordering=dim_ordering)
+        self.activation = activations.get(activation)
+        self.border_mode = border_mode
+        self.subsample = tuple(subsample)
+        self.depth_multiplier = depth_multiplier
+        assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
+        self.dim_ordering = dim_ordering
+
+        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
+        self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
+        self.b_regularizer = regularizers.get(b_regularizer)
+        self.activity_regularizer = regularizers.get(activity_regularizer)
+
+        self.depthwise_constraint = constraints.get(depthwise_constraint)
+        self.pointwise_constraint = constraints.get(pointwise_constraint)
+        self.b_constraint = constraints.get(b_constraint)
+
+        self.bias = bias
+        self.input_spec = [InputSpec(ndim=4)]
+        self.initial_weights = weights
+        super(SeparableConv2D, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        if self.dim_ordering == 'th':
+            stack_size = input_shape[1]
+            depthwise_shape = (self.depth_multiplier, stack_size, self.nb_row, self.nb_col)
+            pointwise_shape = (self.nb_filter, self.depth_multiplier * stack_size, 1, 1)
+        elif self.dim_ordering == 'tf':
+            stack_size = input_shape[3]
+            depthwise_shape = (self.nb_row, self.nb_col, stack_size, self.depth_multiplier)
+            pointwise_shape = (1, 1, self.depth_multiplier * stack_size, self.nb_filter)
+        else:
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+        self.depthwise_kernel = self.init(depthwise_shape,
+                                          name='{}_depthwise_kernel'.format(self.name))
+        self.pointwise_kernel = self.init(pointwise_shape,
+                                          name='{}_pointwise_kernel'.format(self.name))
+        if self.bias:
+            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
+            self.trainable_weights = [self.depthwise_kernel,
+                                      self.pointwise_kernel,
+                                      self.b]
+        else:
+            self.trainable_weights = [self.depthwise_kernel,
+                                      self.pointwise_kernel]
+        self.regularizers = []
+        if self.depthwise_regularizer:
+            self.depthwise_regularizer.set_param(self.depthwise_kernel)
+            self.regularizers.append(self.depthwise_regularizer)
+        if self.pointwise_regularizer:
+            self.pointwise_regularizer.set_param(self.pointwise_kernel)
+            self.regularizers.append(self.pointwise_regularizer)
+        if self.bias and self.b_regularizer:
+            self.b_regularizer.set_param(self.b)
+            self.regularizers.append(self.b_regularizer)
+        if self.activity_regularizer:
+            self.activity_regularizer.set_layer(self)
+            self.regularizers.append(self.activity_regularizer)
+
+        self.constraints = {}
+        if self.depthwise_constraint:
+            self.constraints[self.depthwise_kernel] = self.depthwise_constraint
+        if self.pointwise_constraint:
+            self.constraints[self.pointwise_kernel] = self.pointwise_constraint
+        if self.bias and self.b_constraint:
+            self.constraints[self.b] = self.b_constraint
+
+        if self.initial_weights is not None:
+            self.set_weights(self.initial_weights)
+            del self.initial_weights
+
+    def get_output_shape_for(self, input_shape):
+        if self.dim_ordering == 'th':
+            rows = input_shape[2]
+            cols = input_shape[3]
+        elif self.dim_ordering == 'tf':
+            rows = input_shape[1]
+            cols = input_shape[2]
+        else:
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+
+        rows = conv_output_length(rows, self.nb_row,
+                                  self.border_mode, self.subsample[0])
+        cols = conv_output_length(cols, self.nb_col,
+                                  self.border_mode, self.subsample[1])
+
+        if self.dim_ordering == 'th':
+            return (input_shape[0], self.nb_filter, rows, cols)
+        elif self.dim_ordering == 'tf':
+            return (input_shape[0], rows, cols, self.nb_filter)
+        else:
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+
+    def call(self, x, mask=None):
+        output = K.separable_conv2d(x, self.depthwise_kernel,
+                                    self.pointwise_kernel,
+                                    strides=self.subsample,
+                                    border_mode=self.border_mode,
+                                    dim_ordering=self.dim_ordering)
+        if self.bias:
+            if self.dim_ordering == 'th':
+                output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
+            elif self.dim_ordering == 'tf':
+                output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
+            else:
+                raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+        output = self.activation(output)
+        return output
+
+    def get_config(self):
+        config = {'nb_filter': self.nb_filter,
+                  'nb_row': self.nb_row,
+                  'nb_col': self.nb_col,
+                  'init': self.init.__name__,
+                  'activation': self.activation.__name__,
+                  'border_mode': self.border_mode,
+                  'subsample': self.subsample,
+                  'depth_multiplier': self.depth_multiplier,
+                  'dim_ordering': self.dim_ordering,
+                  'depthwise_regularizer': self.depthwise_regularizer.get_config() if self.depthwise_regularizer else None,
+                  'pointwise_regularizer': self.pointwise_regularizer.get_config() if self.pointwise_regularizer else None,
+                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
+                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
+                  'depthwise_constraint': self.depthwise_constraint.get_config() if self.depthwise_constraint else None,
+                  'pointwise_constraint': self.pointwise_constraint.get_config() if self.pointwise_constraint else None,
+                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None,
+                  'bias': self.bias}
+        base_config = super(SeparableConv2D, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
 class Convolution3D(Layer):
     '''Convolution operator for filtering windows of three-dimensional inputs.
 
     When using this layer as the first layer in a model,
diff --git a/tests/keras/layers/test_convolutional.py b/tests/keras/layers/test_convolutional.py
index fa37e023d..82237737f 100644
--- a/tests/keras/layers/test_convolutional.py
+++ b/tests/keras/layers/test_convolutional.py
@@ -53,7 +53,7 @@ def test_averagepooling_1d():
 
 
 def test_convolution_2d():
-    nb_samples = 8
+    nb_samples = 2
     nb_filter = 3
     stack_size = 4
     nb_row = 10
@@ -85,7 +85,7 @@ def test_convolution_2d():
 
 
 def test_atrous_conv_2d():
-    nb_samples = 8
+    nb_samples = 2
     nb_filter = 3
     stack_size = 4
     nb_row = 10
@@ -121,6 +121,45 @@ def test_atrous_conv_2d():
                    input_shape=(nb_samples, stack_size, nb_row, nb_col))
 
 
+@pytest.mark.skipif(K._BACKEND != 'tensorflow', reason="Requires TF backend")
+def test_separable_conv_2d():
+    nb_samples = 2
+    nb_filter = 8
+    stack_size = 4
+    nb_row = 10
+    nb_col = 6
+
+    for border_mode in ['valid', 'same']:
+        for subsample in [(1, 1), (2, 2)]:
+            for multiplier in [1, 2]:
+                if border_mode == 'same' and subsample != (1, 1):
+                    continue
+
+                layer_test(convolutional.SeparableConv2D,
+                           kwargs={'nb_filter': nb_filter,
+                                   'nb_row': 3,
+                                   'nb_col': 3,
+                                   'border_mode': border_mode,
+                                   'subsample': subsample,
+                                   'depth_multiplier': multiplier},
+                           input_shape=(nb_samples, stack_size, nb_row, nb_col))
+
+    layer_test(convolutional.SeparableConv2D,
+               kwargs={'nb_filter': nb_filter,
+                       'nb_row': 3,
+                       'nb_col': 3,
+                       'border_mode': border_mode,
+                       'depthwise_regularizer': 'l2',
+                       'pointwise_regularizer': 'l2',
+                       'b_regularizer': 'l2',
+                       'activity_regularizer': 'activity_l2',
+                       'pointwise_constraint': 'unitnorm',
+                       'depthwise_constraint': 'unitnorm',
+                       'subsample': subsample,
+                       'depth_multiplier': multiplier},
+               input_shape=(nb_samples, stack_size, nb_row, nb_col))
+
+
 def test_maxpooling_2d():
     pool_size = (3, 3)
 
@@ -209,7 +248,7 @@ def test_averagepooling_3d():
 
 
 def test_zero_padding_2d():
-    nb_samples = 9
+    nb_samples = 2
     stack_size = 7
     input_nb_row = 11
     input_nb_col = 12
@@ -235,7 +274,7 @@ def test_zero_padding_2d():
 
 @pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend")
 def test_zero_padding_3d():
-    nb_samples = 9
+    nb_samples = 2
     stack_size = 7
     input_len_dim1 = 10
     input_len_dim2 = 11
@@ -268,7 +307,7 @@ def test_upsampling_1d():
 
 
 def test_upsampling_2d():
-    nb_samples = 9
+    nb_samples = 2
     stack_size = 7
     input_nb_row = 11
     input_nb_col = 12
@@ -309,7 +348,7 @@ def test_upsampling_2d():
 
 @pytest.mark.skipif(K._BACKEND != 'theano', reason="Requires Theano backend")
 def test_upsampling_3d():
-    nb_samples = 9
+    nb_samples = 2
     stack_size = 7
     input_len_dim1 = 10
     input_len_dim2 = 11
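For reference, here is a minimal usage sketch of the layer introduced by this patch. It is not part of the diff: the imports, input shape, optimizer, and loss are illustrative assumptions; only the SeparableConv2D constructor arguments come from the patch. It assumes the TensorFlow backend (required by this implementation) and passes dim_ordering='tf' explicitly so the example does not depend on the keras.json setting.

# Usage sketch (not part of the patch). Assumes the TensorFlow backend and
# Keras 1.x-style imports; the 64x64 RGB input shape is illustrative.
from keras.models import Sequential
from keras.layers import Activation
from keras.layers.convolutional import SeparableConv2D

model = Sequential()
# 32 output filters, a 3x3 spatial kernel, and 2 depthwise outputs per input channel.
model.add(SeparableConv2D(32, 3, 3,
                          depth_multiplier=2,
                          border_mode='same',
                          dim_ordering='tf',
                          input_shape=(64, 64, 3)))
model.add(Activation('relu'))
model.compile(optimizer='sgd', loss='mse')

# With border_mode='same' and the default subsample=(1, 1),
# spatial dimensions are preserved per get_output_shape_for above.
print(model.output_shape)  # expected: (None, 64, 64, 32)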