From 037e592f2ba7c18b71bc9b39f84de11af0252863 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Fri, 8 Jan 2016 10:02:28 -0800
Subject: [PATCH] Naming, batch_flatten

---
 keras/backend/tensorflow_backend.py      | 10 +++++-
 keras/backend/theano_backend.py          |  8 +++++
 keras/initializations.py                 | 46 ++++++++++++------------
 keras/layers/core.py                     | 29 +++++++++++----
 keras/utils/layer_utils.py               |  6 +++-
 tests/keras/layers/test_core.py          | 15 ++++++++
 tests/keras/layers/test_normalization.py |  2 ++
 7 files changed, 85 insertions(+), 31 deletions(-)

diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index 5bb472c4c..154ffcf13 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -287,6 +287,10 @@ def tile(x, n):
 
 
 def flatten(x):
+    return tf.reshape(x, [-1])
+
+
+def batch_flatten(x):
     '''Turn an n-D tensor into a 2D tensor where
     the first dimension is conserved.
     '''
@@ -345,12 +349,16 @@ def set_value(x, value):
 
 class Function(object):
     def __init__(self, inputs, outputs, updates=[]):
+        assert type(inputs) in {list, tuple}
+        assert type(outputs) in {list, tuple}
+        assert type(updates) in {list, tuple}
         self.inputs = list(inputs)
         self.outputs = list(outputs)
         with tf.control_dependencies(self.outputs):
             self.updates = [tf.assign(p, new_p) for (p, new_p) in updates]
 
     def __call__(self, inputs):
+        assert type(inputs) in {list, tuple}
         names = [v.name for v in self.inputs]
         feed_dict = dict(zip(names, inputs))
         session = _get_session()
@@ -442,7 +450,7 @@ def rnn(step_function, inputs, initial_states,
             new_states = successive_states[-1]
 
         outputs = tf.transpose(outputs, (1, 0, 2))
-    return last_output, outputs, states
+    return last_output, outputs, new_states
 
 
 def switch(condition, then_expression, else_expression):
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index f43e276c1..fa049b8f9 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -287,6 +287,10 @@ def tile(x, n):
 
 
 def flatten(x):
+    return T.flatten(x)
+
+
+def batch_flatten(x):
     '''Turn an n-D tensor into a 2D tensor where
     the first dimension is conserved.
     '''
@@ -378,10 +382,14 @@ def set_value(x, value):
 
 class Function(object):
     def __init__(self, inputs, outputs, updates=[], **kwargs):
+        assert type(inputs) in {list, tuple}
+        assert type(outputs) in {list, tuple}
+        assert type(updates) in {list, tuple}
         self.function = theano.function(inputs, outputs, updates=updates,
                                         allow_input_downcast=True, **kwargs)
 
     def __call__(self, inputs):
+        assert type(inputs) in {list, tuple}
         return self.function(*inputs)
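Note on the two backend changes above: flatten now collapses a tensor all the
way down to 1D, while the old behavior (conserve the batch axis, flatten
everything else) moves to the new batch_flatten; the Flatten layer is updated
accordingly further down. A minimal sketch of the distinction, assuming a 3D
input and the backend's eval helper for reading values back (illustrative
only, not part of the patch):

    from keras import backend as K
    import numpy as np

    x = K.variable(np.random.random((2, 3, 4)))

    # flatten: all axes collapsed into a single 1D tensor
    assert K.eval(K.flatten(x)).shape == (24,)

    # batch_flatten: the first (batch) axis is conserved
    assert K.eval(K.batch_flatten(x)).shape == (2, 12)

Relatedly, Function (the object behind K.function) now asserts that inputs,
outputs, and updates are lists or tuples, both at construction and at call
time, so single tensors must be wrapped, e.g. K.function([x], [y]) rather
than K.function(x, y).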
diff --git a/keras/initializations.py b/keras/initializations.py
index d0afff97c..a1451e6d6 100644
--- a/keras/initializations.py
+++ b/keras/initializations.py
@@ -9,52 +9,54 @@ def get_fans(shape):
     return fan_in, fan_out
 
 
-def uniform(shape, scale=0.05):
-    return K.variable(np.random.uniform(low=-scale, high=scale, size=shape))
+def uniform(shape, scale=0.05, name=None):
+    return K.variable(np.random.uniform(low=-scale, high=scale, size=shape),
+                      name=name)
 
 
-def normal(shape, scale=0.05):
-    return K.variable(np.random.normal(loc=0.0, scale=scale, size=shape))
+def normal(shape, scale=0.05, name=None):
+    return K.variable(np.random.normal(loc=0.0, scale=scale, size=shape),
+                      name=name)
 
 
-def lecun_uniform(shape):
+def lecun_uniform(shape, name=None):
     ''' Reference: LeCun 98, Efficient Backprop
         http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
     '''
     fan_in, fan_out = get_fans(shape)
     scale = np.sqrt(3. / fan_in)
-    return uniform(shape, scale)
+    return uniform(shape, scale, name=name)
 
 
-def glorot_normal(shape):
+def glorot_normal(shape, name=None):
     ''' Reference: Glorot & Bengio, AISTATS 2010
     '''
     fan_in, fan_out = get_fans(shape)
     s = np.sqrt(2. / (fan_in + fan_out))
-    return normal(shape, s)
+    return normal(shape, s, name=name)
 
 
-def glorot_uniform(shape):
+def glorot_uniform(shape, name=None):
     fan_in, fan_out = get_fans(shape)
     s = np.sqrt(6. / (fan_in + fan_out))
-    return uniform(shape, s)
+    return uniform(shape, s, name=name)
 
 
-def he_normal(shape):
+def he_normal(shape, name=None):
     ''' Reference: He et al., http://arxiv.org/abs/1502.01852
     '''
     fan_in, fan_out = get_fans(shape)
     s = np.sqrt(2. / fan_in)
-    return normal(shape, s)
+    return normal(shape, s, name=name)
 
 
-def he_uniform(shape):
+def he_uniform(shape, name=None):
     fan_in, fan_out = get_fans(shape)
     s = np.sqrt(6. / fan_in)
-    return uniform(shape, s)
+    return uniform(shape, s, name=name)
 
 
-def orthogonal(shape, scale=1.1):
+def orthogonal(shape, scale=1.1, name=None):
     ''' From Lasagne. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
     '''
     flat_shape = (shape[0], np.prod(shape[1:]))
@@ -63,23 +65,23 @@ def orthogonal(shape, scale=1.1):
     # pick the one with the correct shape
     q = u if u.shape == flat_shape else v
     q = q.reshape(shape)
-    return K.variable(scale * q[:shape[0], :shape[1]])
+    return K.variable(scale * q[:shape[0], :shape[1]], name=name)
 
 
-def identity(shape, scale=1):
+def identity(shape, scale=1, name=None):
     if len(shape) != 2 or shape[0] != shape[1]:
         raise Exception('Identity matrix initialization can only be used '
                         'for 2D square matrices.')
     else:
-        return K.variable(scale * np.identity(shape[0]))
+        return K.variable(scale * np.identity(shape[0]), name=name)
 
 
-def zero(shape):
-    return K.zeros(shape)
+def zero(shape, name=None):
+    return K.zeros(shape, name=name)
 
 
-def one(shape):
-    return K.ones(shape)
+def one(shape, name=None):
+    return K.ones(shape, name=name)
 
 
 from .utils.generic_utils import get_from_module
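All initializers above gain an optional name argument, forwarded to the
backend variable constructor so that weight tensors can carry readable names
in the underlying computation graph. A short usage sketch (illustrative only,
not part of the patch; 'my_W' is an arbitrary example name):

    from keras import initializations

    # the returned backend variable is tagged with the given name
    W = initializations.glorot_uniform((3, 4), name='my_W')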
diff --git a/keras/layers/core.py b/keras/layers/core.py
index 534f12194..c8a50d3e7 100644
--- a/keras/layers/core.py
+++ b/keras/layers/core.py
@@ -36,20 +36,34 @@ class Layer(object):
         allowed_kwargs = {'input_shape',
                           'trainable',
                           'batch_input_shape',
-                          'cache_enabled'}
+                          'cache_enabled',
+                          'name'}
         for kwarg in kwargs:
             assert kwarg in allowed_kwargs, 'Keyword argument not understood: ' + kwarg
+
         if 'input_shape' in kwargs:
             self.set_input_shape((None,) + tuple(kwargs['input_shape']))
         if 'batch_input_shape' in kwargs:
             self.set_input_shape(tuple(kwargs['batch_input_shape']))
+        self.trainable = True
         if 'trainable' in kwargs:
-            self._trainable = kwargs['trainable']
+            self.trainable = kwargs['trainable']
+        self.name = self.__class__.__name__.lower()
+        if 'name' in kwargs:
+            self.name = kwargs['name']
         if not hasattr(self, 'params'):
             self.params = []
-        self._cache_enabled = True
+        self.cache_enabled = True
         if 'cache_enabled' in kwargs:
-            self._cache_enabled = kwargs['cache_enabled']
+            self.cache_enabled = kwargs['cache_enabled']
+
+    @property
+    def name(self):
+        return self._name
+
+    @name.setter
+    def name(self, name):
+        self._name = name
 
     @property
     def cache_enabled(self):
@@ -233,7 +247,8 @@ class Layer(object):
             config['input_shape'] = self._input_shape[1:]
         if hasattr(self, '_trainable'):
             config['trainable'] = self._trainable
-        config['cache_enabled'] = self.cache_enabled
+        config['cache_enabled'] = self.cache_enabled
+        config['custom_name'] = self.name
         return config
 
     def get_params(self):
@@ -688,7 +703,7 @@ class Reshape(Layer):
     def _fix_unknown_dimension(self, input_shape, output_shape):
         '''Find and replace a single missing dimension in an output shape
         given an input shape.
-        
+
         A near direct port of the internal numpy function
         _fix_unknown_dimension in numpy/core/src/multiarray/shape.c
@@ -819,7 +834,7 @@ class Flatten(Layer):
 
     def get_output(self, train=False):
         X = self.get_input(train)
-        return K.flatten(X)
+        return K.batch_flatten(X)
 
 
 class RepeatVector(Layer):
diff --git a/keras/utils/layer_utils.py b/keras/utils/layer_utils.py
index 069b3fa5f..a92dd75fa 100644
--- a/keras/utils/layer_utils.py
+++ b/keras/utils/layer_utils.py
@@ -71,10 +71,11 @@ def container_from_config(original_layer_dict, custom_objects={}):
             kwargs[kwarg] = layer_dict[kwarg]
         return AutoEncoder(**kwargs)
 
-    else:
+    else:  # this is a non-topological layer (e.g. Dense, etc.)
         layer_dict.pop('name')
 
         for k, v in layer_dict.items():
+            # a dictionary argument may be a regularizer or constraint
             if isinstance(v, dict):
                 vname = v.pop('name')
                 if vname in [x for x, y in inspect.getmembers(constraints, predicate=inspect.isclass)]:
@@ -85,6 +86,9 @@ def container_from_config(original_layer_dict, custom_objects={}):
                     # not a regularizer or constraint, don't touch it
                     v['name'] = vname
 
+        # the "name" keyword argument of layers is saved as "custom_name"
+        if 'custom_name' in layer_dict:
+            layer_dict['name'] = layer_dict.pop('custom_name')
         base_layer = get_layer(name, layer_dict)
         return base_layer
 
diff --git a/tests/keras/layers/test_core.py b/tests/keras/layers/test_core.py
index 965d5b7be..621626a01 100644
--- a/tests/keras/layers/test_core.py
+++ b/tests/keras/layers/test_core.py
@@ -130,6 +130,21 @@ def test_maxout_dense():
     _runner(layer)
 
 
+def test_naming():
+    layer = core.Dense(2, input_dim=2)
+    assert layer.name == 'dense'
+
+    model = Sequential()
+    model.add(core.Dense(2, input_dim=2, name='my_dense'))
+    model.add(core.Dense(2, name='my_dense'))
+
+    assert model.layers[0].name == 'my_dense'
+    assert model.layers[1].name == 'my_dense'
+
+    model.compile(optimizer='rmsprop', loss='mse')
+    model.train_on_batch(np.random.random((2, 2)), np.random.random((2, 2)))
+
+
 @pytest.mark.skipif(K._BACKEND == 'tensorflow',
                     reason='currently not working with TensorFlow')
 def test_sequences():
diff --git a/tests/keras/layers/test_normalization.py b/tests/keras/layers/test_normalization.py
index f38e70068..a62eaa258 100644
--- a/tests/keras/layers/test_normalization.py
+++ b/tests/keras/layers/test_normalization.py
@@ -85,6 +85,8 @@ def test_batchnorm_config():
                                            epsilon=0.1, momentum=0.9)
     conf = norm.get_config()
     del conf['cache_enabled']
+    del conf['trainable']
+    del conf['custom_name']
     conf_target = {"input_shape": (10, 10),
                    "name": normalization.BatchNormalization.__name__,
                    "epsilon": 0.1, "mode": 1, "momentum": 0.9}
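Taken together, the core.py changes give every layer a name attribute,
defaulting to the lowercased class name and overridable through the new name
keyword argument; get_config stores it under custom_name so that the existing
name key can keep holding the class name, and container_from_config maps
custom_name back to the name keyword when rebuilding a layer. A sketch of the
round trip, assuming a plain Dense layer (illustrative only, not part of the
patch):

    from keras.layers.core import Dense
    from keras.utils.layer_utils import container_from_config

    layer = Dense(2, input_dim=2, name='my_dense')
    config = layer.get_config()
    assert config['name'] == 'Dense'            # class name, as before
    assert config['custom_name'] == 'my_dense'  # user-facing name

    # 'custom_name' is mapped back to the 'name' keyword argument
    rebuilt = container_from_config(config)
    assert rebuilt.name == 'my_dense'

As the new test_naming test shows, names are labels rather than unique
identifiers: two layers in the same model may share a name.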