diff --git a/keras_core/metrics/__init__.py b/keras_core/metrics/__init__.py
index 19213822c..6ec09b5b3 100644
--- a/keras_core/metrics/__init__.py
+++ b/keras_core/metrics/__init__.py
@@ -8,6 +8,7 @@ from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
 from keras_core.metrics.confusion_metrics import FalseNegatives
 from keras_core.metrics.confusion_metrics import FalsePositives
 from keras_core.metrics.confusion_metrics import Precision
+from keras_core.metrics.confusion_metrics import Recall
 from keras_core.metrics.confusion_metrics import TrueNegatives
 from keras_core.metrics.confusion_metrics import TruePositives
 from keras_core.metrics.hinge_metrics import CategoricalHinge
@@ -39,6 +40,7 @@ ALL_OBJECTS = {
     FalseNegatives,
     FalsePositives,
     Precision,
+    Recall,
     TrueNegatives,
     TruePositives,
     # Hinge
diff --git a/keras_core/metrics/confusion_metrics.py b/keras_core/metrics/confusion_metrics.py
index e074d6bb9..be9461bf5 100644
--- a/keras_core/metrics/confusion_metrics.py
+++ b/keras_core/metrics/confusion_metrics.py
@@ -14,9 +14,9 @@ class _ConfusionMatrixConditionCount(Metric):
         confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
             conditions.
         thresholds: (Optional) Defaults to 0.5. A float value or a python list /
-            tuple of float threshold values in [0, 1]. A threshold is compared
+            tuple of float threshold values in `[0, 1]`. A threshold is compared
             with prediction values to determine the truth value of predictions
-            (i.e., above the threshold is `true`, below is `false`). One metric
+            (i.e., above the threshold is `True`, below is `False`). One metric
             value is generated for each threshold value.
         name: (Optional) string name of the metric instance.
         dtype: (Optional) data type of the metric result.
@@ -85,9 +85,9 @@ class FalsePositives(_ConfusionMatrixConditionCount):

     Args:
         thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-            list/tuple of float threshold values in [0, 1]. A threshold is
+            list/tuple of float threshold values in `[0, 1]`. A threshold is
             compared with prediction values to determine the truth value of
-            predictions (i.e., above the threshold is `true`, below is `false`).
+            predictions (i.e., above the threshold is `True`, below is `False`).
             If used with a loss function that sets `from_logits=True` (i.e. no
             sigmoid applied to predictions), `thresholds` should be set to 0.
             One metric value is generated for each threshold value.
@@ -129,9 +129,9 @@ class FalseNegatives(_ConfusionMatrixConditionCount):

     Args:
         thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-            list/tuple of float threshold values in [0, 1]. A threshold is
+            list/tuple of float threshold values in `[0, 1]`. A threshold is
             compared with prediction values to determine the truth value of
-            predictions (i.e., above the threshold is `true`, below is `false`).
+            predictions (i.e., above the threshold is `True`, below is `False`).
             If used with a loss function that sets `from_logits=True` (i.e. no
             sigmoid applied to predictions), `thresholds` should be set to 0.
             One metric value is generated for each threshold value.
@@ -173,9 +173,9 @@ class TrueNegatives(_ConfusionMatrixConditionCount):

     Args:
         thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-            list/tuple of float threshold values in [0, 1]. A threshold is
+            list/tuple of float threshold values in `[0, 1]`. A threshold is
             compared with prediction values to determine the truth value of
-            predictions (i.e., above the threshold is `true`, below is `false`).
+            predictions (i.e., above the threshold is `True`, below is `False`).
             If used with a loss function that sets `from_logits=True` (i.e. no
             sigmoid applied to predictions), `thresholds` should be set to 0.
             One metric value is generated for each threshold value.
@@ -217,9 +217,9 @@ class TruePositives(_ConfusionMatrixConditionCount):

     Args:
         thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-            list/tuple of float threshold values in [0, 1]. A threshold is
+            list/tuple of float threshold values in `[0, 1]`. A threshold is
             compared with prediction values to determine the truth value of
-            predictions (i.e., above the threshold is `true`, below is `false`).
+            predictions (i.e., above the threshold is `True`, below is `False`).
             If used with a loss function that sets `from_logits=True` (i.e. no
             sigmoid applied to predictions), `thresholds` should be set to 0.
             One metric value is generated for each threshold value.
@@ -272,13 +272,13 @@ class Precision(Metric):

     Args:
         thresholds: (Optional) A float value, or a Python list/tuple of float
-            threshold values in [0, 1]. A threshold is compared with prediction
-            values to determine the truth value of predictions (i.e., above the
-            threshold is `true`, below is `false`). If used with a loss function
-            that sets `from_logits=True` (i.e. no sigmoid applied to
-            predictions), `thresholds` should be set to 0. One metric value is
-            generated for each threshold value. If neither thresholds nor top_k
-            are set, the default is to calculate precision with
+            threshold values in `[0, 1]`. A threshold is compared with
+            prediction values to determine the truth value of predictions (i.e.,
+            above the threshold is `True`, below is `False`). If used with a
+            loss function that sets `from_logits=True` (i.e. no sigmoid applied
+            to predictions), `thresholds` should be set to 0. One metric value
+            is generated for each threshold value. If neither thresholds nor
+            top_k are set, the default is to calculate precision with
             `thresholds=0.5`.
         top_k: (Optional) Unset by default. An int value specifying the top-k
             predictions to consider when calculating precision.
@@ -326,7 +326,7 @@ class Precision(Metric):

     ```python
     model.compile(optimizer='adam',
-                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
+                  loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
                   metrics=[keras_core.metrics.Precision(thresholds=0)])
     ```
     """
@@ -369,7 +369,7 @@ class Precision(Metric):
                 Can be a `Tensor` whose rank is either 0, or the same rank as
                 `y_true`, and must be broadcastable to `y_true`.
         """
-        return metrics_utils.update_confusion_matrix_variables(
+        metrics_utils.update_confusion_matrix_variables(
             {
                 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501
                 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501
@@ -403,3 +403,143 @@ class Precision(Metric):
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
+
+
+@keras_core_export("keras_core.metrics.Recall")
+class Recall(Metric):
+    """Computes the recall of the predictions with respect to the labels.
+
+    This metric creates two local variables, `true_positives` and
+    `false_negatives`, that are used to compute the recall. This value is
+    ultimately returned as `recall`, an idempotent operation that simply divides
+    `true_positives` by the sum of `true_positives` and `false_negatives`.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    If `top_k` is set, recall will be computed as how often on average a class
+    among the labels of a batch entry is in the top-k predictions.
+
+    If `class_id` is specified, we calculate recall by considering only the
+    entries in the batch for which `class_id` is in the label, and computing the
+    fraction of them for which `class_id` is above the threshold and/or in the
+    top-k predictions.
+
+    Args:
+        thresholds: (Optional) A float value, or a Python list/tuple of float
+            threshold values in `[0, 1]`. A threshold is compared with
+            prediction values to determine the truth value of predictions (i.e.,
+            above the threshold is `True`, below is `False`). If used with a loss
+            function that sets `from_logits=True` (i.e. no sigmoid applied to
+            predictions), `thresholds` should be set to 0. One metric value is
+            generated for each threshold value. If neither thresholds nor top_k
+            are set, the default is to calculate recall with `thresholds=0.5`.
+        top_k: (Optional) Unset by default. An int value specifying the top-k
+            predictions to consider when calculating recall.
+        class_id: (Optional) Integer class ID for which we want binary metrics.
+            This must be in the half-open interval `[0, num_classes)`, where
+            `num_classes` is the last dimension of predictions.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+
+    Standalone usage:
+
+    >>> m = keras_core.metrics.Recall()
+    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
+    >>> m.result()
+    0.6666667
+
+    >>> m.reset_state()
+    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
+    >>> m.result()
+    1.0
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(optimizer='sgd',
+                  loss='mse',
+                  metrics=[keras_core.metrics.Recall()])
+    ```
+
+    Usage with a loss with `from_logits=True`:
+
+    ```python
+    model.compile(optimizer='adam',
+                  loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
+                  metrics=[keras_core.metrics.Recall(thresholds=0)])
+    ```
+    """
+
+    def __init__(
+        self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
+    ):
+        super().__init__(name=name, dtype=dtype)
+        self.init_thresholds = thresholds
+        self.top_k = top_k
+        self.class_id = class_id
+
+        default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
+        self.thresholds = metrics_utils.parse_init_thresholds(
+            thresholds, default_threshold=default_threshold
+        )
+        self._thresholds_distributed_evenly = (
+            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
+        )
+        self.true_positives = self.add_variable(
+            shape=(len(self.thresholds),),
+            initializer=initializers.Zeros(),
+            name="true_positives",
+        )
+        self.false_negatives = self.add_variable(
+            shape=(len(self.thresholds),),
+            initializer=initializers.Zeros(),
+            name="false_negatives",
+        )
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        """Accumulates true positive and false negative statistics.
+
+        Args:
+            y_true: The ground truth values, with the same dimensions as
+                `y_pred`. Will be cast to `bool`.
+            y_pred: The predicted values. Each element must be in the range
+                `[0, 1]`.
+            sample_weight: Optional weighting of each example. Defaults to 1.
+                Can be a `Tensor` whose rank is either 0, or the same rank as
+                `y_true`, and must be broadcastable to `y_true`.
+ """ + metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 + }, + y_true, + y_pred, + thresholds=self.thresholds, + thresholds_distributed_evenly=self._thresholds_distributed_evenly, + top_k=self.top_k, + class_id=self.class_id, + sample_weight=sample_weight, + ) + + def result(self): + result = ops.divide( + self.true_positives, + self.true_positives + self.false_negatives + backend.epsilon(), + ) + return result[0] if len(self.thresholds) == 1 else result + + def reset_state(self): + num_thresholds = len(to_list(self.thresholds)) + self.true_positives.assign(ops.zeros((num_thresholds,))) + self.false_negatives.assign(ops.zeros((num_thresholds,))) + + def get_config(self): + config = { + "thresholds": self.init_thresholds, + "top_k": self.top_k, + "class_id": self.class_id, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_core/metrics/confusion_metrics_test.py b/keras_core/metrics/confusion_metrics_test.py index fb3b9fb1a..6a544bfe4 100644 --- a/keras_core/metrics/confusion_metrics_test.py +++ b/keras_core/metrics/confusion_metrics_test.py @@ -534,3 +534,158 @@ class PrecisionTest(testing.TestCase): self.assertAlmostEqual(1, result) self.assertAlmostEqual(1, p_obj.true_positives) self.assertAlmostEqual(0, p_obj.false_positives) + + +class RecallTest(testing.TestCase): + def test_config(self): + r_obj = metrics.Recall( + name="my_recall", thresholds=[0.4, 0.9], top_k=15, class_id=12 + ) + self.assertEqual(r_obj.name, "my_recall") + self.assertLen(r_obj.variables, 2) + self.assertEqual( + [v.name for v in r_obj.variables], + ["true_positives", "false_negatives"], + ) + self.assertEqual(r_obj.thresholds, [0.4, 0.9]) + self.assertEqual(r_obj.top_k, 15) + self.assertEqual(r_obj.class_id, 12) + + # Check save and restore config + r_obj2 = metrics.Recall.from_config(r_obj.get_config()) + self.assertEqual(r_obj2.name, "my_recall") + self.assertLen(r_obj2.variables, 2) + self.assertEqual(r_obj2.thresholds, [0.4, 0.9]) + self.assertEqual(r_obj2.top_k, 15) + self.assertEqual(r_obj2.class_id, 12) + + def test_unweighted(self): + r_obj = metrics.Recall() + y_pred = np.array([1, 0, 1, 0]) + y_true = np.array([0, 1, 1, 0]) + self.assertAlmostEqual(0.5, r_obj(y_true, y_pred)) + + def test_unweighted_all_incorrect(self): + r_obj = metrics.Recall(thresholds=[0.5]) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs) + y_true = np.array(1 - inputs) + self.assertAlmostEqual(0, r_obj(y_true, y_pred)) + + def test_weighted(self): + r_obj = metrics.Recall() + y_pred = np.array([[1, 0, 1, 0], [0, 1, 0, 1]]) + y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]]) + result = r_obj( + y_true, + y_pred, + sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]), + ) + weighted_tp = 3.0 + 1.0 + weighted_t = (2.0 + 3.0) + (4.0 + 1.0) + expected_recall = weighted_tp / weighted_t + self.assertAlmostEqual(expected_recall, result) + + def test_div_by_zero(self): + r_obj = metrics.Recall() + y_pred = np.array([0, 0, 0, 0]) + y_true = np.array([0, 0, 0, 0]) + self.assertEqual(0, r_obj(y_true, y_pred)) + + def test_unweighted_with_threshold(self): + r_obj = metrics.Recall(thresholds=[0.5, 0.7]) + y_pred = np.array([1, 0, 0.6, 0]) + y_true = np.array([0, 1, 1, 0]) + self.assertAllClose([0.5, 0.0], r_obj(y_true, y_pred), 0) + + def test_weighted_with_threshold(self): + 
+        r_obj = metrics.Recall(thresholds=[0.5, 1.0])
+        y_true = np.array([[0, 1], [1, 0]])
+        y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
+        weights = np.array([[1, 4], [3, 2]], dtype="float32")
+        result = r_obj(y_true, y_pred, sample_weight=weights)
+        weighted_tp = 0 + 3.0
+        weighted_positives = (0 + 3.0) + (4.0 + 0.0)
+        expected_recall = weighted_tp / weighted_positives
+        self.assertAllClose([expected_recall, 0], result, 1e-3)
+
+    def test_multiple_updates(self):
+        r_obj = metrics.Recall(thresholds=[0.5, 1.0])
+        y_true = np.array([[0, 1], [1, 0]])
+        y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
+        weights = np.array([[1, 4], [3, 2]], dtype="float32")
+        for _ in range(2):
+            r_obj.update_state(y_true, y_pred, sample_weight=weights)
+
+        weighted_tp = (0 + 3.0) + (0 + 3.0)
+        weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + (
+            (0 + 3.0) + (4.0 + 0.0)
+        )
+        expected_recall = weighted_tp / weighted_positives
+        self.assertAllClose([expected_recall, 0], r_obj.result(), 1e-3)
+
+    def test_unweighted_top_k(self):
+        r_obj = metrics.Recall(top_k=3)
+        y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2])
+        y_true = np.array([0, 1, 1, 0, 0])
+        self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
+
+    def test_weighted_top_k(self):
+        r_obj = metrics.Recall(top_k=3)
+        y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]])
+        y_true1 = np.array([[0, 1, 1, 0, 1]])
+        r_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]]))
+
+        y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2])
+        y_true2 = np.array([1, 0, 1, 1, 1])
+        result = r_obj(y_true2, y_pred2, sample_weight=np.array(3))
+
+        tp = (2 + 5) + (3 + 3)
+        positives = (4 + 2 + 5) + (3 + 3 + 3 + 3)
+        expected_recall = tp / positives
+        self.assertAlmostEqual(expected_recall, result)
+
+    def test_unweighted_class_id(self):
+        r_obj = metrics.Recall(class_id=2)
+
+        y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
+        y_true = np.array([0, 1, 1, 0, 0])
+        self.assertAlmostEqual(1, r_obj(y_true, y_pred))
+        self.assertAlmostEqual(1, r_obj.true_positives)
+        self.assertAlmostEqual(0, r_obj.false_negatives)
+
+        y_pred = np.array([0.2, 0.1, 0, 0, 0.2])
+        y_true = np.array([0, 1, 1, 0, 0])
+        self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
+        self.assertAlmostEqual(1, r_obj.true_positives)
+        self.assertAlmostEqual(1, r_obj.false_negatives)
+
+        y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
+        y_true = np.array([0, 1, 0, 0, 0])
+        self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
+        self.assertAlmostEqual(1, r_obj.true_positives)
+        self.assertAlmostEqual(1, r_obj.false_negatives)
+
+    def test_unweighted_top_k_and_class_id(self):
+        r_obj = metrics.Recall(class_id=2, top_k=2)
+
+        y_pred = np.array([0.2, 0.6, 0.3, 0, 0.2])
+        y_true = np.array([0, 1, 1, 0, 0])
+        self.assertAlmostEqual(1, r_obj(y_true, y_pred))
+        self.assertAlmostEqual(1, r_obj.true_positives)
+        self.assertAlmostEqual(0, r_obj.false_negatives)
+
+        y_pred = np.array([1, 1, 0.9, 1, 1])
+        y_true = np.array([0, 1, 1, 0, 0])
+        self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
+        self.assertAlmostEqual(1, r_obj.true_positives)
+        self.assertAlmostEqual(1, r_obj.false_negatives)
+
+    def test_unweighted_top_k_and_threshold(self):
+        r_obj = metrics.Recall(thresholds=0.7, top_k=2)
+
+        y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2])
+        y_true = np.array([1, 1, 1, 0, 1])
+        self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))
+        self.assertAlmostEqual(1, r_obj.true_positives)
+        self.assertAlmostEqual(3, r_obj.false_negatives)
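
For a quick sanity check of the new `Recall` metric outside the unit tests, here is a minimal standalone sketch (assuming `keras_core` is importable from this branch and a backend is configured); the expected values follow the docstring example and `test_unweighted_with_threshold` above.

```python
import numpy as np

import keras_core

# recall = TP / (TP + FN); numbers mirror the docstring example.
m = keras_core.metrics.Recall()
m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
print(float(m.result()))  # 2 TP / (2 TP + 1 FN) ~= 0.6666667

# A list of thresholds yields one recall value per threshold,
# mirroring test_unweighted_with_threshold.
m2 = keras_core.metrics.Recall(thresholds=[0.5, 0.7])
m2.update_state(np.array([0, 1, 1, 0]), np.array([1.0, 0.0, 0.6, 0.0]))
print(m2.result())  # ~[0.5, 0.0]
```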