Add cosine similarity loss and update l2_normalize from regularizers (#34)

* Begin cosine loss

* Add testing for cosine similarity

* Fix formatting

* Docstring standardization

* Formatting

* Create numerical_utils
Gabriel Rasskin 2023-04-27 19:09:42 -04:00 committed by Francois Chollet
parent 4d18257aed
commit d695e974a2
5 changed files with 175 additions and 12 deletions

@@ -4,6 +4,7 @@ from keras_core.api_export import keras_core_export
from keras_core.losses.loss import Loss
from keras_core.losses.loss import squeeze_to_same_rank
from keras_core.saving import serialization_lib
from keras_core.utils.numerical_utils import l2_normalize
class LossFunctionWrapper(Loss):
@@ -44,7 +45,7 @@ class MeanSquaredError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance.
"""
@@ -70,7 +71,7 @@ class MeanAbsoluteError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance.
"""
@@ -96,7 +97,7 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance.
"""
@@ -126,7 +127,7 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance.
"""
@@ -143,6 +144,43 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.CosineSimilarity")
class CosineSimilarity(LossFunctionWrapper):
"""Computes the cosine similarity between `y_true` & `y_pred`.
Note that the loss is a number between -1 and 1: a value closer to -1
indicates greater similarity between `y_true` and `y_pred`, 0 indicates
orthogonality, and a value closer to 1 indicates greater dissimilarity.
This makes it usable as a loss function in a setting where you try to
maximize the proximity between predictions and targets. If either `y_true`
or `y_pred` is a zero vector, the cosine similarity will be 0 regardless
of the proximity between predictions and targets.
Formula:
```python
loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
```
Args:
axis: The axis along which the cosine similarity is computed
(the features axis). Defaults to -1.
reduction: Type of reduction to apply to loss. Options are `"sum"`,
`"sum_over_batch_size"` or `None`. Defaults to
`"sum_over_batch_size"`.
name: Optional name for the instance.
"""
def __init__(
self,
axis=-1,
reduction="sum_over_batch_size",
name="cosine_similarity",
):
super().__init__(
cosine_similarity, reduction=reduction, name=name, axis=axis
)
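A minimal usage sketch of the new class (not part of the diff; the `keras_core.losses` import path follows the export decorator above, and the per-sample values are worked out by hand from the formula):

```python
import numpy as np
from keras_core import losses

# First pair points in the same direction (cosine similarity 1 -> loss -1),
# second pair points in opposite directions (cosine similarity -1 -> loss 1).
y_true = np.array([[0.0, 1.0], [1.0, 1.0]])
y_pred = np.array([[0.0, 1.0], [-1.0, -1.0]])

loss_fn = losses.CosineSimilarity(axis=-1)
loss = loss_fn(y_true, y_pred)
# The default "sum_over_batch_size" reduction averages [-1.0, 1.0] to ~0.0.
```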
@keras_core_export("keras_core.losses.Hinge")
class Hinge(LossFunctionWrapper):
"""Computes the hinge loss between `y_true` & `y_pred`.
@@ -159,7 +197,7 @@ class Hinge(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance. Defaults to `"hinge"`
"""
@@ -186,7 +224,7 @@ class SquaredHinge(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance. Defaults to `"squared_hinge"`
"""
@@ -212,7 +250,7 @@ class CategoricalHinge(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to loss. For almost all cases
this defaults to `"sum_over_batch_size"`. Options are `"sum"`,
`"sum_over_batch_size"` or None.
`"sum_over_batch_size"` or `None`.
name: Optional name for the instance. Defaults to
`"categorical_hinge"`
"""
@@ -505,3 +543,42 @@ def mean_squared_logarithmic_error(y_true, y_pred):
first_log = ops.log(ops.maximum(y_pred, epsilon) + 1.0)
second_log = ops.log(ops.maximum(y_true, epsilon) + 1.0)
return ops.mean(ops.square(first_log - second_log), axis=-1)
@keras_core_export("keras_core.losses.cosine_similarity")
def cosine_similarity(y_true, y_pred, axis=-1):
"""Computes the cosine similarity between labels and predictions.
Formula:
```python
loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
```
Note that the loss is a number between -1 and 1: a value closer to -1
indicates greater similarity between `y_true` and `y_pred`, 0 indicates
orthogonality, and a value closer to 1 indicates greater dissimilarity.
This makes it usable as a loss function in a setting where you try to
maximize the proximity between predictions and targets. If either `y_true`
or `y_pred` is a zero vector, the cosine similarity will be 0 regardless
of the proximity between predictions and targets.
Standalone usage:
>>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
>>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
>>> loss = keras_core.losses.cosine_similarity(y_true, y_pred, axis=-1)
>>> loss
[-0., -0.99999994, 0.99999994]
Args:
y_true: Tensor of true targets.
y_pred: Tensor of predicted targets.
axis: Axis along which to determine similarity. Defaults to -1.
Returns:
Cosine similarity tensor.
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
y_pred = l2_normalize(y_pred, axis=axis)
y_true = l2_normalize(y_true, axis=axis)
return -ops.sum(y_true * y_pred, axis=axis)
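As a cross-check, here is a pure-NumPy sketch (not part of the diff) that reproduces the standalone-usage example above with the same normalize-then-negated-dot-product formula; the epsilon clamp mirrors `l2_normalize` and is why the reported values land just shy of ±1:

```python
import numpy as np

def l2_norm(x, axis=-1, epsilon=1e-7):
    # Normalize each row to unit length; clamping the squared sum keeps
    # zero vectors from dividing by zero.
    square_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(square_sum, epsilon))

y_true = np.array([[0.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
y_pred = np.array([[1.0, 0.0], [1.0, 1.0], [-1.0, -1.0]])

loss = -np.sum(l2_norm(y_true) * l2_norm(y_pred), axis=-1)
print(loss)  # approximately [-0., -1., 1.]
```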

@@ -382,3 +382,83 @@ class CategoricalHingeTest(testing.TestCase):
hinge_obj = losses.CategoricalHinge()
loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertEqual(loss, 0.0)
class CosineSimilarityTest(testing.TestCase):
def l2_norm(self, x, axis):
epsilon = 1e-12
square_sum = np.sum(np.square(x), axis=axis, keepdims=True)
x_inv_norm = 1 / np.sqrt(np.maximum(square_sum, epsilon))
return np.multiply(x, x_inv_norm)
def setup(self, axis=1):
self.np_y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)
self.np_y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)
y_true = self.l2_norm(self.np_y_true, axis)
y_pred = self.l2_norm(self.np_y_pred, axis)
self.expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(axis,))
self.y_true = self.np_y_true
self.y_pred = self.np_y_pred
def test_config(self):
cosine_obj = losses.CosineSimilarity(
axis=2, reduction="sum", name="cosine_loss"
)
self.assertEqual(cosine_obj.name, "cosine_loss")
self.assertEqual(cosine_obj.reduction, "sum")
def test_unweighted(self):
self.setup()
cosine_obj = losses.CosineSimilarity()
loss = cosine_obj(self.y_true, self.y_pred)
expected_loss = -np.mean(self.expected_loss)
self.assertAlmostEqual(loss, expected_loss, 3)
def test_scalar_weighted(self):
self.setup()
cosine_obj = losses.CosineSimilarity()
sample_weight = 2.3
loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
expected_loss = -np.mean(self.expected_loss * sample_weight)
self.assertAlmostEqual(loss, expected_loss, 3)
def test_sample_weighted(self):
self.setup()
cosine_obj = losses.CosineSimilarity()
sample_weight = np.asarray([1.2, 3.4])
loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
expected_loss = -np.mean(self.expected_loss * sample_weight)
self.assertAlmostEqual(loss, expected_loss, 3)
def test_timestep_weighted(self):
self.setup()
cosine_obj = losses.CosineSimilarity()
np_y_true = self.np_y_true.reshape((2, 3, 1))
np_y_pred = self.np_y_pred.reshape((2, 3, 1))
sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape((2, 3))
y_true = self.l2_norm(np_y_true, 2)
y_pred = self.l2_norm(np_y_pred, 2)
expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(2,))
y_true = np_y_true
y_pred = np_y_pred
loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight)
expected_loss = -np.mean(expected_loss * sample_weight)
self.assertAlmostEqual(loss, expected_loss, 3)
def test_zero_weighted(self):
self.setup()
cosine_obj = losses.CosineSimilarity()
loss = cosine_obj(self.y_true, self.y_pred, sample_weight=0)
self.assertAlmostEqual(loss, 0.0, 3)
def test_axis(self):
self.setup(axis=1)
cosine_obj = losses.CosineSimilarity(axis=1)
loss = cosine_obj(self.y_true, self.y_pred)
expected_loss = -np.mean(self.expected_loss)
self.assertAlmostEqual(loss, expected_loss, 3)

@@ -1,6 +1,7 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.backend import floatx
from keras_core.losses.loss import squeeze_to_same_rank
from keras_core.metrics import reduction_metrics

@@ -2,6 +2,7 @@ import math
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.utils.numerical_utils import l2_normalize
@keras_core_export(
@@ -344,8 +345,3 @@ def validate_float_arg(value, name):
f"Received: {name}={value}"
)
return float(value)
def l2_normalize(x, axis=0):
l2_norm = ops.sqrt(ops.sum(ops.square(x), axis=axis))
return x / l2_norm

@@ -0,0 +1,9 @@
from keras_core import backend
from keras_core import operations as ops
def l2_normalize(x, axis=0):
epsilon = backend.epsilon()
square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
l2_norm = ops.reciprocal(ops.sqrt(ops.maximum(square_sum, epsilon)))
return ops.multiply(x, l2_norm)
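A plain-NumPy stand-in (not part of the diff) for the new helper, illustrating why the epsilon clamp matters for the zero-vector case mentioned in the cosine-similarity docstrings; the 1e-7 default here is an assumption standing in for `backend.epsilon()`:

```python
import numpy as np

def l2_normalize_ref(x, axis=0, epsilon=1e-7):
    # Divide by the L2 norm along `axis`, but never by anything smaller
    # than sqrt(epsilon), so an all-zero slice maps to zeros instead of NaN.
    square_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    return x * (1.0 / np.sqrt(np.maximum(square_sum, epsilon)))

x = np.array([[3.0, 4.0], [0.0, 0.0]])
print(l2_normalize_ref(x, axis=1))
# [[0.6 0.8]
#  [0.  0. ]]  -> the zero row stays zero, so its cosine similarity is 0.
```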