Merge pull request #1824 from farizrahman4u/patch-2

bAbI: Doubling the FB LSTM baseline :)
This commit is contained in:
François Chollet 2016-02-25 12:51:04 -08:00
commit 55d9374961

@ -7,8 +7,8 @@ http://arxiv.org/abs/1502.05698
Task Number | FB LSTM Baseline | Keras QA
--- | --- | ---
QA1 - Single Supporting Fact | 50 | 52.1
QA2 - Two Supporting Facts | 20 | 37.0
QA1 - Single Supporting Fact | 50 | 100.0
QA2 - Two Supporting Facts | 20 | 50.0
QA3 - Three Supporting Facts | 20 | 20.5
QA4 - Two Arg. Relations | 61 | 62.9
QA5 - Three Arg. Relations | 70 | 61.9
@ -34,8 +34,8 @@ https://research.facebook.com/researchers/1543934539189348
Notes:
- With default word, sentence, and query vector sizes, the GRU model achieves:
- 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
- 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
- 100% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
- 50% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline.
- The task does not traditionally parse the question separately. This likely
@ -138,12 +138,12 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
Y.append(y)
return pad_sequences(X, maxlen=story_maxlen), pad_sequences(Xq, maxlen=query_maxlen), np.array(Y)
RNN = recurrent.GRU
RNN = recurrent.LSTM
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100
BATCH_SIZE = 32
EPOCHS = 20
EPOCHS = 40
print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE))
path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz')
@ -178,15 +178,19 @@ print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))
print('Build model...')
sentrnn = Sequential()
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
sentrnn.add(RNN(SENT_HIDDEN_SIZE, return_sequences=False))
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=story_maxlen, mask_zero=True))
sentrnn.add(Dropout(0.3))
qrnn = Sequential()
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
qrnn.add(RNN(QUERY_HIDDEN_SIZE, return_sequences=False))
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=query_maxlen))
qrnn.add(Dropout(0.3))
qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
qrnn.add(RepeatVector(story_maxlen))
model = Sequential()
model.add(Merge([sentrnn, qrnn], mode='concat'))
model.add(Merge([sentrnn, qrnn], mode='sum'))
model.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
model.add(Dropout(0.3))
model.add(Dense(vocab_size, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')