Merge branch 'instance_weight' of https://github.com/tdhd/keras into tdhd-instance_weight

This commit is contained in:
fchollet 2015-06-20 14:46:40 -07:00
commit 34cbc1d401
3 changed files with 139 additions and 15 deletions

@ -13,7 +13,7 @@ model = keras.models.Sequential()
- __loss__: str (name of objective function) or objective function. See [objectives](objectives.md).
- __class_mode__: one of "categorical", "binary". This is only used for computing classification accuracy or using the predict_classes method.
- __theano_mode__: A `theano.compile.mode.Mode` ([reference](http://deeplearning.net/software/theano/library/compile/mode.html)) instance controlling specifying compilation options.
- __fit__(X, y, batch_size=128, nb_epoch=100, verbose=1, validation_split=0., validation_data=None, shuffle=True, show_accuracy=False, callbacks=[], class_weight=None): Train a model for a fixed number of epochs.
- __fit__(X, y, batch_size=128, nb_epoch=100, verbose=1, validation_split=0., validation_data=None, shuffle=True, show_accuracy=False, callbacks=[], class_weight=None, sample_weight=None): Train a model for a fixed number of epochs.
- __Return__: a history dictionary with a record of training loss values at successive epochs, as well as validation loss values (if applicable), accuracy (if applicable), etc.
- __Arguments__:
- __X__: data.
@ -27,6 +27,7 @@ model = keras.models.Sequential()
- __show_accuracy__: boolean. Whether to display class accuracy in the logs to stdout at each epoch.
- __callbacks__: `keras.callbacks.Callback` list. List of callbacks to apply during training. See [callbacks](callbacks.md).
- __class_weight__: If specified, must be a dictionary which maps every class label to its class weight. Scales the loss of every training sample depending on the passed weights.
- __sample_weight__: If specified, must either be a list or np.array with the same number of elements as y. Scales the loss of every training sample depending on the passed weights.
- __evaluate__(X, y, batch_size=128, show_accuracy=False, verbose=1): Show performance of the model over some validation data.
- __Return__: The loss score over the data.
- __Arguments__: Same meaning as fit method above. verbose is used as a binary flag (progress bar or nothing).

@ -41,8 +41,13 @@ def slice_X(X, start=None, stop=None):
else:
return X[start:stop]
def calculate_class_weights(Y, class_weight):
if isinstance(class_weight, dict):
def calculate_loss_weights(Y, sample_weight=None, class_weight=None):
if sample_weight is not None:
if isinstance(sample_weight, list):
w = np.array(sample_weight)
else:
w = sample_weight
elif isinstance(class_weight, dict):
if Y.shape[1] > 1:
y_classes = Y.argmax(axis=1)
elif Y.shape[1] == 1:
@ -71,9 +76,9 @@ class Model(object):
self.y = T.zeros_like(self.y_train)
# parameter for rescaling the objective function
self.class_weights = T.vector()
self.loss_weights = T.vector()
train_loss = self.loss(self.y, self.y_train, self.class_weights)
train_loss = self.loss(self.y, self.y_train, self.loss_weights)
test_score = self.loss(self.y, self.y_test)
if class_mode == "categorical":
@ -90,11 +95,11 @@ class Model(object):
updates = self.optimizer.get_updates(self.params, self.regularizers, self.constraints, train_loss)
if type(self.X_train) == list:
train_ins = self.X_train + [self.y, self.class_weights]
train_ins = self.X_train + [self.y, self.loss_weights]
test_ins = self.X_test + [self.y]
predict_ins = self.X_test
else:
train_ins = [self.X_train, self.y, self.class_weights]
train_ins = [self.X_train, self.y, self.loss_weights]
test_ins = [self.X_test, self.y]
predict_ins = [self.X_test]
@ -110,11 +115,12 @@ class Model(object):
allow_input_downcast=True, mode=theano_mode)
def train(self, X, y, accuracy=False, class_weight=None):
def train(self, X, y, accuracy=False, sample_weight=None, class_weight=None):
X = standardize_X(X)
y = standardize_y(y)
# calculate the weight vector for the loss function
w = calculate_class_weights(y, class_weight)
w = calculate_loss_weights(y, sample_weight=sample_weight, class_weight=class_weight)
ins = X + [y, w]
if accuracy:
return self._train_with_acc(*ins)
@ -133,11 +139,14 @@ class Model(object):
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
validation_split=0., validation_data=None, shuffle=True, show_accuracy=False, class_weight=None):
validation_split=0., validation_data=None, shuffle=True, show_accuracy=False,
sample_weight=None, class_weight=None):
X = standardize_X(X)
y = standardize_y(y)
sample_weight = calculate_loss_weights(y, sample_weight=sample_weight, class_weight=class_weight)
do_validation = False
if validation_data:
try:
@ -191,8 +200,7 @@ class Model(object):
X_batch = slice_X(X, batch_ids)
y_batch = y[batch_ids]
# calculate weight vector for current batch
w = calculate_class_weights(y_batch, class_weight)
w = sample_weight[batch_ids]
batch_logs = {}
batch_logs['batch'] = batch_index

@ -85,8 +85,6 @@ for nb in range(nb_classes):
batch_size = 64
nb_epoch = 10
np.random.seed(1337) # for reproducibility
def generateData(n_samples, n_dim):
A_feats = np.random.randn(n_samples, n_dim)
B_feats = np.random.randn(n_samples, n_dim)
@ -162,3 +160,120 @@ neg_preds, pos_preds = (1.0*np.sum(res == -1)/res.shape[0], 1.0*np.sum(res == 1)
assert(neg_preds > pos_preds)
print("sqr hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))
############################
# sample weight test cases #
############################
batch_size = 128
nb_epoch = 15
# the data, shuffled and split between tran and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000,784)[:max_train_samples]
X_test = X_test.reshape(10000,784)[:max_test_samples]
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
X_train /= 255
X_test /= 255
# convert class vectors to binary class matrices
y_train = y_train[:max_train_samples]
y_test = y_test[:max_test_samples]
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
print("Sample weight test cases")
# categorical crossentropy
model_sampleweights_fit = createMNISTModel()
model_sampleweights_train = createMNISTModel()
model_fit = createMNISTModel()
model_train = createMNISTModel()
sample_weight = np.ones((Y_train.shape[0]))
sample_weight[y_train == 9] = high_weight
model_sampleweights_fit.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=0, validation_data=(X_test, Y_test), sample_weight=sample_weight)
model_fit.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=0, validation_data=(X_test, Y_test))
model_sampleweights_train.train(X_train, Y_train, sample_weight=sample_weight)
model_train.train(X_train, Y_train)
print('MNIST Classification accuracies on test set for fitted models:')
for nb in range(nb_classes):
testIdcs = np.where(y_test == np.array(nb))[0]
X_temp = X_test[testIdcs, :]
Y_temp = Y_test[testIdcs,:]
# eval model which was trained with fit()
score_sw = model_sampleweights_fit.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
score = model_fit.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
# eval model which was trained with train()
score_sw_train = model_sampleweights_train.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
score_train = model_train.evaluate(X_temp, Y_temp, show_accuracy=True, verbose=0)
# print test accuracies for class weighted model vs. uniform weights
print("Digit %d: sample_weight = %d -> %.3f \t sample_weight = %d -> %.3f" % (nb, 100 if nb == 9 else 1, score_sw[1], 1, score[1]))
if nb == 9 and (score_sw[1] <= score[1] or score_sw_train[1] <= score_train[1]):
raise Exception('Sample weights are not implemented correctly')
n_dim = 100
X_train, y_train = generateData(1000, n_dim)
X_test, y_test = generateData(5000, n_dim)
y_train[y_train == -1] = 0
y_test[y_test == -1] = 0
sample_weight = np.ones((y_train.shape[0]))
sample_weight[y_train == 0] = 1.5
# binary crossentropy
model = createModel('binary_crossentropy', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), sample_weight=sample_weight)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds, pos_preds = (1.0*np.sum(res == 0)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("binary crossentropy: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# MAE
model = createModel('mean_absolute_error', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), sample_weight=sample_weight)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds, pos_preds = (1.0*np.sum(res == 0)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("MAE: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# MSE
model = createModel('mean_squared_error', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), sample_weight=sample_weight)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds, pos_preds = (1.0*np.sum(res == 0)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("MSE: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# hinge losses, map labels
y_train[y_train == 0] = -1
y_test[y_test == 0] = -1
sample_weight = np.ones((y_train.shape[0]))
sample_weight[y_train == -1] = 1.5
# hinge
model = createModel('hinge', n_dim, "tanh")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), sample_weight=sample_weight)
res = model.predict(X_test, verbose=verbosity)
res[res < 0] = -1
res[res >= 0] = 1
neg_preds, pos_preds = (1.0*np.sum(res == -1)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))
# squared hinge
model = createModel('squared_hinge', n_dim, "tanh")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=verbosity, validation_data=(X_test, y_test), sample_weight=sample_weight)
res = model.predict(X_test, verbose=verbosity)
res[res < 0] = -1
res[res >= 0] = 1
neg_preds, pos_preds = (1.0*np.sum(res == -1)/res.shape[0], 1.0*np.sum(res == 1)/res.shape[0])
assert(neg_preds > pos_preds)
print("sqr hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))