Fix some docstring in keras_core/losses and fill in missing tests (#219)

* initials

* add tests
Chen Qian 2023-05-25 21:45:35 -07:00 committed by Francois Chollet
parent a460a35362
commit 31cd921e50
2 changed files with 74 additions and 54 deletions

keras_core/losses/losses.py
@@ -47,7 +47,7 @@ class MeanSquaredError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -73,7 +73,7 @@ class MeanAbsoluteError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -99,7 +99,7 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -129,7 +129,7 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -168,7 +168,7 @@ class CosineSimilarity(LossFunctionWrapper):
(the features axis). Defaults to -1.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -257,7 +257,7 @@ class Hinge(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -284,7 +284,7 @@ class SquaredHinge(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -310,7 +310,7 @@ class CategoricalHinge(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -336,7 +336,7 @@ class KLDivergence(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -360,7 +360,7 @@ class Poisson(LossFunctionWrapper):
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
"""
@@ -380,26 +380,25 @@ class BinaryCrossentropy(LossFunctionWrapper):
- `y_true` (true label): This is either 0 or 1.
- `y_pred` (predicted value): This is the model's prediction, i.e, a single
-floating-point value which either represents a
-[logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
-when `from_logits=True`) or a probability (i.e, value in [0., 1.] when
-`from_logits=False`).
+floating-point value which either represents a
+[logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
+when `from_logits=True`) or a probability (i.e, value in [0., 1.] when
+`from_logits=False`).
Args:
from_logits: Whether to interpret `y_pred` as a tensor of
[logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
-assume that `y_pred` contains probabilities (i.e., values in [0,
-1]).
+assume that `y_pred` is probabilities (i.e., values in [0, 1]).
label_smoothing: Float in range [0, 1]. When 0, no smoothing occurs.
When > 0, we compute the loss between the predicted labels
and a smoothed version of the true labels, where the smoothing
squeezes the labels towards 0.5. Larger values of
`label_smoothing` correspond to heavier smoothing.
-axis: The axis along which to compute crossentropy (the features
-axis). Defaults to -1.
+axis: The axis along which to compute crossentropy (the features axis).
+Defaults to -1.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
Examples:
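
(The hunk truncates the docstring's example code; as an illustrative sketch of
the `from_logits` distinction described above, not taken from the file:)

    import numpy as np
    from keras_core import losses

    y_true = np.array([[0.0], [1.0]])
    logits = np.array([[2.0], [2.0]])       # raw scores in (-inf, inf)
    probs = 1.0 / (1.0 + np.exp(-logits))   # sigmoid squashes into [0, 1]

    # The two calls compute the same loss (~1.1269), as long as
    # `from_logits` matches what `y_pred` actually contains.
    print(losses.BinaryCrossentropy(from_logits=True)(y_true, logits))
    print(losses.BinaryCrossentropy(from_logits=False)(y_true, probs))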
@@ -493,10 +492,10 @@ class BinaryFocalCrossentropy(LossFunctionWrapper):
- `y_true` (true label): This is either 0 or 1.
- `y_pred` (predicted value): This is the model's prediction, i.e, a single
-floating-point value which either represents a
-[logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
-when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when
-`from_logits=False`).
+floating-point value which either represents a
+[logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
+when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when
+`from_logits=False`).
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
helps to apply a "focal factor" to down-weight easy examples and focus more
@@ -529,7 +528,7 @@ class BinaryFocalCrossentropy(LossFunctionWrapper):
Defaults to `-1`.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
Examples:
@@ -665,11 +664,8 @@ class CategoricalCrossentropy(LossFunctionWrapper):
Use this crossentropy loss function when there are two or more label
classes. We expect labels to be provided in a `one_hot` representation. If
you want to provide labels as integers, please use
-`SparseCategoricalCrossentropy` loss. There should be `# classes` floating
-point values per feature.
-In the snippet below, there is `# classes` floating pointing values per
-example. The shape of both `y_pred` and `y_true` are
+`SparseCategoricalCrossentropy` loss. There should be `num_classes` floating
+point values per feature, i.e., the shape of both `y_pred` and `y_true` are
`[batch_size, num_classes]`.
Args:
@@ -683,7 +679,7 @@ class CategoricalCrossentropy(LossFunctionWrapper):
axis). Defaults to -1.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
Examples:
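
(Again the example code sits outside the hunk; a small sketch of the one-hot
shape contract described above:)

    import numpy as np
    from keras_core import losses

    # Three classes, one-hot targets: `y_true` and `y_pred` both have shape
    # [batch_size, num_classes] = [2, 3].
    y_true = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
    print(losses.CategoricalCrossentropy()(y_true, y_pred))  # ~1.177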
@@ -791,7 +787,7 @@ class CategoricalFocalCrossentropy(LossFunctionWrapper):
Extending this to multi-class case is straightforward:
`FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)`
-In the snippet below, there is `# classes` floating pointing values per
+In the snippet below, there is `num_classes` floating pointing values per
example. The shape of both `y_pred` and `y_true` are
`(batch_size, num_classes)`.
@@ -814,7 +810,7 @@ class CategoricalFocalCrossentropy(LossFunctionWrapper):
axis). Defaults to -1.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
Examples:
@@ -903,16 +899,16 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
feature for `y_true`.
In the snippet below, there is a single floating point value per example for
-`y_true` and `# classes` floating pointing values per example for `y_pred`.
-The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
-`[batch_size, num_classes]`.
+`y_true` and `num_classes` floating pointing values per example for
+`y_pred`. The shape of `y_true` is `[batch_size]` and the shape of `y_pred`
+is `[batch_size, num_classes]`.
Args:
from_logits: Whether `y_pred` is expected to be a logits tensor. By
default, we assume that `y_pred` encodes a probability distribution.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`.
-Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
name: Optional name for the loss instance.
Examples:
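
(And the integer-label counterpart, an illustrative sketch of the shapes the
rewritten paragraph describes:)

    import numpy as np
    from keras_core import losses

    # Integer targets: `y_true` has shape [batch_size], while `y_pred` has
    # shape [batch_size, num_classes].
    y_true = np.array([1, 2])
    y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
    print(losses.SparseCategoricalCrossentropy()(y_true, y_pred))  # ~1.177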
@@ -1729,8 +1725,8 @@ def binary_focal_crossentropy(
helps to apply a focal factor to down-weight easy examples and focus more on
hard examples. By default, the focal tensor is computed as follows:
-`focal_factor = (1 - output)**gamma` for class 1
-`focal_factor = output**gamma` for class 0
+`focal_factor = (1 - output) ** gamma` for class 1
+`focal_factor = output ** gamma` for class 0
where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal
effect on the binary crossentropy loss.
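
(A sketch of the focal factor described above; illustrative, not the library's
implementation:)

    import numpy as np

    def focal_factor(output, y_true, gamma=2.0):
        # (1 - p)**gamma for class 1, p**gamma for class 0: confident,
        # "easy" predictions get factors near 0; with gamma=0 the factor
        # is 1 everywhere and the loss reduces to plain crossentropy.
        return np.where(y_true == 1, (1.0 - output) ** gamma, output ** gamma)

    print(focal_factor(np.array([0.9, 0.1]), np.array([1, 0])))  # [0.01 0.01]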

keras_core/losses/losses_test.py
@@ -37,8 +37,16 @@ class MeanSquaredErrorTest(testing.TestCase):
self.assertAlmostEqual(loss, 767.8 / 6)
def test_timestep_weighted(self):
-    # TODO
-    pass
+mse_obj = losses.MeanSquaredError()
+y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+loss = mse_obj(
+    y_true,
+    y_pred,
+    sample_weight=sample_weight,
+)
+self.assertAlmostEqual(loss, 97.833336)
def test_zero_weighted(self):
mse_obj = losses.MeanSquaredError()
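
The expected value in the new test can be reproduced by hand: with the default
`"sum_over_batch_size"` reduction, the weighted per-timestep squared errors are
summed and divided by the number of loss elements (a numpy sketch):

    import numpy as np

    y_true = np.array([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
    y_pred = np.array([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
    w = np.array([3, 6, 5, 0, 4, 2]).reshape(2, 3)

    # Averaging over the size-1 features axis leaves one squared error
    # per timestep, shape (2, 3): [[9, 1, 100], [169, 9, 9]].
    per_step = np.mean((y_true - y_pred) ** 2, axis=-1)
    print(np.sum(per_step * w) / per_step.size)  # 587 / 6 = 97.8333...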
@@ -47,10 +55,6 @@ class MeanSquaredErrorTest(testing.TestCase):
loss = mse_obj(y_true, y_pred, sample_weight=0)
self.assertAlmostEqual(loss, 0.0)
-def test_invalid_sample_weight(self):
-    # TODO
-    pass
def test_no_reduction(self):
mse_obj = losses.MeanSquaredError(reduction=None)
y_true = np.array([[1, 9, 2], [-5, -2, 6]])
@@ -101,8 +105,16 @@ class MeanAbsoluteErrorTest(testing.TestCase):
self.assertAlmostEqual(loss, 81.4 / 6)
def test_timestep_weighted(self):
-    # TODO
-    pass
+mae_obj = losses.MeanAbsoluteError()
+y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+loss = mae_obj(
+    y_true,
+    y_pred,
+    sample_weight=sample_weight,
+)
+self.assertAlmostEqual(loss, 13.833333)
def test_zero_weighted(self):
mae_obj = losses.MeanAbsoluteError()
@@ -111,10 +123,6 @@ class MeanAbsoluteErrorTest(testing.TestCase):
loss = mae_obj(y_true, y_pred, sample_weight=0)
self.assertAlmostEqual(loss, 0.0)
-def test_invalid_sample_weight(self):
-    # TODO
-    pass
def test_no_reduction(self):
mae_obj = losses.MeanAbsoluteError(reduction=None)
y_true = np.array([[1, 9, 2], [-5, -2, 6]])
@@ -165,8 +173,16 @@ class MeanAbsolutePercentageErrorTest(testing.TestCase):
self.assertAlmostEqual(loss, 422.8888, 3)
def test_timestep_weighted(self):
-    # TODO
-    pass
+mape_obj = losses.MeanAbsolutePercentageError()
+y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+loss = mape_obj(
+    y_true,
+    y_pred,
+    sample_weight=sample_weight,
+)
+self.assertAlmostEqual(loss, 694.4444)
def test_zero_weighted(self):
mape_obj = losses.MeanAbsolutePercentageError()
@@ -212,8 +228,16 @@ class MeanSquaredLogarithmicErrorTest(testing.TestCase):
self.assertAlmostEqual(loss, 3.7856, 3)
def test_timestep_weighted(self):
-    # TODO
-    pass
+msle_obj = losses.MeanSquaredLogarithmicError()
+y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+loss = msle_obj(
+    y_true,
+    y_pred,
+    sample_weight=sample_weight,
+)
+self.assertAlmostEqual(loss, 2.647374)
def test_zero_weighted(self):
msle_obj = losses.MeanSquaredLogarithmicError()