From 5e2b9358d6741840f3b86adbaf4f5ac8a99ea9ee Mon Sep 17 00:00:00 2001 From: fchollet Date: Tue, 26 May 2015 13:50:54 -0700 Subject: [PATCH] Fix issue in Theano scan for recurrent layers --- keras/layers/recurrent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py index 962bd3680..81f793914 100644 --- a/keras/layers/recurrent.py +++ b/keras/layers/recurrent.py @@ -60,7 +60,7 @@ class SimpleRNN(Layer): self._step, # this will be called with arguments (sequences[i], outputs[i-1], non_sequences[i]) sequences=x, # tensors to iterate over, inputs to _step # initialization of the output. Input to _step with default tap=-1. - outputs_info=alloc_zeros_matrix(X.shape[1], self.output_dim), + outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1), non_sequences=self.U, # static inputs to _step truncate_gradient=self.truncate_gradient ) @@ -129,7 +129,7 @@ class SimpleDeepRNN(Layer): self._step, sequences=x, outputs_info=[dict( - initial=T.alloc(np.cast[theano.config.floatX](0.), self.depth, X.shape[1], self.output_dim), + initial=T.unbroadcast(T.alloc(np.cast[theano.config.floatX](0.), self.depth, X.shape[1], self.output_dim), 1), taps = [(-i-1) for i in range(self.depth)] )], non_sequences=self.Us, @@ -233,7 +233,7 @@ class GRU(Layer): outputs, updates = theano.scan( self._step, sequences=[x_z, x_r, x_h], - outputs_info=alloc_zeros_matrix(X.shape[1], self.output_dim), + outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1), non_sequences=[self.U_z, self.U_r, self.U_h], truncate_gradient=self.truncate_gradient ) @@ -346,8 +346,8 @@ class LSTM(Layer): self._step, sequences=[xi, xf, xo, xc], outputs_info=[ - alloc_zeros_matrix(X.shape[1], self.output_dim), - alloc_zeros_matrix(X.shape[1], self.output_dim) + T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1), + T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1) ], 
non_sequences=[self.U_i, self.U_f, self.U_o, self.U_c], truncate_gradient=self.truncate_gradient