moved weight multiplication outside of T.max
This commit is contained in:
parent
24d735ecb1
commit
35c2f36759
@ -20,13 +20,15 @@ def mean_absolute_error(y_true, y_pred, weight=None):
|
||||
|
||||
def squared_hinge(y_true, y_pred, weight=None):
|
||||
if weight is not None:
|
||||
return T.sqr(T.maximum(1. - weight.reshape((weight.shape[0], 1))*(y_true * y_pred), 0.)).mean()
|
||||
weight = weight.reshape((weight.shape[0], 1))
|
||||
return T.sqr(weight*T.maximum(1. - (y_true * y_pred), 0.)).mean()
|
||||
else:
|
||||
return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean()
|
||||
|
||||
def hinge(y_true, y_pred, weight=None):
|
||||
if weight is not None:
|
||||
return T.maximum(1. - weight.reshape((weight.shape[0], 1))*(y_true * y_pred), 0.).mean()
|
||||
weight = weight.reshape((weight.shape[0], 1))
|
||||
return (weight*T.maximum(1. - (y_true * y_pred), 0.)).mean()
|
||||
else:
|
||||
return T.maximum(1. - y_true * y_pred, 0.).mean()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user