keras/examples/mnist_hierarchical_rnn.py

"""This is an example of using a Hierarchical RNN (HRNN) to classify MNIST digits.

HRNNs can learn across multiple levels of temporal hierarchy over a complex sequence.
Usually, the first recurrent layer of an HRNN encodes a sentence (e.g. of word vectors)
into a sentence vector. The second recurrent layer then encodes a sequence of
such vectors (encoded by the first layer) into a document vector. This
document vector is considered to preserve both the word-level and
sentence-level structure of the context.

# References
    - [A Hierarchical Neural Autoencoder for Paragraphs and Documents](https://web.stanford.edu/~jurafsky/pubs/P15-1107.pdf)
        Encodes paragraphs and documents with HRNN.
        Results have shown that HRNN outperforms standard
        RNNs and may play some role in more sophisticated generation tasks like
        summarization or question answering.
    - [Hierarchical recurrent neural network for skeleton based action recognition](http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7298714)
        Achieved state-of-the-art results on skeleton based action recognition with 3 levels
        of bidirectional HRNN combined with fully connected layers.

In the MNIST example below, the first LSTM layer encodes every
row of pixels of shape (28, 1) into a row vector of shape (128,). The second LSTM
layer then encodes these 28 row vectors, stacked to shape (28, 128), into an image vector
representing the whole image. A final Dense layer is added for prediction.

After 5 epochs: train acc: 0.9858, val acc: 0.9864
"""
from __future__ import print_function

from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Dense, TimeDistributed
from keras.layers import LSTM
from keras.utils import np_utils

# Training parameters.
batch_size = 32
nb_classes = 10
nb_epochs = 5

# Embedding dimensions.
row_hidden = 128
col_hidden = 128

# The data, shuffled and split between train and test sets.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Reshapes data to 4D for Hierarchical RNN.
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# Converts class vectors to binary class matrices.
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

row, col, pixel = X_train.shape[1:]

# 4D input.
x = Input(shape=(row, col, pixel))

# Encodes a row of pixels using TimeDistributed Wrapper.
encoded_rows = TimeDistributed(LSTM(output_dim=row_hidden))(x)

# Encodes columns of encoded rows.
encoded_columns = LSTM(col_hidden)(encoded_rows)

# Final predictions and model.
prediction = Dense(nb_classes, activation='softmax')(encoded_columns)
model = Model(input=x, output=prediction)
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Training.
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epochs,
          verbose=1, validation_data=(X_test, Y_test))

# Evaluation.
scores = model.evaluate(X_test, Y_test, verbose=0)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
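
# Shape walkthrough (a minimal sketch, not part of the original example).
# It illustrates how shapes flow through the two encoding levels described
# in the docstring. The helper names `toy_encode` and `hierarchical_shapes`
# are hypothetical, and the "encoder" is a plain tanh recurrence with fixed
# random weights in NumPy, standing in for an LSTM. It is not the Keras layer.
import numpy as np


def toy_encode(sequence, hidden, seed=0):
    """Encodes a (timesteps, features) array into a (hidden,) vector."""
    rng = np.random.RandomState(seed)
    w_in = rng.normal(scale=0.1, size=(sequence.shape[1], hidden))
    w_rec = rng.normal(scale=0.1, size=(hidden, hidden))
    state = np.zeros(hidden)
    for step in sequence:
        # One recurrent update per timestep; only the last state is kept.
        state = np.tanh(step.dot(w_in) + state.dot(w_rec))
    return state


def hierarchical_shapes(image, row_dim=128, col_dim=128):
    """Returns the row encodings and the image encoding for one image."""
    # Level 1: encode every row independently, as TimeDistributed would.
    row_vectors = np.stack([toy_encode(r, row_dim) for r in image])
    # Level 2: encode the sequence of row vectors into one image vector.
    image_vector = toy_encode(row_vectors, col_dim, seed=1)
    return row_vectors, image_vector


# A blank 28x28x1 image is enough to check the shapes.
toy_rows, toy_image = hierarchical_shapes(np.zeros((28, 28, 1)))
print('Encoded rows shape:', toy_rows.shape)
print('Image vector shape:', toy_image.shape)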