Upsampling layer fix to work when input shape is None (#3429)
This commit is contained in:
parent
55447cbb3d
commit
4805e5856b
@ -1045,7 +1045,8 @@ class UpSampling1D(Layer):
|
|||||||
super(UpSampling1D, self).__init__(**kwargs)
|
super(UpSampling1D, self).__init__(**kwargs)
|
||||||
|
|
||||||
def get_output_shape_for(self, input_shape):
|
def get_output_shape_for(self, input_shape):
|
||||||
return (input_shape[0], self.length * input_shape[1], input_shape[2])
|
length = self.length * input_shape[1] if input_shape[1] is not None else None
|
||||||
|
return (input_shape[0], length, input_shape[2])
|
||||||
|
|
||||||
def call(self, x, mask=None):
|
def call(self, x, mask=None):
|
||||||
output = K.repeat_elements(x, self.length, axis=1)
|
output = K.repeat_elements(x, self.length, axis=1)
|
||||||
@ -1094,14 +1095,18 @@ class UpSampling2D(Layer):
|
|||||||
|
|
||||||
def get_output_shape_for(self, input_shape):
|
def get_output_shape_for(self, input_shape):
|
||||||
if self.dim_ordering == 'th':
|
if self.dim_ordering == 'th':
|
||||||
|
width = self.size[0] * input_shape[2] if input_shape[2] is not None else None
|
||||||
|
height = self.size[1] * input_shape[3] if input_shape[3] is not None else None
|
||||||
return (input_shape[0],
|
return (input_shape[0],
|
||||||
input_shape[1],
|
input_shape[1],
|
||||||
self.size[0] * input_shape[2],
|
width,
|
||||||
self.size[1] * input_shape[3])
|
height)
|
||||||
elif self.dim_ordering == 'tf':
|
elif self.dim_ordering == 'tf':
|
||||||
|
width = self.size[0] * input_shape[1] if input_shape[1] is not None else None
|
||||||
|
height = self.size[1] * input_shape[2] if input_shape[2] is not None else None
|
||||||
return (input_shape[0],
|
return (input_shape[0],
|
||||||
self.size[0] * input_shape[1],
|
width,
|
||||||
self.size[1] * input_shape[2],
|
height,
|
||||||
input_shape[3])
|
input_shape[3])
|
||||||
else:
|
else:
|
||||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||||
@ -1153,16 +1158,22 @@ class UpSampling3D(Layer):
|
|||||||
|
|
||||||
def get_output_shape_for(self, input_shape):
|
def get_output_shape_for(self, input_shape):
|
||||||
if self.dim_ordering == 'th':
|
if self.dim_ordering == 'th':
|
||||||
|
dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None
|
||||||
|
dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None
|
||||||
|
dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None
|
||||||
return (input_shape[0],
|
return (input_shape[0],
|
||||||
input_shape[1],
|
input_shape[1],
|
||||||
self.size[0] * input_shape[2],
|
dim1,
|
||||||
self.size[1] * input_shape[3],
|
dim2,
|
||||||
self.size[2] * input_shape[4])
|
dim3)
|
||||||
elif self.dim_ordering == 'tf':
|
elif self.dim_ordering == 'tf':
|
||||||
|
dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
|
||||||
|
dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
|
||||||
|
dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
|
||||||
return (input_shape[0],
|
return (input_shape[0],
|
||||||
self.size[0] * input_shape[1],
|
dim1,
|
||||||
self.size[1] * input_shape[2],
|
dim2,
|
||||||
self.size[2] * input_shape[3],
|
dim3,
|
||||||
input_shape[4])
|
input_shape[4])
|
||||||
else:
|
else:
|
||||||
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
|
||||||
|
Loading…
Reference in New Issue
Block a user