Fix babi_rnn example
This commit is contained in:
parent
c429e651c1
commit
62f9053330
@ -179,11 +179,12 @@ print('Build model...')
|
|||||||
|
|
||||||
sentrnn = Sequential()
|
sentrnn = Sequential()
|
||||||
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE,
|
sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE,
|
||||||
input_length=story_maxlen, mask_zero=True))
|
input_length=story_maxlen))
|
||||||
sentrnn.add(Dropout(0.3))
|
sentrnn.add(Dropout(0.3))
|
||||||
|
|
||||||
qrnn = Sequential()
|
qrnn = Sequential()
|
||||||
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=query_maxlen))
|
qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE,
|
||||||
|
input_length=query_maxlen))
|
||||||
qrnn.add(Dropout(0.3))
|
qrnn.add(Dropout(0.3))
|
||||||
qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
|
qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
|
||||||
qrnn.add(RepeatVector(story_maxlen))
|
qrnn.add(RepeatVector(story_maxlen))
|
||||||
|
Loading…
Reference in New Issue
Block a user