Spelling errors (#6232)
This commit is contained in:
parent
9eb7ecd3e5
commit
5bd3976e79
@ -78,7 +78,7 @@ INVERT = True
|
|||||||
|
|
||||||
# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
|
# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
|
||||||
# int is DIGITS.
|
# int is DIGITS.
|
||||||
MAxLEN = DIGITS + 1 + DIGITS
|
MAXLEN = DIGITS + 1 + DIGITS
|
||||||
|
|
||||||
# All the numbers, plus sign and space for padding.
|
# All the numbers, plus sign and space for padding.
|
||||||
chars = '0123456789+ '
|
chars = '0123456789+ '
|
||||||
@ -98,9 +98,9 @@ while len(questions) < TRAINING_SIZE:
|
|||||||
if key in seen:
|
if key in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(key)
|
seen.add(key)
|
||||||
# Pad the data with spaces such that it is always MAxLEN.
|
# Pad the data with spaces such that it is always MAXLEN.
|
||||||
q = '{}+{}'.format(a, b)
|
q = '{}+{}'.format(a, b)
|
||||||
query = q + ' ' * (MAxLEN - len(q))
|
query = q + ' ' * (MAXLEN - len(q))
|
||||||
ans = str(a + b)
|
ans = str(a + b)
|
||||||
# Answers can be of maximum size DIGITS + 1.
|
# Answers can be of maximum size DIGITS + 1.
|
||||||
ans += ' ' * (DIGITS + 1 - len(ans))
|
ans += ' ' * (DIGITS + 1 - len(ans))
|
||||||
@ -113,10 +113,10 @@ while len(questions) < TRAINING_SIZE:
|
|||||||
print('Total addition questions:', len(questions))
|
print('Total addition questions:', len(questions))
|
||||||
|
|
||||||
print('Vectorization...')
|
print('Vectorization...')
|
||||||
x = np.zeros((len(questions), MAxLEN, len(chars)), dtype=np.bool)
|
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)
|
||||||
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)
|
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)
|
||||||
for i, sentence in enumerate(questions):
|
for i, sentence in enumerate(questions):
|
||||||
x[i] = ctable.encode(sentence, MAxLEN)
|
x[i] = ctable.encode(sentence, MAXLEN)
|
||||||
for i, sentence in enumerate(expected):
|
for i, sentence in enumerate(expected):
|
||||||
y[i] = ctable.encode(sentence, DIGITS + 1)
|
y[i] = ctable.encode(sentence, DIGITS + 1)
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ model = Sequential()
|
|||||||
# "Encode" the input sequence using an RNN, producing an output of HIDDEN_SIZE.
|
# "Encode" the input sequence using an RNN, producing an output of HIDDEN_SIZE.
|
||||||
# Note: In a situation where your input sequences have a variable length,
|
# Note: In a situation where your input sequences have a variable length,
|
||||||
# use input_shape=(None, num_feature).
|
# use input_shape=(None, num_feature).
|
||||||
model.add(RNN(HIDDEN_SIZE, input_shape=(MAxLEN, len(chars))))
|
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))
|
||||||
# As the decoder RNN's input, repeatedly provide with the last hidden state of
|
# As the decoder RNN's input, repeatedly provide with the last hidden state of
|
||||||
# RNN for each time step. Repeat 'DIGITS + 1' times as that's the maximum
|
# RNN for each time step. Repeat 'DIGITS + 1' times as that's the maximum
|
||||||
# length of output, e.g., when DIGITS=3, max output is 999+999=1998.
|
# length of output, e.g., when DIGITS=3, max output is 999+999=1998.
|
||||||
|
Loading…
Reference in New Issue
Block a user