Touch-ups to addition RNN example

This commit is contained in:
fchollet 2015-08-17 17:57:20 -07:00
parent 588ce7a7e2
commit d5455154f2

@ -25,14 +25,21 @@ and
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
Theoretically it introduces shorter term dependencies between source and target.
Two digits inverted:
+ One layer JZS1 (128 HN) with 55 iterations = 99% train/test accuracy
+ One layer JZS1 (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
Three digits inverted:
+ One layer JZS1 (128 HN) with 19 iterations = 99% train/test accuracy
+ One layer JZS1 (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
Four digits inverted:
+ One layer JZS1 (128 HN) with 20 iterations = 99% train/test accuracy
+ One layer JZS1 (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
Five digits inverted:
+ One layer JZS1 (128 HN) with 28 iterations = 99% train/test accuracy
+ One layer JZS1 (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
"""
@ -61,9 +68,14 @@ class CharacterTable(object):
X = X.argmax(axis=-1)
return ''.join(self.indices_char[x] for x in X)
class colors:
ok = '\033[92m'
fail = '\033[91m'
close = '\033[0m'
# Parameters for the model and dataset
# Note: Training size is number of queries to generate, not final number of unique queries
TRAINING_SIZE = 800000
TRAINING_SIZE = 50000
DIGITS = 3
INVERT = True
# Try replacing JZS1 with LSTM, GRU, or SimpleRNN
@ -80,7 +92,7 @@ questions = []
expected = []
seen = set()
print('Generating data...')
for i in xrange(TRAINING_SIZE):
while len(questions) < TRAINING_SIZE:
f = lambda: int(''.join(np.random.choice(list('0123456789')) for i in xrange(np.random.randint(1, DIGITS + 1))))
a, b = f(), f()
# Skip any addition questions we've already seen
@ -132,7 +144,7 @@ model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
# Train the model each generation and show predictions against the validation dataset
for iteration in range(1, 60):
for iteration in range(1, 200):
print()
print('-' * 50)
print('Iteration', iteration)
@ -148,5 +160,5 @@ for iteration in range(1, 60):
guess = ctable.decode(preds[0], calc_argmax=False)
print('Q', q[::-1] if INVERT else q)
print('T', correct)
print('' if correct == guess else '', guess)
print(colors.ok + '' + colors.close if correct == guess else colors.fail + '' + colors.close, guess)
print('---')