Fixed several bugs discovered during testing.
This commit is contained in:
parent
138ad116dc
commit
3d8a8e2d77
@ -1,17 +1,27 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..utils.np_utils import to_categorical
|
||||||
|
|
||||||
class KerasClassifier(object):
|
class KerasClassifier(object):
|
||||||
"""
|
"""
|
||||||
Implementation of the scikit-learn classifier API for Keras.
|
Implementation of the scikit-learn classifier API for Keras.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
model : object, optional
|
model : object
|
||||||
A pre-compiled Keras model is required to use the scikit-learn wrapper.
|
An un-compiled Keras model object is required to use the scikit-learn wrapper.
|
||||||
|
optimizer : string, optional
|
||||||
|
Optimization method used by the model during compilation/training.
|
||||||
|
loss : string, optional
|
||||||
|
Loss function used by the model during compilation/training.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model=None):
|
def __init__(self, model, optimizer='adam', loss='categorical_crossentropy'):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.loss = loss
|
||||||
|
self.compiled_model_ = None
|
||||||
self.classes_ = []
|
self.classes_ = []
|
||||||
self.config_ = []
|
self.config_ = []
|
||||||
self.weights_ = []
|
self.weights_ = []
|
||||||
@ -31,7 +41,7 @@ class KerasClassifier(object):
|
|||||||
params : dict
|
params : dict
|
||||||
Dictionary of parameter names mapped to their values.
|
Dictionary of parameter names mapped to their values.
|
||||||
"""
|
"""
|
||||||
return {'model': self.model}
|
return {'model': self.model, 'optimizer': self.optimizer, 'loss': self.loss}
|
||||||
|
|
||||||
def set_params(self, **params):
|
def set_params(self, **params):
|
||||||
"""
|
"""
|
||||||
@ -47,13 +57,17 @@ class KerasClassifier(object):
|
|||||||
self
|
self
|
||||||
"""
|
"""
|
||||||
for parameter, value in params.items():
|
for parameter, value in params.items():
|
||||||
self.setattr(parameter, value)
|
setattr(self, parameter, value)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True):
|
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True):
|
||||||
"""
|
"""
|
||||||
Fit the model according to the given training data.
|
Fit the model according to the given training data.
|
||||||
|
|
||||||
|
Makes a copy of the un-compiled model definition to use for
|
||||||
|
compilation and fitting, leaving the original definition
|
||||||
|
intact.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X : array-like, shape = (n_samples, n_features)
|
X : array-like, shape = (n_samples, n_features)
|
||||||
@ -77,11 +91,17 @@ class KerasClassifier(object):
|
|||||||
"""
|
"""
|
||||||
if len(y.shape) == 1:
|
if len(y.shape) == 1:
|
||||||
self.classes_ = list(np.unique(y))
|
self.classes_ = list(np.unique(y))
|
||||||
|
if self.loss == 'categorical_crossentropy':
|
||||||
|
y = to_categorical(y)
|
||||||
else:
|
else:
|
||||||
self.classes_ = np.arange(0, y.shape[1])
|
self.classes_ = np.arange(0, y.shape[1])
|
||||||
self.model.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose, shuffle=shuffle)
|
|
||||||
|
self.compiled_model_ = copy.deepcopy(self.model)
|
||||||
|
self.compiled_model_.compile(optimizer=self.optimizer, loss=self.loss)
|
||||||
|
self.compiled_model_.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose, shuffle=shuffle)
|
||||||
self.config_ = self.model.get_config()
|
self.config_ = self.model.get_config()
|
||||||
self.weights_ = self.model.get_weights()
|
self.weights_ = self.model.get_weights()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def score(self, X, y, batch_size=128, verbose=0):
|
def score(self, X, y, batch_size=128, verbose=0):
|
||||||
@ -105,7 +125,8 @@ class KerasClassifier(object):
|
|||||||
score : float
|
score : float
|
||||||
Mean accuracy of self.predict(X) wrt. y.
|
Mean accuracy of self.predict(X) wrt. y.
|
||||||
"""
|
"""
|
||||||
loss, accuracy = self.model.evaluate(X, y, batch_size=batch_size, show_accuracy=True, verbose=verbose)
|
loss, accuracy = self.compiled_model_.evaluate(X, y, batch_size=batch_size,
|
||||||
|
show_accuracy=True, verbose=verbose)
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
def predict(self, X, batch_size=128, verbose=0):
|
def predict(self, X, batch_size=128, verbose=0):
|
||||||
@ -127,7 +148,7 @@ class KerasClassifier(object):
|
|||||||
preds : array-like, shape = (n_samples)
|
preds : array-like, shape = (n_samples)
|
||||||
Class predictions.
|
Class predictions.
|
||||||
"""
|
"""
|
||||||
return self.model.predict_classes(X, batch_size=batch_size, verbose=verbose)
|
return self.compiled_model_.predict_classes(X, batch_size=batch_size, verbose=verbose)
|
||||||
|
|
||||||
def predict_proba(self, X, batch_size=128, verbose=0):
|
def predict_proba(self, X, batch_size=128, verbose=0):
|
||||||
"""
|
"""
|
||||||
@ -148,4 +169,4 @@ class KerasClassifier(object):
|
|||||||
proba : array-like, shape = (n_samples, n_outputs)
|
proba : array-like, shape = (n_samples, n_outputs)
|
||||||
Class probability estimates.
|
Class probability estimates.
|
||||||
"""
|
"""
|
||||||
return self.model.predict_proba(X, batch_size=batch_size, verbose=verbose)
|
return self.compiled_model_.predict_proba(X, batch_size=batch_size, verbose=verbose)
|
||||||
|
Loading…
Reference in New Issue
Block a user