'''Compare LSTM implementations on the IMDB sentiment classification task.

consume_less='cpu' preprocesses input to the LSTM which typically results in
faster computations at the expense of increased peak memory usage as the
preprocessed input must be kept in memory.

consume_less='mem' does away with the preprocessing, meaning that it might take
a little longer, but should require less peak memory.

consume_less='gpu' concatenates the input, output and forget gate's weights
into one, large matrix, resulting in faster computation time as the GPU can
utilize more cores, at the expense of reduced regularization because the same
dropout is shared across the gates.

Note that the relative performance of the different `consume_less` modes
can vary depending on your device, your model and the size of your data.
'''
import time

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import imdb
from keras.layers import Embedding, Dense, LSTM
from keras.models import Sequential
from keras.preprocessing import sequence
max_features = 20000
|
|
|
|
max_length = 80
|
2016-05-05 18:17:25 +00:00
|
|
|
embedding_dim = 256
|
|
|
|
batch_size = 128
|
2016-05-05 18:01:48 +00:00
|
|
|
epochs = 10
|
|
|
|
modes = ['cpu', 'mem', 'gpu']
|
|
|
|
|
|
|
|
print('Loading data...')
(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=max_features)
# Pad/truncate each review to a fixed length so batches are rectangular.
X_train, X_test = (sequence.pad_sequences(split, max_length)
                   for split in (X_train, X_test))
# Compile and train different models while measuring performance.
results = []
for mode in modes:
    print('Testing mode: consume_less="{}"'.format(mode))

    # Identical architectures except for the LSTM's `consume_less` mode.
    model = Sequential()
    model.add(Embedding(max_features, embedding_dim, input_length=max_length, dropout=0.2))
    model.add(LSTM(embedding_dim, dropout_W=0.2, dropout_U=0.2, consume_less=mode))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    # Time the full training run, then average over the epochs.
    start_time = time.time()
    history = model.fit(X_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(X_test, y_test))
    elapsed = time.time() - start_time
    average_time_per_epoch = elapsed / epochs

    results.append((history, average_time_per_epoch))
# Compare models' accuracy, loss and elapsed time per epoch.
plt.style.use('ggplot')

# The two history plots (accuracy, loss) share the same layout; build them
# in one loop and keep the returned axes for plotting the curves below.
history_axes = []
for position, title, ylabel in (((0, 0), 'Accuracy', 'Validation Accuracy'),
                                ((1, 0), 'Loss', 'Validation Loss')):
    axis = plt.subplot2grid((2, 2), position)
    axis.set_title(title)
    axis.set_ylabel(ylabel)
    axis.set_xlabel('Epochs')
    history_axes.append(axis)
ax1, ax2 = history_axes

# Bar chart of average seconds per epoch spans the right-hand column.
ax3 = plt.subplot2grid((2, 2), (0, 1), rowspan=2)
ax3.set_title('Time')
ax3.set_ylabel('Seconds')

for mode, (history, _) in zip(modes, results):
    ax1.plot(history.epoch, history.history['val_acc'], label=mode)
    ax2.plot(history.epoch, history.history['val_loss'], label=mode)
ax1.legend()
ax2.legend()

ax3.bar(np.arange(len(results)), [elapsed for _, elapsed in results],
        tick_label=modes, align='center')

plt.tight_layout()
plt.show()