Fix TF RNN issues
This commit is contained in:
parent
7ecd6c3c5f
commit
cbee000b66
@ -41,7 +41,7 @@ y_test = np.array(y_test)
|
||||
|
||||
print('Build model...')
|
||||
model = Graph()
|
||||
model.add_input(name='input', input_shape=(1,), dtype=int)
|
||||
model.add_input(name='input', input_shape=(maxlen,), dtype=int)
|
||||
model.add_node(Embedding(max_features, 128, input_length=maxlen),
|
||||
name='embedding', input='input')
|
||||
model.add_node(LSTM(64), name='forward', input='embedding')
|
||||
|
@ -381,7 +381,7 @@ def rnn(step_function, inputs, initial_states,
|
||||
successive_states = []
|
||||
successive_outputs = []
|
||||
if go_backwards:
|
||||
input_list = input_list.reverse()
|
||||
input_list.reverse()
|
||||
for input in input_list:
|
||||
output, new_states = step_function(input, states)
|
||||
if masking:
|
||||
|
@ -224,12 +224,13 @@ class Graph(Layer):
|
||||
self.input_order.append(name)
|
||||
layer = Layer() # empty layer
|
||||
layer.set_input_shape(input_shape)
|
||||
ndim = len(input_shape) + 1
|
||||
if dtype == 'float':
|
||||
layer.input = K.placeholder(ndim=ndim, name=name)
|
||||
layer.input = K.placeholder(shape=layer.input_shape, name=name)
|
||||
else:
|
||||
if ndim == 2:
|
||||
layer.input = K.placeholder(ndim=2, dtype='int32', name=name)
|
||||
if len(input_shape) == 1:
|
||||
layer.input = K.placeholder(shape=layer.input_shape,
|
||||
dtype='int32',
|
||||
name=name)
|
||||
else:
|
||||
raise Exception('Type "int" can only be used with ndim==2 (Embedding).')
|
||||
self.inputs[name] = layer
|
||||
|
@ -17,9 +17,12 @@ class Embedding(Layer):
|
||||
'''
|
||||
input_ndim = 2
|
||||
|
||||
def __init__(self, input_dim, output_dim, init='uniform', input_length=None,
|
||||
W_regularizer=None, activity_regularizer=None, W_constraint=None,
|
||||
mask_zero=False, weights=None, **kwargs):
|
||||
def __init__(self, input_dim, output_dim,
|
||||
init='uniform', input_length=None,
|
||||
W_regularizer=None, activity_regularizer=None,
|
||||
W_constraint=None,
|
||||
mask_zero=False,
|
||||
weights=None, **kwargs):
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.init = initializations.get(init)
|
||||
|
Loading…
Reference in New Issue
Block a user