Fixed several bugs discovered during testing.

John Wittenauer 2015-05-25 22:51:02 -04:00
parent 138ad116dc
commit 3d8a8e2d77

@@ -1,17 +1,27 @@
 from __future__ import absolute_import
+import copy
 import numpy as np
+from ..utils.np_utils import to_categorical
 class KerasClassifier(object):
     """
     Implementation of the scikit-learn classifier API for Keras.
     Parameters
     ----------
-    model : object, optional
-        A pre-compiled Keras model is required to use the scikit-learn wrapper.
+    model : object
+        An un-compiled Keras model object is required to use the scikit-learn wrapper.
+    optimizer : string, optional
+        Optimization method used by the model during compilation/training.
+    loss : string, optional
+        Loss function used by the model during compilation/training.
     """
-    def __init__(self, model=None):
+    def __init__(self, model, optimizer='adam', loss='categorical_crossentropy'):
         self.model = model
+        self.optimizer = optimizer
+        self.loss = loss
+        self.compiled_model_ = None
         self.classes_ = []
         self.config_ = []
         self.weights_ = []
@@ -31,7 +41,7 @@ class KerasClassifier(object):
         params : dict
             Dictionary of parameter names mapped to their values.
         """
-        return {'model': self.model}
+        return {'model': self.model, 'optimizer': self.optimizer, 'loss': self.loss}
     def set_params(self, **params):
         """
@@ -47,13 +57,17 @@
         self
         """
         for parameter, value in params.items():
-            self.setattr(parameter, value)
+            setattr(self, parameter, value)
         return self
     def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True):
         """
         Fit the model according to the given training data.
+        Makes a copy of the un-compiled model definition to use for
+        compilation and fitting, leaving the original definition
+        intact.
         Parameters
         ----------
         X : array-like, shape = (n_samples, n_features)
@@ -77,11 +91,17 @@
         """
         if len(y.shape) == 1:
             self.classes_ = list(np.unique(y))
+            if self.loss == 'categorical_crossentropy':
+                y = to_categorical(y)
         else:
             self.classes_ = np.arange(0, y.shape[1])
-        self.model.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose, shuffle=shuffle)
+        self.compiled_model_ = copy.deepcopy(self.model)
+        self.compiled_model_.compile(optimizer=self.optimizer, loss=self.loss)
+        self.compiled_model_.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose, shuffle=shuffle)
         self.config_ = self.model.get_config()
         self.weights_ = self.model.get_weights()
         return self
     def score(self, X, y, batch_size=128, verbose=0):
@@ -105,7 +125,8 @@
         score : float
            Mean accuracy of self.predict(X) wrt. y.
         """
-        loss, accuracy = self.model.evaluate(X, y, batch_size=batch_size, show_accuracy=True, verbose=verbose)
+        loss, accuracy = self.compiled_model_.evaluate(X, y, batch_size=batch_size,
+                                                       show_accuracy=True, verbose=verbose)
         return accuracy
     def predict(self, X, batch_size=128, verbose=0):
@@ -127,7 +148,7 @@
         preds : array-like, shape = (n_samples)
             Class predictions.
         """
-        return self.model.predict_classes(X, batch_size=batch_size, verbose=verbose)
+        return self.compiled_model_.predict_classes(X, batch_size=batch_size, verbose=verbose)
     def predict_proba(self, X, batch_size=128, verbose=0):
         """
@@ -148,4 +169,4 @@
         proba : array-like, shape = (n_samples, n_outputs)
             Class probability estimates.
         """
-        return self.model.predict_proba(X, batch_size=batch_size, verbose=verbose)
+        return self.compiled_model_.predict_proba(X, batch_size=batch_size, verbose=verbose)
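
For reference, a minimal usage sketch of the wrapper after this change (not part of the commit). It assumes the class is importable from wherever this file lives (the module path is not shown in the diff), uses synthetic data, and builds a 2015-era Keras Sequential model; layer constructor signatures varied between early Keras releases, so treat the model definition as illustrative only.

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation
# from <wrapper module> import KerasClassifier  # import path not shown in this diff

# Synthetic data: 100 samples, 20 features, 3 integer class labels (illustration only).
X = np.random.rand(100, 20).astype('float32')
y = np.random.randint(0, 3, size=100)

# An *un-compiled* model definition; the wrapper now compiles a deep copy inside fit().
model = Sequential()
model.add(Dense(20, 64))               # Keras 0.x-style (input_dim, output_dim) arguments
model.add(Activation('relu'))
model.add(Dense(64, 3))
model.add(Activation('softmax'))

clf = KerasClassifier(model, optimizer='adam', loss='categorical_crossentropy')
clf.fit(X, y, nb_epoch=5, verbose=0)   # integer labels are one-hot encoded internally for this loss

print(clf.classes_)                    # class labels inferred from y
preds = clf.predict(X)                 # class predictions from the compiled copy, shape (n_samples,)
proba = clf.predict_proba(X)           # class probabilities, shape (n_samples, n_classes)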