model.fit() return training history
This commit is contained in:
parent
376373d159
commit
9390d63056
@ -159,6 +159,14 @@ class Sequential(object):
|
||||
if verbose:
|
||||
print("Train on %d samples, validate on %d samples" % (len(y), len(y_val)))
|
||||
|
||||
history = {'epoch':[], 'loss':[]}
|
||||
if show_accuracy:
|
||||
history['acc'] = []
|
||||
if do_validation:
|
||||
history['val_loss'] = []
|
||||
if show_accuracy:
|
||||
history['val_acc'] = []
|
||||
|
||||
index_array = np.arange(len(X))
|
||||
for epoch in range(nb_epoch):
|
||||
if verbose:
|
||||
@ -167,21 +175,29 @@ class Sequential(object):
|
||||
if shuffle:
|
||||
np.random.shuffle(index_array)
|
||||
|
||||
av_loss = 0.
|
||||
av_acc = 0.
|
||||
seen = 0
|
||||
|
||||
batches = make_batches(len(X), batch_size)
|
||||
for batch_index, (batch_start, batch_end) in enumerate(batches):
|
||||
if shuffle:
|
||||
batch_ids = index_array[batch_start:batch_end]
|
||||
else:
|
||||
batch_ids = slice(batch_start, batch_end)
|
||||
seen += len(batch_ids)
|
||||
X_batch = X[batch_ids]
|
||||
y_batch = y[batch_ids]
|
||||
|
||||
if show_accuracy:
|
||||
loss, acc = self._train_with_acc(X_batch, y_batch)
|
||||
log_values = [('loss', loss), ('acc.', acc)]
|
||||
av_loss += loss * len(batch_ids)
|
||||
av_acc += acc * len(batch_ids)
|
||||
else:
|
||||
loss = self._train(X_batch, y_batch)
|
||||
log_values = [('loss', loss)]
|
||||
av_loss += loss * len(batch_ids)
|
||||
|
||||
# validation
|
||||
if do_validation and (batch_index == len(batches) - 1):
|
||||
@ -196,6 +212,16 @@ class Sequential(object):
|
||||
if verbose:
|
||||
progbar.update(batch_end, log_values)
|
||||
|
||||
history['epoch'].append(epoch)
|
||||
history['loss'].append(av_loss/seen)
|
||||
if do_validation:
|
||||
history['val_loss'].append(float(val_loss))
|
||||
if show_accuracy:
|
||||
history['acc'].append(av_acc/seen)
|
||||
if do_validation:
|
||||
history['val_acc'].append(float(val_acc))
|
||||
return history
|
||||
|
||||
|
||||
def predict_proba(self, X, batch_size=128, verbose=1):
|
||||
batches = make_batches(len(X), batch_size)
|
||||
|
@ -39,11 +39,12 @@ class Progbar(object):
|
||||
'''
|
||||
for k, v in values:
|
||||
if k not in self.sum_values:
|
||||
self.sum_values[k] = [v, 1]
|
||||
self.sum_values[k] = [v * (current-self.seen_so_far), current-self.seen_so_far]
|
||||
self.unique_values.append(k)
|
||||
else:
|
||||
self.sum_values[k][0] += v * (current-self.seen_so_far)
|
||||
self.sum_values[k][1] += (current-self.seen_so_far)
|
||||
self.seen_so_far = current
|
||||
|
||||
now = time.time()
|
||||
if self.verbose == 1:
|
||||
@ -84,7 +85,6 @@ class Progbar(object):
|
||||
|
||||
sys.stdout.write(info)
|
||||
sys.stdout.flush()
|
||||
self.seen_so_far = current
|
||||
|
||||
if current >= self.target:
|
||||
sys.stdout.write("\n")
|
||||
|
Loading…
Reference in New Issue
Block a user