Set default backend to TF

This commit is contained in:
Francois Chollet 2016-09-13 16:24:43 -07:00
parent d90e1db50b
commit 82318263a1
8 changed files with 55 additions and 52 deletions

@ -23,7 +23,7 @@ _keras_dir = os.path.join(_keras_base_dir, '.keras')
if not os.path.exists(_keras_dir):
os.makedirs(_keras_dir)
_BACKEND = 'theano'
_BACKEND = 'tensorflow'
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
_config = json.load(open(_config_path))

@ -6,7 +6,7 @@ from collections import defaultdict
_FLOATX = 'float32'
_EPSILON = 10e-8
_UID_PREFIXES = defaultdict(int)
_IMAGE_DIM_ORDERING = 'th'
_IMAGE_DIM_ORDERING = 'tf'
_LEGACY_WEIGHT_ORDERING = False

@ -844,7 +844,7 @@ def temporal_padding(x, padding=1):
return tf.pad(x, pattern)
def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'):
def spatial_2d_padding(x, padding=(1, 1), dim_ordering=_IMAGE_DIM_ORDERING):
'''Pads the 2nd and 3rd dimensions of a 4D tensor
with "padding[0]" and "padding[1]" (resp.) zeros left and right.
'''
@ -858,7 +858,7 @@ def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'):
return tf.pad(x, pattern)
def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering=_IMAGE_DIM_ORDERING):
'''Pads 5D tensor with zeros for the depth, height, width dimension with
"padding[0]", "padding[1]" and "padding[2]" (resp.) zeros left and right

@ -573,7 +573,7 @@ def temporal_padding(x, padding=1):
return T.set_subtensor(output[:, padding:x.shape[1] + padding, :], x)
def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'):
def spatial_2d_padding(x, padding=(1, 1), dim_ordering=_IMAGE_DIM_ORDERING):
'''Pad the 2nd and 3rd dimensions of a 4D tensor
with "padding[0]" and "padding[1]" (resp.) zeros left and right.
'''
@ -604,7 +604,7 @@ def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'):
return T.set_subtensor(output[indices], x)
def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering=_IMAGE_DIM_ORDERING):
'''Pad the 2nd, 3rd and 4th dimensions of a 5D tensor
with "padding[0]", "padding[1]" and "padding[2]" (resp.) zeros left and right.
'''
@ -1197,7 +1197,7 @@ def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
def conv3d(x, kernel, strides=(1, 1, 1),
border_mode='valid', dim_ordering='th',
border_mode='valid', dim_ordering=_IMAGE_DIM_ORDERING,
volume_shape=None, filter_shape=None):
'''
Run on cuDNN if available.
@ -1259,7 +1259,7 @@ def conv3d(x, kernel, strides=(1, 1, 1),
def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
dim_ordering='th', pool_mode='max'):
dim_ordering=_IMAGE_DIM_ORDERING, pool_mode='max'):
if border_mode == 'same':
w_pad = pool_size[0] - 2 if pool_size[0] % 2 == 1 else pool_size[0] - 1
h_pad = pool_size[1] - 2 if pool_size[1] % 2 == 1 else pool_size[1] - 1
@ -1302,7 +1302,7 @@ def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
dim_ordering='th', pool_mode='max'):
dim_ordering=_IMAGE_DIM_ORDERING, pool_mode='max'):
if border_mode == 'same':
# TODO: add implementation for border_mode="same"
raise Exception('border_mode="same" not supported with Theano.')

@ -16,7 +16,7 @@ def test_image_classification():
with convolutional hidden layer.
'''
np.random.seed(1337)
input_shape = (3, 16, 16)
input_shape = (16, 16, 3)
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=500,
nb_test=200,
input_shape=input_shape,

@ -487,8 +487,8 @@ class TestBackend(object):
kernel_th = KTH.variable(convert_kernel(kernel_val))
kernel_tf = KTF.variable(kernel_val)
zth = KTH.eval(KTH.conv2d(xth, kernel_th))
ztf = KTF.eval(KTF.conv2d(xtf, kernel_tf))
zth = KTH.eval(KTH.conv2d(xth, kernel_th, dim_ordering='th'))
ztf = KTF.eval(KTF.conv2d(xtf, kernel_tf, dim_ordering='th'))
assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-05)
@ -531,8 +531,8 @@ class TestBackend(object):
kernel_th = KTH.variable(convert_kernel(kernel_val))
kernel_tf = KTF.variable(kernel_val)
zth = KTH.eval(KTH.conv3d(xth, kernel_th))
ztf = KTF.eval(KTF.conv3d(xtf, kernel_tf))
zth = KTH.eval(KTH.conv3d(xth, kernel_th, dim_ordering='th'))
ztf = KTF.eval(KTF.conv3d(xtf, kernel_tf, dim_ordering='th'))
assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-05)
@ -558,23 +558,23 @@ class TestBackend(object):
assert_allclose(zth, ztf, atol=1e-05)
def test_pool2d(self):
check_single_tensor_operation('pool2d', (5, 3, 10, 12), pool_size=(2, 2),
check_single_tensor_operation('pool2d', (5, 10, 12, 3), pool_size=(2, 2),
strides=(1, 1), border_mode='valid')
check_single_tensor_operation('pool2d', (5, 3, 9, 11), pool_size=(2, 2),
check_single_tensor_operation('pool2d', (5, 9, 11, 3), pool_size=(2, 2),
strides=(1, 1), border_mode='valid')
check_single_tensor_operation('pool2d', (5, 3, 9, 11), pool_size=(2, 3),
check_single_tensor_operation('pool2d', (5, 9, 11, 3), pool_size=(2, 3),
strides=(1, 1), border_mode='valid')
def test_pool3d(self):
check_single_tensor_operation('pool3d', (5, 3, 10, 12, 5), pool_size=(2, 2, 2),
check_single_tensor_operation('pool3d', (5, 10, 12, 5, 3), pool_size=(2, 2, 2),
strides=(1, 1, 1), border_mode='valid')
check_single_tensor_operation('pool3d', (5, 3, 9, 11, 5), pool_size=(2, 2, 2),
check_single_tensor_operation('pool3d', (5, 9, 11, 5, 3), pool_size=(2, 2, 2),
strides=(1, 1, 1), border_mode='valid')
check_single_tensor_operation('pool3d', (5, 3, 9, 11, 5), pool_size=(2, 3, 2),
check_single_tensor_operation('pool3d', (5, 9, 11, 5, 3), pool_size=(2, 3, 2),
strides=(1, 1, 1), border_mode='valid')
def test_random_normal(self):

@ -75,7 +75,7 @@ def test_convolution_2d():
'nb_col': 3,
'border_mode': border_mode,
'subsample': subsample},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
input_shape=(nb_samples, nb_row, nb_col, stack_size))
layer_test(convolutional.Convolution2D,
kwargs={'nb_filter': nb_filter,
@ -86,7 +86,7 @@ def test_convolution_2d():
'b_regularizer': 'l2',
'activity_regularizer': 'activity_l2',
'subsample': subsample},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
input_shape=(nb_samples, nb_row, nb_col, stack_size))
@keras_test
@ -108,23 +108,23 @@ def test_deconvolution_2d():
kwargs={'nb_filter': nb_filter,
'nb_row': 3,
'nb_col': 3,
'output_shape': (nb_samples, nb_filter, rows, cols),
'output_shape': (nb_samples, rows, cols, nb_filter),
'border_mode': border_mode,
'subsample': subsample},
input_shape=(nb_samples, stack_size, nb_row, nb_col),
input_shape=(nb_samples, nb_row, nb_col, stack_size),
fixed_batch_size=True)
layer_test(convolutional.Deconvolution2D,
kwargs={'nb_filter': nb_filter,
'nb_row': 3,
'nb_col': 3,
'output_shape': (nb_samples, nb_filter, rows, cols),
'output_shape': (nb_samples, rows, cols, nb_filter),
'border_mode': border_mode,
'W_regularizer': 'l2',
'b_regularizer': 'l2',
'activity_regularizer': 'activity_l2',
'subsample': subsample},
input_shape=(nb_samples, stack_size, nb_row, nb_col),
input_shape=(nb_samples, nb_row, nb_col, stack_size),
fixed_batch_size=True)
@ -151,7 +151,7 @@ def test_atrous_conv_2d():
'border_mode': border_mode,
'subsample': subsample,
'atrous_rate': atrous_rate},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
input_shape=(nb_samples, nb_row, nb_col, stack_size))
layer_test(convolutional.AtrousConv2D,
kwargs={'nb_filter': nb_filter,
@ -163,7 +163,7 @@ def test_atrous_conv_2d():
'activity_regularizer': 'activity_l2',
'subsample': subsample,
'atrous_rate': atrous_rate},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
input_shape=(nb_samples, nb_row, nb_col, stack_size))
@pytest.mark.skipif(K._BACKEND != 'tensorflow', reason="Requires TF backend")
@ -188,7 +188,7 @@ def test_separable_conv_2d():
'border_mode': border_mode,
'subsample': subsample,
'depth_multiplier': multiplier},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
input_shape=(nb_samples, nb_row, nb_col, stack_size))
layer_test(convolutional.SeparableConv2D,
kwargs={'nb_filter': nb_filter,
@ -203,7 +203,7 @@ def test_separable_conv_2d():
'depthwise_constraint': 'unitnorm',
'subsample': subsample,
'depth_multiplier': multiplier},
input_shape=(nb_samples, stack_size, nb_row, nb_col))
input_shape=(nb_samples, nb_row, nb_col, stack_size))
@keras_test
@ -239,7 +239,7 @@ def test_maxpooling_2d():
kwargs={'strides': strides,
'border_mode': 'valid',
'pool_size': pool_size},
input_shape=(3, 4, 11, 12))
input_shape=(3, 11, 12, 4))
@keras_test
@ -253,7 +253,7 @@ def test_averagepooling_2d():
kwargs={'strides': strides,
'border_mode': border_mode,
'pool_size': pool_size},
input_shape=(3, 4, 11, 12))
input_shape=(3, 11, 12, 4))
@keras_test
@ -281,8 +281,9 @@ def test_convolution_3d():
'kernel_dim3': kernel_dim3,
'border_mode': border_mode,
'subsample': subsample},
input_shape=(nb_samples, stack_size,
input_len_dim1, input_len_dim2, input_len_dim3))
input_shape=(nb_samples,
input_len_dim1, input_len_dim2, input_len_dim3,
stack_size))
layer_test(convolutional.Convolution3D,
kwargs={'nb_filter': nb_filter,
@ -294,8 +295,9 @@ def test_convolution_3d():
'b_regularizer': 'l2',
'activity_regularizer': 'activity_l2',
'subsample': subsample},
input_shape=(nb_samples, stack_size,
input_len_dim1, input_len_dim2, input_len_dim3))
input_shape=(nb_samples,
input_len_dim1, input_len_dim2, input_len_dim3,
stack_size))
@keras_test
@ -329,7 +331,7 @@ def test_zero_padding_2d():
input_nb_row = 11
input_nb_col = 12
input = np.ones((nb_samples, stack_size, input_nb_row, input_nb_col))
input = np.ones((nb_samples, input_nb_row, input_nb_col, stack_size))
# basic test
layer_test(convolutional.ZeroPadding2D,
@ -342,9 +344,9 @@ def test_zero_padding_2d():
out = K.eval(layer.output)
for offset in [0, 1, -1, -2]:
assert_allclose(out[:, offset, :, :], 0.)
assert_allclose(out[:, :, offset, :], 0.)
assert_allclose(out[:, :, :, offset], 0.)
assert_allclose(out[:, :, 2:-2, 2:-2], 1.)
assert_allclose(out[:, 2:-2, 2:-2, :], 1.)
layer.get_config()
@ -355,8 +357,9 @@ def test_zero_padding_3d():
input_len_dim2 = 11
input_len_dim3 = 12
input = np.ones((nb_samples, stack_size, input_len_dim1,
input_len_dim2, input_len_dim3))
input = np.ones((nb_samples,
input_len_dim1, input_len_dim2, input_len_dim3,
stack_size))
# basic test
layer_test(convolutional.ZeroPadding3D,
@ -368,10 +371,10 @@ def test_zero_padding_3d():
layer.set_input(K.variable(input), shape=input.shape)
out = K.eval(layer.output)
for offset in [0, 1, -1, -2]:
assert_allclose(out[:, offset, :, :, :], 0.)
assert_allclose(out[:, :, offset, :, :], 0.)
assert_allclose(out[:, :, :, offset, :], 0.)
assert_allclose(out[:, :, :, :, offset], 0.)
assert_allclose(out[:, :, 2:-2, 2:-2, 2:-2], 1.)
assert_allclose(out[:, 2:-2, 2:-2, 2:-2, :], 1.)
layer.get_config()

@ -43,10 +43,10 @@ def test_TimeDistributed():
# test with Convolution2D
model = Sequential()
model.add(wrappers.TimeDistributed(convolutional.Convolution2D(5, 2, 2, border_mode='same'), input_shape=(2, 3, 4, 4)))
model.add(wrappers.TimeDistributed(convolutional.Convolution2D(5, 2, 2, border_mode='same'), input_shape=(2, 4, 4, 3)))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(np.random.random((1, 2, 3, 4, 4)), np.random.random((1, 2, 5, 4, 4)))
model.train_on_batch(np.random.random((1, 2, 4, 4, 3)), np.random.random((1, 2, 4, 4, 5)))
model = model_from_json(model.to_json())
model.summary()