From 24d6cca275ea7e47664901fbfb59f033af901242 Mon Sep 17 00:00:00 2001 From: Javier Dehesa Date: Tue, 29 Nov 2016 19:31:14 +0000 Subject: [PATCH] Enforce shape invariance for states in RNN loop (#4536) Fixes some shape invariance errors arising sometimes when building RNNs. --- keras/backend/tensorflow_backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 2c1c06e4d..5b56e705d 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1284,6 +1284,8 @@ def rnn(step_function, inputs, initial_states, output, new_states = step_function(current_input, tuple(states) + tuple(constants)) + for state, new_state in zip(states, new_states): + new_state.set_shape(state.get_shape()) tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]])) output = tf.select(tiled_mask_t, output, states[0]) new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))] @@ -1295,6 +1297,8 @@ def rnn(step_function, inputs, initial_states, output, new_states = step_function(current_input, tuple(states) + tuple(constants)) + for state, new_state in zip(states, new_states): + new_state.set_shape(state.get_shape()) output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states)