diff --git a/examples/imdb_bidirectional_lstm.py b/examples/imdb_bidirectional_lstm.py
index 53c12b613..dcae48e20 100644
--- a/examples/imdb_bidirectional_lstm.py
+++ b/examples/imdb_bidirectional_lstm.py
@@ -9,8 +9,8 @@ import numpy as np
 np.random.seed(1337)  # for reproducibility
 
 from keras.preprocessing import sequence
-from keras.models import Model
-from keras.layers import Dense, Dropout, Embedding, LSTM, Input, merge
+from keras.models import Sequential
+from keras.layers import Dense, Dropout, Embedding, LSTM, Input, Bidirectional
 from keras.datasets import imdb
 
@@ -31,24 +31,11 @@ print('X_test shape:', X_test.shape)
 y_train = np.array(y_train)
 y_test = np.array(y_test)
 
-
-# this is the placeholder tensor for the input sequences
-sequence = Input(shape=(maxlen,), dtype='int32')
-# this embedding layer will transform the sequences of integers
-# into vectors of size 128
-embedded = Embedding(max_features, 128, input_length=maxlen)(sequence)
-
-# apply forwards LSTM
-forwards = LSTM(64)(embedded)
-# apply backwards LSTM
-backwards = LSTM(64, go_backwards=True)(embedded)
-
-# concatenate the outputs of the 2 LSTMs
-merged = merge([forwards, backwards], mode='concat', concat_axis=-1)
-after_dp = Dropout(0.5)(merged)
-output = Dense(1, activation='sigmoid')(after_dp)
-
-model = Model(input=sequence, output=output)
+model = Sequential()
+model.add(Embedding(max_features, 128, input_length=maxlen))
+model.add(Bidirectional(LSTM(64)))
+model.add(Dropout(0.5))
+model.add(Dense(1, activation='sigmoid'))
 
 # try using different optimizers and different optimizer configs
 model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index bb11cd0bb..b9f575fb2 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -859,6 +859,15 @@ def one_hot(indices, nb_classes):
     return tf.one_hot(indices, depth=nb_classes, axis=-1)
 
 
+def reverse(x, axes):
+    '''Reverse a tensor along the specified axes.
+    '''
+    if isinstance(axes, int):
+        axes = [axes]
+    # tf.reverse takes one boolean flag per dimension
+    dims = [i in axes for i in range(x.get_shape().ndims)]
+    return tf.reverse(x, dims)
+
+
 # VALUE MANIPULATION
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index c8c116b57..08dbbd30c 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -606,6 +606,15 @@ def one_hot(indices, nb_classes):
     return oh
 
 
+def reverse(x, axes):
+    '''Reverse a tensor along the specified axes.
+    '''
+    if isinstance(axes, int):
+        axes = [axes]
+    # a step of -1 reverses the selected axes; other axes are left intact
+    slices = [slice(None, None, -1) if i in axes else slice(None)
+              for i in range(x.ndim)]
+    return x[tuple(slices)]
+
+
 # VALUE MANIPULATION
diff --git a/keras/engine/topology.py b/keras/engine/topology.py
index 88075a8ea..5f52ac911 100644
--- a/keras/engine/topology.py
+++ b/keras/engine/topology.py
@@ -284,10 +284,14 @@ class Layer(object):
         # these properties will be set upon call of self.build(),
         # which itself will be called upon self.add_inbound_node if necessary.
-        self.trainable_weights = []
-        self.non_trainable_weights = []
-        self.regularizers = []
-        self.constraints = {}  # dict {tensor: constraint instance}
+        if not hasattr(self, 'trainable_weights'):
+            self.trainable_weights = []
+        if not hasattr(self, 'non_trainable_weights'):
+            self.non_trainable_weights = []
+        if not hasattr(self, 'regularizers'):
+            self.regularizers = []
+        if not hasattr(self, 'constraints'):
+            self.constraints = {}  # dict {tensor: constraint instance}
         self.built = False
 
         # these properties should be set by the user via keyword arguments.
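Note on the `keras/engine/topology.py` change: `Bidirectional` (added below) exposes `trainable_weights`, `non_trainable_weights`, `regularizers`, and `constraints` as read-only properties, so `Layer.__init__` can no longer assign to those names unconditionally. A minimal sketch of the conflict, using hypothetical class names:

```python
class Base(object):
    def __init__(self):
        # Without the hasattr guard, an unconditional
        # `self.trainable_weights = []` raises
        # "AttributeError: can't set attribute" for any subclass that
        # defines trainable_weights as a property with no setter.
        if not hasattr(self, 'trainable_weights'):
            self.trainable_weights = []


class PropertyLayer(Base):
    @property
    def trainable_weights(self):
        # read-only, like Bidirectional.trainable_weights
        return ['w_forward', 'w_backward']


layer = PropertyLayer()
print(layer.trainable_weights)  # ['w_forward', 'w_backward']
```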
diff --git a/keras/layers/wrappers.py b/keras/layers/wrappers.py
index 2966e6ff2..5dd895b96 100644
--- a/keras/layers/wrappers.py
+++ b/keras/layers/wrappers.py
@@ -133,3 +133,127 @@ class TimeDistributed(Wrapper):
         output_shape = self.get_output_shape_for(input_shape)
         y = K.reshape(y, (-1, input_length) + output_shape[2:])
         return y
+
+
+class Bidirectional(Wrapper):
+    '''Bidirectional wrapper for RNNs.
+
+    # Arguments:
+        layer: `Recurrent` instance.
+        merge_mode: Mode by which outputs of the forward and backward RNNs
+            will be combined. One of {'sum', 'mul', 'concat', 'ave', None}.
+            If None, the outputs will not be combined;
+            they will be returned as a list.
+
+    # Examples:
+    ```python
+        model = Sequential()
+        model.add(Bidirectional(LSTM(10, return_sequences=True),
+                                input_shape=(5, 10)))
+        model.add(Bidirectional(LSTM(10)))
+        model.add(Dense(5))
+        model.add(Activation('softmax'))
+        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+    ```
+    '''
+    def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
+        assert merge_mode in ['sum', 'mul', 'ave', 'concat', None], (
+            "Invalid merge mode. Merge mode should be one of "
+            "{'sum', 'mul', 'ave', 'concat', None}")
+        self.forward_layer = layer
+        # the backward layer is a clone of the wrapped layer with
+        # go_backwards flipped
+        config = layer.get_config()
+        config['go_backwards'] = not config['go_backwards']
+        self.backward_layer = layer.__class__.from_config(config)
+        self.forward_layer.name = 'forward_' + self.forward_layer.name
+        self.backward_layer.name = 'backward_' + self.backward_layer.name
+        self.merge_mode = merge_mode
+        if weights:
+            nw = len(weights)
+            self.forward_layer.initial_weights = weights[:nw // 2]
+            self.backward_layer.initial_weights = weights[nw // 2:]
+        self.stateful = layer.stateful
+        self.return_sequences = layer.return_sequences
+        self.supports_masking = True
+        super(Bidirectional, self).__init__(layer, **kwargs)
+
+    def get_weights(self):
+        return self.forward_layer.get_weights() + self.backward_layer.get_weights()
+
+    def set_weights(self, weights):
+        nw = len(weights)
+        self.forward_layer.set_weights(weights[:nw // 2])
+        self.backward_layer.set_weights(weights[nw // 2:])
+
+    def get_output_shape_for(self, input_shape):
+        if self.merge_mode in ['sum', 'ave', 'mul']:
+            return self.forward_layer.get_output_shape_for(input_shape)
+        elif self.merge_mode == 'concat':
+            shape = list(self.forward_layer.get_output_shape_for(input_shape))
+            shape[-1] *= 2
+            return tuple(shape)
+        elif self.merge_mode is None:
+            return [self.forward_layer.get_output_shape_for(input_shape)] * 2
+
+    def call(self, X, mask=None):
+        Y = self.forward_layer.call(X, mask)
+        Y_rev = self.backward_layer.call(X, mask)
+        if self.return_sequences:
+            # flip the backward output back into forward time order
+            Y_rev = K.reverse(Y_rev, 1)
+        if self.merge_mode == 'concat':
+            return K.concatenate([Y, Y_rev])
+        elif self.merge_mode == 'sum':
+            return Y + Y_rev
+        elif self.merge_mode == 'ave':
+            return (Y + Y_rev) / 2
+        elif self.merge_mode == 'mul':
+            return Y * Y_rev
+        elif self.merge_mode is None:
+            return [Y, Y_rev]
+
+    def reset_states(self):
+        self.forward_layer.reset_states()
+        self.backward_layer.reset_states()
+
+    def build(self, input_shape):
+        self.forward_layer.build(input_shape)
+        self.backward_layer.build(input_shape)
+
+    def compute_mask(self, input, mask):
+        if self.return_sequences:
+            if not self.merge_mode:
+                return [mask, mask]
+            else:
+                return mask
+        else:
+            return None
+
+    @property
+    def trainable_weights(self):
+        if hasattr(self.forward_layer, 'trainable_weights'):
+            return self.forward_layer.trainable_weights + self.backward_layer.trainable_weights
+        return []
+
+    @property
+    def non_trainable_weights(self):
+        if hasattr(self.forward_layer, 'non_trainable_weights'):
+            return self.forward_layer.non_trainable_weights + self.backward_layer.non_trainable_weights
+        return []
+
+    @property
+    def updates(self):
+        if hasattr(self.forward_layer, 'updates'):
+            return self.forward_layer.updates + self.backward_layer.updates
+        return []
+
+    @property
+    def regularizers(self):
+        if hasattr(self.forward_layer, 'regularizers'):
+            return self.forward_layer.regularizers + self.backward_layer.regularizers
+        return []
+
+    @property
+    def constraints(self):
+        _constraints = {}
+        if hasattr(self.forward_layer, 'constraints'):
+            _constraints.update(self.forward_layer.constraints)
+            _constraints.update(self.backward_layer.constraints)
+        return _constraints
+
+    def get_config(self):
+        config = {'merge_mode': self.merge_mode}
+        base_config = super(Bidirectional, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
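A quick illustration of `get_output_shape_for` (a sketch against the Keras 1.x API used in this patch, not part of the diff): `'concat'` doubles the feature dimension, while `'sum'`, `'mul'`, and `'ave'` preserve the wrapped layer's output shape.

```python
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers.wrappers import Bidirectional

# concat: forward and backward outputs are stacked along the last axis
concat_model = Sequential()
concat_model.add(Bidirectional(LSTM(64), merge_mode='concat',
                               input_shape=(10, 8)))
print(concat_model.output_shape)  # (None, 128)

# sum: outputs are added elementwise, so the shape is unchanged
sum_model = Sequential()
sum_model.add(Bidirectional(LSTM(64), merge_mode='sum',
                            input_shape=(10, 8)))
print(sum_model.output_shape)  # (None, 64)
```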
diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py
index a947ad021..3b05d92f6 100644
--- a/tests/keras/backend/test_backends.py
+++ b/tests/keras/backend/test_backends.py
@@ -70,6 +70,8 @@ class TestBackend(object):
         check_two_tensor_operation('batch_dot', (4, 2, 3), (4, 5, 3), axes=(2, 2))
 
         check_single_tensor_operation('transpose', (4, 2))
+        check_single_tensor_operation('reverse', (4, 3, 2), axes=1)
+        check_single_tensor_operation('reverse', (4, 3, 2), axes=(1, 2))
 
     def test_shape_operations(self):
         # concatenate
@@ -633,5 +635,6 @@ class TestBackend(object):
         koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
         assert np.all(koh == oh)
 
+
 if __name__ == '__main__':
     pytest.main([__file__])
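The two `reverse` checks above compare the backends against each other. For reference, the expected semantics also match plain numpy slice reversal; a standalone sanity check (a sketch, not part of the test suite):

```python
import numpy as np
from keras import backend as K

x = np.random.random((4, 3, 2))
xv = K.variable(x)

# reversing axis 1 is the same as numpy's x[:, ::-1, :]
assert np.allclose(K.eval(K.reverse(xv, 1)), x[:, ::-1, :])

# reversing axes (1, 2) flips both trailing dimensions
assert np.allclose(K.eval(K.reverse(xv, (1, 2))), x[:, ::-1, ::-1])
```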
diff --git a/tests/keras/layers/test_wrappers.py b/tests/keras/layers/test_wrappers.py
index 7092d0087..1fa46a4c4 100644
--- a/tests/keras/layers/test_wrappers.py
+++ b/tests/keras/layers/test_wrappers.py
@@ -3,7 +3,7 @@ import numpy as np
 from numpy.testing import assert_allclose
 from keras.utils.test_utils import keras_test
 from keras.layers import wrappers, Input
-from keras.layers import core, convolutional
+from keras.layers import core, convolutional, recurrent
 from keras.models import Sequential, Model, model_from_json
 
 
@@ -76,5 +76,36 @@ def test_TimeDistributed():
     outer_model.fit(np.random.random((10, 3, 2)), np.random.random((10, 3, 3)), nb_epoch=1, batch_size=10)
 
 
+@keras_test
+def test_Bidirectional():
+    for rnn in [recurrent.SimpleRNN, recurrent.LSTM]:
+        for mode in ['sum', 'concat']:
+            x = np.random.random((5, 3, 2))
+            output_dim = 6 if mode == 'concat' else 3
+            y = np.random.random((5, output_dim))
+
+            # test with Sequential model
+            model = Sequential()
+            model.add(wrappers.Bidirectional(rnn(3), merge_mode=mode,
+                                             input_shape=(3, 2)))
+            model.add(core.Activation('sigmoid'))
+            model.compile(loss='mse', optimizer='sgd')
+            model.fit(x, y, nb_epoch=1, batch_size=5)
+
+            # test stacked bidirectional layers
+            model = Sequential()
+            model.add(wrappers.Bidirectional(rnn(2, return_sequences=True),
+                                             merge_mode=mode, input_shape=(3, 2)))
+            model.add(wrappers.Bidirectional(rnn(3), merge_mode=mode))
+            model.add(core.Activation('sigmoid'))
+            model.compile(loss='mse', optimizer='sgd')
+            model.fit(x, y, nb_epoch=1, batch_size=5)
+
+            # test with functional API
+            input = Input((3, 2))
+            output = wrappers.Bidirectional(rnn(3), merge_mode=mode)(input)
+            model = Model(input, output)
+            model.compile(loss='mse', optimizer='sgd')
+            model.fit(x, y, nb_epoch=1, batch_size=5)
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
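One more property worth noting about the wrapper's `get_weights`/`set_weights` (a sketch, not part of the test suite): the weight list is the forward layer's weights followed by the backward layer's, so a round trip through the split-and-merge should leave predictions unchanged.

```python
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers.wrappers import Bidirectional

model = Sequential()
model.add(Bidirectional(LSTM(3), input_shape=(3, 2)))
model.compile(loss='mse', optimizer='sgd')

x = np.random.random((5, 3, 2))
weights = model.layers[0].get_weights()  # forward weights + backward weights
y1 = model.predict(x)
model.layers[0].set_weights(weights)     # split in half and reassigned
y2 = model.predict(x)
assert np.allclose(y1, y2)
```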