Enforce shape invariance for states in RNN loop (#4536)

Fixes some shape invariance errors arising sometimes when building RNNs.
This commit is contained in:
Javier Dehesa 2016-11-29 19:31:14 +00:00 committed by François Chollet
parent 83b90c172c
commit 24d6cca275

@ -1284,6 +1284,8 @@ def rnn(step_function, inputs, initial_states,
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]]))
output = tf.select(tiled_mask_t, output, states[0])
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
@ -1295,6 +1297,8 @@ def rnn(step_function, inputs, initial_states,
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)