Upsampling layer fix to work when input shape is None (#3429)

This commit is contained in:
Yann Henon 2016-08-09 17:47:23 -04:00 committed by François Chollet
parent 55447cbb3d
commit 4805e5856b

@ -1045,7 +1045,8 @@ class UpSampling1D(Layer):
super(UpSampling1D, self).__init__(**kwargs) super(UpSampling1D, self).__init__(**kwargs)
def get_output_shape_for(self, input_shape): def get_output_shape_for(self, input_shape):
return (input_shape[0], self.length * input_shape[1], input_shape[2]) length = self.length * input_shape[1] if input_shape[1] is not None else None
return (input_shape[0], length, input_shape[2])
def call(self, x, mask=None): def call(self, x, mask=None):
output = K.repeat_elements(x, self.length, axis=1) output = K.repeat_elements(x, self.length, axis=1)
@ -1094,14 +1095,18 @@ class UpSampling2D(Layer):
def get_output_shape_for(self, input_shape): def get_output_shape_for(self, input_shape):
if self.dim_ordering == 'th': if self.dim_ordering == 'th':
width = self.size[0] * input_shape[2] if input_shape[2] is not None else None
height = self.size[1] * input_shape[3] if input_shape[3] is not None else None
return (input_shape[0], return (input_shape[0],
input_shape[1], input_shape[1],
self.size[0] * input_shape[2], width,
self.size[1] * input_shape[3]) height)
elif self.dim_ordering == 'tf': elif self.dim_ordering == 'tf':
width = self.size[0] * input_shape[1] if input_shape[1] is not None else None
height = self.size[1] * input_shape[2] if input_shape[2] is not None else None
return (input_shape[0], return (input_shape[0],
self.size[0] * input_shape[1], width,
self.size[1] * input_shape[2], height,
input_shape[3]) input_shape[3])
else: else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering) raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
@ -1153,16 +1158,22 @@ class UpSampling3D(Layer):
def get_output_shape_for(self, input_shape): def get_output_shape_for(self, input_shape):
if self.dim_ordering == 'th': if self.dim_ordering == 'th':
dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None
dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None
dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None
return (input_shape[0], return (input_shape[0],
input_shape[1], input_shape[1],
self.size[0] * input_shape[2], dim1,
self.size[1] * input_shape[3], dim2,
self.size[2] * input_shape[4]) dim3)
elif self.dim_ordering == 'tf': elif self.dim_ordering == 'tf':
dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
return (input_shape[0], return (input_shape[0],
self.size[0] * input_shape[1], dim1,
self.size[1] * input_shape[2], dim2,
self.size[2] * input_shape[3], dim3,
input_shape[4]) input_shape[4])
else: else:
raise Exception('Invalid dim_ordering: ' + self.dim_ordering) raise Exception('Invalid dim_ordering: ' + self.dim_ordering)