Bidirectional Wrapper (#3495)

* Add Bidirectional Wrapper

* Fix example

* Update wrappers.py

* Add reverse op

* Update tensorflow_backend.py

* Update wrappers.py

* Update test_wrappers.py

* bug fix

* Update test_wrappers.py

* Update test_wrappers.py

* bug fix

* Add test for reverse op

* Enable reverse along multiple axes

* Update theano_backend.py

* Update theano_backend.py

* Update test_wrappers.py

* Speed up tests

* Validate merge_mode arg, Add None mode

* Update test_wrappers.py

* Update test_wrappers.py

* Add properties; reverse -> backward

* Bug fix

* Resolve naming conflict

* Whitespace fix

* Update imdb_bidirectional_lstm.py

* Fix imports
Fariz Rahman 2016-08-18 04:57:06 +05:30 committed by François Chollet
parent 52c1a7456f
commit f25e894558
7 changed files with 192 additions and 25 deletions

@@ -9,8 +9,8 @@ import numpy as np
np.random.seed(1337) # for reproducibility
from keras.preprocessing import sequence
from keras.models import Model
from keras.layers import Dense, Dropout, Embedding, LSTM, Input, merge
from keras.models import Sequential
from keras.layers import Dense, Dropout, Embedding, LSTM, Input, Bidirectional
from keras.datasets import imdb
@@ -31,24 +31,11 @@ print('X_test shape:', X_test.shape)
y_train = np.array(y_train)
y_test = np.array(y_test)
# this is the placeholder tensor for the input sequences
sequence = Input(shape=(maxlen,), dtype='int32')
# this embedding layer will transform the sequences of integers
# into vectors of size 128
embedded = Embedding(max_features, 128, input_length=maxlen)(sequence)
# apply forwards LSTM
forwards = LSTM(64)(embedded)
# apply backwards LSTM
backwards = LSTM(64, go_backwards=True)(embedded)
# concatenate the outputs of the 2 LSTMs
merged = merge([forwards, backwards], mode='concat', concat_axis=-1)
after_dp = Dropout(0.5)(merged)
output = Dense(1, activation='sigmoid')(after_dp)
model = Model(input=sequence, output=output)
model = Sequential()
model.add(Embedding(max_features, 128, input_length=maxlen))
model.add(Bidirectional(LSTM(64)))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
# try using different optimizers and different optimizer configs
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
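For comparison, the same model can still be written with the functional API that the old example used, only with the new wrapper in place of the manual forward/backward pair. A minimal sketch, not part of this commit, assuming `Input` and `Model` are also kept in the imports:

```python
# Illustrative only: functional-API version of the new Sequential example.
# Bidirectional defaults to merge_mode='concat', matching the old manual merge.
sequence = Input(shape=(maxlen,), dtype='int32')
embedded = Embedding(max_features, 128, input_length=maxlen)(sequence)
bidir = Bidirectional(LSTM(64))(embedded)
after_dp = Dropout(0.5)(bidir)
output = Dense(1, activation='sigmoid')(after_dp)
model = Model(input=sequence, output=output)
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
```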

@@ -859,6 +859,15 @@ def one_hot(indices, nb_classes):
return tf.one_hot(indices, depth=nb_classes, axis=-1)
def reverse(x, axes):
'''Reverse a tensor along the specified axes.
'''
if isinstance(axes, int):
axes = [axes]
dims = [i in axes for i in range(len(x.get_shape()._dims))]
return tf.reverse(x, dims)
# VALUE MANIPULATION

@@ -606,6 +606,15 @@ def one_hot(indices, nb_classes):
return oh
def reverse(x, axes):
'''Reverse a tensor along the specified axes.
'''
if isinstance(axes, int):
axes = [axes]
slices = [slice(None, None, -1) if i in axes else slice(None, None, None) for i in range(x.ndim)]
return x[slices]
# VALUE MANIPULATION
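Both backends implement the same semantics: the tensor is flipped along the listed axes, exactly like numpy's negative-step slicing. A quick sanity check, written as a sketch rather than part of the commit:

```python
import numpy as np
from keras import backend as K

# Reversing along the time axis of a (batch, time, features) tensor
# should match numpy's x[:, ::-1, :].
x = np.random.random((4, 3, 2))
xr = K.eval(K.reverse(K.variable(x), axes=1))
assert np.allclose(xr, x[:, ::-1, :])
```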

@@ -284,10 +284,14 @@ class Layer(object):
# these properties will be set upon call of self.build(),
# which itself will be called upon self.add_inbound_node if necessary.
self.trainable_weights = []
self.non_trainable_weights = []
self.regularizers = []
self.constraints = {} # dict {tensor: constraint instance}
if not hasattr(self, 'trainable_weights'):
self.trainable_weights = []
if not hasattr(self, 'non_trainable_weights'):
self.non_trainable_weights = []
if not hasattr(self, 'regularizers'):
self.regularizers = []
if not hasattr(self, 'constraints'):
self.constraints = {} # dict {tensor: constraint instance}
self.built = False
# these properties should be set by the user via keyword arguments.
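The guards matter because the new wrapper below exposes `trainable_weights`, `non_trainable_weights`, `regularizers` and `constraints` as read-only properties; an unconditional assignment in `Layer.__init__` would then raise an `AttributeError`, while `hasattr` sees the class-level property and skips the write. A stripped-down sketch of the mechanism (the class names here are hypothetical, not part of the commit):

```python
# Hypothetical minimal reproduction of the hasattr guard.
class Base(object):
    def __init__(self):
        if not hasattr(self, 'trainable_weights'):
            self.trainable_weights = []  # only set when the subclass defines nothing

class WithProperty(Base):
    @property
    def trainable_weights(self):
        # read-only property, as in the Bidirectional wrapper
        return ['w_forward', 'w_backward']

layer = WithProperty()          # without the guard: AttributeError ("can't set attribute")
print(layer.trainable_weights)  # ['w_forward', 'w_backward']
```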

@@ -133,3 +133,127 @@ class TimeDistributed(Wrapper):
output_shape = self.get_output_shape_for(input_shape)
y = K.reshape(y, (-1, input_length) + output_shape[2:])
return y
class Bidirectional(Wrapper):
''' Bidirectional wrapper for RNNs
# Arguments:
layer: `Recurrent` instance.
merge_mode: Mode by which outputs of the forward and backward RNNs will be combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the outputs will not be combined; they will be returned as a list.
# Examples:
```python
model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
model.add(Bidirectional(LSTM(10)))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
```
'''
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
assert merge_mode in ['sum', 'mul', 'ave', 'concat', None], "Invalid merge mode. Merge mode should be one of {'sum', 'mul', 'ave', 'concat', None}"
self.forward_layer = layer
config = layer.get_config()
config['go_backwards'] = not config['go_backwards']
self.backward_layer = layer.__class__.from_config(config)
self.forward_layer.name = 'forward_' + self.forward_layer.name
self.backward_layer.name = 'backward_' + self.backward_layer.name
self.merge_mode = merge_mode
if weights:
nw = len(weights)
self.forward_layer.initial_weights = weights[:nw//2]
self.backward_layer.initial_weights = weights[nw//2:]
self.stateful = layer.stateful
self.return_sequences = layer.return_sequences
self.supports_masking = True
super(Bidirectional, self).__init__(layer, **kwargs)
def get_weights(self):
return self.forward_layer.get_weights() + self.backward_layer.get_weights()
def set_weights(self, weights):
nw = len(weights)
self.forward_layer.set_weights(weights[:nw//2])
self.backward_layer.set_weights(weights[nw//2:])
def get_output_shape_for(self, input_shape):
if self.merge_mode in ['sum', 'ave', 'mul']:
return self.forward_layer.get_output_shape_for(input_shape)
elif self.merge_mode == 'concat':
shape = list(self.forward_layer.get_output_shape_for(input_shape))
shape[-1] *= 2
return tuple(shape)
elif self.merge_mode is None:
return [self.forward_layer.get_output_shape_for(input_shape)] * 2
def call(self, X, mask=None):
Y = self.forward_layer.call(X, mask)
Y_rev = self.backward_layer.call(X, mask)
if self.return_sequences:
Y_rev = K.reverse(Y_rev, 1)
if self.merge_mode == 'concat':
return K.concatenate([Y, Y_rev])
elif self.merge_mode == 'sum':
return Y + Y_rev
elif self.merge_mode == 'ave':
return (Y + Y_rev) / 2
elif self.merge_mode == 'mul':
return Y * Y_rev
elif self.merge_mode is None:
return [Y, Y_rev]
def reset_states(self):
self.forward_layer.reset_states()
self.backward_layer.reset_states()
def build(self, input_shape):
self.forward_layer.build(input_shape)
self.backward_layer.build(input_shape)
def compute_mask(self, input, mask):
if self.return_sequences:
if not self.merge_mode:
return [mask, mask]
else:
return mask
else:
return None
@property
def trainable_weights(self):
if hasattr(self.forward_layer, 'trainable_weights'):
return self.forward_layer.trainable_weights + self.backward_layer.trainable_weights
return []
@property
def non_trainable_weights(self):
if hasattr(self.forward_layer, 'non_trainable_weights'):
return self.forward_layer.non_trainable_weights + self.backward_layer.non_trainable_weights
return []
@property
def updates(self):
if hasattr(self.forward_layer, 'updates'):
return self.forward_layer.updates + self.backward_layer.updates
return []
@property
def regularizers(self):
if hasattr(self.forward_layer, 'regularizers'):
return self.forward_layer.regularizers + self.backward_layer.regularizers
return []
@property
def constraints(self):
_constraints = {}
if hasattr(self.forward_layer, 'constraints'):
_constraints.update(self.forward_layer.constraints)
_constraints.update(self.backward_layer.constraints)
return _constraints
def get_config(self):
config = {"merge_mode": self.merge_mode}
base_config = super(Bidirectional, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
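How the merge modes affect output shapes can be seen without building anything, since `get_output_shape_for` only delegates to the wrapped RNN. An illustrative sketch, assuming an 8-unit LSTM over `(batch, 5, 10)` inputs:

```python
from keras.layers import LSTM, Bidirectional

shape = (None, 5, 10)
print(Bidirectional(LSTM(8), merge_mode='concat').get_output_shape_for(shape))  # (None, 16)
print(Bidirectional(LSTM(8), merge_mode='sum').get_output_shape_for(shape))     # (None, 8)
print(Bidirectional(LSTM(8), merge_mode=None).get_output_shape_for(shape))      # [(None, 8), (None, 8)]
```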

@@ -70,6 +70,8 @@ class TestBackend(object):
check_two_tensor_operation('batch_dot', (4, 2, 3), (4, 5, 3),
axes=(2, 2))
check_single_tensor_operation('transpose', (4, 2))
check_single_tensor_operation('reverse', (4, 3, 2), axes=1)
check_single_tensor_operation('reverse', (4, 3, 2), axes=(1, 2))
def test_shape_operations(self):
# concatenate
@@ -633,5 +635,6 @@ class TestBackend(object):
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
assert np.all(koh == oh)
if __name__ == '__main__':
pytest.main([__file__])

@@ -3,7 +3,7 @@ import numpy as np
from numpy.testing import assert_allclose
from keras.utils.test_utils import keras_test
from keras.layers import wrappers, Input
from keras.layers import core, convolutional
from keras.layers import core, convolutional, recurrent
from keras.models import Sequential, Model, model_from_json
@@ -76,5 +76,36 @@ def test_TimeDistributed():
outer_model.fit(np.random.random((10, 3, 2)), np.random.random((10, 3, 3)), nb_epoch=1, batch_size=10)
@keras_test
def test_Bidirectional():
for rnn in [recurrent.SimpleRNN, recurrent.LSTM]:
for mode in ['sum', 'concat']:
x = np.random.random((5, 3, 2))
output_dim = 6 if mode == 'concat' else 3
y = np.random.random((5, output_dim))
# test with Sequential model
model = Sequential()
model.add(wrappers.Bidirectional(rnn(3), merge_mode=mode, input_shape=(3, 2)))
model.add(core.Activation('sigmoid'))
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, nb_epoch=1, batch_size=5)
# test stacked bidirectional layers
model = Sequential()
model.add(wrappers.Bidirectional(rnn(2, return_sequences=True), merge_mode=mode, input_shape=(3, 2)))
model.add(wrappers.Bidirectional(rnn(3), merge_mode=mode))
model.add(core.Activation('sigmoid'))
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, nb_epoch=1, batch_size=5)
# test with functional API
input = Input((3, 2))
output = wrappers.Bidirectional(rnn(3), merge_mode=mode)(input)
model = Model(input, output)
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, nb_epoch=1, batch_size=5)
if __name__ == '__main__':
pytest.main([__file__])