From 79406f111bfe2041b9d288a0b152daf6447235ab Mon Sep 17 00:00:00 2001
From: Julien Phalip
Date: Fri, 16 Dec 2016 16:59:01 -0800
Subject: [PATCH] Make sure that changes to the global floatx are effectively
 taken into account by the backend. (#4739)

---
 keras/backend/tensorflow_backend.py  | 61 ++++++++++++++++++----------
 keras/backend/theano_backend.py      | 40 ++++++++++++------
 tests/keras/backend/test_backends.py | 50 ++++++++++++++++++++++-
 3 files changed, 117 insertions(+), 34 deletions(-)

diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index dc0ca8797..319858aee 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -12,7 +12,7 @@ import numpy as np
 import os
 import copy
 import warnings
-from .common import _FLOATX, _EPSILON, image_dim_ordering, reset_uids
+from .common import floatx, _EPSILON, image_dim_ordering, reset_uids
 
 py_all = all
 
 # INTERNAL UTILS
@@ -207,7 +207,7 @@ def to_dense(tensor):
         return tensor
 
 
-def variable(value, dtype=_FLOATX, name=None):
+def variable(value, dtype=None, name=None):
     '''Instantiates a variable and returns it.
 
     # Arguments
@@ -232,6 +232,8 @@
         [ 3., 4.]])
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     if hasattr(value, 'tocoo'):
         sparse_coo = value.tocoo()
         indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
@@ -271,7 +273,7 @@ def _initialize_variables():
         sess.run(tf.initialize_variables(uninitialized_variables))
 
 
-def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
+def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
     '''Instantiates a placeholder tensor and returns it.
 
     # Arguments
@@ -296,6 +298,8 @@
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     if not shape:
         if ndim:
             shape = tuple([None for _ in range(ndim)])
@@ -448,7 +452,7 @@ def eval(x):
     return to_dense(x).eval(session=get_session())
 
 
-def zeros(shape, dtype=_FLOATX, name=None):
+def zeros(shape, dtype=None, name=None):
    '''Instantiates an all-zeros variable and returns it.
 
     # Arguments
@@ -469,13 +473,15 @@
         [ 0., 0., 0., 0.]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     return variable(tf.constant_initializer(0., dtype=tf_dtype)(shape),
                     dtype, name)
 
 
-def ones(shape, dtype=_FLOATX, name=None):
+def ones(shape, dtype=None, name=None):
     '''Instantiates an all-ones tensor variable and returns it.
 
     # Arguments
@@ -498,13 +504,15 @@
         [ 1., 1., 1., 1.]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     return variable(tf.constant_initializer(1., dtype=tf_dtype)(shape),
                     dtype, name)
 
 
-def eye(size, dtype=_FLOATX, name=None):
+def eye(size, dtype=None, name=None):
     '''Instantiates an identity matrix and returns it.
 
     # Arguments
@@ -577,7 +585,7 @@ def ones_like(x, name=None):
     return tf.ones_like(x, name=name)
 
 
-def random_uniform_variable(shape, low, high, dtype=_FLOATX,
+def random_uniform_variable(shape, low, high, dtype=None,
                             name=None, seed=None):
     '''Instantiates a Keras variable filled with samples drawn from a uniform
     distribution and returns it.
@@ -609,6 +617,8 @@ def random_uniform_variable(shape, low, high, dtype=_FLOATX,
         [ 0.66137183, 0.00869417, 0.89220798]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     if seed is None:
@@ -619,7 +629,7 @@
     return variable(value, dtype=dtype, name=name)
 
 
-def random_normal_variable(shape, mean, scale, dtype=_FLOATX,
+def random_normal_variable(shape, mean, scale, dtype=None,
                            name=None, seed=None):
     '''Instantiates a Keras variable filled with samples drawn from a normal
     distribution and returns it.
@@ -651,6 +661,8 @@
         [ 0.92629528, 0.28055015, 1.70484698]], dtype=float32)
     ```
     '''
+    if dtype is None:
+        dtype = floatx()
     shape = tuple(map(int, shape))
     tf_dtype = _convert_string_dtype(dtype)
     if seed is None:
@@ -963,7 +975,7 @@ def var(x, axis=None, keepdims=False):
     '''
     axis = _normalize_axis(axis, ndim(x))
     if x.dtype.base_dtype == tf.bool:
-        x = tf.cast(x, _FLOATX)
+        x = tf.cast(x, floatx())
     m = tf.reduce_mean(x, reduction_indices=axis, keep_dims=True)
     devs_squared = tf.square(x - m)
     return tf.reduce_mean(devs_squared,
@@ -982,7 +994,7 @@ def mean(x, axis=None, keepdims=False):
     '''
     axis = _normalize_axis(axis, ndim(x))
     if x.dtype.base_dtype == tf.bool:
-        x = tf.cast(x, _FLOATX)
+        x = tf.cast(x, floatx())
     return tf.reduce_mean(x, reduction_indices=axis,
                           keep_dims=keepdims)
 
 
@@ -2073,7 +2085,7 @@ def _preprocess_deconv_output_shape(shape, dim_ordering):
 
 
 def _preprocess_conv2d_input(x, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(x) == 'float64':
         x = tf.cast(x, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2085,7 +2097,7 @@
 
 
 def _preprocess_conv3d_input(x, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(x) == 'float64':
         x = tf.cast(x, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2097,7 +2109,7 @@
 
 
 def _preprocess_conv2d_kernel(kernel, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(kernel) == 'float64':
         kernel = tf.cast(kernel, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2109,7 +2121,7 @@
 
 
 def _preprocess_conv3d_kernel(kernel, dim_ordering):
-    if _FLOATX == 'float64':
+    if dtype(kernel) == 'float64':
         kernel = tf.cast(kernel, 'float32')
     if dim_ordering == 'th':
         # TF uses the last dimension as channel dimension,
@@ -2134,7 +2146,7 @@ def _postprocess_conv2d_output(x, dim_ordering):
     if dim_ordering == 'th':
         x = tf.transpose(x, (0, 3, 1, 2))
 
-    if _FLOATX == 'float64':
+    if floatx() == 'float64':
         x = tf.cast(x, 'float64')
     return x
 
@@ -2143,7 +2155,7 @@ def _postprocess_conv3d_output(x, dim_ordering):
     if dim_ordering == 'th':
         x = tf.transpose(x, (0, 4, 1, 2, 3))
 
-    if _FLOATX == 'float64':
+    if floatx() == 'float64':
         x = tf.cast(x, 'float64')
     return x
 
@@ -2158,13 +2170,14 @@ def conv1d(x, kernel, stride=1, border_mode='valid',
         border_mode: string, "same" or "valid".
     '''
     # pre-process dtype
-    if _FLOATX == 'float64':
+    x_dtype = dtype(x)
+    if x_dtype == 'float64':
         x = tf.cast(x, 'float32')
         kernel = tf.cast(kernel, 'float32')
     padding = _preprocess_border_mode(border_mode)
     x = tf.nn.conv1d(x, kernel, stride, padding=padding)
     # post-process dtype
-    if _FLOATX == 'float64':
+    if x_dtype == 'float64':
         x = tf.cast(x, 'float64')
     return x
 
@@ -2367,21 +2380,27 @@ def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
 
 # RANDOMNESS
 
-def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
+def random_normal(shape, mean=0.0, std=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(10e6)
     return tf.random_normal(shape, mean=mean, stddev=std,
                             dtype=dtype, seed=seed)
 
 
-def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
+def random_uniform(shape, low=0.0, high=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(10e6)
     return tf.random_uniform(shape, minval=low, maxval=high,
                              dtype=dtype, seed=seed)
 
 
-def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
+def random_binomial(shape, p=0.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(10e6)
     return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index 6773c2097..5985bc0b7 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -14,7 +14,7 @@ except ImportError:
     from theano.sandbox.softsign import softsign as T_softsign
 import inspect
 import numpy as np
-from .common import _FLOATX, _EPSILON, image_dim_ordering
+from .common import _FLOATX, floatx, _EPSILON, image_dim_ordering
 
 py_all = all
 
@@ -56,9 +56,11 @@ def to_dense(tensor):
         return tensor
 
 
-def variable(value, dtype=_FLOATX, name=None):
+def variable(value, dtype=None, name=None):
     '''Instantiates a variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     if hasattr(value, 'tocoo'):
         _assert_sparse_module()
         return th_sparse_module.as_sparse_variable(value)
@@ -67,9 +69,11 @@
     return theano.shared(value=value, name=name, strict=False)
 
 
-def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
+def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
     '''Instantiate an input data placeholder variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     if shape is None and ndim is None:
         raise ValueError('Specify either a shape or ndim value.')
     if shape is not None:
@@ -111,21 +115,27 @@ def eval(x):
     return to_dense(x).eval()
 
 
-def zeros(shape, dtype=_FLOATX, name=None):
+def zeros(shape, dtype=None, name=None):
     '''Instantiates an all-zeros variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     return variable(np.zeros(shape), dtype, name)
 
 
-def ones(shape, dtype=_FLOATX, name=None):
+def ones(shape, dtype=None, name=None):
     '''Instantiates an all-ones variable.
     '''
+    if dtype is None:
+        dtype = floatx()
     return variable(np.ones(shape), dtype, name)
 
 
-def eye(size, dtype=_FLOATX, name=None):
+def eye(size, dtype=None, name=None):
     '''Instantiates an identity matrix.
     '''
+    if dtype is None:
+        dtype = floatx()
     return variable(np.eye(size), dtype, name)
 
 
@@ -137,12 +147,12 @@ def zeros_like(x, name=None):
     return T.zeros_like(x)
 
 
-def random_uniform_variable(shape, low, high, dtype=_FLOATX, name=None):
+def random_uniform_variable(shape, low, high, dtype=None, name=None):
     return variable(np.random.uniform(low=low, high=high, size=shape),
                     dtype=dtype, name=name)
 
 
-def random_normal_variable(shape, mean, scale, dtype=_FLOATX, name=None):
+def random_normal_variable(shape, mean, scale, dtype=None, name=None):
     return variable(np.random.normal(loc=0.0, scale=scale, size=shape),
                     dtype=dtype, name=name)
 
@@ -284,7 +294,7 @@ def mean(x, axis=None, keepdims=False):
     dtype = None
     # bool is available since theano v0.9dev
     if 'int' in x.dtype or x.dtype == 'bool':
-        dtype = _FLOATX
+        dtype = floatx()
     return T.mean(x, axis=axis, keepdims=keepdims, dtype=dtype)
 
 
@@ -1799,21 +1809,27 @@ def _old_theano_pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
 
 # RANDOMNESS
 
-def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
+def random_normal(shape, mean=0.0, std=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)
     return rng.normal(size=shape, avg=mean, std=std, dtype=dtype)
 
 
-def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
+def random_uniform(shape, low=0.0, high=1.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)
     return rng.uniform(shape, low=low, high=high, dtype=dtype)
 
 
-def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
+def random_binomial(shape, p=0.0, dtype=None, seed=None):
+    if dtype is None:
+        dtype = floatx()
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)
diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py
index 651abf6c8..b0dccd671 100644
--- a/tests/keras/backend/test_backends.py
+++ b/tests/keras/backend/test_backends.py
@@ -3,11 +3,19 @@ from numpy.testing import assert_allclose
 import numpy as np
 import scipy.sparse as sparse
 
-from keras.backend import theano_backend as KTH
+from keras import backend as K
+from keras.backend import theano_backend as KTH, floatx, set_floatx, variable
 from keras.backend import tensorflow_backend as KTF
 from keras.utils.np_utils import convert_kernel
 
 
+def check_dtype(var, dtype):
+    if K._BACKEND == 'theano':
+        assert var.dtype == dtype
+    else:
+        assert var.dtype.name == '%s_ref' % dtype
+
+
 def check_single_tensor_operation(function_name, input_shape, **kwargs):
     val = np.random.random(input_shape) - 0.5
     xth = KTH.variable(val)
@@ -930,6 +938,46 @@ class TestBackend(object):
             t = backend.arange(10, dtype=dtype)
             assert backend.dtype(t) == dtype
 
+    def test_setfloatx_incorrect_values(self):
+        # Keep track of the old value
+        old_floatx = floatx()
+        # Try some incorrect values
+        initial = floatx()
+        for value in ['', 'beerfloat', 123]:
+            with pytest.raises(Exception):
+                set_floatx(value)
+        assert floatx() == initial
+        # Restore old value
+        set_floatx(old_floatx)
+
+    def test_setfloatx_correct_values(self):
+        # Keep track of the old value
+        old_floatx = floatx()
+        # Check correct values
+        for value in ['float16', 'float32', 'float64']:
+            set_floatx(value)
+            assert floatx() == value
+        # Restore old value
+        set_floatx(old_floatx)
+
+    def test_set_floatx(self):
+        """
+        Make sure that changes to the global floatx are
+        effectively taken into account by the backend.
+        """
+        # Keep track of the old value
+        old_floatx = floatx()
+
+        set_floatx('float16')
+        var = variable([10])
+        check_dtype(var, 'float16')
+
+        set_floatx('float64')
+        var = variable([10])
+        check_dtype(var, 'float64')
+
+        # Restore old value
+        set_floatx(old_floatx)
 
 
 if __name__ == '__main__':
     pytest.main([__file__])
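
Reviewer's note: the sketch below is not part of the patch; it is a minimal illustration of the behavior the change enables, assuming a Keras checkout with this patch applied and either backend installed. It only uses public backend functions touched above (`set_floatx`, `floatx`, `variable`, `dtype`, `random_uniform`); the `'_ref'` dtype-name suffix for TensorFlow variables is the same quirk the new `check_dtype` helper accounts for.

```python
import numpy as np
from keras import backend as K

# Previously the backend captured _FLOATX at import time, so a set_floatx()
# call made after importing Keras was silently ignored. With this patch,
# dtype=None defaults are resolved to floatx() at call time instead.
K.set_floatx('float64')

v = K.variable(np.array([1.0, 2.0, 3.0]))
# On the TensorFlow backend a variable's dtype name carries a '_ref'
# suffix (e.g. 'float64_ref'), hence startswith() rather than equality.
assert K.dtype(v).startswith('float64')

# Plain tensors pick up the global floatx too when dtype is omitted.
r = K.random_uniform((2, 2))
assert K.dtype(r) == 'float64'

# An explicit dtype still overrides the global default.
w = K.variable([1.0, 2.0], dtype='float32')
assert K.dtype(w).startswith('float32')

# Restore the default so the rest of the process is unaffected.
K.set_floatx('float32')
```

Note that `set_floatx()` only affects tensors created after the call; variables created earlier keep the dtype they were built with, which is why `test_set_floatx` re-creates `var` after each `set_floatx()` call.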