Added predict() for Sequential models

This commit is contained in:
Dan Becker 2015-05-11 21:58:15 -06:00
parent 43bfeb0a66
commit ca3e6846e7

@ -3,6 +3,7 @@ from __future__ import print_function
import theano
import theano.tensor as T
import numpy as np
import warnings
from . import optimizers
from . import objectives
@ -222,8 +223,7 @@ class Sequential(object):
history['val_acc'].append(float(val_acc))
return history
def predict_proba(self, X, batch_size=128, verbose=1):
def predict(self, X, batch_size=128, verbose=1):
batches = make_batches(len(X), batch_size)
if verbose==1:
progbar = Progbar(target=len(X))
@ -241,6 +241,12 @@ class Sequential(object):
return preds
def predict_proba(self, X, batch_size=128, verbose=1):
preds = self.predict(X, batch_size, verbose)
if preds.min()<0 or preds.max()>1:
warnings.warn("Network returning invalid probability values.")
return preds
def predict_classes(self, X, batch_size=128, verbose=1):
proba = self.predict_proba(X, batch_size=batch_size, verbose=verbose)