Merge pull request #1824 from farizrahman4u/patch-2
bAbI: Doubling the FB LSTM baseline :)
commit 55d9374961
@@ -7,8 +7,8 @@ http://arxiv.org/abs/1502.05698
 
 Task Number | FB LSTM Baseline | Keras QA
 --- | --- | ---
-QA1 - Single Supporting Fact | 50 | 52.1
-QA2 - Two Supporting Facts | 20 | 37.0
+QA1 - Single Supporting Fact | 50 | 100.0
+QA2 - Two Supporting Facts | 20 | 50.0
 QA3 - Three Supporting Facts | 20 | 20.5
 QA4 - Two Arg. Relations | 61 | 62.9
 QA5 - Three Arg. Relations | 70 | 61.9
@@ -34,8 +34,8 @@ https://research.facebook.com/researchers/1543934539189348
 Notes:
 
 - With default word, sentence, and query vector sizes, the GRU model achieves:
-  - 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
-  - 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
+  - 100% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
+  - 50% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
 In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline.
 
 - The task does not traditionally parse the question separately. This likely
@@ -138,12 +138,12 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
         Y.append(y)
     return pad_sequences(X, maxlen=story_maxlen), pad_sequences(Xq, maxlen=query_maxlen), np.array(Y)
 
-RNN = recurrent.GRU
+RNN = recurrent.LSTM
 EMBED_HIDDEN_SIZE = 50
 SENT_HIDDEN_SIZE = 100
 QUERY_HIDDEN_SIZE = 100
 BATCH_SIZE = 32
-EPOCHS = 20
+EPOCHS = 40
 print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE))
 
 path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz')
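For orientation: the tail of vectorize_stories shown above returns padded word-index matrices for the stories and queries plus one-hot answer vectors, which is exactly what the model in the next hunk consumes. A minimal usage sketch follows; the names train, test, X, Xq, Y, tX, tXq, tY are illustrative assumptions inferred from the return statement, not verbatim from this file:

    # Hypothetical call sites for vectorize_stories; only the function's
    # own parameter names come from the diff above.
    X, Xq, Y = vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
    tX, tXq, tY = vectorize_stories(test, word_idx, story_maxlen, query_maxlen)
    # X  -> (num_samples, story_maxlen)  zero-padded story word indices
    # Xq -> (num_samples, query_maxlen)  zero-padded query word indices
    # Y  -> (num_samples, vocab_size)    one-hot answer vectors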
@@ -178,15 +178,19 @@ print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))
 print('Build model...')
 
 sentrnn = Sequential()
-sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
-sentrnn.add(RNN(SENT_HIDDEN_SIZE, return_sequences=False))
+sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=story_maxlen, mask_zero=True))
+sentrnn.add(Dropout(0.3))
 
 qrnn = Sequential()
-qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
-qrnn.add(RNN(QUERY_HIDDEN_SIZE, return_sequences=False))
+qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=query_maxlen))
+qrnn.add(Dropout(0.3))
+qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
+qrnn.add(RepeatVector(story_maxlen))
 
 model = Sequential()
-model.add(Merge([sentrnn, qrnn], mode='concat'))
+model.add(Merge([sentrnn, qrnn], mode='sum'))
+model.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
+model.add(Dropout(0.3))
 model.add(Dense(vocab_size, activation='softmax'))
 
 model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')
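Read out of diff form, the new architecture is: embed the story as a sequence, encode the query with an RNN into a single vector, tile that vector across the story's timesteps with RepeatVector, sum-merge the two, and let a final RNN read the question-conditioned story. Below is a consolidated sketch under the Keras-0.x-era API this example targets; the import paths and the fit call are assumptions for illustration, not part of this diff:

    from keras.models import Sequential
    from keras.layers.core import Dense, Dropout, Merge, RepeatVector
    from keras.layers.embeddings import Embedding
    from keras.layers import recurrent

    RNN = recurrent.LSTM

    # Story branch: per-word embeddings, kept as a sequence
    # of shape (story_maxlen, EMBED_HIDDEN_SIZE).
    sentrnn = Sequential()
    sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=story_maxlen, mask_zero=True))
    sentrnn.add(Dropout(0.3))

    # Query branch: encode the question into one EMBED_HIDDEN_SIZE vector,
    # then repeat it once per story timestep so shapes line up for 'sum'.
    qrnn = Sequential()
    qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=query_maxlen))
    qrnn.add(Dropout(0.3))
    qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
    qrnn.add(RepeatVector(story_maxlen))

    # Sum-merge injects the question into every story position; a final
    # RNN reads the combined sequence and a softmax picks the answer word.
    model = Sequential()
    model.add(Merge([sentrnn, qrnn], mode='sum'))
    model.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
    model.add(Dropout(0.3))
    model.add(Dense(vocab_size, activation='softmax'))

    model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')
    # Assumed training call; X, Xq, Y as returned by vectorize_stories.
    model.fit([X, Xq], Y, batch_size=BATCH_SIZE, nb_epoch=EPOCHS, validation_split=0.05)

On the design choice: mode='sum' requires both branches to emit the same shape (story_maxlen, EMBED_HIDDEN_SIZE), which is why the query RNN's output size is EMBED_HIDDEN_SIZE rather than QUERY_HIDDEN_SIZE and why RepeatVector(story_maxlen) is needed. The old mode='concat' model instead glued the two fixed-size sentence and query encodings together, discarding word order in the story before the question could interact with it.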