Touch-ups in pooling layers

This commit is contained in:
Francois Chollet 2015-12-08 10:16:47 -08:00
parent 31534bd15e
commit 93af5e95fd

@ -243,7 +243,8 @@ class Convolution2D(Layer):
class Pooling1D(Layer):
input_dim = 3
def __init__(self, pool_length=2, stride=None, border_mode='valid', **kwargs):
def __init__(self, pool_length=2, stride=None,
border_mode='valid', **kwargs):
super(Pooling1D, self).__init__(**kwargs)
if stride is None:
stride = pool_length
@ -262,14 +263,15 @@ class Pooling1D(Layer):
self.border_mode, self.stride)
return (input_shape[0], length, input_shape[2])
def pooling_function(self, back_end, inputs, pool_size, strides, border_mode, dim_ordering):
def pooling_function(self, back_end, inputs, pool_size, strides,
border_mode, dim_ordering):
raise NotImplementedError
def get_output(self, train=False):
X = self.get_input(train)
X = K.expand_dims(X, -1) # add dummy last dimension
X = K.permute_dimensions(X, (0, 2, 1, 3))
output = self.pooling_function(back_end=K, inputs=X, pool_size=self.pool_size,
output = self.pooling_function(inputs=X, pool_size=self.pool_size,
strides=self.st,
border_mode=self.border_mode,
dim_ordering='th')
@ -286,22 +288,24 @@ class Pooling1D(Layer):
class MaxPooling1D(Pooling1D):
def __init__(self, **kwargs):
super(MaxPooling1D, self).__init__(**kwargs)
def __init__(self, *args, **kwargs):
super(MaxPooling1D, self).__init__(*args, **kwargs)
def pooling_function(self, back_end, inputs, pool_size, strides, border_mode, dim_ordering):
output = back_end.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='max')
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='max')
return output
class AveragePooling1D(Pooling1D):
def __init__(self, **kwargs):
super(AveragePooling1D, self).__init__(**kwargs)
def __init__(self, *args, **kwargs):
super(AveragePooling1D, self).__init__(*args, **kwargs)
def pooling_function(self, back_end, inputs, pool_size, strides, border_mode, dim_ordering):
output = back_end.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='avg')
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='avg')
return output
@ -345,12 +349,13 @@ class Pooling2D(Layer):
else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
def pooling_function(self, back_end, inputs, pool_size, strides, border_mode, dim_ordering):
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
raise NotImplementedError
def get_output(self, train=False):
X = self.get_input(train)
output = self.pooling_function(back_end=K, inputs=X, pool_size=self.pool_size,
output = self.pooling_function(inputs=X, pool_size=self.pool_size,
strides=self.strides,
border_mode=self.border_mode,
dim_ordering=self.dim_ordering)
@ -367,22 +372,24 @@ class Pooling2D(Layer):
class MaxPooling2D(Pooling2D):
def __init__(self, **kwargs):
super(MaxPooling2D, self).__init__(**kwargs)
def __init__(self, *args, **kwargs):
super(MaxPooling2D, self).__init__(*args, **kwargs)
def pooling_function(self, back_end, inputs, pool_size, strides, border_mode, dim_ordering):
output = back_end.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='max')
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='max')
return output
class AveragePooling2D(Pooling2D):
def __init__(self, **kwargs):
super(AveragePooling2D, self).__init__(**kwargs)
def __init__(self, *args, **kwargs):
super(AveragePooling2D, self).__init__(*args, **kwargs)
def pooling_function(self, back_end, inputs, pool_size, strides, border_mode, dim_ordering):
output = back_end.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='avg')
def pooling_function(self, inputs, pool_size, strides,
border_mode, dim_ordering):
output = K.pool2d(inputs, pool_size, strides,
border_mode, dim_ordering, pool_mode='avg')
return output