Naming, batch_flatten

Francois Chollet 2016-01-08 10:02:28 -08:00
parent 13379da81b
commit 037e592f2b
7 changed files with 85 additions and 31 deletions

@@ -287,6 +287,10 @@ def tile(x, n):
def flatten(x):
return tf.reshape(x, [-1])
def batch_flatten(x):
'''Turn an n-D tensor into a 2D tensor where
the first dimension is conserved.
'''
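(A minimal sketch of the difference between the existing flatten and the new batch_flatten; the shapes are chosen purely for illustration and K is the usual keras.backend import.)

import numpy as np
from keras import backend as K

x = K.variable(np.zeros((10, 3, 32, 32)))   # e.g. a batch of ten 3x32x32 feature maps

flat = K.flatten(x)         # 1D tensor of shape (10 * 3 * 32 * 32,)
bflat = K.batch_flatten(x)  # 2D tensor of shape (10, 3 * 32 * 32): batch axis preserved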
@@ -345,12 +349,16 @@ def set_value(x, value):
class Function(object):
def __init__(self, inputs, outputs, updates=[]):
assert type(inputs) in {list, tuple}
assert type(outputs) in {list, tuple}
assert type(updates) in {list, tuple}
self.inputs = list(inputs)
self.outputs = list(outputs)
with tf.control_dependencies(self.outputs):
self.updates = [tf.assign(p, new_p) for (p, new_p) in updates]
def __call__(self, inputs):
assert type(inputs) in {list, tuple}
names = [v.name for v in self.inputs]
feed_dict = dict(zip(names, inputs))
session = _get_session()
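(The Function wrapper above is what K.function returns; a hedged usage sketch, assuming a simple placeholder graph.)

import numpy as np
from keras import backend as K

x = K.placeholder(shape=(None, 4))
f = K.function([x], [x * 2.])     # inputs and outputs must be lists
out = f([np.ones((3, 4))])        # -> [array of shape (3, 4) filled with 2.0]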
@@ -442,7 +450,7 @@ def rnn(step_function, inputs, initial_states,
new_states = successive_states[-1]
outputs = tf.transpose(outputs, (1, 0, 2))
return last_output, outputs, states
return last_output, outputs, new_states
def switch(condition, then_expression, else_expression):
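(A sketch of what the one-line fix above means for callers of K.rnn: the third return value is now the state after the last timestep instead of the untouched initial state. The cumulative-sum step function and shapes are illustrative only.)

import numpy as np
from keras import backend as K

x = K.variable(np.ones((2, 3, 4)))       # (samples, timesteps, input_dim)
init = [K.variable(np.zeros((2, 4)))]    # one initial state tensor

def step(x_t, states):
    s = states[0] + x_t                  # running sum over timesteps
    return s, [s]

last_output, outputs, states = K.rnn(step, x, init)
# with the fix, K.eval(states[0]) is all 3s (the final state), not all 0s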

@@ -287,6 +287,10 @@ def tile(x, n):
def flatten(x):
return T.flatten(x)
def batch_flatten(x):
'''Turn an n-D tensor into a 2D tensor where
the first dimension is conserved.
'''
@@ -378,10 +382,14 @@ def set_value(x, value):
class Function(object):
def __init__(self, inputs, outputs, updates=[], **kwargs):
assert type(inputs) in {list, tuple}
assert type(outputs) in {list, tuple}
assert type(updates) in {list, tuple}
self.function = theano.function(inputs, outputs, updates=updates,
allow_input_downcast=True, **kwargs)
def __call__(self, inputs):
assert type(inputs) in {list, tuple}
return self.function(*inputs)

@@ -9,52 +9,54 @@ def get_fans(shape):
return fan_in, fan_out
def uniform(shape, scale=0.05):
return K.variable(np.random.uniform(low=-scale, high=scale, size=shape))
def uniform(shape, scale=0.05, name=None):
return K.variable(np.random.uniform(low=-scale, high=scale, size=shape),
name=name)
def normal(shape, scale=0.05):
return K.variable(np.random.normal(loc=0.0, scale=scale, size=shape))
def normal(shape, scale=0.05, name=None):
return K.variable(np.random.normal(loc=0.0, scale=scale, size=shape),
name=name)
def lecun_uniform(shape):
def lecun_uniform(shape, name=None):
''' Reference: LeCun 98, Efficient Backprop
http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
'''
fan_in, fan_out = get_fans(shape)
scale = np.sqrt(3. / fan_in)
return uniform(shape, scale)
return uniform(shape, scale, name=name)
def glorot_normal(shape):
def glorot_normal(shape, name=None):
''' Reference: Glorot & Bengio, AISTATS 2010
'''
fan_in, fan_out = get_fans(shape)
s = np.sqrt(2. / (fan_in + fan_out))
return normal(shape, s)
return normal(shape, s, name=name)
def glorot_uniform(shape):
def glorot_uniform(shape, name=None):
fan_in, fan_out = get_fans(shape)
s = np.sqrt(6. / (fan_in + fan_out))
return uniform(shape, s)
return uniform(shape, s, name=name)
def he_normal(shape):
def he_normal(shape, name=None):
''' Reference: He et al., http://arxiv.org/abs/1502.01852
'''
fan_in, fan_out = get_fans(shape)
s = np.sqrt(2. / fan_in)
return normal(shape, s)
return normal(shape, s, name=name)
def he_uniform(shape):
def he_uniform(shape, name=None):
fan_in, fan_out = get_fans(shape)
s = np.sqrt(6. / fan_in)
return uniform(shape, s)
return uniform(shape, s, name=name)
def orthogonal(shape, scale=1.1):
def orthogonal(shape, scale=1.1, name=None):
''' From Lasagne. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
'''
flat_shape = (shape[0], np.prod(shape[1:]))
@@ -63,23 +65,23 @@ def orthogonal(shape, scale=1.1):
# pick the one with the correct shape
q = u if u.shape == flat_shape else v
q = q.reshape(shape)
return K.variable(scale * q[:shape[0], :shape[1]])
return K.variable(scale * q[:shape[0], :shape[1]], name=name)
def identity(shape, scale=1):
def identity(shape, scale=1, name=None):
if len(shape) != 2 or shape[0] != shape[1]:
raise Exception('Identity matrix initialization can only be used '
'for 2D square matrices.')
else:
return K.variable(scale * np.identity(shape[0]))
return K.variable(scale * np.identity(shape[0]), name=name)
def zero(shape):
return K.zeros(shape)
def zero(shape, name=None):
return K.zeros(shape, name=name)
def one(shape):
return K.ones(shape)
def one(shape, name=None):
return K.ones(shape, name=name)
from .utils.generic_utils import get_from_module
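(A hedged sketch of the new name argument threaded through the initializers; the variable names here are illustrative only.)

from keras import initializations

# variables created by an initializer can now carry an explicit name,
# which shows up in the backend graph and makes debugging easier
W = initializations.glorot_uniform((64, 32), name='dense_W')
b = initializations.zero((32,), name='dense_b')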

@@ -36,20 +36,34 @@ class Layer(object):
allowed_kwargs = {'input_shape',
'trainable',
'batch_input_shape',
'cache_enabled'}
'cache_enabled',
'name'}
for kwarg in kwargs:
assert kwarg in allowed_kwargs, 'Keyword argument not understood: ' + kwarg
if 'input_shape' in kwargs:
self.set_input_shape((None,) + tuple(kwargs['input_shape']))
if 'batch_input_shape' in kwargs:
self.set_input_shape(tuple(kwargs['batch_input_shape']))
self.trainable = True
if 'trainable' in kwargs:
self._trainable = kwargs['trainable']
self.trainable = kwargs['trainable']
self.name = self.__class__.__name__.lower()
if 'name' in kwargs:
self.name = kwargs['name']
if not hasattr(self, 'params'):
self.params = []
self._cache_enabled = True
self.cache_enabled = True
if 'cache_enabled' in kwargs:
self._cache_enabled = kwargs['cache_enabled']
self.cache_enabled = kwargs['cache_enabled']
@property
def name(self):
return self._name
@name.setter
def name(self, name):
self._name = name
@property
def cache_enabled(self):
@@ -234,6 +248,7 @@ class Layer(object):
if hasattr(self, '_trainable'):
config['trainable'] = self._trainable
config['cache_enabled'] = self.cache_enabled
config['custom_name'] = self.name
return config
def get_params(self):
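(A sketch of how the user-facing name round-trips through get_config after this change: the constructor's name kwarg is saved under 'custom_name', since the 'name' key of a layer config already holds the class name.)

from keras.layers.core import Dense

layer = Dense(2, input_dim=2, name='my_dense')
config = layer.get_config()
assert config['custom_name'] == 'my_dense'   # the user-chosen name
assert config['name'] == 'Dense'             # the class name, as before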
@@ -819,7 +834,7 @@ class Flatten(Layer):
def get_output(self, train=False):
X = self.get_input(train)
return K.flatten(X)
return K.batch_flatten(X)
class RepeatVector(Layer):

@@ -71,10 +71,11 @@ def container_from_config(original_layer_dict, custom_objects={}):
kwargs[kwarg] = layer_dict[kwarg]
return AutoEncoder(**kwargs)
else:
else: # this is a non-topological layer (e.g. Dense, etc.)
layer_dict.pop('name')
for k, v in layer_dict.items():
# a dictionary argument may be a regularizer or constraint
if isinstance(v, dict):
vname = v.pop('name')
if vname in [x for x, y in inspect.getmembers(constraints, predicate=inspect.isclass)]:
@@ -85,6 +86,9 @@ def container_from_config(original_layer_dict, custom_objects={}):
# not a regularizer or constraint, don't touch it
v['name'] = vname
# the "name" keyword argument of layers is saved as "custom_name"
if 'custom_name' in layer_dict:
layer_dict['name'] = layer_dict.pop('custom_name')
base_layer = get_layer(name, layer_dict)
return base_layer
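(A minimal round-trip sketch of what the custom_name handling enables; it assumes model.to_json and keras.models.model_from_json, which in this version rebuild layers through container_from_config.)

from keras.models import Sequential, model_from_json
from keras.layers.core import Dense

model = Sequential()
model.add(Dense(2, input_dim=2, name='my_dense'))

rebuilt = model_from_json(model.to_json())
assert rebuilt.layers[0].name == 'my_dense'   # the name survives serialization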

@@ -130,6 +130,21 @@ def test_maxout_dense():
_runner(layer)
def test_naming():
layer = core.Dense(2, input_dim=2)
assert layer.name == 'dense'
model = Sequential()
model.add(core.Dense(2, input_dim=2, name='my_dense'))
model.add(core.Dense(2, name='my_dense'))
assert model.layers[0].name == 'my_dense'
assert model.layers[1].name == 'my_dense'
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(np.random.random((2, 2)), np.random.random((2, 2)))
@pytest.mark.skipif(K._BACKEND == 'tensorflow',
reason='currently not working with TensorFlow')
def test_sequences():

@@ -85,6 +85,8 @@ def test_batchnorm_config():
epsilon=0.1, momentum=0.9)
conf = norm.get_config()
del conf['cache_enabled']
del conf['trainable']
del conf['custom_name']
conf_target = {"input_shape": (10, 10),
"name": normalization.BatchNormalization.__name__,
"epsilon": 0.1, "mode": 1, "momentum": 0.9}