Merge branch 'maxpumperla-pooling'

This commit is contained in:
Francois Chollet 2015-12-08 10:21:39 -08:00
commit d400fc4512
8 changed files with 217 additions and 45 deletions

@ -15,6 +15,7 @@ pages:
- Index: documentation.md
- Examples: examples.md
- FAQ: faq.md
- Backends: backend.md
- Optimizers: optimizers.md
- Objectives: objectives.md
- Models: models.md

@ -100,6 +100,26 @@ Max pooling operation for temporal data.
---
## AveragePooling1D
```python
keras.layers.convolutional.AveragePooling1D(pool_length=2, stride=None, border_mode='valid')
```
Average pooling operation for temporal data.
- __Input shape__: 3D tensor with shape: `(samples, steps, features)`.
- __Output shape__: 3D tensor with shape: `(samples, downsampled_steps, features)`.
- __Arguments__:
- __pool_length__: factor by which to downscale. 2 will halve the input.
- __stride__: integer or None. Stride value.
- __border_mode__: 'valid' or 'same'. **Note:** 'same' will only work with TensorFlow for the time being.
---
## MaxPooling2D
```python
@ -121,6 +141,28 @@ or 4D tensor with shape: `(samples, pooled_rows, pooled_cols, channels)` if dim_
- __border_mode__: 'valid' or 'same'. **Note:** 'same' will only work with TensorFlow for the time being.
- __dim_ordering__: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3.
---
## AveragePooling2D
```python
keras.layers.convolutional.AveragePooling2D(pool_size=(2, 2), border_mode='valid', dim_ordering='th')
```
Average pooling operation for spatial data.
- __Input shape__: 4D tensor with shape: `(samples, channels, rows, cols)` if dim_ordering='th'
or 4D tensor with shape: `(samples, rows, cols, channels)` if dim_ordering='tf'.
- __Output shape__: 4D tensor with shape: `(nb_samples, channels, pooled_rows, pooled_cols)` if dim_ordering='th'
or 4D tensor with shape: `(samples, pooled_rows, pooled_cols, channels)` if dim_ordering='tf'.
- __Arguments__:
- __pool_size__: tuple of 2 integers, factors by which to downscale (vertical, horizontal). (2, 2) will halve the image in each dimension.
- __strides__: tuple of 2 integers, or None. Strides values.
- __border_mode__: 'valid' or 'same'. **Note:** 'same' will only work with TensorFlow for the time being.
- __dim_ordering__: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3.
---
@ -199,4 +241,4 @@ or 4D tensor with shape: `(samples, padded_rows, padded_cols, channels)` if dim_
- __padding__: tuple of 2 integers, the size of the padding for rows and columns respectively.
- __dim_ordering__: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3.
---
---

@ -545,8 +545,8 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th',
return x
def maxpool2d(x, pool_size, strides=(1, 1),
border_mode='valid', dim_ordering='th'):
def pool2d(x, pool_size, strides=(1, 1),
border_mode='valid', dim_ordering='th', pool_mode='max'):
'''
pool_size: tuple of 2 integers.
strides: tuple of 2 integers.
@ -567,18 +567,23 @@ def maxpool2d(x, pool_size, strides=(1, 1),
# tf max_pool only supports float32
x = tf.cast(x, 'float32')
if dim_ordering == 'th':
# TF uses the last dimension as channel dimension,
# instead of the 2nd one.
# TH input shape: (samples, input_depth, rows, cols)
# TF input shape: (samples, rows, cols, input_depth)
# TH kernel shape: (depth, input_depth, rows, cols)
# TF kernel shape: (rows, cols, input_depth, depth)
x = tf.transpose(x, (0, 2, 3, 1))
x = tf.nn.max_pool(x, pool_size, strides, padding=padding)
x = tf.transpose(x, (0, 3, 1, 2))
elif dim_ordering == 'tf':
x = tf.nn.max_pool(x, pool_size, strides, padding=padding)
if dim_ordering in {'tf', 'th'}:
if dim_ordering == 'th':
# TF uses the last dimension as channel dimension,
# instead of the 2nd one.
# TH input shape: (samples, input_depth, rows, cols)
# TF input shape: (samples, rows, cols, input_depth)
# TH kernel shape: (depth, input_depth, rows, cols)
# TF kernel shape: (rows, cols, input_depth, depth)
x = tf.transpose(x, (0, 2, 3, 1))
if pool_mode == 'max':
x = tf.nn.max_pool(x, pool_size, strides, padding=padding)
elif pool_mode == 'avg':
x = tf.nn.avg_pool(x, pool_size, strides, padding=padding)
else:
raise Exception('Invalid pooling mode: ' + str(pool_mode))
if dim_ordering == 'th':
x = tf.transpose(x, (0, 3, 1, 2))
else:
raise Exception('Unknown dim_ordering: ' + str(dim_ordering))

@ -579,8 +579,8 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th',
return conv_out
def maxpool2d(x, pool_size, strides=(1, 1), border_mode='valid',
dim_ordering='th'):
def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
dim_ordering='th', pool_mode='max'):
if border_mode == 'same':
# TODO: add implementation for border_mode="same"
raise Exception('border_mode="same" not supported with Theano.')
@ -596,19 +596,26 @@ def maxpool2d(x, pool_size, strides=(1, 1), border_mode='valid',
if dim_ordering == 'tf':
x = x.dimshuffle((0, 3, 1, 2))
pool_out = downsample.max_pool_2d(x,
ds=pool_size,
st=strides,
ignore_border=ignore_border,
padding=padding,
mode='average_exc_pad')
if pool_mode == 'max':
pool_out = downsample.max_pool_2d(x, ds=pool_size, st=strides,
ignore_border=ignore_border,
padding=padding,
mode='max')
elif pool_mode == 'avg':
pool_out = downsample.max_pool_2d(x, ds=pool_size, st=strides,
ignore_border=ignore_border,
padding=padding,
mode='average_exc_pad')
else:
raise Exception('Invalid pooling mode: ' + str(pool_mode))
if dim_ordering == 'tf':
pool_out = pool_out.dimshuffle((0, 2, 3, 1))
return pool_out
# RANDOMNESS
def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
if seed is None:
seed = np.random.randint(10e6)
@ -622,8 +629,6 @@ def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
rng = RandomStreams(seed=seed)
return rng.uniform(shape, low=low, high=high, dtype=dtype)
'''
more TODO:

@ -240,12 +240,12 @@ class Convolution2D(Layer):
return dict(list(base_config.items()) + list(config.items()))
class MaxPooling1D(Layer):
input_ndim = 3
class Pooling1D(Layer):
input_dim = 3
def __init__(self, pool_length=2, stride=None,
border_mode='valid', **kwargs):
super(MaxPooling1D, self).__init__(**kwargs)
super(Pooling1D, self).__init__(**kwargs)
if stride is None:
stride = pool_length
self.pool_length = pool_length
@ -263,13 +263,18 @@ class MaxPooling1D(Layer):
self.border_mode, self.stride)
return (input_shape[0], length, input_shape[2])
def pooling_function(self, back_end, inputs, pool_size, strides,
border_mode, dim_ordering):
raise NotImplementedError
def get_output(self, train=False):
X = self.get_input(train)
X = K.expand_dims(X, -1) # add dummy last dimension
X = K.permute_dimensions(X, (0, 2, 1, 3))
output = K.maxpool2d(X, pool_size=self.pool_size, strides=self.st,
border_mode=self.border_mode,
dim_ordering='th')
output = self.pooling_function(inputs=X, pool_size=self.pool_size,
strides=self.st,
border_mode=self.border_mode,
dim_ordering='th')
output = K.permute_dimensions(output, (0, 2, 1, 3))
return K.squeeze(output, 3) # remove dummy last dimension
@ -278,16 +283,38 @@ class MaxPooling1D(Layer):
"stride": self.stride,
"pool_length": self.pool_length,
"border_mode": self.border_mode}
base_config = super(MaxPooling1D, self).get_config()
base_config = super(Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class MaxPooling2D(Layer):
class MaxPooling1D(Pooling1D):
def __init__(self, *args, **kwargs):
super(MaxPooling1D, self).__init__(*args, **kwargs)
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='max')
return output
class AveragePooling1D(Pooling1D):
def __init__(self, *args, **kwargs):
super(AveragePooling1D, self).__init__(*args, **kwargs)
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='avg')
return output
class Pooling2D(Layer):
input_ndim = 4
def __init__(self, pool_size=(2, 2), strides=None, border_mode='valid',
dim_ordering='th', **kwargs):
super(MaxPooling2D, self).__init__(**kwargs)
super(Pooling2D, self).__init__(**kwargs)
self.input = K.placeholder(ndim=4)
self.pool_size = tuple(pool_size)
if strides is None:
@ -322,12 +349,16 @@ class MaxPooling2D(Layer):
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
raise NotImplementedError
def get_output(self, train=False):
X = self.get_input(train)
output = K.maxpool2d(X, pool_size=self.pool_size,
strides=self.strides,
border_mode=self.border_mode,
dim_ordering=self.dim_ordering)
output = self.pooling_function(inputs=X, pool_size=self.pool_size,
strides=self.strides,
border_mode=self.border_mode,
dim_ordering=self.dim_ordering)
return output
def get_config(self):
@ -336,10 +367,32 @@ class MaxPooling2D(Layer):
"border_mode": self.border_mode,
"strides": self.strides,
"dim_ordering": self.dim_ordering}
base_config = super(MaxPooling2D, self).get_config()
base_config = super(Pooling2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class MaxPooling2D(Pooling2D):
def __init__(self, *args, **kwargs):
super(MaxPooling2D, self).__init__(*args, **kwargs)
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='max')
return output
class AveragePooling2D(Pooling2D):
def __init__(self, *args, **kwargs):
super(AveragePooling2D, self).__init__(*args, **kwargs)
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='avg')
return output
class UpSampling1D(Layer):
input_ndim = 3

@ -291,17 +291,17 @@ class TestBackend(unittest.TestCase):
# check_two_tensor_operation('conv2d', (5, 3, 10, 12), (4, 3, 3, 3),
# strides=(2, 2), border_mode='valid')
# def test_maxpool2d(self):
# '''maxpool2d works "properly" with Theano and TF but outputs different
# def test_pool2d(self):
# '''pool2d works "properly" with Theano and TF but outputs different
# values in each case. Cause unclear (input shape format?)
# '''
# check_single_tensor_operation('maxpool2d', (5, 3, 10, 12), pool_size=(2, 2),
# check_single_tensor_operation('pool2d', (5, 3, 10, 12), pool_size=(2, 2),
# strides=(1, 1), border_mode='valid')
# check_single_tensor_operation('maxpool2d', (5, 3, 9, 11), pool_size=(2, 2),
# check_single_tensor_operation('pool2d', (5, 3, 9, 11), pool_size=(2, 2),
# strides=(1, 1), border_mode='valid')
# check_single_tensor_operation('maxpool2d', (5, 3, 9, 11), pool_size=(2, 3),
# check_single_tensor_operation('pool2d', (5, 3, 9, 11), pool_size=(2, 3),
# strides=(1, 1), border_mode='valid')
def test_random_normal(self):

@ -58,6 +58,20 @@ class TestConvolutions(unittest.TestCase):
K.eval(layer.get_output(train))
layer.get_config()
def test_averagepooling_1d(self):
nb_samples = 9
nb_steps = 7
input_dim = 10
input = np.ones((nb_samples, nb_steps, input_dim))
for stride in [1, 2]:
layer = convolutional.AveragePooling1D(stride=stride,
border_mode='valid')
layer.input = K.variable(input)
for train in [True, False]:
K.eval(layer.get_output(train))
layer.get_config()
def test_convolution_2d(self):
nb_samples = 8
nb_filter = 9
@ -113,6 +127,23 @@ class TestConvolutions(unittest.TestCase):
K.eval(layer.get_output(train))
layer.get_config()
def test_averagepooling_2d(self):
nb_samples = 9
stack_size = 7
input_nb_row = 11
input_nb_col = 12
pool_size = (3, 3)
input = np.ones((nb_samples, stack_size, input_nb_row, input_nb_col))
for strides in [(1, 1), (2, 2)]:
layer = convolutional.AveragePooling2D(strides=strides,
border_mode='valid',
pool_size=pool_size)
layer.input = K.variable(input)
for train in [True, False]:
K.eval(layer.get_output(train))
layer.get_config()
def test_zero_padding_2d(self):
nb_samples = 9
stack_size = 7

@ -56,6 +56,7 @@ def test_TimeDistributedDense():
input_data = np.random.random((2, 2, 3))
check_layer_output_shape(layer, input_data)
#################
# Convolutional #
#################
@ -111,6 +112,18 @@ def test_MaxPooling1D():
check_layer_output_shape(layer, input_data)
def test_AveragePooling1D():
for ignore_border in [True, False]:
for pool_length in [1, 2]:
for stride in [1]:
for input_data_shape in [(2, 3, 4), (2, 4, 4)]:
layer = AveragePooling1D(pool_length=pool_length,
stride=stride,
border_mode='valid')
input_data = np.random.random(input_data_shape)
check_layer_output_shape(layer, input_data)
def test_MaxPooling2D():
for ignore_border in [True, False]:
for strides in [(1, 1), (2, 2)]:
@ -132,6 +145,27 @@ def test_MaxPooling2D():
check_layer_output_shape(layer, input_data)
def test_AveragePooling2D():
for ignore_border in [True, False]:
for strides in [(1, 1), (2, 2)]:
for pool_size in [(2, 2), (3, 3), (4, 4)]:
for input_data_shape in [(2, 1, 4, 4), (2, 1, 5, 5), (2, 1, 6, 6)]:
layer = AveragePooling2D(pool_size=pool_size,
strides=strides,
border_mode='valid',
dim_ordering='th')
input_data = np.random.random(input_data_shape)
check_layer_output_shape(layer, input_data)
for input_data_shape in [(2, 4, 4, 1)]:
layer = AveragePooling2D(pool_size=pool_size,
strides=strides,
border_mode='valid',
dim_ordering='tf')
input_data = np.random.random(input_data_shape)
check_layer_output_shape(layer, input_data)
def test_UpSampling1D():
layer = UpSampling1D(length=2)
input_data = np.random.random((2, 2, 3))
@ -163,6 +197,7 @@ def test_ZeroPadding2D():
input_data = np.random.random((2, 2, 3, 1))
check_layer_output_shape(layer, input_data)
# #############
# # Recurrent #
# #############