Added learning phase to callbacks (#2297) (#2303)

* added learning phase to callbacks (#2297)

* cleaned imports

* replaced tabs by spaces

* added case where uses_learning_phase is False

* fixed pep8 blank line bug
This commit is contained in:
Thomas Boquet 2016-04-13 21:00:49 -04:00 committed by François Chollet
parent 4f5f88b9ba
commit 57ea065db7
2 changed files with 10 additions and 8 deletions

@ -467,8 +467,14 @@ class TensorBoard(Callback):
if epoch % self.histogram_freq == 0:
# TODO: implement batched calls to sess.run
# (current call will likely go OOM on GPU)
feed_dict = dict(zip(self.model.inputs,
self.model.validation_data))
if self.model.uses_learning_phase:
cut_v_data = len(self.model.inputs)
val_data = self.model.validation_data[:cut_v_data] + [0]
tensors = self.model.inputs + [K.learning_phase()]
else:
val_data = self.model.validation_data
tensors = self.model.inputs
feed_dict = dict(zip(tensors, val_data))
result = self.sess.run([self.merged], feed_dict=feed_dict)
summary_str = result[0]
self.writer.add_summary(summary_str, epoch)

@ -126,7 +126,7 @@ def test_LearningRateScheduler():
assert (float(K.get_value(model.optimizer.lr)) - 0.2) < K.epsilon()
@pytest.mark.skipif((K._BACKEND != 'tensorflow') or (sys.version_info[0] == 3),
@pytest.mark.skipif((K._BACKEND != 'tensorflow'),
reason="Requires tensorflow backend")
def test_TensorBoard():
import shutil
@ -252,8 +252,4 @@ def test_TensorBoard():
KTF.set_session(old_session)
if __name__ == '__main__':
# pytest.main([__file__])
# test_ModelCheckpoint()
# test_EarlyStopping()
# test_LearningRateScheduler()
test_TensorBoard()
pytest.main([__file__])