Merge pull request #188 from tdhd/classweights

Add support for class_weight in fit
This commit is contained in:
François Chollet 2015-06-17 10:25:34 -07:00
commit 4830b4be27
4 changed files with 232 additions and 24 deletions

@ -13,7 +13,7 @@ model = keras.models.Sequential()
- __loss__: str (name of objective function) or objective function. See [objectives](objectives.md).
- __class_mode__: one of "categorical", "binary". This is only used for computing classification accuracy or using the predict_classes method.
- __theano_mode__: A `theano.compile.mode.Mode` ([reference](http://deeplearning.net/software/theano/library/compile/mode.html)) instance controlling specifying compilation options.
- __fit__(X, y, batch_size=128, nb_epoch=100, verbose=1, validation_split=0., validation_data=None, shuffle=True, show_accuracy=False, callbacks=[]): Train a model for a fixed number of epochs.
- __fit__(X, y, batch_size=128, nb_epoch=100, verbose=1, validation_split=0., validation_data=None, shuffle=True, show_accuracy=False, callbacks=[], class_weight=None): Train a model for a fixed number of epochs.
- __Return__: a history dictionary with a record of training loss values at successive epochs, as well as validation loss values (if applicable), accuracy (if applicable), etc.
- __Arguments__:
- __X__: data.
@ -26,6 +26,7 @@ model = keras.models.Sequential()
- __shuffle__: boolean. Whether to shuffle the samples at each epoch.
- __show_accuracy__: boolean. Whether to display class accuracy in the logs to stdout at each epoch.
- __callbacks__: `keras.callbacks.Callback` list. List of callbacks to apply during training. See [callbacks](callbacks.md).
- __class_weight__: If specified, must be a dictionary which maps every class label to its class weight. Scales the loss of every training sample depending on the passed weights.
- __evaluate__(X, y, batch_size=128, show_accuracy=False, verbose=1): Show performance of the model over some validation data.
- __Return__: The loss score over the data.
- __Arguments__: Same meaning as fit method above. verbose is used as a binary flag (progress bar or nothing).
@ -35,7 +36,7 @@ model = keras.models.Sequential()
- __predict_classes__(X, batch_size=128, verbose=1): Return an array of class predictions for some test data.
- __Return__: An array of labels for some test data.
- __Arguments__: Same meaning as fit method above. verbose is used as a binary flag (progress bar or nothing).
- __train__(X, y, accuracy=False): Single gradient update on one batch. if accuracy==False, return tuple (loss_on_batch, accuracy_on_batch). Else, return loss_on_batch.
- __train__(X, y, accuracy=False, class_weight=None): Single gradient update on one batch. if accuracy==False, return tuple (loss_on_batch, accuracy_on_batch). Else, return loss_on_batch.
- __Return__: loss over the data, or tuple `(loss, accuracy)` if `accuracy=True`.
- __test__(X, y, accuracy=False): Single performance evaluation on one batch. if accuracy==False, return tuple (loss_on_batch, accuracy_on_batch). Else, return loss_on_batch.
- __Return__: loss over the data, or tuple `(loss, accuracy)` if `accuracy=True`.

@ -41,6 +41,18 @@ def slice_X(X, start=None, stop=None):
else:
return X[start:stop]
def calculate_class_weights(Y, class_weight):
if isinstance(class_weight, dict):
if Y.shape[1] > 1:
y_classes = Y.argmax(axis=1)
elif Y.shape[1] == 1:
y_classes = np.reshape(Y, Y.shape[0])
else:
y_classes = Y
w = np.array(map(lambda x: class_weight[x], y_classes))
else:
w = np.ones((Y.shape[0]))
return w
class Model(object):
@ -58,7 +70,10 @@ class Model(object):
# target of model
self.y = T.zeros_like(self.y_train)
train_loss = self.loss(self.y, self.y_train)
# parameter for rescaling the objective function
self.class_weights = T.vector()
train_loss = self.loss(self.y, self.y_train, self.class_weights)
test_score = self.loss(self.y, self.y_test)
if class_mode == "categorical":
@ -75,11 +90,11 @@ class Model(object):
updates = self.optimizer.get_updates(self.params, self.regularizers, self.constraints, train_loss)
if type(self.X_train) == list:
train_ins = self.X_train + [self.y]
train_ins = self.X_train + [self.y, self.class_weights]
test_ins = self.X_test + [self.y]
predict_ins = self.X_test
else:
train_ins = [self.X_train, self.y]
train_ins = [self.X_train, self.y, self.class_weights]
test_ins = [self.X_test, self.y]
predict_ins = [self.X_test]
@ -95,10 +110,12 @@ class Model(object):
allow_input_downcast=True, mode=theano_mode)
def train(self, X, y, accuracy=False):
def train(self, X, y, accuracy=False, class_weight=None):
X = standardize_X(X)
y = standardize_y(y)
ins = X + [y]
# calculate the weight vector for the loss function
w = calculate_class_weights(y, class_weight)
ins = X + [y, w]
if accuracy:
return self._train_with_acc(*ins)
else:
@ -116,8 +133,8 @@ class Model(object):
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
validation_split=0., validation_data=None, shuffle=True, show_accuracy=False):
validation_split=0., validation_data=None, shuffle=True, show_accuracy=False, class_weight=None):
X = standardize_X(X)
y = standardize_y(y)
@ -174,12 +191,15 @@ class Model(object):
X_batch = slice_X(X, batch_ids)
y_batch = y[batch_ids]
# calculate weight vector for current batch
w = calculate_class_weights(y_batch, class_weight)
batch_logs = {}
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
callbacks.on_batch_begin(batch_index, batch_logs)
ins = X_batch + [y_batch]
ins = X_batch + [y_batch, w]
if show_accuracy:
loss, acc = self._train_with_acc(*ins)
batch_logs['accuracy'] = acc
@ -289,7 +309,6 @@ class Sequential(Model, containers.Sequential):
- predict
- predict_proba
- predict_classes
Inherits from containers.Sequential the following methods:
- add
- get_output
@ -353,4 +372,5 @@ class Sequential(Model, containers.Sequential):
weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
self.layers[k].set_weights(weights)
f.close()

@ -6,29 +6,52 @@ from six.moves import range
epsilon = 1.0e-9
def mean_squared_error(y_true, y_pred):
return T.sqr(y_pred - y_true).mean()
def mean_squared_error(y_true, y_pred, weight=None):
if weight is not None:
return T.sqr(weight.reshape((weight.shape[0], 1))*(y_pred - y_true)).mean()
else:
return T.sqr(y_pred - y_true).mean()
def mean_absolute_error(y_true, y_pred):
return T.abs_(y_pred - y_true).mean()
def mean_absolute_error(y_true, y_pred, weight=None):
if weight is not None:
return T.abs_(weight.reshape((weight.shape[0], 1))*(y_pred - y_true)).mean()
else:
return T.abs_(y_pred - y_true).mean()
def squared_hinge(y_true, y_pred):
return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean()
def squared_hinge(y_true, y_pred, weight=None):
if weight is not None:
weight = weight.reshape((weight.shape[0], 1))
return T.sqr(weight*T.maximum(1. - (y_true * y_pred), 0.)).mean()
else:
return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean()
def hinge(y_true, y_pred):
return T.maximum(1. - y_true * y_pred, 0.).mean()
def hinge(y_true, y_pred, weight=None):
if weight is not None:
weight = weight.reshape((weight.shape[0], 1))
return (weight*T.maximum(1. - (y_true * y_pred), 0.)).mean()
else:
return T.maximum(1. - y_true * y_pred, 0.).mean()
def categorical_crossentropy(y_true, y_pred):
def categorical_crossentropy(y_true, y_pred, weight=None):
'''Expects a binary class matrix instead of a vector of scalar classes
'''
y_pred = T.clip(y_pred, epsilon, 1.0 - epsilon)
# scale preds so that the class probas of each sample sum to 1
y_pred /= y_pred.sum(axis=1, keepdims=True)
return T.nnet.categorical_crossentropy(y_pred, y_true).mean()
cce = T.nnet.categorical_crossentropy(y_pred, y_true)
if weight is not None:
# return avg. of scaled cat. crossentropy
return (weight*cce).mean()
else:
return cce.mean()
def binary_crossentropy(y_true, y_pred):
def binary_crossentropy(y_true, y_pred, weight=None):
y_pred = T.clip(y_pred, epsilon, 1.0 - epsilon)
return T.nnet.binary_crossentropy(y_pred, y_true).mean()
bce = T.nnet.binary_crossentropy(y_pred, y_true)
if weight is not None:
return (weight.reshape((weight.shape[0], 1))*bce).mean()
else:
return bce.mean()
# aliases
mse = MSE = mean_squared_error

164
test/test_classweights.py Normal file

@ -0,0 +1,164 @@
from __future__ import absolute_import
from __future__ import print_function
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Merge, Dropout
from keras.optimizers import SGD
from keras.utils import np_utils
import numpy as np
nb_classes = 10
batch_size = 128
nb_epoch = 15
max_train_samples = 5000
max_test_samples = 1000
np.random.seed(1337) # for reproducibility
# the data, shuffled and split between tran and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000,784)[:max_train_samples]
X_test = X_test.reshape(10000,784)[:max_test_samples]
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
X_train /= 255
X_test /= 255
# convert class vectors to binary class matrices
y_train = y_train[:max_train_samples]
y_test = y_test[:max_test_samples]
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
def createMNISTModel():
model = Sequential()
model.add(Dense(784, 50))
model.add(Activation('relu'))
model.add(Dense(50, 10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
return model
model_classweights_fit = createMNISTModel()
model_fit = createMNISTModel()
model_classweights_train = createMNISTModel()
model_train = createMNISTModel()
high_weight = 100
class_weight = {0:1,1:1,2:1,3:1,4:1,5:1,6:1,7:1,8:1,9:high_weight}
############################
# categorical crossentropy #
############################
print("Testing fit methods with and without classweights")
# fit
model_classweights_fit.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=0, validation_data=(X_test, Y_test), class_weight=class_weight)
model_fit.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=0, validation_data=(X_test, Y_test))
print("Testing train methods with and without classweights")
# train
model_classweights_train.train(X_train, Y_train, class_weight=class_weight)
model_train.train(X_train, Y_train)
print('MNIST Classification accuracies on test set for fitted models:')
for nb in range(nb_classes):
testIdcs = np.where(y_test == np.array(nb))[0]
X_temp = X_test[testIdcs, :]
Y_temp = Y_test[testIdcs,:]
# eval model which was trained with fit()
score_cw = model_classweights_fit.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
score = model_fit.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
# eval model which was trained with train()
score_cw_train = model_classweights_train.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
score_train = model_train.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
# print test accuracies for class weighted model vs. uniform weights
print("Digit %d: class_weight = %d -> %.3f \t class_weight = %d -> %.3f" % (nb, class_weight[nb], score_cw[1], 1, score[1]))
if class_weight[nb] == high_weight and (score_cw[1] <= score[1] or score_cw_train[1] <= score_train[1]):
raise Exception('Class weights are not implemented correctly')
####################################################
# test cases for all remaining objective functions #
####################################################
batch_size = 64
nb_epoch = 10
np.random.seed(1337) # for reproducibility
def generateData(n_samples, n_dim):
A_feats = np.random.randn(n_samples, n_dim)
B_feats = np.random.randn(n_samples, n_dim)
A_label = np.zeros((n_samples,1))
B_label = np.ones((n_samples,1))
X = np.vstack((A_feats, B_feats))
y = np.vstack((A_label, B_label)).squeeze()
return X, y
n_dim = 100
X_train, y_train = generateData(1000, n_dim)
X_test, y_test = generateData(5000, n_dim)
def createModel(ls, n_dim, activation="sigmoid"):
model = Sequential()
model.add(Dense(n_dim, 50))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(50, 1))
model.add(Activation(activation))
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss=ls, optimizer=sgd, class_mode="binary")
return model
verbosity = 0
cw = {0: 1.5, 1: 1}
# binary crossentropy
model = createModel('binary_crossentropy', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds, pos_preds = (1.0*np.sum(res == 0)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("binary crossentropy: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# MAE
model = createModel('mean_absolute_error', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds, pos_preds = (1.0*np.sum(res == 0)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("MAE: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# MSE
model = createModel('mean_squared_error', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds, pos_preds = (1.0*np.sum(res == 0)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("MSE: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# hinge losses, map labels
y_train[y_train == 0] = -1
y_test[y_test == 0] = -1
cw = {-1: 1.5, 1: 1}
# hinge
model = createModel('hinge', n_dim, "tanh")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity)
res[res < 0] = -1
res[res >= 0] = 1
neg_preds, pos_preds = (1.0*np.sum(res == -1)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# squared hinge
model = createModel('squared_hinge', n_dim, "tanh")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity)
res[res < 0] = -1
res[res >= 0] = 1
neg_preds, pos_preds = (1.0*np.sum(res == -1)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("sqr hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))