Support for fit_generator + tensorboard

This commit is contained in:
tboquet 2016-02-08 14:55:00 -05:00
parent f2443de96d
commit d68e3316da
3 changed files with 98 additions and 26 deletions

@ -456,6 +456,7 @@ class TensorBoard(Callback):
'with the TensorFlow backend.')
self.log_dir = log_dir
self.histogram_freq = histogram_freq
self.merged = None
def _set_model(self, model):
import tensorflow as tf
@ -463,7 +464,7 @@ class TensorBoard(Callback):
self.model = model
self.sess = KTF._get_session()
if self.histogram_freq:
if self.histogram_freq and not self.merged:
mod_type = self.model.get_config()['name']
if mod_type == 'Sequential':
layers = {l.get_config()['name']: l for l in self.model.layers}
@ -515,7 +516,7 @@ class TensorBoard(Callback):
all_values = self.totals.copy()
all_values.update(logs)
for name, value in all_values.items():
if name in ['batch', 'size']:
continue

@ -970,8 +970,17 @@ class Sequential(Model, containers.Sequential):
_stop.set()
raise Exception('The generator output tuple must have '
'2 or 3 elements.')
sample_weight = standardize_weights(y, sample_weight=sample_weight,
sample_weight_mode=self.sample_weight_mode)
return X, y, sample_weight
if do_validation:
X_val, y_val, sample_weight_val = input_validation(validation_data)
self.validation_data = X_val + [y_val, sample_weight_val]
else:
self.validation_data = None
# start generator thread storing batches into a queue
generator_queue = queue.Queue()
_stop = threading.Event()
@ -1043,10 +1052,9 @@ class Sequential(Model, containers.Sequential):
raise NotImplementedError()
else:
# input validation
X, y, sample_weight = input_validation(validation_data)
val_outs = self.evaluate(X, y,
val_outs = self.evaluate(X_val, y_val,
show_accuracy=show_accuracy,
sample_weight=sample_weight,
sample_weight=sample_weight_val,
verbose=0)
if type(val_outs) != list:
val_outs = [val_outs]
@ -1431,8 +1439,19 @@ class Graph(Model, containers.Graph):
[len(sample_weight[name]) for name in sample_weight.keys()])) != 1:
raise Exception('All input arrays and target arrays must have '
'the same number of samples.')
sample_weight = {name: standardize_weights(data[name],
sample_weight=sample_weight.get(name),
sample_weight_mode=self.sample_weight_modes.get(name)) for name in self.output_order}
return data, sample_weight
if do_validation:
data_val, sample_weight_val = input_validation(validation_data)
sample_weight_val_l = [sample_weight_val[name] for name in self.output_order]
y_val = [standardize_y(data_val[name]) for name in self.output_order]
self.validation_data = [data_val[name] for name in self.input_order] + y_val + sample_weight_val_l
else:
self.validation_data = None
# start generator thread storing batches into a queue
generator_queue = queue.Queue()
_stop = threading.Event()
@ -1498,10 +1517,8 @@ class Graph(Model, containers.Graph):
_stop.set()
raise NotImplementedError()
else:
# input validation
data, sample_weight = input_validation(validation_data)
val_outs = self.evaluate(data,
sample_weight=sample_weight,
val_outs = self.evaluate(data_val,
sample_weight=sample_weight_val,
verbose=0)
if type(val_outs) != list:
val_outs = [val_outs]

@ -136,7 +136,30 @@ def test_TensorBoard():
nb_class=nb_class)
y_test = np_utils.to_categorical(y_test)
y_train = np_utils.to_categorical(y_train)
# case 1 Sequential wo accuracy
def data_generator(train):
    """Yield ``(inputs, targets)`` mini-batches forever, cycling over the data.

    The arrays and ``batch_size`` are read from the enclosing scope:
    ``X_train``/``y_train`` when ``train`` is truthy, otherwise
    ``X_test``/``y_test``.

    Each batch is the slice ``[i * batch_size:(i + 1) * batch_size]``; the
    batch index wraps around after ``len(inputs) // batch_size`` batches, so
    trailing samples that do not fill a whole batch are not yielded when the
    data holds at least one full batch.
    """
    if train:
        inputs, targets = X_train, y_train
    else:
        inputs, targets = X_test, y_test
    # Clamp to one batch per cycle: with fewer samples than batch_size the
    # floor division gives 0 and the modulo below would raise
    # ZeroDivisionError on the second batch.
    max_batch_index = max(1, len(inputs) // batch_size)
    i = 0
    while True:
        start = i * batch_size
        stop = start + batch_size
        yield (inputs[start:stop], targets[start:stop])
        i = (i + 1) % max_batch_index
def data_generator_graph(train):
    """Endlessly yield the whole dataset as a name-to-array dict.

    Every item maps ``'X_vars'`` to the inputs and ``'output'`` to the
    targets, taken from the enclosing scope: ``X_train``/``y_train`` when
    ``train`` is truthy, otherwise ``X_test``/``y_test``. The full arrays
    are yielded each time; no batching is done here.
    """
    while True:
        if not train:
            batch = {'X_vars': X_test, 'output': y_test}
        else:
            batch = {'X_vars': X_train, 'output': y_train}
        yield batch
# case 1 Sequential
with tf.Graph().as_default():
session = tf.Session('')
KTF._set_session(session)
@ -147,28 +170,42 @@ def test_TensorBoard():
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]
# fit with validation data
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=False,
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
# fit with validation data and accuracy
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
# fit generator with validation data
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=False,
validation_data=(X_test, y_test),
callbacks=cbks)
# fit generator without validation data
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=False,
callbacks=cbks)
# fit generator with validation data and accuracy
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=True,
validation_data=(X_test, y_test),
callbacks=cbks)
# fit generator without validation data and accuracy
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
show_accuracy=True,
callbacks=cbks)
assert os.path.exists(filepath)
shutil.rmtree(filepath)
# case 2 Sequential w accuracy
with tf.Graph().as_default():
session = tf.Session('')
KTF._set_session(session)
model = Sequential()
model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
model.add(Dense(nb_class, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd')
# case 2 Graph
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
assert os.path.exists(filepath)
shutil.rmtree(filepath)
# case 3 Graph
with tf.Graph().as_default():
session = tf.Session('')
KTF._set_session(session)
@ -185,10 +222,27 @@ def test_TensorBoard():
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]
# fit with validation
model.fit({'X_vars': X_train, 'output': y_train},
batch_size=batch_size,
validation_data={'X_vars': X_test, 'output': y_test},
callbacks=cbks, nb_epoch=2)
# fit wo validation
model.fit({'X_vars': X_train, 'output': y_train},
batch_size=batch_size,
callbacks=cbks, nb_epoch=2)
# fit generator with validation
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
validation_data={'X_vars': X_test, 'output': y_test},
callbacks=cbks)
# fit generator wo validation
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
callbacks=cbks)
assert os.path.exists(filepath)
shutil.rmtree(filepath)