diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 7afe26261..eade21f5d 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -504,11 +504,9 @@ def expand_dims(x, dim=-1): def squeeze(x, axis): '''Remove a 1-dimension from the tensor at index "axis". ''' - broadcastable = x.broadcastable[:axis] + x.broadcastable[axis+1:] - x = T.patternbroadcast(x, [i == axis for i in range(x.type.ndim)]) - x = T.squeeze(x) - x = T.patternbroadcast(x, broadcastable) - return x + shape = list(x.shape) + shape.pop(axis) + return T.reshape(x, tuple(shape)) def temporal_padding(x, padding=1):