Enforce shape invariance for states in RNN loop (#4536)
Fixes some shape invariance errors arising sometimes when building RNNs.
This commit is contained in:
parent
83b90c172c
commit
24d6cca275
@ -1284,6 +1284,8 @@ def rnn(step_function, inputs, initial_states,
|
|||||||
output, new_states = step_function(current_input,
|
output, new_states = step_function(current_input,
|
||||||
tuple(states) +
|
tuple(states) +
|
||||||
tuple(constants))
|
tuple(constants))
|
||||||
|
for state, new_state in zip(states, new_states):
|
||||||
|
new_state.set_shape(state.get_shape())
|
||||||
tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]]))
|
tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]]))
|
||||||
output = tf.select(tiled_mask_t, output, states[0])
|
output = tf.select(tiled_mask_t, output, states[0])
|
||||||
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
|
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
|
||||||
@ -1295,6 +1297,8 @@ def rnn(step_function, inputs, initial_states,
|
|||||||
output, new_states = step_function(current_input,
|
output, new_states = step_function(current_input,
|
||||||
tuple(states) +
|
tuple(states) +
|
||||||
tuple(constants))
|
tuple(constants))
|
||||||
|
for state, new_state in zip(states, new_states):
|
||||||
|
new_state.set_shape(state.get_shape())
|
||||||
output_ta_t = output_ta_t.write(time, output)
|
output_ta_t = output_ta_t.write(time, output)
|
||||||
return (time + 1, output_ta_t) + tuple(new_states)
|
return (time + 1, output_ta_t) + tuple(new_states)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user