Fix some docstring in keras_core/losses and fill in missing tests (#219)
* initials
* add tests
parent a460a35362
commit 31cd921e50
@@ -47,7 +47,7 @@ class MeanSquaredError(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
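Note: a minimal sketch of the three reduction modes these docstrings describe (assuming the public keras_core API; values shown are illustrative):

import numpy as np
from keras_core import losses

y_true = np.array([[0.0, 1.0], [0.0, 0.0]])
y_pred = np.array([[1.0, 1.0], [1.0, 0.0]])

# "sum_over_batch_size" (the default): mean of the per-sample losses.
print(losses.MeanSquaredError(reduction="sum_over_batch_size")(y_true, y_pred))  # 0.5
# "sum": total of the per-sample losses.
print(losses.MeanSquaredError(reduction="sum")(y_true, y_pred))  # 1.0
# None: no reduction; one loss value per sample.
print(losses.MeanSquaredError(reduction=None)(y_true, y_pred))  # [0.5, 0.5]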
@@ -73,7 +73,7 @@ class MeanAbsoluteError(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -99,7 +99,7 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -129,7 +129,7 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -168,7 +168,7 @@ class CosineSimilarity(LossFunctionWrapper):
             (the features axis). Defaults to -1.
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -257,7 +257,7 @@ class Hinge(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -284,7 +284,7 @@ class SquaredHinge(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -310,7 +310,7 @@ class CategoricalHinge(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -336,7 +336,7 @@ class KLDivergence(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -360,7 +360,7 @@ class Poisson(LossFunctionWrapper):
     Args:
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
     """
 
@@ -380,26 +380,25 @@ class BinaryCrossentropy(LossFunctionWrapper):
 
     - `y_true` (true label): This is either 0 or 1.
    - `y_pred` (predicted value): This is the model's prediction, i.e, a single
-      floating-point value which either represents a
-      [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
-      when `from_logits=True`) or a probability (i.e, value in [0., 1.] when
-      `from_logits=False`).
+        floating-point value which either represents a
+        [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
+        when `from_logits=True`) or a probability (i.e, value in [0., 1.] when
+        `from_logits=False`).
 
     Args:
         from_logits: Whether to interpret `y_pred` as a tensor of
             [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
-            assume that `y_pred` contains probabilities (i.e., values in [0,
-            1]).
+            assume that `y_pred` is probabilities (i.e., values in [0, 1]).
         label_smoothing: Float in range [0, 1]. When 0, no smoothing occurs.
             When > 0, we compute the loss between the predicted labels
            and a smoothed version of the true labels, where the smoothing
            squeezes the labels towards 0.5. Larger values of
            `label_smoothing` correspond to heavier smoothing.
-        axis: The axis along which to compute crossentropy (the features
-            axis). Defaults to -1.
+        axis: The axis along which to compute crossentropy (the features axis).
+            Defaults to -1.
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
 
     Examples:
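Note: a small sketch contrasting the two `from_logits` interpretations of `y_pred` described above (assuming the public keras_core API; numbers are illustrative):

import numpy as np
from keras_core import losses

y_true = np.array([0.0, 1.0, 0.0, 0.0])

# Probabilities in [0, 1]: the default interpretation (`from_logits=False`).
loss_probs = losses.BinaryCrossentropy(from_logits=False)(
    y_true, np.array([0.1, 0.9, 0.2, 0.05])
)

# Raw logits in [-inf, inf]: a sigmoid is applied internally.
loss_logits = losses.BinaryCrossentropy(from_logits=True)(
    y_true, np.array([-2.2, 2.2, -1.4, -2.9])
)
# sigmoid(-2.2) ~ 0.1, sigmoid(2.2) ~ 0.9, etc., so the two results are close.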
@@ -493,10 +492,10 @@ class BinaryFocalCrossentropy(LossFunctionWrapper):
 
     - `y_true` (true label): This is either 0 or 1.
    - `y_pred` (predicted value): This is the model's prediction, i.e, a single
-      floating-point value which either represents a
-      [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
-      when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when
-      `from_logits=False`).
+        floating-point value which either represents a
+        [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]
+        when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when
+        `from_logits=False`).
 
     According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
     helps to apply a "focal factor" to down-weight easy examples and focus more
@@ -529,7 +528,7 @@ class BinaryFocalCrossentropy(LossFunctionWrapper):
             Defaults to `-1`.
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
 
     Examples:
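Note: a sketch of the focal down-weighting this docstring describes (assuming the public keras_core API; the gamma value is illustrative):

import numpy as np
from keras_core import losses

y_true = np.array([0.0, 1.0, 0.0, 0.0])
y_pred = np.array([0.1, 0.9, 0.2, 0.05])  # confident, "easy" predictions

plain = losses.BinaryCrossentropy()(y_true, y_pred)
focal = losses.BinaryFocalCrossentropy(gamma=2.0)(y_true, y_pred)
# focal < plain: easy examples are scaled by (1 - p_t) ** gamma,
# so they contribute far less to the total loss.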
@@ -665,11 +664,8 @@ class CategoricalCrossentropy(LossFunctionWrapper):
     Use this crossentropy loss function when there are two or more label
     classes. We expect labels to be provided in a `one_hot` representation. If
     you want to provide labels as integers, please use
-    `SparseCategoricalCrossentropy` loss. There should be `# classes` floating
-    point values per feature.
-
-    In the snippet below, there is `# classes` floating pointing values per
-    example. The shape of both `y_pred` and `y_true` are
+    `SparseCategoricalCrossentropy` loss. There should be `num_classes` floating
+    point values per feature, i.e., the shape of both `y_pred` and `y_true` are
     `[batch_size, num_classes]`.
 
     Args:
@@ -683,7 +679,7 @@ class CategoricalCrossentropy(LossFunctionWrapper):
             axis). Defaults to -1.
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
 
     Examples:
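Note: a sketch of the one-hot label layout the docstring requires, with `y_true` and `y_pred` both shaped `[batch_size, num_classes]` (assuming the public keras_core API):

import numpy as np
from keras_core import losses

y_true = np.array([[0, 1, 0], [0, 0, 1]])               # one-hot labels
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

loss = losses.CategoricalCrossentropy()(y_true, y_pred)  # ~1.177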
@@ -791,7 +787,7 @@ class CategoricalFocalCrossentropy(LossFunctionWrapper):
     Extending this to multi-class case is straightforward:
     `FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)`
 
-    In the snippet below, there is `# classes` floating pointing values per
+    In the snippet below, there is `num_classes` floating pointing values per
     example. The shape of both `y_pred` and `y_true` are
     `(batch_size, num_classes)`.
 
@@ -814,7 +810,7 @@ class CategoricalFocalCrossentropy(LossFunctionWrapper):
             axis). Defaults to -1.
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
 
     Examples:
@@ -903,16 +899,16 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
     feature for `y_true`.
 
    In the snippet below, there is a single floating point value per example for
-    `y_true` and `# classes` floating pointing values per example for `y_pred`.
-    The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
-    `[batch_size, num_classes]`.
+    `y_true` and `num_classes` floating pointing values per example for
+    `y_pred`. The shape of `y_true` is `[batch_size]` and the shape of `y_pred`
+    is `[batch_size, num_classes]`.
 
     Args:
         from_logits: Whether `y_pred` is expected to be a logits tensor. By
            default, we assume that `y_pred` encodes a probability distribution.
         reduction: Type of reduction to apply to the loss. In almost all cases
             this should be `"sum_over_batch_size"`.
-            Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`.
+            Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
         name: Optional name for the loss instance.
 
     Examples:
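Note: the same example with integer labels, matching the `[batch_size]` / `[batch_size, num_classes]` shapes described above (assuming the public keras_core API):

import numpy as np
from keras_core import losses

y_true = np.array([1, 2])                                # shape [batch_size]
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])  # shape [batch_size, num_classes]

loss = losses.SparseCategoricalCrossentropy()(y_true, y_pred)  # ~1.177, same as the one-hot case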
@@ -1729,8 +1725,8 @@ def binary_focal_crossentropy(
    helps to apply a focal factor to down-weight easy examples and focus more on
     hard examples. By default, the focal tensor is computed as follows:
 
-    `focal_factor = (1 - output)**gamma` for class 1
-    `focal_factor = output**gamma` for class 0
+    `focal_factor = (1 - output) ** gamma` for class 1
+    `focal_factor = output ** gamma` for class 0
     where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal
     effect on the binary crossentropy loss.
 
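Note: a worked numpy sketch of the `focal_factor` formula above (variable names are illustrative, not part of the API):

import numpy as np

gamma = 2.0
output = np.array([0.9, 0.2])  # predicted probability of class 1

# Class-1 target: focal_factor = (1 - output) ** gamma,
# e.g. a confident 0.9 gives (1 - 0.9) ** 2 = 0.01 (heavily down-weighted).
focal_class1 = (1.0 - output) ** gamma
# Class-0 target: focal_factor = output ** gamma,
# e.g. a confident 0.2 gives 0.2 ** 2 = 0.04.
focal_class0 = output ** gamma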
@@ -37,8 +37,16 @@ class MeanSquaredErrorTest(testing.TestCase):
         self.assertAlmostEqual(loss, 767.8 / 6)
 
     def test_timestep_weighted(self):
-        # TODO
-        pass
+        mse_obj = losses.MeanSquaredError()
+        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+        loss = mse_obj(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight,
+        )
+        self.assertAlmostEqual(loss, 97.833336)
 
     def test_zero_weighted(self):
         mse_obj = losses.MeanSquaredError()
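Note: the expected value 97.833336 in the new test can be checked by hand; a plain numpy sketch of the default sum_over_batch_size arithmetic (weighted sum of per-timestep losses divided by the number of loss elements):

import numpy as np

y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3)
y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3)
sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape(2, 3)

per_timestep = (y_true - y_pred) ** 2     # [[9, 1, 100], [169, 9, 9]]
weighted = per_timestep * sample_weight   # [[27, 6, 500], [0, 36, 18]]
print(weighted.sum() / weighted.size)     # 587 / 6 = 97.8333...

The MAE, MAPE, and MSLE expected values below follow the same pattern.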
@@ -47,10 +55,6 @@ class MeanSquaredErrorTest(testing.TestCase):
         loss = mse_obj(y_true, y_pred, sample_weight=0)
         self.assertAlmostEqual(loss, 0.0)
 
-    def test_invalid_sample_weight(self):
-        # TODO
-        pass
-
     def test_no_reduction(self):
         mse_obj = losses.MeanSquaredError(reduction=None)
         y_true = np.array([[1, 9, 2], [-5, -2, 6]])
@@ -101,8 +105,16 @@ class MeanAbsoluteErrorTest(testing.TestCase):
         self.assertAlmostEqual(loss, 81.4 / 6)
 
     def test_timestep_weighted(self):
-        # TODO
-        pass
+        mae_obj = losses.MeanAbsoluteError()
+        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+        loss = mae_obj(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight,
+        )
+        self.assertAlmostEqual(loss, 13.833333)
 
     def test_zero_weighted(self):
         mae_obj = losses.MeanAbsoluteError()
@@ -111,10 +123,6 @@ class MeanAbsoluteErrorTest(testing.TestCase):
         loss = mae_obj(y_true, y_pred, sample_weight=0)
         self.assertAlmostEqual(loss, 0.0)
 
-    def test_invalid_sample_weight(self):
-        # TODO
-        pass
-
     def test_no_reduction(self):
         mae_obj = losses.MeanAbsoluteError(reduction=None)
         y_true = np.array([[1, 9, 2], [-5, -2, 6]])
@@ -165,8 +173,16 @@ class MeanAbsolutePercentageErrorTest(testing.TestCase):
         self.assertAlmostEqual(loss, 422.8888, 3)
 
     def test_timestep_weighted(self):
-        # TODO
-        pass
+        mape_obj = losses.MeanAbsolutePercentageError()
+        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+        loss = mape_obj(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight,
+        )
+        self.assertAlmostEqual(loss, 694.4444)
 
     def test_zero_weighted(self):
         mape_obj = losses.MeanAbsolutePercentageError()
@@ -212,8 +228,16 @@ class MeanSquaredLogarithmicErrorTest(testing.TestCase):
         self.assertAlmostEqual(loss, 3.7856, 3)
 
     def test_timestep_weighted(self):
-        # TODO
-        pass
+        msle_obj = losses.MeanSquaredLogarithmicError()
+        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+        loss = msle_obj(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight,
+        )
+        self.assertAlmostEqual(loss, 2.647374)
 
     def test_zero_weighted(self):
         msle_obj = losses.MeanSquaredLogarithmicError()