diff --git a/keras/models.py b/keras/models.py index e89ef8eb0..ae9302bd1 100644 --- a/keras/models.py +++ b/keras/models.py @@ -159,6 +159,14 @@ class Sequential(object): if verbose: print("Train on %d samples, validate on %d samples" % (len(y), len(y_val))) + history = {'epoch':[], 'loss':[]} + if show_accuracy: + history['acc'] = [] + if do_validation: + history['val_loss'] = [] + if show_accuracy: + history['val_acc'] = [] + index_array = np.arange(len(X)) for epoch in range(nb_epoch): if verbose: @@ -167,21 +175,29 @@ class Sequential(object): if shuffle: np.random.shuffle(index_array) + av_loss = 0. + av_acc = 0. + seen = 0 + batches = make_batches(len(X), batch_size) for batch_index, (batch_start, batch_end) in enumerate(batches): if shuffle: batch_ids = index_array[batch_start:batch_end] else: batch_ids = slice(batch_start, batch_end) + seen += len(batch_ids) X_batch = X[batch_ids] y_batch = y[batch_ids] if show_accuracy: loss, acc = self._train_with_acc(X_batch, y_batch) log_values = [('loss', loss), ('acc.', acc)] + av_loss += loss * len(batch_ids) + av_acc += acc * len(batch_ids) else: loss = self._train(X_batch, y_batch) log_values = [('loss', loss)] + av_loss += loss * len(batch_ids) # validation if do_validation and (batch_index == len(batches) - 1): @@ -196,6 +212,16 @@ class Sequential(object): if verbose: progbar.update(batch_end, log_values) + history['epoch'].append(epoch) + history['loss'].append(av_loss/seen) + if do_validation: + history['val_loss'].append(float(val_loss)) + if show_accuracy: + history['acc'].append(av_acc/seen) + if do_validation: + history['val_acc'].append(float(val_acc)) + return history + def predict_proba(self, X, batch_size=128, verbose=1): batches = make_batches(len(X), batch_size) diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py index ad39d041e..4865c876a 100644 --- a/keras/utils/generic_utils.py +++ b/keras/utils/generic_utils.py @@ -39,11 +39,12 @@ class Progbar(object): ''' for k, v in values: if k not in self.sum_values: - self.sum_values[k] = [v, 1] + self.sum_values[k] = [v * (current-self.seen_so_far), current-self.seen_so_far] self.unique_values.append(k) else: self.sum_values[k][0] += v * (current-self.seen_so_far) self.sum_values[k][1] += (current-self.seen_so_far) + self.seen_so_far = current now = time.time() if self.verbose == 1: @@ -84,7 +85,6 @@ class Progbar(object): sys.stdout.write(info) sys.stdout.flush() - self.seen_so_far = current if current >= self.target: sys.stdout.write("\n")