diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py index 0afb68975..d236d51ff 100644 --- a/tests/keras/backend/test_backends.py +++ b/tests/keras/backend/test_backends.py @@ -92,14 +92,15 @@ class TestBackend(object): check_single_tensor_operation('expand_dims', (4, 3), dim=-1) check_single_tensor_operation('expand_dims', (4, 3, 2), dim=1) check_single_tensor_operation('squeeze', (4, 3, 1), axis=2) - check_composed_tensor_operations('reshape', {'shape':(4,3,1,1)}, - 'squeeze', {'axis':2}, + check_single_tensor_operation('squeeze', (4, 1, 1), axis=1) + check_composed_tensor_operations('reshape', {'shape': (4, 3, 1, 1)}, + 'squeeze', {'axis': 2}, (4, 3, 1, 1)) def test_repeat_elements(self): reps = 3 for ndims in [1, 2, 3]: - shape = np.arange(2, 2+ndims) + shape = np.arange(2, 2 + ndims) arr = np.arange(np.prod(shape)).reshape(shape) arr_th = KTH.variable(arr) arr_tf = KTF.variable(arr)