Fix TF RNN issues

This commit is contained in:
Francois Chollet 2015-11-29 12:22:41 -08:00
parent 7ecd6c3c5f
commit cbee000b66
4 changed files with 13 additions and 9 deletions

@ -41,7 +41,7 @@ y_test = np.array(y_test)
print('Build model...')
model = Graph()
model.add_input(name='input', input_shape=(1,), dtype=int)
model.add_input(name='input', input_shape=(maxlen,), dtype=int)
model.add_node(Embedding(max_features, 128, input_length=maxlen),
name='embedding', input='input')
model.add_node(LSTM(64), name='forward', input='embedding')

@ -381,7 +381,7 @@ def rnn(step_function, inputs, initial_states,
successive_states = []
successive_outputs = []
if go_backwards:
input_list = input_list.reverse()
input_list.reverse()
for input in input_list:
output, new_states = step_function(input, states)
if masking:

@ -224,12 +224,13 @@ class Graph(Layer):
self.input_order.append(name)
layer = Layer() # empty layer
layer.set_input_shape(input_shape)
ndim = len(input_shape) + 1
if dtype == 'float':
layer.input = K.placeholder(ndim=ndim, name=name)
layer.input = K.placeholder(shape=layer.input_shape, name=name)
else:
if ndim == 2:
layer.input = K.placeholder(ndim=2, dtype='int32', name=name)
if len(input_shape) == 1:
layer.input = K.placeholder(shape=layer.input_shape,
dtype='int32',
name=name)
else:
raise Exception('Type "int" can only be used with ndim==2 (Embedding).')
self.inputs[name] = layer

@ -17,9 +17,12 @@ class Embedding(Layer):
'''
input_ndim = 2
def __init__(self, input_dim, output_dim, init='uniform', input_length=None,
W_regularizer=None, activity_regularizer=None, W_constraint=None,
mask_zero=False, weights=None, **kwargs):
def __init__(self, input_dim, output_dim,
init='uniform', input_length=None,
W_regularizer=None, activity_regularizer=None,
W_constraint=None,
mask_zero=False,
weights=None, **kwargs):
self.input_dim = input_dim
self.output_dim = output_dim
self.init = initializations.get(init)