Upsampling layer fix to work when input shape is None (#3429)
This commit is contained in:
parent
55447cbb3d
commit
4805e5856b
@ -1045,7 +1045,8 @@ class UpSampling1D(Layer):
|
||||
super(UpSampling1D, self).__init__(**kwargs)
|
||||
|
||||
def get_output_shape_for(self, input_shape):
|
||||
return (input_shape[0], self.length * input_shape[1], input_shape[2])
|
||||
length = self.length * input_shape[1] if input_shape[1] is not None else None
|
||||
return (input_shape[0], length, input_shape[2])
|
||||
|
||||
def call(self, x, mask=None):
|
||||
output = K.repeat_elements(x, self.length, axis=1)
|
||||
@ -1094,14 +1095,18 @@ class UpSampling2D(Layer):
|
||||
|
||||
def get_output_shape_for(self, input_shape):
|
||||
if self.dim_ordering == 'th':
|
||||
width = self.size[0] * input_shape[2] if input_shape[2] is not None else None
|
||||
height = self.size[1] * input_shape[3] if input_shape[3] is not None else None
|
||||
return (input_shape[0],
|
||||
input_shape[1],
|
||||
self.size[0] * input_shape[2],
|
||||
self.size[1] * input_shape[3])
|
||||
width,
|
||||
height)
|
||||
elif self.dim_ordering == 'tf':
|
||||
width = self.size[0] * input_shape[1] if input_shape[1] is not None else None
|
||||
height = self.size[1] * input_shape[2] if input_shape[2] is not None else None
|
||||
return (input_shape[0],
|
||||
self.size[0] * input_shape[1],
|
||||
self.size[1] * input_shape[2],
|
||||
width,
|
||||
height,
|
||||
input_shape[3])
|
||||
else:
|
||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||
@ -1153,16 +1158,22 @@ class UpSampling3D(Layer):
|
||||
|
||||
def get_output_shape_for(self, input_shape):
|
||||
if self.dim_ordering == 'th':
|
||||
dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None
|
||||
dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None
|
||||
dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None
|
||||
return (input_shape[0],
|
||||
input_shape[1],
|
||||
self.size[0] * input_shape[2],
|
||||
self.size[1] * input_shape[3],
|
||||
self.size[2] * input_shape[4])
|
||||
dim1,
|
||||
dim2,
|
||||
dim3)
|
||||
elif self.dim_ordering == 'tf':
|
||||
dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
|
||||
dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
|
||||
dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
|
||||
return (input_shape[0],
|
||||
self.size[0] * input_shape[1],
|
||||
self.size[1] * input_shape[2],
|
||||
self.size[2] * input_shape[3],
|
||||
dim1,
|
||||
dim2,
|
||||
dim3,
|
||||
input_shape[4])
|
||||
else:
|
||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||
|
Loading…
Reference in New Issue
Block a user