Support for fit_generator + tensorboard
This commit is contained in:
parent
f2443de96d
commit
d68e3316da
@ -456,6 +456,7 @@ class TensorBoard(Callback):
|
|||||||
'with the TensorFlow backend.')
|
'with the TensorFlow backend.')
|
||||||
self.log_dir = log_dir
|
self.log_dir = log_dir
|
||||||
self.histogram_freq = histogram_freq
|
self.histogram_freq = histogram_freq
|
||||||
|
self.merged = None
|
||||||
|
|
||||||
def _set_model(self, model):
|
def _set_model(self, model):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@ -463,7 +464,7 @@ class TensorBoard(Callback):
|
|||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.sess = KTF._get_session()
|
self.sess = KTF._get_session()
|
||||||
if self.histogram_freq:
|
if self.histogram_freq and not self.merged:
|
||||||
mod_type = self.model.get_config()['name']
|
mod_type = self.model.get_config()['name']
|
||||||
if mod_type == 'Sequential':
|
if mod_type == 'Sequential':
|
||||||
layers = {l.get_config()['name']: l for l in self.model.layers}
|
layers = {l.get_config()['name']: l for l in self.model.layers}
|
||||||
@ -515,7 +516,7 @@ class TensorBoard(Callback):
|
|||||||
|
|
||||||
all_values = self.totals.copy()
|
all_values = self.totals.copy()
|
||||||
all_values.update(logs)
|
all_values.update(logs)
|
||||||
|
|
||||||
for name, value in all_values.items():
|
for name, value in all_values.items():
|
||||||
if name in ['batch', 'size']:
|
if name in ['batch', 'size']:
|
||||||
continue
|
continue
|
||||||
|
@ -970,8 +970,17 @@ class Sequential(Model, containers.Sequential):
|
|||||||
_stop.set()
|
_stop.set()
|
||||||
raise Exception('The generator output tuple must have '
|
raise Exception('The generator output tuple must have '
|
||||||
'2 or 3 elements.')
|
'2 or 3 elements.')
|
||||||
|
|
||||||
|
sample_weight = standardize_weights(y, sample_weight=sample_weight,
|
||||||
|
sample_weight_mode=self.sample_weight_mode)
|
||||||
return X, y, sample_weight
|
return X, y, sample_weight
|
||||||
|
|
||||||
|
if do_validation:
|
||||||
|
X_val, y_val, sample_weight_val = input_validation(validation_data)
|
||||||
|
self.validation_data = X_val + [y_val, sample_weight_val]
|
||||||
|
else:
|
||||||
|
self.validation_data = None
|
||||||
|
|
||||||
# start generator thread storing batches into a queue
|
# start generator thread storing batches into a queue
|
||||||
generator_queue = queue.Queue()
|
generator_queue = queue.Queue()
|
||||||
_stop = threading.Event()
|
_stop = threading.Event()
|
||||||
@ -1043,10 +1052,9 @@ class Sequential(Model, containers.Sequential):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
else:
|
else:
|
||||||
# input validation
|
# input validation
|
||||||
X, y, sample_weight = input_validation(validation_data)
|
val_outs = self.evaluate(X_val, y_val,
|
||||||
val_outs = self.evaluate(X, y,
|
|
||||||
show_accuracy=show_accuracy,
|
show_accuracy=show_accuracy,
|
||||||
sample_weight=sample_weight,
|
sample_weight=sample_weight_val,
|
||||||
verbose=0)
|
verbose=0)
|
||||||
if type(val_outs) != list:
|
if type(val_outs) != list:
|
||||||
val_outs = [val_outs]
|
val_outs = [val_outs]
|
||||||
@ -1431,8 +1439,19 @@ class Graph(Model, containers.Graph):
|
|||||||
[len(sample_weight[name]) for name in sample_weight.keys()])) != 1:
|
[len(sample_weight[name]) for name in sample_weight.keys()])) != 1:
|
||||||
raise Exception('All input arrays and target arrays must have '
|
raise Exception('All input arrays and target arrays must have '
|
||||||
'the same number of samples.')
|
'the same number of samples.')
|
||||||
|
sample_weight = {name: standardize_weights(data[name],
|
||||||
|
sample_weight=sample_weight.get(name),
|
||||||
|
sample_weight_mode=self.sample_weight_modes.get(name)) for name in self.output_order}
|
||||||
return data, sample_weight
|
return data, sample_weight
|
||||||
|
|
||||||
|
if do_validation:
|
||||||
|
data_val, sample_weight_val = input_validation(validation_data)
|
||||||
|
sample_weight_val_l = [sample_weight_val[name] for name in self.output_order]
|
||||||
|
y_val = [standardize_y(data_val[name]) for name in self.output_order]
|
||||||
|
self.validation_data = [data_val[name] for name in self.input_order] + y_val + sample_weight_val_l
|
||||||
|
else:
|
||||||
|
self.validation_data = None
|
||||||
|
|
||||||
# start generator thread storing batches into a queue
|
# start generator thread storing batches into a queue
|
||||||
generator_queue = queue.Queue()
|
generator_queue = queue.Queue()
|
||||||
_stop = threading.Event()
|
_stop = threading.Event()
|
||||||
@ -1498,10 +1517,8 @@ class Graph(Model, containers.Graph):
|
|||||||
_stop.set()
|
_stop.set()
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
else:
|
else:
|
||||||
# input validation
|
val_outs = self.evaluate(data_val,
|
||||||
data, sample_weight = input_validation(validation_data)
|
sample_weight=sample_weight_val,
|
||||||
val_outs = self.evaluate(data,
|
|
||||||
sample_weight=sample_weight,
|
|
||||||
verbose=0)
|
verbose=0)
|
||||||
if type(val_outs) != list:
|
if type(val_outs) != list:
|
||||||
val_outs = [val_outs]
|
val_outs = [val_outs]
|
||||||
|
@ -136,7 +136,30 @@ def test_TensorBoard():
|
|||||||
nb_class=nb_class)
|
nb_class=nb_class)
|
||||||
y_test = np_utils.to_categorical(y_test)
|
y_test = np_utils.to_categorical(y_test)
|
||||||
y_train = np_utils.to_categorical(y_train)
|
y_train = np_utils.to_categorical(y_train)
|
||||||
# case 1 Sequential wo accuracy
|
|
||||||
|
def data_generator(train):
|
||||||
|
if train:
|
||||||
|
max_batch_index = len(X_train) // batch_size
|
||||||
|
else:
|
||||||
|
max_batch_index = len(X_test) // batch_size
|
||||||
|
i = 0
|
||||||
|
while 1:
|
||||||
|
if train:
|
||||||
|
yield (X_train[i * batch_size: (i + 1) * batch_size], y_train[i * batch_size: (i + 1) * batch_size])
|
||||||
|
else:
|
||||||
|
yield (X_test[i * batch_size: (i + 1) * batch_size], y_test[i * batch_size: (i + 1) * batch_size])
|
||||||
|
i += 1
|
||||||
|
i = i % max_batch_index
|
||||||
|
|
||||||
|
def data_generator_graph(train):
|
||||||
|
while 1:
|
||||||
|
if train:
|
||||||
|
yield {'X_vars': X_train, 'output': y_train}
|
||||||
|
else:
|
||||||
|
yield {'X_vars': X_test, 'output': y_test}
|
||||||
|
|
||||||
|
# case 1 Sequential
|
||||||
|
|
||||||
with tf.Graph().as_default():
|
with tf.Graph().as_default():
|
||||||
session = tf.Session('')
|
session = tf.Session('')
|
||||||
KTF._set_session(session)
|
KTF._set_session(session)
|
||||||
@ -147,28 +170,42 @@ def test_TensorBoard():
|
|||||||
|
|
||||||
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
|
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
|
||||||
cbks = [tsb]
|
cbks = [tsb]
|
||||||
|
|
||||||
|
# fit with validation data
|
||||||
|
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=False,
|
||||||
|
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
|
||||||
|
|
||||||
|
# fit with validation data and accuracy
|
||||||
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
|
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
|
||||||
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
|
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
|
||||||
|
|
||||||
|
# fit generator with validation data
|
||||||
|
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
|
||||||
|
show_accuracy=False,
|
||||||
|
validation_data=(X_test, y_test),
|
||||||
|
callbacks=cbks)
|
||||||
|
|
||||||
|
# fit generator without validation data
|
||||||
|
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
|
||||||
|
show_accuracy=False,
|
||||||
|
callbacks=cbks)
|
||||||
|
|
||||||
|
# fit generator with validation data and accuracy
|
||||||
|
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
|
||||||
|
show_accuracy=True,
|
||||||
|
validation_data=(X_test, y_test),
|
||||||
|
callbacks=cbks)
|
||||||
|
|
||||||
|
# fit generator without validation data and accuracy
|
||||||
|
model.fit_generator(data_generator(True), len(X_train), nb_epoch=2,
|
||||||
|
show_accuracy=True,
|
||||||
|
callbacks=cbks)
|
||||||
|
|
||||||
assert os.path.exists(filepath)
|
assert os.path.exists(filepath)
|
||||||
shutil.rmtree(filepath)
|
shutil.rmtree(filepath)
|
||||||
|
|
||||||
# case 2 Sequential w accuracy
|
# case 2 Graph
|
||||||
with tf.Graph().as_default():
|
|
||||||
session = tf.Session('')
|
|
||||||
KTF._set_session(session)
|
|
||||||
model = Sequential()
|
|
||||||
model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
|
|
||||||
model.add(Dense(nb_class, activation='softmax'))
|
|
||||||
model.compile(loss='categorical_crossentropy', optimizer='sgd')
|
|
||||||
|
|
||||||
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
|
|
||||||
cbks = [tsb]
|
|
||||||
model.fit(X_train, y_train, batch_size=batch_size, show_accuracy=True,
|
|
||||||
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=2)
|
|
||||||
assert os.path.exists(filepath)
|
|
||||||
shutil.rmtree(filepath)
|
|
||||||
|
|
||||||
# case 3 Graph
|
|
||||||
with tf.Graph().as_default():
|
with tf.Graph().as_default():
|
||||||
session = tf.Session('')
|
session = tf.Session('')
|
||||||
KTF._set_session(session)
|
KTF._set_session(session)
|
||||||
@ -185,10 +222,27 @@ def test_TensorBoard():
|
|||||||
|
|
||||||
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
|
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
|
||||||
cbks = [tsb]
|
cbks = [tsb]
|
||||||
|
|
||||||
|
# fit with validation
|
||||||
model.fit({'X_vars': X_train, 'output': y_train},
|
model.fit({'X_vars': X_train, 'output': y_train},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
validation_data={'X_vars': X_test, 'output': y_test},
|
validation_data={'X_vars': X_test, 'output': y_test},
|
||||||
callbacks=cbks, nb_epoch=2)
|
callbacks=cbks, nb_epoch=2)
|
||||||
|
|
||||||
|
# fit wo validation
|
||||||
|
model.fit({'X_vars': X_train, 'output': y_train},
|
||||||
|
batch_size=batch_size,
|
||||||
|
callbacks=cbks, nb_epoch=2)
|
||||||
|
|
||||||
|
# fit generator with validation
|
||||||
|
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
|
||||||
|
validation_data={'X_vars': X_test, 'output': y_test},
|
||||||
|
callbacks=cbks)
|
||||||
|
|
||||||
|
# fit generator wo validation
|
||||||
|
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
|
||||||
|
callbacks=cbks)
|
||||||
|
|
||||||
assert os.path.exists(filepath)
|
assert os.path.exists(filepath)
|
||||||
shutil.rmtree(filepath)
|
shutil.rmtree(filepath)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user