# keras/examples/addition_rnn.py (165 lines, 5.8 KiB, Python)
# -*- coding: utf-8 -*-
from __future__ import print_function
from keras.models import Sequential, slice_X
from keras.layers.core import Activation, Dense, RepeatVector
from keras.layers import recurrent
from sklearn.utils import shuffle
import numpy as np
"""
An implementation of sequence to sequence learning for performing addition
Input: "535+61"
Output: "596"
Padding is handled by using a repeated sentinel character (space)
By default, the JZS1 recurrent neural network is used
JZS1 was an "evolved" recurrent neural network that performed well on an arithmetic benchmark in:
"An Empirical Exploration of Recurrent Network Architectures"
http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf
Input may optionally be inverted, shown to increase performance in many tasks in:
"Learning to Execute"
http://arxiv.org/abs/1410.4615
and
"Sequence to Sequence Learning with Neural Networks"
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
Theoretically it introduces shorter term dependencies between source and target.

Two digits inverted:
+ One layer JZS1 (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs

Three digits inverted:
+ One layer JZS1 (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs

Four digits inverted:
+ One layer JZS1 (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs

Five digits inverted:
+ One layer JZS1 (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
"""
class CharacterTable(object):
    """Translate between strings and one-hot matrices over a fixed alphabet.

    Given a set of characters:
    + Encode them to a one hot integer representation
    + Decode the one hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    """

    def __init__(self, chars, maxlen):
        """Build the character <-> index lookup tables.

        chars: iterable of characters; duplicates are dropped and the rest
            are sorted, which fixes the index assigned to each character.
        maxlen: default number of rows `encode` allocates.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}
        self.maxlen = maxlen

    def encode(self, C, maxlen=None):
        """Return a (maxlen, len(self.chars)) one-hot matrix for string C.

        Row i has a single 1 in the column of C[i]; rows past the end of C
        stay all-zero. A falsy `maxlen` (None or 0) falls back to
        `self.maxlen`.

        Raises KeyError for a character outside the table and IndexError
        when C is longer than the number of rows.
        """
        maxlen = maxlen if maxlen else self.maxlen
        X = np.zeros((maxlen, len(self.chars)))
        for i, c in enumerate(C):
            X[i, self.char_indices[c]] = 1
        return X

    def decode(self, X, calc_argmax=True):
        """Return the string represented by X.

        With `calc_argmax` true, X is a matrix of per-character scores and
        the highest-scoring column of each row is taken; otherwise X is
        already an iterable of character indices.
        """
        if calc_argmax:
            X = X.argmax(axis=-1)
        return ''.join(self.indices_char[x] for x in X)
class colors(object):
    """ANSI terminal escape sequences used to colour the prediction marker."""

    ok = '\033[92m'     # start of "correct" colouring
    fail = '\033[91m'   # start of "incorrect" colouring
    close = '\033[0m'   # reset terminal attributes
# Parameters for the model and dataset
TRAINING_SIZE = 50000  # number of unique addition questions to generate
DIGITS = 3  # each operand is built from at most this many digits
INVERT = True  # reverse the question string before it is encoded
# Try replacing JZS1 with LSTM, GRU, or SimpleRNN
RNN = recurrent.JZS1
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1  # number of stacked decoder RNN layers
# Longest possible question: DIGITS digits, a '+', then DIGITS digits
MAXLEN = DIGITS + 1 + DIGITS
# Alphabet shared by questions and answers; the space is the padding character
chars = '0123456789+ '
ctable = CharacterTable(chars, MAXLEN)
questions = []
expected = []
seen = set()
print('Generating data...')


def _random_operand():
    """Return an int parsed from 1 to DIGITS random decimal digits.

    Leading zeros are possible, so the value may have fewer digits than
    were drawn.
    """
    n_digits = np.random.randint(1, DIGITS + 1)
    return int(''.join(np.random.choice(list('0123456789')) for _ in range(n_digits)))


while len(questions) < TRAINING_SIZE:
    a, b = _random_operand(), _random_operand()
    # Skip any addition questions we've already seen
    # Also skip any such that X+Y == Y+X (hence the sorting)
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Pad the data with spaces such that it is always MAXLEN
    q = '{}+{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a + b)
    # Answers can be of maximum size DIGITS + 1
    ans += ' ' * (DIGITS + 1 - len(ans))
    if INVERT:
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))
print('Vectorization...')
# One-hot tensors: (sample, position in string, character index).
# The builtin `bool` is used because the `np.bool` alias no longer exists in
# current NumPy releases.
X = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    X[i] = ctable.encode(sentence, maxlen=MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, maxlen=DIGITS + 1)
# Shuffle (X, y) in unison as the later parts of X will almost all be larger digits
X, y = shuffle(X, y)
# Explicitly set apart 10% for validation data that we never train over.
# Floor division keeps split_at an int on Python 3 so it is usable as an index.
split_at = len(X) - len(X) // 10
(X_train, X_val) = (slice_X(X, 0, split_at), slice_X(X, split_at))
(y_train, y_val) = (y[:split_at], y[split_at:])
print('Build model...')
model = Sequential()
# "Encode" the input sequence using an RNN, producing an output of HIDDEN_SIZE
model.add(RNN(len(chars), HIDDEN_SIZE))
# For the decoder's input, we repeat the encoded input for each time step
model.add(RepeatVector(DIGITS + 1))
# The decoder RNN could be multiple layers stacked or a single layer
# (`range` rather than `xrange` so the script also runs on Python 3)
for _ in range(LAYERS):
    model.add(RNN(HIDDEN_SIZE, HIDDEN_SIZE, return_sequences=True))
# For each step of the output sequence, decide which character should be chosen
model.add(Dense(HIDDEN_SIZE, len(chars)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
# Train the model each generation and show predictions against the validation dataset
for iteration in range(1, 200):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    # Fit on the training split only: fitting on the full (X, y) would also
    # train on the held-out 10% and make the validation numbers meaningless.
    model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1,
              validation_data=(X_val, y_val), show_accuracy=True)
    ###
    # Select 10 samples from the validation set at random so we can visualize errors
    for _ in range(10):
        ind = np.random.randint(0, len(X_val))
        # Index with a length-1 array so each row keeps its leading batch axis
        rowX, rowy = X_val[np.array([ind])], y_val[np.array([ind])]
        preds = model.predict_classes(rowX, verbose=0)
        q = ctable.decode(rowX[0])
        correct = ctable.decode(rowy[0])
        # preds already holds class indices, so no argmax is needed
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q[::-1] if INVERT else q)
        print('T', correct)
        # NOTE(review): the marker literals were empty strings, so nothing
        # visible was printed between the colour codes; the glyphs below are
        # assumed replacements -- confirm against the upstream file.
        print(colors.ok + '☑' + colors.close if correct == guess else colors.fail + '☒' + colors.close, guess)
        print('---')