from __future__ import absolute_import
from __future__ import print_function
from functools import reduce
import re
import tarfile

import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets.data_utils import get_file
from keras.layers.embeddings import Embedding
from keras.layers.core import Dense, Merge
from keras.layers import recurrent
from keras.models import Sequential
from keras.preprocessing.sequence import pad_sequences

'''
Trains two recurrent neural networks based upon a story and a question.
The resulting merged vector is then queried to answer a range of bAbI tasks.

The results are comparable to those for an LSTM model provided in Weston et al.:
"Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks"
http://arxiv.org/abs/1502.05698

Task Number                  | FB LSTM Baseline | Keras QA
---                          | ---              | ---
QA1 - Single Supporting Fact | 50               | 52.1
QA2 - Two Supporting Facts   | 20               | 37.0
QA3 - Three Supporting Facts | 20               | 20.5
QA4 - Two Arg. Relations     | 61               | 62.9
QA5 - Three Arg. Relations   | 70               | 61.9
QA6 - Yes/No Questions       | 48               | 50.7
QA7 - Counting               | 49               | 78.9
QA8 - Lists/Sets             | 45               | 77.2
QA9 - Simple Negation        | 64               | 64.0
QA10 - Indefinite Knowledge  | 44               | 47.7
QA11 - Basic Coreference     | 72               | 74.9
QA12 - Conjunction           | 74               | 76.4
QA13 - Compound Coreference  | 94               | 94.4
QA14 - Time Reasoning        | 27               | 34.8
QA15 - Basic Deduction       | 21               | 32.4
QA16 - Basic Induction       | 23               | 50.6
QA17 - Positional Reasoning  | 51               | 49.1
QA18 - Size Reasoning        | 52               | 90.8
QA19 - Path Finding          | 8                | 9.0
QA20 - Agent's Motivations   | 91               | 90.7

For resources related to the bAbI project, refer to:
https://research.facebook.com/researchers/1543934539189348

Notes:

- With default word, sentence, and query vector sizes, the GRU model achieves:
  - 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
  - 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
  In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline.

- The task does not traditionally parse the question separately. This likely
  improves accuracy and is a good example of merging two RNNs.

- The word vector embeddings are not shared between the story and question RNNs.

- See how the accuracy changes given 10,000 training samples (en-10k) instead
  of only 1000. 1000 was used in order to be comparable to the original paper.

- Experiment with GRU, LSTM, and JZS1-3, as they give subtly different results.

- The length and noise (i.e. 'useless' story components) impact the ability of
  LSTMs / GRUs to provide the correct answer. Given only the supporting facts,
  these RNNs can achieve 100% accuracy on many tasks. Memory networks and neural
  networks that use attentional processes can efficiently search through this
  noise to find the relevant statements, improving performance substantially.
  This becomes especially obvious on QA2 and QA3, both far longer than QA1.
'''


def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # Split on runs of non-word characters, keeping them as tokens; the
    # original pattern r'(\W+)?' could match the empty string, which newer
    # versions of re.split reject.
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]


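# parse_stories below consumes the plain-text bAbI format, which looks like
# (an illustrative sample, not read from the archive itself):
#   1 Mary moved to the bathroom.
#   2 John went to the hallway.
#   3 Where is Mary?	bathroom	1
# Sentence ids restart at 1 for each new story, and question lines carry a
# tab-separated answer plus the ids of the supporting facts.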
def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.

    If only_supporting is true, only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = line.decode('utf-8').strip()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            substory = None
            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # Keep an empty placeholder so sentence ids still index into story
            story.append('')
        else:
            sent = tokenize(line)
            story.append(sent)
    return data


def get_stories(f, only_supporting=False, max_length=None):
    '''Given a file handle, read the file, retrieve the stories, and then convert the sentences into a single story.

    If max_length is supplied, any stories longer than max_length tokens will be discarded.
    '''
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    # Concatenate a story's per-sentence token lists into one token sequence
    flatten = lambda data: reduce(lambda x, y: x + y, data)
    data = [(flatten(story), q, answer) for story, q, answer in data if not max_length or len(flatten(story)) < max_length]
    return data


def vectorize_stories(data):
    '''Convert tokenized stories, queries, and answers into padded index arrays.

    Relies on the module-level word_idx, vocab_size, story_maxlen, and
    query_maxlen defined below.
    '''
    X = []
    Xq = []
    Y = []
    for story, query, answer in data:
        x = [word_idx[w] for w in story]
        xq = [word_idx[w] for w in query]
        # One-hot encode the answer word over the vocabulary
        y = np.zeros(vocab_size)
        y[word_idx[answer]] = 1
        X.append(x)
        Xq.append(xq)
        Y.append(y)
    return pad_sequences(X, maxlen=story_maxlen), pad_sequences(Xq, maxlen=query_maxlen), np.array(Y)


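# Note: by default pad_sequences pads and truncates at the start of each
# sequence, filling with index 0, which is reserved for masking when word_idx
# is built below.
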
RNN = recurrent.GRU
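# As the notes above suggest, other recurrent units can be swapped in here for
# comparison, e.g.:
# RNN = recurrent.LSTM
# RNN = recurrent.JZS1  # likewise JZS2 / JZS3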
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100
BATCH_SIZE = 32
EPOCHS = 20
print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE))

path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz')
tar = tarfile.open(path)
# Default QA1 with 1000 samples
# challenge = 'tasks_1-20_v1-2/en/qa1_single-supporting-fact_{}.txt'
# QA1 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt'
# QA2 with 1000 samples
challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
# QA2 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt'
train = get_stories(tar.extractfile(challenge.format('train')))
test = get_stories(tar.extractfile(challenge.format('test')))
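# train and test are now lists of (story tokens, question tokens, answer word)
# triples, ready to be indexed against the vocabulary built next.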

vocab = sorted(reduce(lambda x, y: x | y, (set(story + q + [answer]) for story, q, answer in train + test)))
# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
story_maxlen = max(map(len, (x for x, _, _ in train + test)))
query_maxlen = max(map(len, (x for _, x, _ in train + test)))

X, Xq, Y = vectorize_stories(train)
tX, tXq, tY = vectorize_stories(test)

print('vocab = {}'.format(vocab))
print('X.shape = {}'.format(X.shape))
print('Xq.shape = {}'.format(Xq.shape))
print('Y.shape = {}'.format(Y.shape))
print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))

print('Build model...')

sentrnn = Sequential()
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
sentrnn.add(RNN(SENT_HIDDEN_SIZE, return_sequences=False))

qrnn = Sequential()
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
qrnn.add(RNN(QUERY_HIDDEN_SIZE, return_sequences=False))

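# Merging with mode='concat' joins the two RNN outputs into a single
# (SENT_HIDDEN_SIZE + QUERY_HIDDEN_SIZE)-dimensional vector; the softmax layer
# then maps it over the vocabulary to pick the answer word.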
model = Sequential()
model.add(Merge([sentrnn, qrnn], mode='concat'))
model.add(Dense(vocab_size, activation='softmax'))

model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')

print('Training')
model.fit([X, Xq], Y, batch_size=BATCH_SIZE, nb_epoch=EPOCHS, validation_split=0.05, show_accuracy=True)
loss, acc = model.evaluate([tX, tXq], tY, batch_size=BATCH_SIZE, show_accuracy=True)
print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc))