keras/examples/lstm_text_generation.py

'''Example script to generate text from Nietzsche's writings.
At least 20 epochs are required before the generated text
starts sounding coherent.
It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.
If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
from __future__ import print_function
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.layers import LSTM
from keras.optimizers import RMSprop
from keras.utils.data_utils import get_file
import numpy as np
import random
import sys
path = get_file('nietzsche.txt', origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
text = open(path).read().lower()
print('corpus length:', len(text))
chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
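# e.g. if the sorted vocabulary begins ['\n', ' ', '!', ...], then
# char_indices['!'] == 2 and indices_char[2] == '!'; the two dicts are
# inverse mappings used to one-hot encode and decode characters below.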
# cut the text in semi-redundant sequences of maxlen characters
maxlen = 40
step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
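# With maxlen = 40 and step = 3, consecutive windows overlap by 37
# characters: sentences[0] == text[0:40], sentences[1] == text[3:43],
# and next_chars[0] == text[40] is the prediction target for sentences[0].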
print('nb sequences:', len(sentences))
print('Vectorization...')
X = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)
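# Note: np.bool was deprecated in NumPy 1.20 and later removed; on modern
# NumPy, use the builtin bool (or np.bool_) as the dtype here.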
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
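# Each sample in X is thus a (maxlen, len(chars)) binary matrix, one
# one-hot row per character of the window, and each row of y is the
# one-hot code of the single character that follows that window.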
# build the model: a single LSTM
print('Build model...')
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))
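# The model maps a (maxlen, len(chars)) one-hot window to a softmax
# distribution over the character vocabulary: LSTM(128) reads the window
# and its final hidden state feeds the Dense layer, which scores each
# candidate next character.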
optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)


def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array,
    # re-weighted by temperature: low values make the choice more
    # conservative, high values make it more random
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
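# Dividing the log-probabilities by the temperature and re-normalizing is
# equivalent to sampling from p_i ** (1/T) / sum_j(p_j ** (1/T)): T -> 0
# approaches greedy argmax decoding, T == 1.0 reproduces the model's own
# distribution, and T > 1.0 flattens it toward uniform randomness. For
# instance, sample(preds, 0.2) almost always picks the most likely
# character, while sample(preds, 1.2) produces more surprising output.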
# train the model, output generated text after each iteration
for iteration in range(1, 60):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(X, y, batch_size=128, nb_epoch=1)

    start_index = random.randint(0, len(text) - maxlen - 1)

    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print()
        print('----- diversity:', diversity)

        generated = ''
        sentence = text[start_index: start_index + maxlen]
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

        for i in range(400):
            x = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x[0, t, char_indices[char]] = 1.

            preds = model.predict(x, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            generated += next_char
            sentence = sentence[1:] + next_char

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()