refactored weighting in models.py

This commit is contained in:
Philipp 2015-06-19 21:30:23 +02:00
parent 1b9f35f535
commit 824f9f5e80

@ -41,17 +41,11 @@ def slice_X(X, start=None, stop=None):
else:
return X[start:stop]
def standardize_sample_weights(sample_weight):
if isinstance(sample_weight, list):
w = np.array(sample_weight)
elif isinstance(sample_weight, np.ndarray):
w = sample_weight
else:
w = None
return w
def calculate_loss_weights(Y, sample_weight=None, class_weight=None):
if sample_weight is not None:
if isinstance(sample_weight, list):
w = np.array(sample_weight)
else:
w = sample_weight
elif isinstance(class_weight, dict):
if Y.shape[1] > 1:
@ -125,7 +119,6 @@ class Model(object):
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_sample_weights(sample_weight)
w = calculate_loss_weights(y, sample_weight=sample_weight, class_weight=class_weight)
ins = X + [y, w]
@ -152,7 +145,7 @@ class Model(object):
X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_sample_weights(sample_weight)
sample_weight = calculate_loss_weights(y, sample_weight=sample_weight, class_weight=class_weight)
do_validation = False
if validation_data:
@ -207,10 +200,7 @@ class Model(object):
X_batch = slice_X(X, batch_ids)
y_batch = y[batch_ids]
if sample_weight is not None:
w = calculate_loss_weights(y_batch, sample_weight=sample_weight[batch_ids])
else:
w = calculate_loss_weights(y_batch, class_weight=class_weight)
w = sample_weight[batch_ids]
batch_logs = {}
batch_logs['batch'] = batch_index