From d68e3316dafb0a49e60420bf1067e3145547f997 Mon Sep 17 00:00:00 2001 From: tboquet Date: Mon, 8 Feb 2016 14:55:00 -0500 Subject: [PATCH] Support for fit_generator + tensorboard --- keras/callbacks.py | 5 +- keras/models.py | 31 +++++++++--- tests/keras/test_callbacks.py | 88 ++++++++++++++++++++++++++++------- 3 files changed, 98 insertions(+), 26 deletions(-) diff --git a/keras/callbacks.py b/keras/callbacks.py index 93c47e706..d1c612706 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -456,6 +456,7 @@ class TensorBoard(Callback): 'with the TensorFlow backend.') self.log_dir = log_dir self.histogram_freq = histogram_freq + self.merged = None def _set_model(self, model): import tensorflow as tf @@ -463,7 +464,7 @@ class TensorBoard(Callback): self.model = model self.sess = KTF._get_session() - if self.histogram_freq: + if self.histogram_freq and not self.merged: mod_type = self.model.get_config()['name'] if mod_type == 'Sequential': layers = {l.get_config()['name']: l for l in self.model.layers} @@ -515,7 +516,7 @@ class TensorBoard(Callback): all_values = self.totals.copy() all_values.update(logs) - + for name, value in all_values.items(): if name in ['batch', 'size']: continue diff --git a/keras/models.py b/keras/models.py index d3da5a335..2d7569d56 100644 --- a/keras/models.py +++ b/keras/models.py @@ -970,8 +970,17 @@ class Sequential(Model, containers.Sequential): _stop.set() raise Exception('The generator output tuple must have ' '2 or 3 elements.') + + sample_weight = standardize_weights(y, sample_weight=sample_weight, + sample_weight_mode=self.sample_weight_mode) return X, y, sample_weight + if do_validation: + X_val, y_val, sample_weight_val = input_validation(validation_data) + self.validation_data = X_val + [y_val, sample_weight_val] + else: + self.validation_data = None + # start generator thread storing batches into a queue generator_queue = queue.Queue() _stop = threading.Event() @@ -1043,10 +1052,9 @@ class Sequential(Model, containers.Sequential): raise NotImplementedError() else: # input validation - X, y, sample_weight = input_validation(validation_data) - val_outs = self.evaluate(X, y, + val_outs = self.evaluate(X_val, y_val, show_accuracy=show_accuracy, - sample_weight=sample_weight, + sample_weight=sample_weight_val, verbose=0) if type(val_outs) != list: val_outs = [val_outs] @@ -1431,8 +1439,19 @@ class Graph(Model, containers.Graph): [len(sample_weight[name]) for name in sample_weight.keys()])) != 1: raise Exception('All input arrays and target arrays must have ' 'the same number of samples.') + sample_weight = {name: standardize_weights(data[name], + sample_weight=sample_weight.get(name), + sample_weight_mode=self.sample_weight_modes.get(name)) for name in self.output_order} return data, sample_weight + if do_validation: + data_val, sample_weight_val = input_validation(validation_data) + sample_weight_val_l = [sample_weight_val[name] for name in self.output_order] + y_val = [standardize_y(data_val[name]) for name in self.output_order] + self.validation_data = [data_val[name] for name in self.input_order] + y_val + sample_weight_val_l + else: + self.validation_data = None + # start generator thread storing batches into a queue generator_queue = queue.Queue() _stop = threading.Event() @@ -1498,10 +1517,8 @@ class Graph(Model, containers.Graph): _stop.set() raise NotImplementedError() else: - # input validation - data, sample_weight = input_validation(validation_data) - val_outs = self.evaluate(data, - sample_weight=sample_weight, + val_outs = self.evaluate(data_val, + sample_weight=sample_weight_val, verbose=0) if type(val_outs) != list: val_outs = [val_outs] diff --git a/tests/keras/test_callbacks.py b/tests/keras/test_callbacks.py index b02d9d390..3c936fbdc 100644 --- a/tests/keras/test_callbacks.py +++ b/tests/keras/test_callbacks.py @@ -136,7 +136,30 @@ def test_TensorBoard(): nb_class=nb_class) y_test = np_utils.to_categorical(y_test) y_train = np_utils.to_categorical(y_train) - # case 1 Sequential wo accuracy + + def data_generator(train): + if train: + max_batch_index = len(X_train) // batch_size + else: + max_batch_index = len(X_test) // batch_size + i = 0 + while 1: + if train: + yield (X_train[i * batch_size: (i + 1) * batch_size], y_train[i * batch_size: (i + 1) * batch_size]) + else: + yield (X_test[i * batch_size: (i + 1) * batch_size], y_test[i * batch_size: (i + 1) * batch_size]) + i += 1 + i = i % max_batch_index + + def data_generator_graph(train): + while 1: + if train: + yield {'X_vars': X_train, 'output': y_train} + else: + yield {'X_vars': X_test, 'output': y_test} + + # case 1 Sequential + with tf.Graph().as_default(): session = tf.Session('') KTF._set_session(session) @@ -147,28 +170,42 @@ def test_TensorBoard(): tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1) cbks = [tsb] + + # fit with validation data + model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=False, + validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2) + + # fit with validation data and accuracy model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True, validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2) + + # fit generator with validation data + model.fit_generator(data_generator(True), len(X_train), nb_epoch=2, + show_accuracy=False, + validation_data=(X_test, y_test), + callbacks=cbks) + + # fit generator without validation data + model.fit_generator(data_generator(True), len(X_train), nb_epoch=2, + show_accuracy=False, + callbacks=cbks) + + # fit generator with validation data and accuracy + model.fit_generator(data_generator(True), len(X_train), nb_epoch=2, + show_accuracy=True, + validation_data=(X_test, y_test), + callbacks=cbks) + + # fit generator without validation data and accuracy + model.fit_generator(data_generator(True), len(X_train), nb_epoch=2, + show_accuracy=True, + callbacks=cbks) + assert os.path.exists(filepath) shutil.rmtree(filepath) - # case 2 Sequential w accuracy - with tf.Graph().as_default(): - session = tf.Session('') - KTF._set_session(session) - model = Sequential() - model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu')) - model.add(Dense(nb_class, activation='softmax')) - model.compile(loss='categorical_crossentropy', optimizer='sgd') + # case 2 Graph - tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1) - cbks = [tsb] - model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True, - validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2) - assert os.path.exists(filepath) - shutil.rmtree(filepath) - - # case 3 Graph with tf.Graph().as_default(): session = tf.Session('') KTF._set_session(session) @@ -185,10 +222,27 @@ def test_TensorBoard(): tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1) cbks = [tsb] + + # fit with validation model.fit({'X_vars': X_train, 'output': y_train}, batch_size=batch_size, validation_data={'X_vars': X_test, 'output': y_test}, callbacks=cbks, nb_epoch=2) + + # fit wo validation + model.fit({'X_vars': X_train, 'output': y_train}, + batch_size=batch_size, + callbacks=cbks, nb_epoch=2) + + # fit generator with validation + model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2, + validation_data={'X_vars': X_test, 'output': y_test}, + callbacks=cbks) + + # fit generator wo validation + model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2, + callbacks=cbks) + assert os.path.exists(filepath) shutil.rmtree(filepath)