Updated test case for class weights
This commit is contained in:
parent
35c2f36759
commit
2165fd03dc
@ -2,7 +2,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from keras.datasets import mnist
|
from keras.datasets import mnist
|
||||||
from keras.models import Sequential
|
from keras.models import Sequential
|
||||||
from keras.layers.core import Dense, Activation, Merge
|
from keras.layers.core import Dense, Activation, Merge, Dropout
|
||||||
|
from keras.optimizers import SGD
|
||||||
from keras.utils import np_utils
|
from keras.utils import np_utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -49,9 +50,9 @@ model_train = createMNISTModel()
|
|||||||
# Weight configuration for the MNIST class-weight test: digit 9 is
# overweighted so its errors dominate the loss, all other digits keep
# unit weight.
high_weight = 100
class_weight = dict.fromkeys(range(10), 1)
class_weight[9] = high_weight
|
|
||||||
########################
|
############################
|
||||||
#test different methods#
|
# categorical crossentropy #
|
||||||
########################
|
############################
|
||||||
|
|
||||||
print("Testing fit methods with and without classweights")
|
print("Testing fit methods with and without classweights")
|
||||||
# fit
|
# fit
|
||||||
@ -78,4 +79,86 @@ for nb in range(nb_classes):
|
|||||||
if class_weight[nb] == high_weight and (score_cw[1] <= score[1] or score_cw_train[1] <= score_train[1]):
|
if class_weight[nb] == high_weight and (score_cw[1] <= score[1] or score_cw_train[1] <= score_train[1]):
|
||||||
raise Exception('Class weights are not implemented correctly')
|
raise Exception('Class weights are not implemented correctly')
|
||||||
|
|
||||||
|
####################################################
|
||||||
|
# test cases for all remaining objective functions #
|
||||||
|
####################################################
|
||||||
|
# Shared training configuration for the objective-function tests below.
batch_size = 64  # minibatch size for every fit() call
nb_epoch = 10    # epochs trained per loss function

np.random.seed(1337)  # for reproducibility
|
def generateData(n_samples, n_dim):
    """Build a balanced two-class toy dataset from standard normals.

    Class A (label 0) and class B (label 1) each get `n_samples`
    feature vectors of dimension `n_dim`, stacked A-first.

    Returns:
        (X, y) where X has shape (2*n_samples, n_dim) and y is a flat
        float array of 0/1 labels aligned with the rows of X.
    """
    # Draw class-A features before class-B so the np.random call
    # sequence (and therefore seeded reproducibility) is unchanged.
    feats_a = np.random.randn(n_samples, n_dim)
    feats_b = np.random.randn(n_samples, n_dim)
    X = np.vstack((feats_a, feats_b))
    # 1-D label vector: n_samples zeros followed by n_samples ones.
    y = np.concatenate((np.zeros(n_samples), np.ones(n_samples)))
    return X, y
|
|
||||||
|
# 100-dimensional toy problem: 1000 training and 5000 test samples,
# balanced between the two classes by construction.
n_dim = 100
X_train, y_train = generateData(1000, n_dim)
X_test, y_test = generateData(5000, n_dim)
|
def createModel(ls, n_dim, activation="sigmoid"):
    """Build and compile a small binary MLP: n_dim -> 50 -> 1.

    ls         -- loss identifier, passed straight to model.compile
    n_dim      -- input dimensionality
    activation -- output activation ("sigmoid" default; callers use
                  "tanh" for the hinge losses)

    Returns the compiled Sequential model.
    """
    model = Sequential()
    # Hidden relu layer with dropout, then a single output unit
    # (old-style Dense(input_dim, output_dim) API).
    for layer in [Dense(n_dim, 50),
                  Activation('relu'),
                  Dropout(0.5),
                  Dense(50, 1),
                  Activation(activation)]:
        model.add(layer)
    # Plain SGD with momentum; class_mode="binary" gives 0/1 accuracy.
    sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss=ls, optimizer=sgd, class_mode="binary")
    return model
|
|
||||||
|
verbosity = 0        # silence Keras progress output
cw = {0: 1.5, 1: 1}  # overweight the negative class

# --- binary crossentropy ---------------------------------------------
# With class 0 weighted 1.5x, the trained model should predict the
# negative class more often than the positive one on the test set.
model = createModel('binary_crossentropy', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=verbosity,
          validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds = float(np.sum(res == 0)) / res.shape[0]
pos_preds = float(np.sum(res == 1)) / res.shape[0]
assert(neg_preds > pos_preds)
print("binary crossentropy: %0.2f VS %0.2f" % (neg_preds, pos_preds))
|
|
||||||
|
# --- MAE -------------------------------------------------------------
# Mean absolute error with the same 1.5x weight on class 0: negative
# predictions should again outnumber positive ones.
model = createModel('mean_absolute_error', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=verbosity,
          validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds = float(np.sum(res == 0)) / res.shape[0]
pos_preds = float(np.sum(res == 1)) / res.shape[0]
assert(neg_preds > pos_preds)
print("MAE: %0.2f VS %0.2f" % (neg_preds, pos_preds))
|
|
||||||
|
# --- MSE -------------------------------------------------------------
# Mean squared error with the weighted negative class: the prediction
# balance should tilt toward class 0.
model = createModel('mean_squared_error', n_dim)
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=verbosity,
          validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity).round()
neg_preds = float(np.sum(res == 0)) / res.shape[0]
pos_preds = float(np.sum(res == 1)) / res.shape[0]
assert(neg_preds > pos_preds)
print("MSE: %0.2f VS %0.2f" % (neg_preds, pos_preds))
|
|
||||||
|
# --- hinge -----------------------------------------------------------
# Hinge losses expect {-1, +1} labels: remap the zeros in place and
# reweight the (now -1) negative class.
y_train[y_train == 0] = -1
y_test[y_test == 0] = -1
cw = {-1: 1.5, 1: 1}

model = createModel('hinge', n_dim, "tanh")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=verbosity,
          validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity)
# Threshold the raw tanh outputs at zero to get {-1, +1} predictions.
res[res < 0] = -1
res[res >= 0] = 1
neg_preds = float(np.sum(res == -1)) / res.shape[0]
pos_preds = float(np.sum(res == 1)) / res.shape[0]
assert(neg_preds > pos_preds)
print("hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))
|
|
||||||
|
# --- squared hinge ---------------------------------------------------
# Same protocol as the hinge test: tanh output thresholded at zero,
# negative ({-1}) class weighted 1.5x should dominate predictions.
model = createModel('squared_hinge', n_dim, "tanh")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=verbosity,
          validation_data=(X_test, y_test), class_weight=cw)
res = model.predict(X_test, verbose=verbosity)
res[res < 0] = -1
res[res >= 0] = 1
neg_preds = float(np.sum(res == -1)) / res.shape[0]
pos_preds = float(np.sum(res == 1)) / res.shape[0]
assert(neg_preds > pos_preds)
print("sqr hinge: %0.2f VS %0.2f" % (neg_preds, pos_preds))
|
||||||
|
Loading…
Reference in New Issue
Block a user