model.fit() return training history

This commit is contained in:
fchollet 2015-05-09 18:14:19 -07:00
parent 376373d159
commit 9390d63056
2 changed files with 28 additions and 2 deletions

@ -159,6 +159,14 @@ class Sequential(object):
if verbose:
print("Train on %d samples, validate on %d samples" % (len(y), len(y_val)))
history = {'epoch':[], 'loss':[]}
if show_accuracy:
history['acc'] = []
if do_validation:
history['val_loss'] = []
if show_accuracy:
history['val_acc'] = []
index_array = np.arange(len(X))
for epoch in range(nb_epoch):
if verbose:
@ -167,21 +175,29 @@ class Sequential(object):
if shuffle:
np.random.shuffle(index_array)
av_loss = 0.
av_acc = 0.
seen = 0
batches = make_batches(len(X), batch_size)
for batch_index, (batch_start, batch_end) in enumerate(batches):
if shuffle:
batch_ids = index_array[batch_start:batch_end]
else:
batch_ids = slice(batch_start, batch_end)
seen += len(batch_ids)
X_batch = X[batch_ids]
y_batch = y[batch_ids]
if show_accuracy:
loss, acc = self._train_with_acc(X_batch, y_batch)
log_values = [('loss', loss), ('acc.', acc)]
av_loss += loss * len(batch_ids)
av_acc += acc * len(batch_ids)
else:
loss = self._train(X_batch, y_batch)
log_values = [('loss', loss)]
av_loss += loss * len(batch_ids)
# validation
if do_validation and (batch_index == len(batches) - 1):
@ -196,6 +212,16 @@ class Sequential(object):
if verbose:
progbar.update(batch_end, log_values)
history['epoch'].append(epoch)
history['loss'].append(av_loss/seen)
if do_validation:
history['val_loss'].append(float(val_loss))
if show_accuracy:
history['acc'].append(av_acc/seen)
if do_validation:
history['val_acc'].append(float(val_acc))
return history
def predict_proba(self, X, batch_size=128, verbose=1):
batches = make_batches(len(X), batch_size)

@ -39,11 +39,12 @@ class Progbar(object):
'''
for k, v in values:
if k not in self.sum_values:
self.sum_values[k] = [v, 1]
self.sum_values[k] = [v * (current-self.seen_so_far), current-self.seen_so_far]
self.unique_values.append(k)
else:
self.sum_values[k][0] += v * (current-self.seen_so_far)
self.sum_values[k][1] += (current-self.seen_so_far)
self.seen_so_far = current
now = time.time()
if self.verbose == 1:
@ -84,7 +85,6 @@ class Progbar(object):
sys.stdout.write(info)
sys.stdout.flush()
self.seen_so_far = current
if current >= self.target:
sys.stdout.write("\n")