Daniele Bonadiman 2015-11-09 00:29:14 +01:00
parent ca360b0d15
commit 28882a868d
2 changed files with 15 additions and 8 deletions

@ -55,7 +55,7 @@ model.add_node(Dense(1, activation='sigmoid'), name='sigmoid', input='dropout')
model.add_output(name='output', input='sigmoid') model.add_output(name='output', input='sigmoid')
# try using different optimizers and different optimizer configs # try using different optimizers and different optimizer configs
model.compile('adam',{'output':'binary_crossentropy'}) model.compile('adam', {'output':'binary_crossentropy'})
print("Train...") print("Train...")
model.fit({'input':X_train, 'output':y_train}, batch_size=batch_size, nb_epoch=4) model.fit({'input':X_train, 'output':y_train}, batch_size=batch_size, nb_epoch=4)

@ -125,7 +125,8 @@ class SimpleRNN(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(SimpleRNN, self).get_config() base_config = super(SimpleRNN, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -224,7 +225,8 @@ class SimpleDeepRNN(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(SimpleDeepRNN, self).get_config() base_config = super(SimpleDeepRNN, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -339,7 +341,8 @@ class GRU(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(GRU, self).get_config() base_config = super(GRU, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -472,7 +475,8 @@ class LSTM(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(LSTM, self).get_config() base_config = super(LSTM, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -591,7 +595,8 @@ class JZS1(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(JZS1, self).get_config() base_config = super(JZS1, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -712,7 +717,8 @@ class JZS2(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(JZS2, self).get_config() base_config = super(JZS2, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -825,6 +831,7 @@ class JZS3(Recurrent):
"truncate_gradient": self.truncate_gradient, "truncate_gradient": self.truncate_gradient,
"return_sequences": self.return_sequences, "return_sequences": self.return_sequences,
"input_dim": self.input_dim, "input_dim": self.input_dim,
"input_length": self.input_length} "input_length": self.input_length,
"go_backwards": self.go_backwards}
base_config = super(JZS3, self).get_config() base_config = super(JZS3, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))