From bb9e71f2f1aa9a1c71e273497d668673036f98a0 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 1 May 2023 11:18:19 -0700 Subject: [PATCH] Merge branch 'main' into crossentropy --- keras_core/losses/losses.py | 1112 +++++++++++++++++++++++++++++++---- 1 file changed, 1010 insertions(+), 102 deletions(-) diff --git a/keras_core/losses/losses.py b/keras_core/losses/losses.py index 4c4c1f160..3475d3a16 100644 --- a/keras_core/losses/losses.py +++ b/keras_core/losses/losses.py @@ -1,3 +1,5 @@ +import warnings + from keras_core import backend from keras_core import operations as ops from keras_core.api_export import keras_core_export @@ -43,10 +45,10 @@ class MeanSquaredError(LossFunctionWrapper): ``` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__( @@ -69,10 +71,10 @@ class MeanAbsoluteError(LossFunctionWrapper): ``` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__( @@ -95,10 +97,10 @@ class MeanAbsolutePercentageError(LossFunctionWrapper): ``` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__( @@ -125,10 +127,10 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper): ``` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__( @@ -164,10 +166,10 @@ class CosineSimilarity(LossFunctionWrapper): Args: axis: The axis along which the cosine similarity is computed (the features axis). Defaults to -1. - reduction: Type of reduction to apply to loss. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. Defaults to - `"sum_over_batch_size"`. - name: Optional name for the instance. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__( @@ -195,10 +197,10 @@ class Hinge(LossFunctionWrapper): provided we will convert them to -1 or 1. 
Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. Defaults to `"hinge"` + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__(self, reduction="sum_over_batch_size", name="hinge"): @@ -222,10 +224,10 @@ class SquaredHinge(LossFunctionWrapper): provided we will convert them to -1 or 1. Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. Defaults to `"squared_hinge"` + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__(self, reduction="sum_over_batch_size", name="squared_hinge"): @@ -248,11 +250,10 @@ class CategoricalHinge(LossFunctionWrapper): where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. Defaults to - `"categorical_hinge"` + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__( @@ -275,10 +276,10 @@ class KLDivergence(LossFunctionWrapper): ``` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. Defaults to 'kl_divergence'. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__(self, reduction="sum_over_batch_size", name="kl_divergence"): @@ -299,10 +300,10 @@ class Poisson(LossFunctionWrapper): ``` Args: - reduction: Type of reduction to apply to loss. For almost all cases - this defaults to `"sum_over_batch_size"`. Options are `"sum"`, - `"sum_over_batch_size"` or `None`. - name: Optional name for the instance. Defaults to `"poisson"` + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. """ def __init__(self, reduction="sum_over_batch_size", name="poisson"): @@ -312,6 +313,572 @@ class Poisson(LossFunctionWrapper): return Loss.get_config(self) +@keras_core_export("keras_core.losses.BinaryCrossentropy") +class BinaryCrossentropy(LossFunctionWrapper): + """Computes the cross-entropy loss between true labels and predicted labels. + + Use this cross-entropy loss for binary (0 or 1) classification applications. + The loss function requires the following inputs: + + - `y_true` (true label): This is either 0 or 1. 
+ - `y_pred` (predicted value): This is the model's prediction, i.e, a single + floating-point value which either represents a + [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] + when `from_logits=True`) or a probability (i.e, value in [0., 1.] when + `from_logits=False`). + + Args: + from_logits: Whether to interpret `y_pred` as a tensor of + [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we + assume that `y_pred` contains probabilities (i.e., values in [0, + 1]). + label_smoothing: Float in range [0, 1]. When 0, no smoothing occurs. + When > 0, we compute the loss between the predicted labels + and a smoothed version of the true labels, where the smoothing + squeezes the labels towards 0.5. Larger values of + `label_smoothing` correspond to heavier smoothing. + axis: The axis along which to compute crossentropy (the features + axis). Defaults to -1. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. + + Examples: + + **Recommended Usage:** (set `from_logits=True`) + + With `compile()` API: + + ```python + model.compile( + loss=keras_core.losses.BinaryCrossentropy(from_logits=True), + ... + ) + ``` + + As a standalone function: + + >>> # Example 1: (batch_size = 1, number of samples = 4) + >>> y_true = [0, 1, 0, 0] + >>> y_pred = [-18.6, 0.51, 2.94, -12.8] + >>> bce = keras_core.losses.BinaryCrossentropy(from_logits=True) + >>> bce(y_true, y_pred) + 0.865 + + >>> # Example 2: (batch_size = 2, number of samples = 4) + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]] + >>> # Using default 'auto'/'sum_over_batch_size' reduction type. + >>> bce = keras_core.losses.BinaryCrossentropy(from_logits=True) + >>> bce(y_true, y_pred) + 0.865 + >>> # Using 'sample_weight' attribute + >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]) + 0.243 + >>> # Using 'sum' reduction` type. + >>> bce = keras_core.losses.BinaryCrossentropy(from_logits=True, + ... reduction="sum") + >>> bce(y_true, y_pred) + 1.730 + >>> # Using 'none' reduction type. + >>> bce = keras_core.losses.BinaryCrossentropy(from_logits=True, + ... reduction=None) + >>> bce(y_true, y_pred) + array([0.235, 1.496], dtype=float32) + + **Default Usage:** (set `from_logits=False`) + + >>> # Make the following updates to the above "Recommended Usage" section + >>> # 1. Set `from_logits=False` + >>> keras_core.losses.BinaryCrossentropy() # OR ...('from_logits=False') + >>> # 2. Update `y_pred` to use probabilities instead of logits + >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]] + """ + + def __init__( + self, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="binary_crossentropy", + ): + super().__init__( + binary_crossentropy, + name=name, + reduction=reduction, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + + +@keras_core_export("keras_core.losses.BinaryFocalCrossentropy") +class BinaryFocalCrossentropy(LossFunctionWrapper): + """Computes focal cross-entropy loss between true labels and predictions. + + Binary cross-entropy loss is often used for binary (0 or 1) classification + tasks. The loss function requires the following inputs: + + - `y_true` (true label): This is either 0 or 1. 
+ - `y_pred` (predicted value): This is the model's prediction, i.e, a single + floating-point value which either represents a + [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] + when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when + `from_logits=False`). + + According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it + helps to apply a "focal factor" to down-weight easy examples and focus more + on hard examples. By default, the focal tensor is computed as follows: + + `focal_factor = (1 - output) ** gamma` for class 1 + `focal_factor = output ** gamma` for class 0 + where `gamma` is a focusing parameter. When `gamma=0`, this function is + equivalent to the binary crossentropy loss. + + Args: + apply_class_balancing: A bool, whether to apply weight balancing on the + binary classes 0 and 1. + alpha: A weight balancing factor for class 1, default is `0.25` as + mentioned in reference [Lin et al., 2018]( + https://arxiv.org/pdf/1708.02002.pdf). The weight for class 0 is + `1.0 - alpha`. + gamma: A focusing parameter used to compute the focal factor, default is + `2.0` as mentioned in the reference + [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf). + from_logits: Whether to interpret `y_pred` as a tensor of + [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we + assume that `y_pred` are probabilities (i.e., values in `[0, 1]`). + label_smoothing: Float in `[0, 1]`. When `0`, no smoothing occurs. + When > `0`, we compute the loss between the predicted labels + and a smoothed version of the true labels, where the smoothing + squeezes the labels towards `0.5`. + Larger values of `label_smoothing` correspond to heavier smoothing. + axis: The axis along which to compute crossentropy (the features axis). + Defaults to `-1`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. + + Examples: + + With the `compile()` API: + + ```python + model.compile( + loss=keras_core.losses.BinaryFocalCrossentropy( + gamma=2.0, from_logits=True), + ... + ) + ``` + + As a standalone function: + + >>> # Example 1: (batch_size = 1, number of samples = 4) + >>> y_true = [0, 1, 0, 0] + >>> y_pred = [-18.6, 0.51, 2.94, -12.8] + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... gamma=2, from_logits=True) + >>> loss(y_true, y_pred) + 0.691 + + >>> # Apply class weight + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=2, from_logits=True) + >>> loss(y_true, y_pred) + 0.51 + + >>> # Example 2: (batch_size = 2, number of samples = 4) + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]] + >>> # Using default 'auto'/'sum_over_batch_size' reduction type. + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... gamma=3, from_logits=True) + >>> loss(y_true, y_pred) + 0.647 + + >>> # Apply class weight + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=3, from_logits=True) + >>> loss(y_true, y_pred) + 0.482 + + >>> # Using 'sample_weight' attribute with focal effect + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... gamma=3, from_logits=True) + >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2]) + 0.133 + + >>> # Apply class weight + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... 
apply_class_balancing=True, gamma=3, from_logits=True) + >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2]) + 0.097 + + >>> # Using 'sum' reduction` type. + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... gamma=4, from_logits=True, + ... reduction="sum") + >>> loss(y_true, y_pred) + 1.222 + + >>> # Apply class weight + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=4, from_logits=True, + ... reduction="sum") + >>> loss(y_true, y_pred) + 0.914 + + >>> # Using 'none' reduction type. + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... gamma=5, from_logits=True, + ... reduction=None) + >>> loss(y_true, y_pred) + array([0.0017 1.1561], dtype=float32) + + >>> # Apply class weight + >>> loss = keras_core.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=5, from_logits=True, + ... reduction=None) + >>> loss(y_true, y_pred) + array([0.0004 0.8670], dtype=float32) + """ + + def __init__( + self, + apply_class_balancing=False, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="binary_focal_crossentropy", + ): + """Initializes `BinaryFocalCrossentropy` instance.""" + super().__init__( + binary_focal_crossentropy, + apply_class_balancing=apply_class_balancing, + alpha=alpha, + gamma=gamma, + name=name, + reduction=reduction, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.apply_class_balancing = apply_class_balancing + self.alpha = alpha + self.gamma = gamma + + def get_config(self): + config = { + "apply_class_balancing": self.apply_class_balancing, + "alpha": self.alpha, + "gamma": self.gamma, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@keras_core_export("keras_core.losses.CategoricalCrossentropy") +class CategoricalCrossentropy(LossFunctionWrapper): + """Computes the crossentropy loss between the labels and predictions. + + Use this crossentropy loss function when there are two or more label + classes. We expect labels to be provided in a `one_hot` representation. If + you want to provide labels as integers, please use + `SparseCategoricalCrossentropy` loss. There should be `# classes` floating + point values per feature. + + In the snippet below, there is `# classes` floating pointing values per + example. The shape of both `y_pred` and `y_true` are + `[batch_size, num_classes]`. + + Args: + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, + meaning the confidence on label values are relaxed. For example, if + `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. + axis: The axis along which to compute crossentropy (the features + axis). Defaults to -1. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. + + Examples: + + Standalone usage: + + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> # Using 'auto'/'sum_over_batch_size' reduction type. + >>> cce = keras_core.losses.CategoricalCrossentropy() + >>> cce(y_true, y_pred) + 1.177 + + >>> # Calling with 'sample_weight'. 
+ >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7])) + 0.814 + + >>> # Using 'sum' reduction type. + >>> cce = keras_core.losses.CategoricalCrossentropy( + ... reduction="sum") + >>> cce(y_true, y_pred) + 2.354 + + >>> # Using 'none' reduction type. + >>> cce = keras_core.losses.CategoricalCrossentropy( + ... reduction=None) + >>> cce(y_true, y_pred) + array([0.0513, 2.303], dtype=float32) + + Usage with the `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss=keras_core.losses.CategoricalCrossentropy()) + ``` + """ + + def __init__( + self, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="categorical_crossentropy", + ): + super().__init__( + categorical_crossentropy, + name=name, + reduction=reduction, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + + +@keras_core_export("keras_core.losses.CategoricalFocalCrossentropy") +class CategoricalFocalCrossentropy(LossFunctionWrapper): + """Computes the alpha balanced focal crossentropy loss. + + Use this crossentropy loss function when there are two or more label + classes and if you want to handle class imbalance without using + `class_weights`. We expect labels to be provided in a `one_hot` + representation. + + According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it + helps to apply a focal factor to down-weight easy examples and focus more on + hard examples. The general formula for the focal loss (FL) + is as follows: + + `FL(p_t) = (1 - p_t) ** gamma * log(p_t)` + + where `p_t` is defined as follows: + `p_t = output if y_true == 1, else 1 - output` + + `(1 - p_t) ** gamma` is the `modulating_factor`, where `gamma` is a focusing + parameter. When `gamma` = 0, there is no focal effect on the cross entropy. + `gamma` reduces the importance given to simple examples in a smooth manner. + + The authors use alpha-balanced variant of focal loss (FL) in the paper: + `FL(p_t) = -alpha * (1 - p_t) ** gamma * log(p_t)` + + where `alpha` is the weight factor for the classes. If `alpha` = 1, the + loss won't be able to handle class imbalance properly as all + classes will have the same weight. This can be a constant or a list of + constants. If alpha is a list, it must have the same length as the number + of classes. + + The formula above can be generalized to: + `FL(p_t) = alpha * (1 - p_t) ** gamma * CrossEntropy(y_true, y_pred)` + + where minus comes from `CrossEntropy(y_true, y_pred)` (CE). + + Extending this to multi-class case is straightforward: + `FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)` + + In the snippet below, there is `# classes` floating pointing values per + example. The shape of both `y_pred` and `y_true` are + `(batch_size, num_classes)`. + + Args: + alpha: A weight balancing factor for all classes, default is `0.25` as + mentioned in the reference. It can be a list of floats or a scalar. + In the multi-class case, alpha may be set by inverse class + frequency by using `compute_class_weight` from `sklearn.utils`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. It helps to gradually reduce the importance given to + simple (easy) examples in a smooth manner. + from_logits: Whether `output` is expected to be a logits tensor. By + default, we consider that `output` encodes a probability + distribution. + label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, + meaning the confidence on label values are relaxed. 
For example, if + `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. + axis: The axis along which to compute crossentropy (the features + axis). Defaults to -1. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. + name: Optional name for the loss instance. + + Examples: + + Standalone usage: + + >>> y_true = [[0., 1., 0.], [0., 0., 1.]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> # Using 'auto'/'sum_over_batch_size' reduction type. + >>> cce = keras_core.losses.CategoricalFocalCrossentropy() + >>> cce(y_true, y_pred) + 0.23315276 + + >>> # Calling with 'sample_weight'. + >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7])) + 0.1632 + + >>> # Using 'sum' reduction type. + >>> cce = keras_core.losses.CategoricalFocalCrossentropy( + ... reduction="sum") + >>> cce(y_true, y_pred) + 0.46631 + + >>> # Using 'none' reduction type. + >>> cce = keras_core.losses.CategoricalFocalCrossentropy( + ... reduction=None) + >>> cce(y_true, y_pred) + array([3.2058331e-05, 4.6627346e-01], dtype=float32) + + Usage with the `compile()` API: + + ```python + model.compile(optimizer='adam', + loss=keras_core.losses.CategoricalFocalCrossentropy()) + ``` + """ + + def __init__( + self, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="categorical_focal_crossentropy", + ): + """Initializes `CategoricalFocalCrossentropy` instance.""" + super().__init__( + categorical_focal_crossentropy, + alpha=alpha, + gamma=gamma, + name=name, + reduction=reduction, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.alpha = alpha + self.gamma = gamma + + def get_config(self): + config = { + "alpha": self.alpha, + "gamma": self.gamma, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@keras_core_export("keras_core.losses.SparseCategoricalCrossentropy") +class SparseCategoricalCrossentropy(LossFunctionWrapper): + """Computes the crossentropy loss between the labels and predictions. + + Use this crossentropy loss function when there are two or more label + classes. We expect labels to be provided as integers. If you want to + provide labels using `one-hot` representation, please use + `CategoricalCrossentropy` loss. There should be `# classes` floating point + values per feature for `y_pred` and a single floating point value per + feature for `y_true`. + + In the snippet below, there is a single floating point value per example for + `y_true` and `# classes` floating pointing values per example for `y_pred`. + The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is + `[batch_size, num_classes]`. + + Args: + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + ignore_class: Optional integer. The ID of a class to be ignored during + loss computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. + By default (`ignore_class=None`), all classes are considered. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. + Suuported options are `"sum"`, `"sum_over_batch_size"` or `None`. 
+ name: Optional name for the loss instance. + + Examples: + + >>> y_true = [1, 2] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> # Using 'auto'/'sum_over_batch_size' reduction type. + >>> scce = keras_core.losses.SparseCategoricalCrossentropy() + >>> scce(y_true, y_pred) + 1.177 + + >>> # Calling with 'sample_weight'. + >>> scce(y_true, y_pred, sample_weight=np.array([0.3, 0.7])) + 0.814 + + >>> # Using 'sum' reduction type. + >>> scce = keras_core.losses.SparseCategoricalCrossentropy( + ... reduction="sum") + >>> scce(y_true, y_pred) + 2.354 + + >>> # Using 'none' reduction type. + >>> scce = keras_core.losses.SparseCategoricalCrossentropy( + ... reduction=None) + >>> scce(y_true, y_pred) + array([0.0513, 2.303], dtype=float32) + + Usage with the `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss=keras_core.losses.SparseCategoricalCrossentropy()) + ``` + """ + + def __init__( + self, + from_logits=False, + ignore_class=None, + reduction="sum_over_batch_size", + name="sparse_categorical_crossentropy", + ): + super().__init__( + sparse_categorical_crossentropy, + name=name, + reduction=reduction, + from_logits=from_logits, + ignore_class=ignore_class, + ) + + def convert_binary_labels_to_hinge(y_true): """Converts binary labels into -1/1 for hinge loss/metric calculation.""" are_zeros = ops.equal(y_true, 0) @@ -347,12 +914,6 @@ def hinge(y_true, y_pred): loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1) ``` - Standalone usage: - - >>> y_true = np.random.choice([-1, 1], size=(2, 3)) - >>> y_pred = np.random.random(size=(2, 3)) - >>> loss = keras_core.losses.hinge(y_true, y_pred) - Args: y_true: The ground truth values. `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are provided they will be converted @@ -361,6 +922,12 @@ def hinge(y_true, y_pred): Returns: Hinge loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.choice([-1, 1], size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras_core.losses.hinge(y_true, y_pred) """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.cast(y_true, dtype=y_pred.dtype) @@ -384,12 +951,6 @@ def squared_hinge(y_true, y_pred): loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1) ``` - Standalone usage: - - >>> y_true = np.random.choice([-1, 1], size=(2, 3)) - >>> y_pred = np.random.random(size=(2, 3)) - >>> loss = keras_core.losses.squared_hinge(y_true, y_pred) - Args: y_true: The ground truth values. `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are provided we will convert them @@ -398,6 +959,12 @@ def squared_hinge(y_true, y_pred): Returns: Squared hinge loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.choice([-1, 1], size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras_core.losses.squared_hinge(y_true, y_pred) """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.cast(y_true, y_pred.dtype) @@ -424,13 +991,6 @@ def categorical_hinge(y_true, y_pred): where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)` - Standalone usage: - - >>> y_true = np.random.randint(0, 3, size=(2,)) - >>> y_true = np.eye(np.max(y_true) + 1)[y_true] - >>> y_pred = np.random.random(size=(2, 3)) - >>> loss = keras_core.losses.categorical_hinge(y_true, y_pred) - Args: y_true: The ground truth values. `y_true` values are expected to be either `{-1, +1}` or `{0, 1}` (i.e. 
a one-hot-encoded tensor) with @@ -439,6 +999,13 @@ def categorical_hinge(y_true, y_pred): Returns: Categorical hinge loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 3, size=(2,)) + >>> y_true = np.eye(np.max(y_true) + 1)[y_true] + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras_core.losses.categorical_hinge(y_true, y_pred) """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.cast(y_true, y_pred.dtype) @@ -495,18 +1062,18 @@ def mean_absolute_error(y_true, y_pred): loss = mean(abs(y_true - y_pred), axis=-1) ``` - Standalone usage: - - >>> y_true = np.random.randint(0, 2, size=(2, 3)) - >>> y_pred = np.random.random(size=(2, 3)) - >>> loss = keras_core.losses.mean_absolute_error(y_true, y_pred) - Args: y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. Returns: Mean absolute error values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras_core.losses.mean_absolute_error(y_true, y_pred) """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) @@ -523,18 +1090,16 @@ def mean_absolute_error(y_true, y_pred): def mean_absolute_percentage_error(y_true, y_pred): """Computes the mean absolute percentage error between `y_true` & `y_pred`. - `loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)` + Formula: + + ```python + loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1) + ``` Division by zero is prevented by dividing by `maximum(y_true, epsilon)` where `epsilon = keras_core.backend.epsilon()` (default to `1e-7`). - Standalone usage: - - >>> y_true = np.random.random(size=(2, 3)) - >>> y_pred = np.random.random(size=(2, 3)) - >>> loss = keras_core.losses.mean_absolute_percentage_error(y_true, y_pred) - Args: y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. @@ -542,6 +1107,12 @@ def mean_absolute_percentage_error(y_true, y_pred): Returns: Mean absolute percentage error values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.random(size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras_core.losses.mean_absolute_percentage_error(y_true, y_pred) """ epsilon = ops.convert_to_tensor(backend.epsilon()) y_pred = ops.convert_to_tensor(y_pred) @@ -570,19 +1141,19 @@ def mean_squared_logarithmic_error(y_true, y_pred): values and 0 values will be replaced with `keras_core.backend.epsilon()` (default to `1e-7`). - Standalone usage: - - >>> y_true = np.random.randint(0, 2, size=(2, 3)) - >>> y_pred = np.random.random(size=(2, 3)) - >>> loss = keras_core.losses.mean_squared_logarithmic_error(y_true, y_pred) - Args: y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. Returns: - Mean squared logarithmic error values. shape = `[batch_size, d0, .. + Mean squared logarithmic error values with shape = `[batch_size, d0, .. dN-1]`. 
+ + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras_core.losses.mean_squared_logarithmic_error(y_true, y_pred) """ epsilon = ops.convert_to_tensor(backend.epsilon()) y_pred = ops.convert_to_tensor(y_pred) @@ -610,12 +1181,6 @@ def cosine_similarity(y_true, y_pred, axis=-1): similarity will be 0 regardless of the proximity between predictions and targets. - Standalone usage: - >>> y_true = [[0., 1.], [1., 1.], [1., 1.]] - >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]] - >>> loss = keras_core.losses.cosine_similarity(y_true, y_pred, axis=-1) - [-0., -0.99999994, 0.99999994] - Args: y_true: Tensor of true targets. y_pred: Tensor of predicted targets. @@ -623,6 +1188,13 @@ def cosine_similarity(y_true, y_pred, axis=-1): Returns: Cosine similarity tensor. + + Example: + + >>> y_true = [[0., 1.], [1., 1.], [1., 1.]] + >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]] + >>> loss = keras_core.losses.cosine_similarity(y_true, y_pred, axis=-1) + [-0., -0.99999994, 0.99999994] """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) @@ -647,7 +1219,14 @@ def kl_divergence(y_true, y_pred): loss = y_true * log(y_true / y_pred) ``` - Standalone usage: + Args: + y_true: Tensor of true targets. + y_pred: Tensor of predicted targets. + + Returns: + KL Divergence loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float32) >>> y_pred = np.random.random(size=(2, 3)) @@ -657,13 +1236,6 @@ def kl_divergence(y_true, y_pred): >>> y_pred = ops.clip(y_pred, 1e-7, 1) >>> assert np.array_equal( ... loss, np.sum(y_true * np.log(y_true / y_pred), axis=-1)) - - Args: - y_true: Tensor of true targets. - y_pred: Tensor of predicted targets. - - Returns: - KL Divergence loss values with shape = `[batch_size, d0, .. dN-1]`. """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, y_pred.dtype) @@ -687,7 +1259,14 @@ def poisson(y_true, y_pred): loss = y_pred - y_true * log(y_pred) ``` - Standalone usage: + Args: + y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. + + Returns: + Poisson loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: >>> y_true = np.random.randint(0, 2, size=(2, 3)) >>> y_pred = np.random.random(size=(2, 3)) @@ -697,15 +1276,344 @@ def poisson(y_true, y_pred): >>> assert np.allclose( ... loss, np.mean(y_pred - y_true * np.log(y_pred), axis=-1), ... atol=1e-5) - - Args: - y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. - y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. - - Returns: - Poisson loss values with shape = `[batch_size, d0, .. dN-1]`. """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) epsilon = ops.convert_to_tensor(backend.epsilon()) return ops.mean(y_pred - y_true * ops.log(y_pred + epsilon), axis=-1) + + +@keras_core_export( + [ + "keras_core.metrics.categorical_crossentropy", + "keras_core.losses.categorical_crossentropy", + ] +) +def categorical_crossentropy( + y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1 +): + """Computes the categorical crossentropy loss. + + Args: + y_true: Tensor of one-hot true targets. + y_pred: Tensor of predicted targets. + from_logits: Whether `y_pred` is expected to be a logits tensor. 
By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels + and `0.9 + 0.1 / num_classes` for target labels. + axis: Defaults to -1. The dimension along which the entropy is + computed. + + Returns: + Categorical crossentropy loss value. + + Example: + + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> loss = keras_core.losses.categorical_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss.numpy() + array([0.0513, 2.303], dtype=float32) + """ + if isinstance(axis, bool): + raise ValueError( + "`axis` must be of type `int`. " + f"Received: axis={axis} of type {type(axis)}" + ) + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if y_pred.shape[-1] == 1: + warnings.warn( + "In loss categorical_crossentropy, expected " + "y_pred.shape to be (batch_size, num_classes) " + f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. " + "Consider using 'binary_crossentropy' if you only have 2 classes.", + SyntaxWarning, + stacklevel=2, + ) + + if label_smoothing: + num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype) + y_true = y_true * (1.0 - label_smoothing) + ( + label_smoothing / num_classes + ) + + return ops.categorical_crossentropy( + y_true, y_pred, from_logits=from_logits, axis=axis + ) + + +@keras_core_export( + [ + "keras_core.metrics.categorical_focal_crossentropy", + "keras_core.losses.categorical_focal_crossentropy", + ] +) +def categorical_focal_crossentropy( + y_true, + y_pred, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, +): + """Computes the categorical focal crossentropy loss. + + Args: + y_true: Tensor of one-hot true targets. + y_pred: Tensor of predicted targets. + alpha: A weight balancing factor for all classes, default is `0.25` as + mentioned in the reference. It can be a list of floats or a scalar. + In the multi-class case, alpha may be set by inverse class + frequency by using `compute_class_weight` from `sklearn.utils`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. It helps to gradually reduce the importance given to + simple examples in a smooth manner. When `gamma` = 0, there is + no focal effect on the categorical crossentropy. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability + distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels + and `0.9 + 0.1 / num_classes` for target labels. + axis: Defaults to -1. The dimension along which the entropy is + computed. + + Returns: + Categorical focal crossentropy loss value. + + Example: + + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]] + >>> loss = keras_core.losses.categorical_focal_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([2.63401289e-04, 6.75912094e-01], dtype=float32) + """ + if isinstance(axis, bool): + raise ValueError( + "`axis` must be of type `int`. 
" + f"Received: axis={axis} of type {type(axis)}" + ) + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if y_pred.shape[-1] == 1: + warnings.warn( + "In loss categorical_focal_crossentropy, expected " + "y_pred.shape to be (batch_size, num_classes) " + f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. " + "Consider using 'binary_crossentropy' if you only have 2 classes.", + SyntaxWarning, + stacklevel=2, + ) + + if label_smoothing: + num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype) + y_true = y_true * (1.0 - label_smoothing) + ( + label_smoothing / num_classes + ) + + return ops.categorical_focal_crossentropy( + target=y_true, + output=y_pred, + alpha=alpha, + gamma=gamma, + from_logits=from_logits, + axis=axis, + ) + + +@keras_core_export( + [ + "keras_core.metrics.sparse_categorical_crossentropy", + "keras_core.losses.sparse_categorical_crossentropy", + ] +) +def sparse_categorical_crossentropy( + y_true, y_pred, from_logits=False, axis=-1, ignore_class=None +): + """Computes the sparse categorical crossentropy loss. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + axis: Defaults to -1. The dimension along which the entropy is + computed. + ignore_class: Optional integer. The ID of a class to be ignored during + loss computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) + in segmentation maps. By default (`ignore_class=None`), + all classes are considered. + + Returns: + Sparse categorical crossentropy loss value. + + Examples: + + >>> y_true = [1, 2] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> loss = keras_core.losses.sparse_categorical_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.0513, 2.303], dtype=float32) + + >>> y_true = [[[ 0, 2], + ... [-1, -1]], + ... [[ 0, 2], + ... [-1, -1]]] + >>> y_pred = [[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + ... [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]], + ... [[[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]], + ... [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]]] + >>> loss = keras_core.losses.sparse_categorical_crossentropy( + ... y_true, y_pred, ignore_class=-1) + array([[[2.3841855e-07, 2.3841855e-07], + [0.0000000e+00, 0.0000000e+00]], + [[2.3841855e-07, 6.9314730e-01], + [0.0000000e+00, 0.0000000e+00]]], dtype=float32) + """ + return ops.sparse_categorical_crossentropy( + y_true, + y_pred, + from_logits=from_logits, + ignore_class=ignore_class, + axis=axis, + ) + + +@keras_core_export( + [ + "keras_core.metrics.binary_crossentropy", + "keras_core.losses.binary_crossentropy", + ] +) +def binary_crossentropy( + y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1 +): + """Computes the binary crossentropy loss. + + Args: + y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by + squeezing them towards 0.5, that is, + using `1. - 0.5 * label_smoothing` for the target class + and `0.5 * label_smoothing` for the non-target class. + axis: The axis along which the mean is computed. Defaults to -1. + + Returns: + Binary crossentropy loss value. 
shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] + >>> loss = keras_core.losses.binary_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.916 , 0.714], dtype=float32) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if label_smoothing: + y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing + + return ops.mean( + ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits), + axis=axis, + ) + + +@keras_core_export( + [ + "keras_core.metrics.binary_focal_crossentropy", + "keras_core.losses.binary_focal_crossentropy", + ] +) +def binary_focal_crossentropy( + y_true, + y_pred, + apply_class_balancing=False, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, +): + """Computes the binary focal crossentropy loss. + + According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it + helps to apply a focal factor to down-weight easy examples and focus more on + hard examples. By default, the focal tensor is computed as follows: + + `focal_factor = (1 - output)**gamma` for class 1 + `focal_factor = output**gamma` for class 0 + where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal + effect on the binary crossentropy loss. + + If `apply_class_balancing == True`, this function also takes into account a + weight balancing factor for the binary classes 0 and 1 as follows: + + `weight = alpha` for class 1 (`target == 1`) + `weight = 1 - alpha` for class 0 + where `alpha` is a float in the range of `[0, 1]`. + + Args: + y_true: Ground truth values, of shape `(batch_size, d0, .. dN)`. + y_pred: The predicted values, of shape `(batch_size, d0, .. dN)`. + apply_class_balancing: A bool, whether to apply weight balancing on the + binary classes 0 and 1. + alpha: A weight balancing factor for class 1, default is `0.25` as + mentioned in the reference. The weight for class 0 is `1.0 - alpha`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by + squeezing them towards 0.5, that is, + using `1. - 0.5 * label_smoothing` for the target class + and `0.5 * label_smoothing` for the non-target class. + axis: The axis along which the mean is computed. Defaults to `-1`. + + Returns: + Binary focal crossentropy loss value + with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] + >>> loss = keras_core.losses.binary_focal_crossentropy( + ... y_true, y_pred, gamma=2) + >>> assert loss.shape == (2,) + >>> loss + array([0.330, 0.206], dtype=float32) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if label_smoothing: + y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing + + return ops.mean( + ops.binary_focal_crossentropy( + target=y_true, + output=y_pred, + apply_class_balancing=apply_class_balancing, + alpha=alpha, + gamma=gamma, + from_logits=from_logits, + ), + axis=axis, + )
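
As a quick sanity check on the focal-weighting formulas quoted in the docstrings added by this patch, here is a small standalone NumPy sketch of `binary_focal_crossentropy` for probability inputs (no class balancing, no label smoothing). It only illustrates the formula `focal_factor = (1 - p_t) ** gamma` applied to the plain binary crossentropy; it is not the keras_core backend implementation, which the patch delegates to `ops.binary_focal_crossentropy`, and the helper name below is made up for this example.

```python
import numpy as np


def binary_focal_crossentropy_reference(y_true, y_pred, gamma=2.0, eps=1e-7):
    """Hypothetical reference implementation of the formula in the docstrings.

    Assumes `y_pred` contains probabilities (i.e. `from_logits=False`) and
    applies neither class balancing nor label smoothing.
    """
    y_true = np.asarray(y_true, dtype="float32")
    y_pred = np.clip(np.asarray(y_pred, dtype="float32"), eps, 1.0 - eps)
    # Plain per-element binary crossentropy.
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    # p_t is the probability assigned to the true class, so the focal factor
    # is (1 - y_pred)**gamma for class 1 and y_pred**gamma for class 0.
    p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
    focal_factor = (1.0 - p_t) ** gamma
    return np.mean(focal_factor * bce, axis=-1)


# Reproduces the values quoted in the `binary_focal_crossentropy` docstring:
y_true = [[0, 1], [0, 0]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
print(binary_focal_crossentropy_reference(y_true, y_pred, gamma=2))
# -> approximately [0.330, 0.206]
```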