Bug fix : squeeze (#3433)
This commit is contained in:
parent
69d5139b8c
commit
55447cbb3d
@ -504,11 +504,9 @@ def expand_dims(x, dim=-1):
|
||||
def squeeze(x, axis):
|
||||
'''Remove a 1-dimension from the tensor at index "axis".
|
||||
'''
|
||||
broadcastable = x.broadcastable[:axis] + x.broadcastable[axis+1:]
|
||||
x = T.patternbroadcast(x, [i == axis for i in range(x.type.ndim)])
|
||||
x = T.squeeze(x)
|
||||
x = T.patternbroadcast(x, broadcastable)
|
||||
return x
|
||||
shape = list(x.shape)
|
||||
shape.pop(axis)
|
||||
return T.reshape(x, tuple(shape))
|
||||
|
||||
|
||||
def temporal_padding(x, padding=1):
|
||||
|
Loading…
Reference in New Issue
Block a user