Merge branch 'add-1' of https://github.com/dbonadiman/keras into dbonadiman-add-1
commit 1760ed6cd7

examples/imdb_bidirectional_lstm.py (new file)

@@ -0,0 +1,64 @@
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.preprocessing import sequence
from keras.utils.np_utils import accuracy
from keras.models import Graph
from keras.layers.core import Dense, Dropout
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.datasets import imdb

'''
Train a bidirectional LSTM on the IMDB sentiment classification task.

The dataset is actually too small for a bidirectional LSTM to be of any
advantage compared to simpler, much faster methods such as TF-IDF + LogReg;
a bidirectional LSTM may simply not be suited to this kind of simple text
classification task.

Notes:
- RNNs are tricky, and bidirectional RNNs in particular. Experiment with
  different recurrent layers and different parameters to find the best
  configuration for your task.

GPU command:
    THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python imdb_bidirectional_lstm.py

Output after 4 epochs on CPU: ~0.8146
'''

max_features = 20000
maxlen = 100  # cut texts after this number of words (among top max_features most common words)
batch_size = 32

print("Loading data...")
(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features, test_split=0.2)
print(len(X_train), 'train sequences')
print(len(X_test), 'test sequences')

print("Pad sequences (samples x time)")
X_train = sequence.pad_sequences(X_train, maxlen=maxlen)
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)
y_train = np.array(y_train)
y_test = np.array(y_test)

print('Build model...')
model = Graph()
model.add_input(name='input', input_shape=(maxlen,), dtype=int)
model.add_node(Embedding(max_features, 128, input_length=maxlen), name='embedding', input='input')
model.add_node(LSTM(64), name='forward', input='embedding')  # you can swap these two LSTM layers for GRU
model.add_node(LSTM(64, go_backwards=True), name='backward', input='embedding')
model.add_node(Dropout(0.5), name='dropout', inputs=['forward', 'backward'])
model.add_node(Dense(1, activation='sigmoid'), name='sigmoid', input='dropout')
model.add_output(name='output', input='sigmoid')

# try using different optimizers and different optimizer configs
model.compile('adam', {'output': 'binary_crossentropy'})

print("Train...")
model.fit({'input': X_train, 'output': y_train}, batch_size=batch_size, nb_epoch=4)
pred = model.predict({'input': X_test}, batch_size=batch_size)['output']
acc = accuracy(y_test, np.round(np.array(pred)))
print('Test accuracy:', acc)
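The inline comment in the model definition notes that the two LSTM nodes can be swapped for GRU, which this commit also extends with go_backwards. A minimal sketch of that variant, assuming the same Keras 0.x Graph API and the variables (maxlen, max_features) defined in the example above; the name gru_model is illustrative and not part of the commit:

    from keras.layers.recurrent import GRU

    gru_model = Graph()
    gru_model.add_input(name='input', input_shape=(maxlen,), dtype=int)
    gru_model.add_node(Embedding(max_features, 128, input_length=maxlen), name='embedding', input='input')
    # two GRU nodes read the same embedding, one in reverse time order
    gru_model.add_node(GRU(64), name='forward', input='embedding')
    gru_model.add_node(GRU(64, go_backwards=True), name='backward', input='embedding')
    # listing both nodes as inputs concatenates them before the dropout
    gru_model.add_node(Dropout(0.5), name='dropout', inputs=['forward', 'backward'])
    gru_model.add_node(Dense(1, activation='sigmoid'), name='sigmoid', input='dropout')
    gru_model.add_output(name='output', input='sigmoid')
    gru_model.compile('adam', {'output': 'binary_crossentropy'})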
keras/layers/recurrent.py

@@ -55,7 +55,8 @@ class SimpleRNN(Recurrent):
     '''
     def __init__(self, output_dim,
                  init='glorot_uniform', inner_init='orthogonal', activation='sigmoid', weights=None,
-                 truncate_gradient=-1, return_sequences=False, input_dim=None, input_length=None, **kwargs):
+                 truncate_gradient=-1, return_sequences=False, input_dim=None,
+                 input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -63,6 +64,7 @@ class SimpleRNN(Recurrent):
         self.activation = activations.get(activation)
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -107,7 +109,8 @@ class SimpleRNN(Recurrent):
             # initialization of the output. Input to _step with default tap=-1.
             outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
             non_sequences=self.U,  # static inputs to _step
-            truncate_gradient=self.truncate_gradient)
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)

         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
@@ -122,7 +125,8 @@ class SimpleRNN(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(SimpleRNN, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
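Each recurrent class below gets the same treatment: go_backwards is stored in __init__, forwarded to theano.scan (whose go_backwards flag iterates the input sequence in reverse time order), and serialized in get_config. A plain-numpy sketch of the effect on a simple recurrence; the names simple_rnn_pass, W, and U are illustrative, not Keras internals:

    import numpy as np

    def simple_rnn_pass(X, W, U, go_backwards=False):
        # X: (timesteps, input_dim); W: (input_dim, output_dim); U: (output_dim, output_dim)
        steps = reversed(range(len(X))) if go_backwards else range(len(X))
        h = np.zeros(U.shape[0])  # zero initial state, as with alloc_zeros_matrix
        for t in steps:
            h = np.tanh(X[t] @ W + h @ U)  # one _step of the recurrence
        return h  # final state; with go_backwards this is the state after reading t=0 last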
@@ -141,7 +145,7 @@ class SimpleDeepRNN(Recurrent):
                  init='glorot_uniform', inner_init='orthogonal',
                  activation='sigmoid', inner_activation='hard_sigmoid',
                  weights=None, truncate_gradient=-1, return_sequences=False,
-                 input_dim=None, input_length=None, **kwargs):
+                 input_dim=None, input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -151,6 +155,7 @@ class SimpleDeepRNN(Recurrent):
         self.depth = depth
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -202,8 +207,8 @@ class SimpleDeepRNN(Recurrent):
                 taps=[(-i-1) for i in range(self.depth)]
             )],
             non_sequences=self.Us,
-            truncate_gradient=self.truncate_gradient
-        )
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)

         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
@@ -220,7 +225,8 @@ class SimpleDeepRNN(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(SimpleDeepRNN, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -251,7 +257,7 @@ class GRU(Recurrent):
                  init='glorot_uniform', inner_init='orthogonal',
                  activation='sigmoid', inner_activation='hard_sigmoid',
                  weights=None, truncate_gradient=-1, return_sequences=False,
-                 input_dim=None, input_length=None, **kwargs):
+                 input_dim=None, input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -260,6 +266,7 @@ class GRU(Recurrent):
         self.truncate_gradient = truncate_gradient
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -317,7 +324,8 @@ class GRU(Recurrent):
             sequences=[x_z, x_r, x_h, padded_mask],
             outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
             non_sequences=[self.U_z, self.U_r, self.U_h],
-            truncate_gradient=self.truncate_gradient)
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)

         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
@@ -333,7 +341,8 @@ class GRU(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(GRU, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -367,7 +376,7 @@ class LSTM(Recurrent):
                  init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',
                  activation='tanh', inner_activation='hard_sigmoid',
                  weights=None, truncate_gradient=-1, return_sequences=False,
-                 input_dim=None, input_length=None, **kwargs):
+                 input_dim=None, input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -377,6 +386,7 @@ class LSTM(Recurrent):
         self.truncate_gradient = truncate_gradient
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -447,7 +457,8 @@ class LSTM(Recurrent):
                 T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)
             ],
             non_sequences=[self.U_i, self.U_f, self.U_o, self.U_c],
-            truncate_gradient=self.truncate_gradient)
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)

         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
@@ -464,7 +475,8 @@ class LSTM(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(LSTM, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -493,7 +505,7 @@ class JZS1(Recurrent):
                  init='glorot_uniform', inner_init='orthogonal',
                  activation='tanh', inner_activation='sigmoid',
                  weights=None, truncate_gradient=-1, return_sequences=False,
-                 input_dim=None, input_length=None, **kwargs):
+                 input_dim=None, input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -502,6 +514,7 @@ class JZS1(Recurrent):
         self.truncate_gradient = truncate_gradient
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -566,7 +579,8 @@ class JZS1(Recurrent):
             sequences=[x_z, x_r, x_h, padded_mask],
             outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
             non_sequences=[self.U_r, self.U_h],
-            truncate_gradient=self.truncate_gradient)
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)
         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
         return outputs[-1]
@@ -581,7 +595,8 @@ class JZS1(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(JZS1, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -610,7 +625,7 @@ class JZS2(Recurrent):
                  init='glorot_uniform', inner_init='orthogonal',
                  activation='tanh', inner_activation='sigmoid',
                  weights=None, truncate_gradient=-1, return_sequences=False,
-                 input_dim=None, input_length=None, **kwargs):
+                 input_dim=None, input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -619,6 +634,7 @@ class JZS2(Recurrent):
         self.truncate_gradient = truncate_gradient
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -684,7 +700,9 @@ class JZS2(Recurrent):
             sequences=[x_z, x_r, x_h, padded_mask],
             outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
             non_sequences=[self.U_z, self.U_r, self.U_h],
-            truncate_gradient=self.truncate_gradient)
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)
+
         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
         return outputs[-1]
@@ -699,7 +717,8 @@ class JZS2(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(JZS2, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -728,7 +747,7 @@ class JZS3(Recurrent):
                  init='glorot_uniform', inner_init='orthogonal',
                  activation='tanh', inner_activation='sigmoid',
                  weights=None, truncate_gradient=-1, return_sequences=False,
-                 input_dim=None, input_length=None, **kwargs):
+                 input_dim=None, input_length=None, go_backwards=False, **kwargs):
         self.output_dim = output_dim
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -737,6 +756,7 @@ class JZS3(Recurrent):
         self.truncate_gradient = truncate_gradient
         self.return_sequences = return_sequences
         self.initial_weights = weights
+        self.go_backwards = go_backwards

         self.input_dim = input_dim
         self.input_length = input_length
@@ -794,8 +814,9 @@ class JZS3(Recurrent):
             sequences=[x_z, x_r, x_h, padded_mask],
             outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
             non_sequences=[self.U_z, self.U_r, self.U_h],
-            truncate_gradient=self.truncate_gradient
-        )
+            truncate_gradient=self.truncate_gradient,
+            go_backwards=self.go_backwards)

         if self.return_sequences:
             return outputs.dimshuffle((1, 0, 2))
         return outputs[-1]
@@ -810,6 +831,7 @@ class JZS3(Recurrent):
                   "truncate_gradient": self.truncate_gradient,
                   "return_sequences": self.return_sequences,
                   "input_dim": self.input_dim,
-                  "input_length": self.input_length}
+                  "input_length": self.input_length,
+                  "go_backwards": self.go_backwards}
         base_config = super(JZS3, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
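Because go_backwards now appears in every get_config, a reversed layer's direction survives config serialization. A minimal sketch of that property, assuming the Keras 0.x layer API as modified by this commit:

    from keras.layers.recurrent import LSTM

    layer = LSTM(64, go_backwards=True)
    # go_backwards is serialized alongside input_length and the other kwargs
    assert layer.get_config()['go_backwards'] is True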