Style fixes in tests

This commit is contained in:
Francois Chollet 2016-08-10 16:44:29 -07:00
parent 4805e5856b
commit ec6eda77ad

@ -92,14 +92,15 @@ class TestBackend(object):
check_single_tensor_operation('expand_dims', (4, 3), dim=-1) check_single_tensor_operation('expand_dims', (4, 3), dim=-1)
check_single_tensor_operation('expand_dims', (4, 3, 2), dim=1) check_single_tensor_operation('expand_dims', (4, 3, 2), dim=1)
check_single_tensor_operation('squeeze', (4, 3, 1), axis=2) check_single_tensor_operation('squeeze', (4, 3, 1), axis=2)
check_composed_tensor_operations('reshape', {'shape':(4,3,1,1)}, check_single_tensor_operation('squeeze', (4, 1, 1), axis=1)
'squeeze', {'axis':2}, check_composed_tensor_operations('reshape', {'shape': (4, 3, 1, 1)},
'squeeze', {'axis': 2},
(4, 3, 1, 1)) (4, 3, 1, 1))
def test_repeat_elements(self): def test_repeat_elements(self):
reps = 3 reps = 3
for ndims in [1, 2, 3]: for ndims in [1, 2, 3]:
shape = np.arange(2, 2+ndims) shape = np.arange(2, 2 + ndims)
arr = np.arange(np.prod(shape)).reshape(shape) arr = np.arange(np.prod(shape)).reshape(shape)
arr_th = KTH.variable(arr) arr_th = KTH.variable(arr)
arr_tf = KTF.variable(arr) arr_tf = KTF.variable(arr)