This commit is contained in:
parent
ca360b0d15
commit
28882a868d
@ -55,7 +55,7 @@ model.add_node(Dense(1, activation='sigmoid'), name='sigmoid', input='dropout')
|
|||||||
model.add_output(name='output', input='sigmoid')
|
model.add_output(name='output', input='sigmoid')
|
||||||
|
|
||||||
# try using different optimizers and different optimizer configs
|
# try using different optimizers and different optimizer configs
|
||||||
model.compile('adam',{'output':'binary_crossentropy'})
|
model.compile('adam', {'output':'binary_crossentropy'})
|
||||||
|
|
||||||
print("Train...")
|
print("Train...")
|
||||||
model.fit({'input':X_train, 'output':y_train}, batch_size=batch_size, nb_epoch=4)
|
model.fit({'input':X_train, 'output':y_train}, batch_size=batch_size, nb_epoch=4)
|
||||||
|
@ -125,7 +125,8 @@ class SimpleRNN(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(SimpleRNN, self).get_config()
|
base_config = super(SimpleRNN, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -224,7 +225,8 @@ class SimpleDeepRNN(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(SimpleDeepRNN, self).get_config()
|
base_config = super(SimpleDeepRNN, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -339,7 +341,8 @@ class GRU(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(GRU, self).get_config()
|
base_config = super(GRU, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -472,7 +475,8 @@ class LSTM(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(LSTM, self).get_config()
|
base_config = super(LSTM, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -591,7 +595,8 @@ class JZS1(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(JZS1, self).get_config()
|
base_config = super(JZS1, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -712,7 +717,8 @@ class JZS2(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(JZS2, self).get_config()
|
base_config = super(JZS2, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -825,6 +831,7 @@ class JZS3(Recurrent):
|
|||||||
"truncate_gradient": self.truncate_gradient,
|
"truncate_gradient": self.truncate_gradient,
|
||||||
"return_sequences": self.return_sequences,
|
"return_sequences": self.return_sequences,
|
||||||
"input_dim": self.input_dim,
|
"input_dim": self.input_dim,
|
||||||
"input_length": self.input_length}
|
"input_length": self.input_length,
|
||||||
|
"go_backwards": self.go_backwards}
|
||||||
base_config = super(JZS3, self).get_config()
|
base_config = super(JZS3, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
Loading…
Reference in New Issue
Block a user