Add cosine proximity objective

This commit is contained in:
Francois Chollet 2016-01-05 10:35:29 -08:00
parent 458641f33a
commit 6cb1172668
2 changed files with 13 additions and 1 deletions

@ -27,3 +27,5 @@ For a few examples of such functions, check out the [objectives source](https://
- __hinge__
- __binary_crossentropy__: Also known as logloss.
- __categorical_crossentropy__: Also known as multiclass logloss. __Note__: using this objective requires that your labels are binary arrays of shape `(nb_samples, nb_classes)`.
- __poisson__: mean of `(predictions - targets * log(predictions))`
- __cosine_proximity__: the opposite (negative) of the mean cosine proximity between predictions and targets.

@ -44,15 +44,25 @@ def binary_crossentropy(y_true, y_pred):
return K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1)
def poisson_loss(y_true, y_pred):
def poisson(y_true, y_pred):
return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
def cosine_proximity(y_true, y_pred):
assert K.ndim(y_true) == 2
assert K.ndim(y_pred) == 2
y_true = K.l2_normalize(y_true, axis=1)
y_pred = K.l2_normalize(y_pred, axis=1)
return -K.mean(y_true * y_pred, axis=1)
# aliases
mse = MSE = mean_squared_error
rmse = RMSE = root_mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
from .utils.generic_utils import get_from_module
def get(identifier):