diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 2c1c06e4d..5b56e705d 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1284,6 +1284,8 @@ def rnn(step_function, inputs, initial_states, output, new_states = step_function(current_input, tuple(states) + tuple(constants)) + for state, new_state in zip(states, new_states): + new_state.set_shape(state.get_shape()) tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]])) output = tf.select(tiled_mask_t, output, states[0]) new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))] @@ -1295,6 +1297,8 @@ def rnn(step_function, inputs, initial_states, output, new_states = step_function(current_input, tuple(states) + tuple(constants)) + for state, new_state in zip(states, new_states): + new_state.set_shape(state.get_shape()) output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states)