Bug fix : squeeze (#3433)

This commit is contained in:
Fariz Rahman 2016-08-10 02:29:47 +05:30 committed by François Chollet
parent 69d5139b8c
commit 55447cbb3d

@ -504,11 +504,9 @@ def expand_dims(x, dim=-1):
def squeeze(x, axis):
'''Remove a 1-dimension from the tensor at index "axis".
'''
broadcastable = x.broadcastable[:axis] + x.broadcastable[axis+1:]
x = T.patternbroadcast(x, [i == axis for i in range(x.type.ndim)])
x = T.squeeze(x)
x = T.patternbroadcast(x, broadcastable)
return x
shape = list(x.shape)
shape.pop(axis)
return T.reshape(x, tuple(shape))
def temporal_padding(x, padding=1):