Bidirectional Wrapper (#3495)
* Add Bidirectional Wrapper
* Fix example
* Update wrappers.py
* Add reverse op
* Update tensorflow_backend.py
* Update wrappers.py
* Update test_wrappers.py
* bug fix
* Update test_wrappers.py
* Update test_wrappers.py
* bug fix
* Add test for reverse op
* Enable reverse along multiple axes
* Update theano_backend.py
* Update theano_backend.py
* Update test_wrappers.py
* Speed up tests
* Validate merge_mode arg, Add None mode
* Update test_wrappers.py
* Update test_wrappers.py
* Add properties; reverse -> backward
* Bug fix
* Resolve naming conflict
* Whitespace fix
* Update imdb_bidirectional_lstm.py
* Fix imports
parent 52c1a7456f
commit f25e894558
examples/imdb_bidirectional_lstm.py
@@ -9,8 +9,8 @@ import numpy as np
 np.random.seed(1337)  # for reproducibility
 
 from keras.preprocessing import sequence
-from keras.models import Model
-from keras.layers import Dense, Dropout, Embedding, LSTM, Input, merge
+from keras.models import Sequential
+from keras.layers import Dense, Dropout, Embedding, LSTM, Input, Bidirectional
 from keras.datasets import imdb
 
 
@@ -31,24 +31,11 @@ print('X_test shape:', X_test.shape)
 y_train = np.array(y_train)
 y_test = np.array(y_test)
 
-
-# this is the placeholder tensor for the input sequences
-sequence = Input(shape=(maxlen,), dtype='int32')
-# this embedding layer will transform the sequences of integers
-# into vectors of size 128
-embedded = Embedding(max_features, 128, input_length=maxlen)(sequence)
-
-# apply forwards LSTM
-forwards = LSTM(64)(embedded)
-# apply backwards LSTM
-backwards = LSTM(64, go_backwards=True)(embedded)
-
-# concatenate the outputs of the 2 LSTMs
-merged = merge([forwards, backwards], mode='concat', concat_axis=-1)
-after_dp = Dropout(0.5)(merged)
-output = Dense(1, activation='sigmoid')(after_dp)
-
-model = Model(input=sequence, output=output)
+model = Sequential()
+model.add(Embedding(max_features, 128, input_length=maxlen))
+model.add(Bidirectional(LSTM(64)))
+model.add(Dropout(0.5))
+model.add(Dense(1, activation='sigmoid'))
 
 # try using different optimizers and different optimizer configs
 model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
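For readers who prefer the functional API that the old example used: the wrapper composes there too. A minimal sketch, not part of this commit, assuming the Keras 1.x functional API and this branch's `Bidirectional` export (`max_features`/`maxlen` values as in the example script):

```python
# Sketch: functional-API equivalent of the updated Sequential example.
# Bidirectional defaults to merge_mode='concat', so 64 forward plus
# 64 backward units yield a 128-dim feature vector per sample.
from keras.models import Model
from keras.layers import Dense, Dropout, Embedding, LSTM, Input, Bidirectional

max_features = 20000
maxlen = 100

sequence = Input(shape=(maxlen,), dtype='int32')
embedded = Embedding(max_features, 128, input_length=maxlen)(sequence)
bidir = Bidirectional(LSTM(64))(embedded)
after_dp = Dropout(0.5)(bidir)
output = Dense(1, activation='sigmoid')(after_dp)

model = Model(input=sequence, output=output)
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
```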
keras/backend/tensorflow_backend.py
@@ -859,6 +859,15 @@ def one_hot(indices, nb_classes):
     return tf.one_hot(indices, depth=nb_classes, axis=-1)
 
 
+def reverse(x, axes):
+    '''Reverse a tensor along the specified axes
+    '''
+    if type(axes) == int:
+        axes = [axes]
+    dims = [True if i in axes else False for i in range(len(x.get_shape()._dims))]
+    return tf.reverse(x, dims)
+
+
 # VALUE MANIPULATION
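The new op mirrors NumPy's negative-step slicing. A quick check of the intended semantics (a sketch, not from the test suite; assumes the TensorFlow backend is active):

```python
# K.reverse(x, axes) should match flipping the same axes in NumPy.
import numpy as np
from keras import backend as K

x = np.arange(24).reshape((4, 3, 2)).astype('float32')
xr = K.eval(K.reverse(K.variable(x), axes=1))
assert np.allclose(xr, x[:, ::-1, :])

# multiple axes at once, as enabled later in this PR
xr2 = K.eval(K.reverse(K.variable(x), axes=(1, 2)))
assert np.allclose(xr2, x[:, ::-1, ::-1])
```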
keras/backend/theano_backend.py
@@ -606,6 +606,15 @@ def one_hot(indices, nb_classes):
    return oh
 
 
+def reverse(x, axes):
+    '''Reverse a tensor along the specified axes
+    '''
+    if type(axes) == int:
+        axes = [axes]
+    slices = [slice(None, None, -1) if i in axes else slice(None, None, None) for i in range(x.ndim)]
+    return x[slices]
+
+
 # VALUE MANIPULATION
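The Theano version builds one slice object per dimension and indexes with them, the same trick that works on any NumPy-like array. A standalone sketch of the idea (the helper name is illustrative, not part of the diff):

```python
# reverse_np: reverse the listed axes by indexing with step -1 slices.
import numpy as np

def reverse_np(x, axes):
    if type(axes) == int:
        axes = [axes]
    # slice(None, None, -1) is the programmatic form of [::-1]
    slices = [slice(None, None, -1) if i in axes else slice(None, None, None)
              for i in range(x.ndim)]
    return x[tuple(slices)]  # a tuple of slices indexes all dims at once

x = np.arange(12).reshape((3, 4))
assert np.array_equal(reverse_np(x, 0), x[::-1])
assert np.array_equal(reverse_np(x, (0, 1)), x[::-1, ::-1])
```

The sketch indexes with a tuple where the committed Theano code passes a list; Theano accepts the list form, while plain NumPy prefers the tuple.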
keras/engine/topology.py
@@ -284,10 +284,14 @@ class Layer(object):
 
         # these properties will be set upon call of self.build(),
         # which itself will be called upon self.add_inbound_node if necessary.
-        self.trainable_weights = []
-        self.non_trainable_weights = []
-        self.regularizers = []
-        self.constraints = {}  # dict {tensor: constraint instance}
+        if not hasattr(self, 'trainable_weights'):
+            self.trainable_weights = []
+        if not hasattr(self, 'non_trainable_weights'):
+            self.non_trainable_weights = []
+        if not hasattr(self, 'regularizers'):
+            self.regularizers = []
+        if not hasattr(self, 'constraints'):
+            self.constraints = {}  # dict {tensor: constraint instance}
         self.built = False
 
         # these properties should be set by the user via keyword arguments.
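The `hasattr` guards matter because `Bidirectional` (below) exposes `trainable_weights` and friends as read-only properties, so an unconditional assignment in `Layer.__init__` would raise `AttributeError: can't set attribute`. A toy reproduction of the conflict (class names are illustrative, not Keras classes):

```python
# Minimal sketch of why Layer.__init__ now checks before assigning.
class Base(object):
    def __init__(self):
        # Without the hasattr check, this assignment would fail on
        # subclasses that define trainable_weights as a property.
        if not hasattr(self, 'trainable_weights'):
            self.trainable_weights = []

class WrapperLike(Base):
    @property
    def trainable_weights(self):
        return ['w_forward', 'w_backward']

w = WrapperLike()  # no AttributeError: the property takes precedence
assert w.trainable_weights == ['w_forward', 'w_backward']
```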
keras/layers/wrappers.py
@@ -133,3 +133,127 @@ class TimeDistributed(Wrapper):
         output_shape = self.get_output_shape_for(input_shape)
         y = K.reshape(y, (-1, input_length) + output_shape[2:])
         return y
+
+
+class Bidirectional(Wrapper):
+    ''' Bidirectional wrapper for RNNs
+
+    # Arguments:
+        layer: `Recurrent` instance.
+        merge_mode: Mode by which outputs of the forward and backward RNNs will be combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the outputs will not be combined, they will be returned as a list.
+
+    # Examples:
+    ```python
+    model = Sequential()
+    model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
+    model.add(Bidirectional(LSTM(10)))
+    model.add(Dense(5))
+    model.add(Activation('softmax'))
+    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+    ```
+    '''
+    def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
+        assert merge_mode in ['sum', 'mul', 'ave', 'concat', None], "Invalid merge mode. Merge mode should be one of {'sum', 'mul', 'ave', 'concat', None}"
+        self.forward_layer = layer
+        config = layer.get_config()
+        config['go_backwards'] = not config['go_backwards']
+        self.backward_layer = layer.__class__.from_config(config)
+        self.forward_layer.name = 'forward_' + self.forward_layer.name
+        self.backward_layer.name = 'backward_' + self.backward_layer.name
+        self.merge_mode = merge_mode
+        if weights:
+            nw = len(weights)
+            self.forward_layer.initial_weights = weights[:nw//2]
+            self.backward_layer.initial_weights = weights[nw//2:]
+        self.stateful = layer.stateful
+        self.return_sequences = layer.return_sequences
+        self.supports_masking = True
+        super(Bidirectional, self).__init__(layer, **kwargs)
+
+    def get_weights(self):
+        return self.forward_layer.get_weights() + self.backward_layer.get_weights()
+
+    def set_weights(self, weights):
+        nw = len(weights)
+        self.forward_layer.set_weights(weights[:nw//2])
+        self.backward_layer.set_weights(weights[nw//2:])
+
+    def get_output_shape_for(self, input_shape):
+        if self.merge_mode in ['sum', 'ave', 'mul']:
+            return self.forward_layer.get_output_shape_for(input_shape)
+        elif self.merge_mode == 'concat':
+            shape = list(self.forward_layer.get_output_shape_for(input_shape))
+            shape[-1] *= 2
+            return tuple(shape)
+        elif self.merge_mode is None:
+            return [self.forward_layer.get_output_shape_for(input_shape)] * 2
+
+    def call(self, X, mask=None):
+        Y = self.forward_layer.call(X, mask)
+        Y_rev = self.backward_layer.call(X, mask)
+        if self.return_sequences:
+            Y_rev = K.reverse(Y_rev, 1)
+        if self.merge_mode == 'concat':
+            return K.concatenate([Y, Y_rev])
+        elif self.merge_mode == 'sum':
+            return Y + Y_rev
+        elif self.merge_mode == 'ave':
+            return (Y + Y_rev) / 2
+        elif self.merge_mode == 'mul':
+            return Y * Y_rev
+        elif self.merge_mode is None:
+            return [Y, Y_rev]
+
+    def reset_states(self):
+        self.forward_layer.reset_states()
+        self.backward_layer.reset_states()
+
+    def build(self, input_shape):
+        self.forward_layer.build(input_shape)
+        self.backward_layer.build(input_shape)
+
+    def compute_mask(self, input, mask):
+        if self.return_sequences:
+            if not self.merge_mode:
+                return [mask, mask]
+            else:
+                return mask
+        else:
+            return None
+
+    @property
+    def trainable_weights(self):
+        if hasattr(self.forward_layer, 'trainable_weights'):
+            return self.forward_layer.trainable_weights + self.backward_layer.trainable_weights
+        return []
+
+    @property
+    def non_trainable_weights(self):
+        if hasattr(self.forward_layer, 'non_trainable_weights'):
+            return self.forward_layer.non_trainable_weights + self.backward_layer.non_trainable_weights
+        return []
+
+    @property
+    def updates(self):
+        if hasattr(self.forward_layer, 'updates'):
+            return self.forward_layer.updates + self.backward_layer.updates
+        return []
+
+    @property
+    def regularizers(self):
+        if hasattr(self.forward_layer, 'regularizers'):
+            return self.forward_layer.regularizers + self.backward_layer.regularizers
+        return []
+
+    @property
+    def constraints(self):
+        _constraints = {}
+        if hasattr(self.forward_layer, 'constraints'):
+            _constraints.update(self.forward_layer.constraints)
+            _constraints.update(self.backward_layer.constraints)
+        return _constraints
+
+    def get_config(self):
+        config = {"merge_mode": self.merge_mode}
+        base_config = super(Bidirectional, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
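How `merge_mode` shows up in the output shape, in one place. A sketch assuming this commit is installed (not from the test suite):

```python
# 'concat' doubles the last dimension; 'sum' keeps the layer's own size.
import numpy as np
from keras.models import Model
from keras.layers import Input, LSTM, Bidirectional

x = Input(shape=(5, 10))
concat = Bidirectional(LSTM(8), merge_mode='concat')(x)  # last dim -> 16
summed = Bidirectional(LSTM(8), merge_mode='sum')(x)     # last dim -> 8

model = Model(x, [concat, summed])
out_concat, out_sum = model.predict(np.random.random((2, 5, 10)))
assert out_concat.shape == (2, 16)
assert out_sum.shape == (2, 8)
```

Note the design choice in `call`: when `return_sequences` is true, the backward layer's output is re-reversed with the new `K.reverse` op so that both sequences line up timestep for timestep before merging.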
tests/keras/backend/test_backends.py
@@ -70,6 +70,8 @@ class TestBackend(object):
         check_two_tensor_operation('batch_dot', (4, 2, 3), (4, 5, 3),
                                    axes=(2, 2))
         check_single_tensor_operation('transpose', (4, 2))
+        check_single_tensor_operation('reverse', (4, 3, 2), axes=1)
+        check_single_tensor_operation('reverse', (4, 3, 2), axes=(1, 2))
 
     def test_shape_operations(self):
         # concatenate
@@ -633,5 +635,6 @@ class TestBackend(object):
         koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
         assert np.all(koh == oh)
 
+
 if __name__ == '__main__':
     pytest.main([__file__])
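`check_single_tensor_operation` evaluates the named op under both the Theano and TensorFlow backends and asserts the results agree. Roughly, for `reverse` (a sketch of the check, not the helper's actual code; tolerance is illustrative):

```python
# Cross-backend agreement check for the new reverse op.
import numpy as np
from keras.backend import theano_backend as KTH
from keras.backend import tensorflow_backend as KTF

x = np.random.random((4, 3, 2))
z_th = KTH.eval(KTH.reverse(KTH.variable(x), axes=1))
z_tf = KTF.eval(KTF.reverse(KTF.variable(x), axes=1))
assert z_th.shape == z_tf.shape
assert np.allclose(z_th, z_tf, atol=1e-05)
```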
tests/keras/layers/test_wrappers.py
@@ -3,7 +3,7 @@ import numpy as np
 from numpy.testing import assert_allclose
 from keras.utils.test_utils import keras_test
 from keras.layers import wrappers, Input
-from keras.layers import core, convolutional
+from keras.layers import core, convolutional, recurrent
 from keras.models import Sequential, Model, model_from_json
 
 
@@ -76,5 +76,36 @@ def test_TimeDistributed():
     outer_model.fit(np.random.random((10, 3, 2)), np.random.random((10, 3, 3)), nb_epoch=1, batch_size=10)
 
 
+@keras_test
+def test_Bidirectional():
+    for rnn in [recurrent.SimpleRNN, recurrent.LSTM]:
+        for mode in ['sum', 'concat']:
+            x = np.random.random((5, 3, 2))
+            output_dim = 6 if mode == 'concat' else 3
+            y = np.random.random((5, output_dim))
+
+            # test with Sequential model
+            model = Sequential()
+            model.add(wrappers.Bidirectional(rnn(3), merge_mode=mode, input_shape=(3, 2)))
+            model.add(core.Activation('sigmoid'))
+            model.compile(loss='mse', optimizer='sgd')
+            model.fit(x, y, nb_epoch=1, batch_size=5)
+
+            # test stacked bidirectional layers
+            model = Sequential()
+            model.add(wrappers.Bidirectional(rnn(2, return_sequences=True), merge_mode=mode, input_shape=(3, 2)))
+            model.add(wrappers.Bidirectional(rnn(3), merge_mode=mode))
+            model.add(core.Activation('sigmoid'))
+            model.compile(loss='mse', optimizer='sgd')
+            model.fit(x, y, nb_epoch=1, batch_size=5)
+
+            # test with functional API
+            input = Input((3, 2))
+            output = wrappers.Bidirectional(rnn(3), merge_mode=mode)(input)
+            model = Model(input, output)
+            model.compile(loss='mse', optimizer='sgd')
+            model.fit(x, y, nb_epoch=1, batch_size=5)
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
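The tests exercise 'sum' and 'concat'; the `None` mode added in this PR returns a list of two tensors, which pairs naturally with the functional API rather than `Sequential`. A hedged sketch, not part of the test file:

```python
# merge_mode=None: forward and backward outputs come back unmerged.
import numpy as np
from keras.models import Model
from keras.layers import Input, LSTM, Bidirectional

inp = Input((3, 2))
fwd, bwd = Bidirectional(LSTM(3), merge_mode=None)(inp)
model = Model(inp, [fwd, bwd])
model.compile(loss='mse', optimizer='sgd')
y_fwd, y_bwd = model.predict(np.random.random((5, 3, 2)))
assert y_fwd.shape == y_bwd.shape == (5, 3)
```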