From d6600a5797dc9e57d95500b32b74e5f0663e0423 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Mon, 1 May 2023 18:54:10 -0700
Subject: [PATCH] Add crossentropy metrics.

---
 keras_core/losses/losses.py                  |   2 +-
 keras_core/metrics/__init__.py               |   9 +-
 keras_core/metrics/probabilistic_metrics.py  | 246 ++++++++++++++++++
 .../metrics/probabilistic_metrics_test.py    | 152 ++++++++++-
 4 files changed, 400 insertions(+), 9 deletions(-)

diff --git a/keras_core/losses/losses.py b/keras_core/losses/losses.py
index 92a18f31e..b369ffcf8 100644
--- a/keras_core/losses/losses.py
+++ b/keras_core/losses/losses.py
@@ -1347,7 +1347,7 @@ def categorical_crossentropy(
     >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
     >>> loss = keras_core.losses.categorical_crossentropy(y_true, y_pred)
     >>> assert loss.shape == (2,)
-    >>> loss.numpy()
+    >>> loss
     array([0.0513, 2.303], dtype=float32)
     """
     if isinstance(axis, bool):
diff --git a/keras_core/metrics/__init__.py b/keras_core/metrics/__init__.py
index 8221bca6c..6403d052a 100644
--- a/keras_core/metrics/__init__.py
+++ b/keras_core/metrics/__init__.py
@@ -8,15 +8,19 @@ from keras_core.metrics.hinge_metrics import CategoricalHinge
 from keras_core.metrics.hinge_metrics import Hinge
 from keras_core.metrics.hinge_metrics import SquaredHinge
 from keras_core.metrics.metric import Metric
+from keras_core.metrics.probabilistic_metrics import BinaryCrossentropy
+from keras_core.metrics.probabilistic_metrics import CategoricalCrossentropy
 from keras_core.metrics.probabilistic_metrics import KLDivergence
 from keras_core.metrics.probabilistic_metrics import Poisson
+from keras_core.metrics.probabilistic_metrics import (
+    SparseCategoricalCrossentropy,
+)
 from keras_core.metrics.reduction_metrics import Mean
 from keras_core.metrics.reduction_metrics import MeanMetricWrapper
 from keras_core.metrics.reduction_metrics import Sum
 from keras_core.metrics.regression_metrics import MeanSquaredError
 from keras_core.metrics.regression_metrics import mean_squared_error
 from keras_core.saving import serialization_lib
-from keras_core.utils import naming
 
 ALL_OBJECTS = {
     Metric,
@@ -30,6 +34,9 @@ ALL_OBJECTS = {
     CategoricalHinge,
     KLDivergence,
     Poisson,
+    BinaryCrossentropy,
+    CategoricalCrossentropy,
+    SparseCategoricalCrossentropy,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 ALL_OBJECTS_DICT.update(
diff --git a/keras_core/metrics/probabilistic_metrics.py b/keras_core/metrics/probabilistic_metrics.py
index 0485da207..1b6beb88f 100644
--- a/keras_core/metrics/probabilistic_metrics.py
+++ b/keras_core/metrics/probabilistic_metrics.py
@@ -1,6 +1,9 @@
 from keras_core.api_export import keras_core_export
+from keras_core.losses.losses import binary_crossentropy
+from keras_core.losses.losses import categorical_crossentropy
 from keras_core.losses.losses import kl_divergence
 from keras_core.losses.losses import poisson
+from keras_core.losses.losses import sparse_categorical_crossentropy
 from keras_core.metrics import reduction_metrics
 
 
@@ -62,6 +65,8 @@ class Poisson(reduction_metrics.MeanMetricWrapper):
         name: (Optional) string name of the metric instance.
         dtype: (Optional) data type of the metric result.
 
+    Examples:
+
     Standalone usage:
 
     >>> m = keras_core.metrics.Poisson()
@@ -89,3 +94,244 @@ class Poisson(reduction_metrics.MeanMetricWrapper):
 
     def get_config(self):
         return {"name": self.name, "dtype": self.dtype}
+
+
+@keras_core_export("keras_core.metrics.BinaryCrossentropy")
+class BinaryCrossentropy(reduction_metrics.MeanMetricWrapper):
+    """Computes the crossentropy metric between the labels and predictions.
+
+    This is the crossentropy metric class to be used when there are only two
+    label classes (0 and 1).
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        from_logits: (Optional) Whether output is expected
+            to be a logits tensor. By default, we consider
+            that output encodes a probability distribution.
+        label_smoothing: (Optional) Float in `[0, 1]`.
+            When > 0, label values are smoothed,
+            meaning the confidence on label values is relaxed.
+            e.g. `label_smoothing=0.2` means that we will use
+            a value of 0.1 for label "0" and 0.9 for label "1".
+
+    Examples:
+
+    Standalone usage:
+
+    >>> m = keras_core.metrics.BinaryCrossentropy()
+    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
+    >>> m.result()
+    0.81492424
+
+    >>> m.reset_state()
+    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
+    ...                sample_weight=[1, 0])
+    >>> m.result()
+    0.9162905
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[keras_core.metrics.BinaryCrossentropy()])
+    ```
+    """
+
+    def __init__(
+        self,
+        name="binary_crossentropy",
+        dtype=None,
+        from_logits=False,
+        label_smoothing=0,
+    ):
+        super().__init__(
+            binary_crossentropy,
+            name,
+            dtype=dtype,
+            from_logits=from_logits,
+            label_smoothing=label_smoothing,
+        )
+        self.from_logits = from_logits
+        self.label_smoothing = label_smoothing
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "dtype": self.dtype,
+            "from_logits": self.from_logits,
+            "label_smoothing": self.label_smoothing,
+        }
+
+
+@keras_core_export("keras_core.metrics.CategoricalCrossentropy")
+class CategoricalCrossentropy(reduction_metrics.MeanMetricWrapper):
+    """Computes the crossentropy metric between the labels and predictions.
+
+    This is the crossentropy metric class to be used when there are multiple
+    label classes (2 or more). It assumes that labels are one-hot encoded,
+    e.g., when label values are `[2, 0, 1]`, then
+    `y_true` is `[[0, 0, 1], [1, 0, 0], [0, 1, 0]]`.
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        from_logits: (Optional) Whether output is expected to be
+            a logits tensor. By default, we consider that output
+            encodes a probability distribution.
+        label_smoothing: (Optional) Float in `[0, 1]`.
+            When > 0, label values are smoothed, meaning the confidence
+            on label values is relaxed. e.g. `label_smoothing=0.2` means
+            that we will use a value of 0.1 for label
+            "0" and 0.9 for label "1".
+        axis: (Optional) Defaults to -1.
+            The dimension along which entropy is computed.
+
+    Examples:
+
+    Standalone usage:
+
+    >>> # EPSILON = 1e-7, y = y_true, y' = y_pred
+    >>> # y' = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
+    >>> # y' = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
+    >>> # xent = -sum(y * log(y'), axis = -1)
+    >>> #      = -((log 0.95), (log 0.1))
+    >>> #      = [0.051, 2.302]
+    >>> # Reduced xent = (0.051 + 2.302) / 2
+    >>> m = keras_core.metrics.CategoricalCrossentropy()
+    >>> m.update_state([[0, 1, 0], [0, 0, 1]],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+    >>> m.result()
+    1.1769392
+
+    >>> m.reset_state()
+    >>> m.update_state([[0, 1, 0], [0, 0, 1]],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
+    ...                sample_weight=np.array([0.3, 0.7]))
+    >>> m.result()
+    1.6271976
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[keras_core.metrics.CategoricalCrossentropy()])
+    ```
+    """
+
+    def __init__(
+        self,
+        name="categorical_crossentropy",
+        dtype=None,
+        from_logits=False,
+        label_smoothing=0,
+        axis=-1,
+    ):
+        super().__init__(
+            categorical_crossentropy,
+            name,
+            dtype=dtype,
+            from_logits=from_logits,
+            label_smoothing=label_smoothing,
+            axis=axis,
+        )
+        self.from_logits = from_logits
+        self.label_smoothing = label_smoothing
+        self.axis = axis
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "dtype": self.dtype,
+            "from_logits": self.from_logits,
+            "label_smoothing": self.label_smoothing,
+            "axis": self.axis,
+        }
+
+
+@keras_core_export("keras_core.metrics.SparseCategoricalCrossentropy")
+class SparseCategoricalCrossentropy(reduction_metrics.MeanMetricWrapper):
+    """Computes the crossentropy metric between the labels and predictions.
+
+    Use this crossentropy metric when there are two or more label classes.
+    It expects labels to be provided as integers. If you want to provide labels
+    that are one-hot encoded, please use the `CategoricalCrossentropy`
+    metric instead.
+
+    There should be `num_classes` floating point values per feature for `y_pred`
+    and a single floating point value per feature for `y_true`.
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        from_logits: (Optional) Whether output is expected
+            to be a logits tensor. By default, we consider that output
+            encodes a probability distribution.
+        axis: (Optional) Defaults to -1.
+            The dimension along which entropy is computed.
+
+    Examples:
+
+    Standalone usage:
+
+    >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
+    >>> # logits = log(y_pred)
+    >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)
+    >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
+    >>> # xent = -sum(y * log(softmax), 1)
+    >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],
+    >>> #                 [-2.3026, -0.2231, -2.3026]]
+    >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
+    >>> # xent = [0.0513, 2.3026]
+    >>> # Reduced xent = (0.0513 + 2.3026) / 2
+    >>> m = keras_core.metrics.SparseCategoricalCrossentropy()
+    >>> m.update_state([1, 2],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+    >>> m.result()
+    1.1769392
+
+    >>> m.reset_state()
+    >>> m.update_state([1, 2],
+    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
+    ...                sample_weight=np.array([0.3, 0.7]))
+    >>> m.result()
+    1.6271976
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(
+        optimizer='sgd',
+        loss='mse',
+        metrics=[keras_core.metrics.SparseCategoricalCrossentropy()])
+    ```
+    """
+
+    def __init__(
+        self,
+        name="sparse_categorical_crossentropy",
+        dtype=None,
+        from_logits=False,
+        axis=-1,
+    ):
+        super().__init__(
+            sparse_categorical_crossentropy,
+            name=name,
+            dtype=dtype,
+            from_logits=from_logits,
+            axis=axis,
+        )
+        self.from_logits = from_logits
+        self.axis = axis
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "dtype": self.dtype,
+            "from_logits": self.from_logits,
+            "axis": self.axis,
+        }
diff --git a/keras_core/metrics/probabilistic_metrics_test.py b/keras_core/metrics/probabilistic_metrics_test.py
index 93c5d71b7..8b9ff75ca 100644
--- a/keras_core/metrics/probabilistic_metrics_test.py
+++ b/keras_core/metrics/probabilistic_metrics_test.py
@@ -65,13 +65,7 @@ class PoissonTest(testing.TestCase):
         )
 
     def test_config(self):
-        poisson_obj = metrics.Poisson(name="poisson", dtype="float32")
-        self.assertEqual(poisson_obj.name, "poisson")
-        self.assertEqual(poisson_obj._dtype, "float32")
-
-        poisson_obj2 = metrics.Poisson.from_config(poisson_obj.get_config())
-        self.assertEqual(poisson_obj2.name, "poisson")
-        self.assertEqual(poisson_obj2._dtype, "float32")
+        self.run_class_serialization_test(metrics.Poisson(name="poisson"))
 
     def test_unweighted(self):
         self.setup()
@@ -96,3 +90,147 @@ class PoissonTest(testing.TestCase):
         expected_result = np.multiply(self.expected_results, sample_weight)
         expected_result = np.sum(expected_result) / np.sum(sample_weight)
         self.assertAllClose(result, expected_result, atol=1e-3)
+
+
+class BinaryCrossentropyTest(testing.TestCase):
+    def test_config(self):
+        self.run_class_serialization_test(
+            metrics.BinaryCrossentropy(
+                name="bce", dtype="int32", label_smoothing=0.2
+            )
+        )
+
+    def test_unweighted(self):
+        bce_obj = metrics.BinaryCrossentropy()
+        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])
+        y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])
+        result = bce_obj(y_true, y_pred)
+        self.assertAllClose(result, 3.9855, atol=1e-3)
+
+    def test_unweighted_with_logits(self):
+        bce_obj = metrics.BinaryCrossentropy(from_logits=True)
+
+        y_true = np.array([[1, 0, 1], [0, 1, 1]])
+        y_pred = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])
+        result = bce_obj(y_true, y_pred)
+        self.assertAllClose(result, 3.333, atol=1e-3)
+
+    def test_weighted(self):
+        bce_obj = metrics.BinaryCrossentropy()
+        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])
+        y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])
+        sample_weight = np.array([1.5, 2.0])
+        result = bce_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAllClose(result, 3.4162, atol=1e-3)
+
+    def test_weighted_from_logits(self):
+        bce_obj = metrics.BinaryCrossentropy(from_logits=True)
+        y_true = np.array([[1, 0, 1], [0, 1, 1]])
+        y_pred = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])
+        sample_weight = np.array([2.0, 2.5])
+        result = bce_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAllClose(result, 3.7037, atol=1e-3)
+
+    def test_label_smoothing(self):
+        logits = np.array([10.0, -10.0, -10.0])
+        y_true = np.array([1, 0, 1])
+        label_smoothing = 0.1
+        bce_obj = metrics.BinaryCrossentropy(
+            from_logits=True, label_smoothing=label_smoothing
+        )
+        result = bce_obj(y_true, logits)
+        expected_value = (10.0 + 5.0 * label_smoothing) / 3.0
+        self.assertAllClose(expected_value, result, atol=1e-3)
+
+
+class CategoricalCrossentropyTest(testing.TestCase):
+    def test_config(self):
+        self.run_class_serialization_test(
+            metrics.CategoricalCrossentropy(
+                name="cce", dtype="int32", label_smoothing=0.2
+            )
+        )
+
+    def test_unweighted(self):
+        cce_obj = metrics.CategoricalCrossentropy()
+        y_true = np.array([[0, 1, 0], [0, 0, 1]])
+        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+        result = cce_obj(y_true, y_pred)
+        self.assertAllClose(result, 1.176, atol=1e-3)
+
+    def test_unweighted_from_logits(self):
+        cce_obj = metrics.CategoricalCrossentropy(from_logits=True)
+
+        y_true = np.array([[0, 1, 0], [0, 0, 1]])
+        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)
+        result = cce_obj(y_true, logits)
+        self.assertAllClose(result, 3.5011, atol=1e-3)
+
+    def test_weighted(self):
+        cce_obj = metrics.CategoricalCrossentropy()
+
+        y_true = np.array([[0, 1, 0], [0, 0, 1]])
+        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+        sample_weight = np.array([1.5, 2.0])
+        result = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAllClose(result, 1.338, atol=1e-3)
+
+    def test_weighted_from_logits(self):
+        cce_obj = metrics.CategoricalCrossentropy(from_logits=True)
+
+        y_true = np.array([[0, 1, 0], [0, 0, 1]])
+        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)
+        sample_weight = np.array([1.5, 2.0])
+        result = cce_obj(y_true, logits, sample_weight=sample_weight)
+        self.assertAllClose(result, 4.0012, atol=1e-3)
+
+    def test_label_smoothing(self):
+        y_true = np.array([[0, 1, 0], [0, 0, 1]])
+        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)
+        label_smoothing = 0.1
+        cce_obj = metrics.CategoricalCrossentropy(
+            from_logits=True, label_smoothing=label_smoothing
+        )
+        result = cce_obj(y_true, logits)
+        self.assertAllClose(result, 3.667, atol=1e-3)
+
+
+class SparseCategoricalCrossentropyTest(testing.TestCase):
+    def test_config(self):
+        self.run_class_serialization_test(
+            metrics.SparseCategoricalCrossentropy(name="scce", dtype="int32")
+        )
+
+    def test_unweighted(self):
+        scce_obj = metrics.SparseCategoricalCrossentropy()
+
+        y_true = np.array([1, 2])
+        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+        result = scce_obj(y_true, y_pred)
+        self.assertAllClose(result, 1.176, atol=1e-3)
+
+    def test_unweighted_from_logits(self):
+        scce_obj = metrics.SparseCategoricalCrossentropy(from_logits=True)
+
+        y_true = np.array([1, 2])
+        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)
+        result = scce_obj(y_true, logits)
+        self.assertAllClose(result, 3.5011, atol=1e-3)
+
+    def test_weighted(self):
+        scce_obj = metrics.SparseCategoricalCrossentropy()
+
+        y_true = np.array([1, 2])
+        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
+        sample_weight = np.array([1.5, 2.0])
+        result = scce_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAllClose(result, 1.338, atol=1e-3)
+
+    def test_weighted_from_logits(self):
+        scce_obj = metrics.SparseCategoricalCrossentropy(from_logits=True)
+
+        y_true = np.array([1, 2])
+        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)
+        sample_weight = np.array([1.5, 2.0])
+        result = scce_obj(y_true, logits, sample_weight=sample_weight)
+        self.assertAllClose(result, 4.0012, atol=1e-3)
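
All three new metric classes defer to `reduction_metrics.MeanMetricWrapper`, so each `update_state()` call folds the per-sample loss values into a running weighted mean rather than recomputing over everything seen so far. A minimal sketch of that contract, reproducing the `BinaryCrossentropy` doctest numbers (this is a toy stand-in, not keras_core's actual implementation; the `1e-7` clipping constant and last-axis averaging are assumptions):

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample binary crossentropy, averaged over the last axis."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    elems = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return elems.mean(axis=-1)


class StreamingMean:
    """Toy MeanMetricWrapper: keeps a weighted running mean of `fn` outputs."""

    def __init__(self, fn):
        self.fn = fn
        self.reset_state()

    def reset_state(self):
        self.total = 0.0
        self.count = 0.0

    def update_state(self, y_true, y_pred, sample_weight=None):
        values = self.fn(y_true, y_pred)
        if sample_weight is None:
            w = np.ones_like(values)
        else:
            w = np.asarray(sample_weight, dtype=np.float64)
        self.total += float(np.sum(values * w))
        self.count += float(np.sum(w))

    def result(self):
        return self.total / self.count


m = StreamingMean(binary_crossentropy)
m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
print(m.result())  # ~0.8149, the first BinaryCrossentropy doctest value

m.reset_state()
m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], sample_weight=[1, 0])
print(m.result())  # ~0.9163, the weighted doctest value
```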
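
The `CategoricalCrossentropy` doctest values follow directly from the clipped-log formula spelled out in the docstring comments. A NumPy check (assuming `EPSILON = 1e-7` clipping exactly as those comments state; the backend's clipping details may differ):

```python
import numpy as np

EPSILON = 1e-7  # the EPSILON referenced in the docstring comments

y_true = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

# Clip predictions away from 0 and 1, then take the per-sample crossentropy.
y_pred = np.clip(y_pred, EPSILON, 1.0 - EPSILON)
xent = -np.sum(y_true * np.log(y_pred), axis=-1)

print(xent)         # ~[0.0513, 2.3026]
print(xent.mean())  # ~1.1769392, the unweighted result()

# With sample_weight=[0.3, 0.7], the metric is a weighted mean:
w = np.array([0.3, 0.7])
print(np.sum(w * xent) / np.sum(w))  # ~1.6271976, the weighted result()
```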
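
`SparseCategoricalCrossentropy` is the integer-label variant of the same computation, which is why both doctests report `1.1769392` for identical predictions. A quick equivalence check (assumes `keras_core` is importable, as in the test file; `np.eye(3)[labels]` is just a compact one-hot encoding):

```python
import numpy as np
from keras_core import metrics

y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
sparse_labels = np.array([1, 2])
one_hot = np.eye(3)[sparse_labels]  # [[0, 1, 0], [0, 0, 1]]

m_sparse = metrics.SparseCategoricalCrossentropy()
m_sparse.update_state(sparse_labels, y_pred)

m_dense = metrics.CategoricalCrossentropy()
m_dense.update_state(one_hot, y_pred)

# Both should report ~1.1769392, as in the doctests above.
print(m_sparse.result(), m_dense.result())
```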
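
In `test_label_smoothing`, the closed form `expected_value = (10.0 + 5.0 * label_smoothing) / 3.0` comes from smoothing the binary targets and evaluating crossentropy on the saturated ±10 logits, where the residual `log(1 + e^-10)` terms are negligible. A back-of-the-envelope check using the standard numerically stable from-logits formula (that the backend uses this exact formulation is an assumption):

```python
import numpy as np

label_smoothing = 0.1
logits = np.array([10.0, -10.0, -10.0])
y_true = np.array([1.0, 0.0, 1.0])

# Binary label smoothing: y * (1 - s) + s / 2  ->  [0.95, 0.05, 0.95]
y_smooth = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

# Stable crossentropy from logits: max(z, 0) - z * y + log(1 + exp(-|z|))
z = logits
bce = np.maximum(z, 0.0) - z * y_smooth + np.log1p(np.exp(-np.abs(z)))

print(bce.mean())                            # ~3.50005
print((10.0 + 5.0 * label_smoothing) / 3.0)  # 3.5, the test's expected_value
```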