Touch-ups in IRNN example

fchollet 2015-07-02 15:21:37 -07:00
parent e50462e71f
commit 12a5c6fe46

@@ -7,12 +7,13 @@ from keras.datasets import mnist
 from keras.models import Sequential
 from keras.layers.core import Dense, Activation
 from keras.initializations import normal, identity
-from keras.layers.recurrent import SimpleRNN
+from keras.layers.recurrent import SimpleRNN, LSTM
 from keras.optimizers import RMSprop
 from keras.utils import np_utils
 '''
-This is a variant of IRNN experiment with sequential MNIST in
+This is a reproduction of the IRNN experiment
+with pixel-by-pixel sequential MNIST in
 "A Simple Way to Initialize Recurrent Networks of Rectified Linear Units "
 by Quoc V. Le, Navdeep Jaitly, Geoffrey E. Hinton
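For context on what the new docstring names: the IRNN of Le et al. is simply a SimpleRNN whose recurrent weight matrix starts as the identity and whose activation is ReLU, which is exactly what the `inner_init=identity` and `activation='relu'` arguments further down encode. A minimal numpy sketch of one such recurrent step (an illustration only, not code from this commit):

```python
import numpy as np

hidden_units = 100
W_in = np.random.normal(scale=0.001, size=(1, hidden_units))  # small random input weights
W_rec = np.eye(hidden_units)  # identity recurrent init: the "I" in IRNN
b = np.zeros(hidden_units)

h = np.zeros(hidden_units)
for x_t in np.random.rand(784, 1):  # one pixel per timestep
    h = np.maximum(0.0, x_t @ W_in + h @ W_rec + b)  # ReLU recurrent step
```

At initialization the identity recurrence carries the hidden state forward unchanged, so gradients can survive across all 784 steps without the gating machinery of an LSTM.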
@@ -22,24 +23,20 @@ from keras.utils import np_utils
 Optimizer is replaced with RMSprop which give more stable and steady
 improvement.
 Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_irnn_rmsprop.py
-Reaches to 80% train/test accuracy and 0.55 train/test loss after 70 epochs
+0.80 train/test accuracy and 0.55 train/test loss after 70 epochs
 (it's still underfitting at that point, though).
 About 15 minuts per epoch on a GRID K520 GPU.
 '''
-batch_size = 16
+batch_size = 32
 nb_classes = 10
 nb_epochs = 200
 hidden_units = 100
 learning_rate = 1e-6
 clip_norm = 1.0
-BPTT_trancate = 28*28
+BPTT_truncate = 28*28
-# the data, shuffled and split between tran and test sets
+# the data, shuffled and split between train and test sets
 (X_train, y_train), (X_test, y_test) = mnist.load_data()
 X_train = X_train.reshape(X_train.shape[0], -1, 1)
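The reshape on the last context line above is what makes the task pixel-by-pixel: each 28x28 image is flattened into a 784-step sequence with a single feature per step, and `BPTT_truncate = 28*28` then backpropagates through the full sequence, i.e. no truncation in practice. A quick shape check, with a zero array standing in for the output of `mnist.load_data()`:

```python
import numpy as np

X_train = np.zeros((60000, 28, 28))  # same shape as the MNIST training images
X_train = X_train.reshape(X_train.shape[0], -1, 1)
print(X_train.shape)  # (60000, 784, 1): one pixel per timestep
```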
@@ -56,11 +53,12 @@ print(X_test.shape[0], 'test samples')
 Y_train = np_utils.to_categorical(y_train, nb_classes)
 Y_test = np_utils.to_categorical(y_test, nb_classes)
+print('Evaluate IRNN...')
 model = Sequential()
 model.add(SimpleRNN(input_dim=1, output_dim=hidden_units,
                     init=lambda shape: normal(shape, scale=0.001),
                     inner_init=lambda shape: identity(shape, scale=1.0),
-                    activation='relu', truncate_gradient=BPTT_trancate))
+                    activation='relu', truncate_gradient=BPTT_truncate))
 model.add(Dense(hidden_units, nb_classes))
 model.add(Activation('softmax'))
 rmsprop = RMSprop(lr=learning_rate)
@@ -69,6 +67,21 @@ model.compile(loss='categorical_crossentropy', optimizer=rmsprop)
 model.fit(X_train, Y_train, batch_size=16, nb_epoch=nb_epochs,
           show_accuracy=True, verbose=1, validation_data=(X_test, Y_test))
-score = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0)
-print('Test score:', score[0])
-print('Test accuracy:', score[1])
+scores = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0)
+print('IRNN test score:', scores[0])
+print('IRNN test accuracy:', scores[1])
+
+print('Compare to LSTM...')
+model = Sequential()
+model.add(LSTM(1, hidden_units))
+model.add(Dense(hidden_units, nb_classes))
+model.add(Activation('softmax'))
+rmsprop = RMSprop(lr=learning_rate)
+model.compile(loss='categorical_crossentropy', optimizer=rmsprop)
+model.fit(X_train, Y_train, batch_size=16, nb_epoch=nb_epochs,
+          show_accuracy=True, verbose=1, validation_data=(X_test, Y_test))
+
+scores = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0)
+print('LSTM test score:', scores[0])
+print('LSTM test accuracy:', scores[1])
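The comparison block uses the Keras 0.x API of the time: `LSTM(1, hidden_units)` takes input and output dimensions positionally, and `show_accuracy`/`nb_epoch` were later removed. Note also that both `model.fit()` calls hardcode `batch_size=16`, so the `batch_size = 32` change earlier in this commit has no effect on training. For re-running the baseline on a current stack, a rough tf.keras 2.x equivalent (an approximation, not code from this commit) would be:

```python
# Approximate modern equivalent of the LSTM baseline above.
# Assumes tf.keras 2.x; the commit itself targets Keras 0.x.
from tensorflow import keras

model = keras.Sequential([
    keras.layers.LSTM(100, input_shape=(784, 1)),  # hidden_units=100, one pixel per step
    keras.layers.Dense(10, activation='softmax'),  # nb_classes=10
])
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(learning_rate=1e-6),
              metrics=['accuracy'])
```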