refactored weighting in models.py
This commit is contained in:
parent
1b9f35f535
commit
824f9f5e80
@ -41,18 +41,12 @@ def slice_X(X, start=None, stop=None):
|
||||
else:
|
||||
return X[start:stop]
|
||||
|
||||
def standardize_sample_weights(sample_weight):
|
||||
if isinstance(sample_weight, list):
|
||||
w = np.array(sample_weight)
|
||||
elif isinstance(sample_weight, np.ndarray):
|
||||
w = sample_weight
|
||||
else:
|
||||
w = None
|
||||
return w
|
||||
|
||||
def calculate_loss_weights(Y, sample_weight=None, class_weight=None):
|
||||
if sample_weight is not None:
|
||||
w = sample_weight
|
||||
if isinstance(sample_weight, list):
|
||||
w = np.array(sample_weight)
|
||||
else:
|
||||
w = sample_weight
|
||||
elif isinstance(class_weight, dict):
|
||||
if Y.shape[1] > 1:
|
||||
y_classes = Y.argmax(axis=1)
|
||||
@ -125,7 +119,6 @@ class Model(object):
|
||||
X = standardize_X(X)
|
||||
y = standardize_y(y)
|
||||
|
||||
sample_weight = standardize_sample_weights(sample_weight)
|
||||
w = calculate_loss_weights(y, sample_weight=sample_weight, class_weight=class_weight)
|
||||
|
||||
ins = X + [y, w]
|
||||
@ -152,7 +145,7 @@ class Model(object):
|
||||
X = standardize_X(X)
|
||||
y = standardize_y(y)
|
||||
|
||||
sample_weight = standardize_sample_weights(sample_weight)
|
||||
sample_weight = calculate_loss_weights(y, sample_weight=sample_weight, class_weight=class_weight)
|
||||
|
||||
do_validation = False
|
||||
if validation_data:
|
||||
@ -207,10 +200,7 @@ class Model(object):
|
||||
X_batch = slice_X(X, batch_ids)
|
||||
y_batch = y[batch_ids]
|
||||
|
||||
if sample_weight is not None:
|
||||
w = calculate_loss_weights(y_batch, sample_weight=sample_weight[batch_ids])
|
||||
else:
|
||||
w = calculate_loss_weights(y_batch, class_weight=class_weight)
|
||||
w = sample_weight[batch_ids]
|
||||
|
||||
batch_logs = {}
|
||||
batch_logs['batch'] = batch_index
|
||||
|
Loading…
Reference in New Issue
Block a user