'''Example script to generate text from Nietzsche's writings.

At least 20 epochs are required before the generated text
starts sounding coherent.

It is recommended to run this script on GPU, as recurrent
networks are quite computationally intensive.

If you try this script on new data, make sure your corpus
has at least ~100k characters. ~1M is better.
'''
from __future__ import print_function

import io
import random
import sys

import numpy as np

from keras.layers import Dense, Activation
from keras.layers import LSTM
from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.utils.data_utils import get_file
2017-03-12 03:44:29 +00:00
|
|
|
path = get_file('nietzsche.txt', origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
|
2015-06-16 00:43:25 +00:00
|
|
|
text = open(path).read().lower()
|
|
|
|
print('corpus length:', len(text))
|
|
|
|
|
2016-06-23 23:17:24 +00:00
|
|
|
chars = sorted(list(set(text)))
|
2015-06-16 00:43:25 +00:00
|
|
|
print('total chars:', len(chars))
|
|
|
|
char_indices = dict((c, i) for i, c in enumerate(chars))
|
|
|
|
indices_char = dict((i, c) for i, c in enumerate(chars))
|
|
|
|
|
2015-06-17 05:52:06 +00:00
|
|
|
# cut the text in semi-redundant sequences of maxlen characters
|
2016-03-19 16:07:15 +00:00
|
|
|
maxlen = 40
|
2015-06-17 05:52:06 +00:00
|
|
|
step = 3
|
2015-06-16 00:43:25 +00:00
|
|
|
sentences = []
|
|
|
|
next_chars = []
|
|
|
|
for i in range(0, len(text) - maxlen, step):
|
2015-10-05 01:44:49 +00:00
|
|
|
sentences.append(text[i: i + maxlen])
|
2015-06-16 00:43:25 +00:00
|
|
|
next_chars.append(text[i + maxlen])
|
|
|
|
print('nb sequences:', len(sentences))
|
|
|
|
|
|
|
|
print('Vectorization...')
|
2015-06-22 21:24:46 +00:00
|
|
|
X = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)
|
|
|
|
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)
|
2015-06-16 00:43:25 +00:00
|
|
|
for i, sentence in enumerate(sentences):
|
|
|
|
for t, char in enumerate(sentence):
|
2015-06-22 21:24:46 +00:00
|
|
|
X[i, t, char_indices[char]] = 1
|
|
|
|
y[i, char_indices[next_chars[i]]] = 1
|
2015-06-16 00:43:25 +00:00
|
|
|
|
|
|
|
|
2016-08-18 21:03:26 +00:00
|
|
|
# build the model: a single LSTM
|
2015-06-16 00:43:25 +00:00
|
|
|
print('Build model...')
|
|
|
|
model = Sequential()
|
2016-07-17 00:47:52 +00:00
|
|
|
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
|
2015-10-05 01:44:49 +00:00
|
|
|
model.add(Dense(len(chars)))
|
2015-06-16 00:43:25 +00:00
|
|
|
model.add(Activation('softmax'))
|
|
|
|
|
2016-07-17 00:47:52 +00:00
|
|
|
optimizer = RMSprop(lr=0.01)
|
|
|
|
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
|
2015-06-16 00:43:25 +00:00
|
|
|
|
2015-10-05 01:44:49 +00:00
|
|
|
|
2016-07-17 00:47:52 +00:00
|
|
|
def sample(preds, temperature=1.0):
    """Draw one index from a probability array, reshaped by temperature.

    Temperatures below 1.0 sharpen the distribution towards its argmax;
    temperatures above 1.0 flatten it towards uniform.
    """
    # work in float64 log-space, apply the temperature, then re-normalize
    logits = np.log(np.asarray(preds).astype('float64')) / temperature
    scaled = np.exp(logits)
    scaled /= np.sum(scaled)
    # one multinomial draw; argmax of the one-hot result is the index
    draw = np.random.multinomial(1, scaled, 1)
    return np.argmax(draw)
|
2015-06-16 00:43:25 +00:00
|
|
|
|
# train the model, output generated text after each iteration
for iteration in range(1, 60):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(X, y,
              batch_size=128,
              epochs=1)

    # one random seed window per iteration, shared by every diversity
    start_index = random.randint(0, len(text) - maxlen - 1)

    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print()
        print('----- diversity:', diversity)

        sentence = text[start_index: start_index + maxlen]
        generated = sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

        # generate 400 characters, feeding each prediction back in
        for _ in range(400):
            # one-hot encode the current window as a single-sample batch
            x = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x[0, t, char_indices[char]] = 1.

            preds = model.predict(x, verbose=0)[0]
            next_char = indices_char[sample(preds, diversity)]

            generated += next_char
            # slide the window one character to the right
            sentence = sentence[1:] + next_char

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()