diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index 019b66d14..55b0e3ec6 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -1045,7 +1045,8 @@ class UpSampling1D(Layer): super(UpSampling1D, self).__init__(**kwargs) def get_output_shape_for(self, input_shape): - return (input_shape[0], self.length * input_shape[1], input_shape[2]) + length = self.length * input_shape[1] if input_shape[1] is not None else None + return (input_shape[0], length, input_shape[2]) def call(self, x, mask=None): output = K.repeat_elements(x, self.length, axis=1) @@ -1094,14 +1095,18 @@ class UpSampling2D(Layer): def get_output_shape_for(self, input_shape): if self.dim_ordering == 'th': + width = self.size[0] * input_shape[2] if input_shape[2] is not None else None + height = self.size[1] * input_shape[3] if input_shape[3] is not None else None return (input_shape[0], input_shape[1], - self.size[0] * input_shape[2], - self.size[1] * input_shape[3]) + width, + height) elif self.dim_ordering == 'tf': + width = self.size[0] * input_shape[1] if input_shape[1] is not None else None + height = self.size[1] * input_shape[2] if input_shape[2] is not None else None return (input_shape[0], - self.size[0] * input_shape[1], - self.size[1] * input_shape[2], + width, + height, input_shape[3]) else: raise Exception('Invalid dim_ordering: ' + self.dim_ordering) @@ -1153,16 +1158,22 @@ class UpSampling3D(Layer): def get_output_shape_for(self, input_shape): if self.dim_ordering == 'th': + dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None + dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None + dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None return (input_shape[0], input_shape[1], - self.size[0] * input_shape[2], - self.size[1] * input_shape[3], - self.size[2] * input_shape[4]) + dim1, + dim2, + dim3) elif self.dim_ordering == 'tf': + dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None + dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None + dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None return (input_shape[0], - self.size[0] * input_shape[1], - self.size[1] * input_shape[2], - self.size[2] * input_shape[3], + dim1, + dim2, + dim3, input_shape[4]) else: raise Exception('Invalid dim_ordering: ' + self.dim_ordering)