From ca3e6846e73fee7dc06ecceddde3320ae46d871f Mon Sep 17 00:00:00 2001 From: Dan Becker Date: Mon, 11 May 2015 21:58:15 -0600 Subject: [PATCH] Added predict() for Sequential models --- keras/models.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/keras/models.py b/keras/models.py index ae9302bd1..854d6e4fc 100644 --- a/keras/models.py +++ b/keras/models.py @@ -3,6 +3,7 @@ from __future__ import print_function import theano import theano.tensor as T import numpy as np +import warnings from . import optimizers from . import objectives @@ -222,8 +223,7 @@ class Sequential(object): history['val_acc'].append(float(val_acc)) return history - - def predict_proba(self, X, batch_size=128, verbose=1): + def predict(self, X, batch_size=128, verbose=1): batches = make_batches(len(X), batch_size) if verbose==1: progbar = Progbar(target=len(X)) @@ -241,6 +241,12 @@ class Sequential(object): return preds + def predict_proba(self, X, batch_size=128, verbose=1): + preds = self.predict(X, batch_size, verbose) + if preds.min()<0 or preds.max()>1: + warnings.warn("Network returning invalid probability values.") + return preds + def predict_classes(self, X, batch_size=128, verbose=1): proba = self.predict_proba(X, batch_size=batch_size, verbose=verbose)