Add Recall metric (#70)

* Add recall

* Review comments
Ian Stenbit 2023-05-02 16:20:15 -06:00 committed by Francois Chollet
parent 841b8d702d
commit 07504d53a8
3 changed files with 316 additions and 19 deletions

@ -8,6 +8,7 @@ from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
from keras_core.metrics.confusion_metrics import FalseNegatives
from keras_core.metrics.confusion_metrics import FalsePositives
from keras_core.metrics.confusion_metrics import Precision
+from keras_core.metrics.confusion_metrics import Recall
from keras_core.metrics.confusion_metrics import TrueNegatives
from keras_core.metrics.confusion_metrics import TruePositives
from keras_core.metrics.hinge_metrics import CategoricalHinge
@ -39,6 +40,7 @@ ALL_OBJECTS = {
FalseNegatives,
FalsePositives,
Precision,
+Recall,
TrueNegatives,
TruePositives,
# Hinge
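Registering `Recall` in `ALL_OBJECTS` (in addition to the import above) is what lets the metric be resolved from a string identifier and round-tripped through serialization. A minimal sketch of that usage, assuming keras_core mirrors the familiar `keras.metrics.get()` resolver backed by this registry (not part of the diff):

```python
import keras_core

# Direct construction works once the class is importable.
m1 = keras_core.metrics.Recall(thresholds=0.5)

# With Recall registered in ALL_OBJECTS, string lookup is assumed to resolve
# to the same class (mirroring keras.metrics.get() behavior).
m2 = keras_core.metrics.get("recall")
print(type(m1), type(m2))
```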

@ -14,9 +14,9 @@ class _ConfusionMatrixConditionCount(Metric):
confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
conditions.
thresholds: (Optional) Defaults to 0.5. A float value or a python list /
-tuple of float threshold values in [0, 1]. A threshold is compared
+tuple of float threshold values in `[0, 1]`. A threshold is compared
with prediction values to determine the truth value of predictions
-(i.e., above the threshold is `true`, below is `false`). One metric
+(i.e., above the threshold is `True`, below is `False`). One metric
value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
@ -85,9 +85,9 @@ class FalsePositives(_ConfusionMatrixConditionCount):
Args:
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-list/tuple of float threshold values in [0, 1]. A threshold is
+list/tuple of float threshold values in `[0, 1]`. A threshold is
compared with prediction values to determine the truth value of
-predictions (i.e., above the threshold is `true`, below is `false`).
+predictions (i.e., above the threshold is `True`, below is `False`).
If used with a loss function that sets `from_logits=True` (i.e. no
sigmoid applied to predictions), `thresholds` should be set to 0.
One metric value is generated for each threshold value.
@ -129,9 +129,9 @@ class FalseNegatives(_ConfusionMatrixConditionCount):
Args:
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-list/tuple of float threshold values in [0, 1]. A threshold is
+list/tuple of float threshold values in `[0, 1]`. A threshold is
compared with prediction values to determine the truth value of
-predictions (i.e., above the threshold is `true`, below is `false`).
+predictions (i.e., above the threshold is `True`, below is `False`).
If used with a loss function that sets `from_logits=True` (i.e. no
sigmoid applied to predictions), `thresholds` should be set to 0.
One metric value is generated for each threshold value.
@ -173,9 +173,9 @@ class TrueNegatives(_ConfusionMatrixConditionCount):
Args:
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-list/tuple of float threshold values in [0, 1]. A threshold is
+list/tuple of float threshold values in `[0, 1]`. A threshold is
compared with prediction values to determine the truth value of
-predictions (i.e., above the threshold is `true`, below is `false`).
+predictions (i.e., above the threshold is `True`, below is `False`).
If used with a loss function that sets `from_logits=True` (i.e. no
sigmoid applied to predictions), `thresholds` should be set to 0.
One metric value is generated for each threshold value.
@ -217,9 +217,9 @@ class TruePositives(_ConfusionMatrixConditionCount):
Args:
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
-list/tuple of float threshold values in [0, 1]. A threshold is
+list/tuple of float threshold values in `[0, 1]`. A threshold is
compared with prediction values to determine the truth value of
-predictions (i.e., above the threshold is `true`, below is `false`).
+predictions (i.e., above the threshold is `True`, below is `False`).
If used with a loss function that sets `from_logits=True` (i.e. no
sigmoid applied to predictions), `thresholds` should be set to 0.
One metric value is generated for each threshold value.
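The same wording fix is applied to each of these condition-count metrics; in every case `thresholds` may be a single float or a list, and the metric reports one count per threshold. A small illustration with made-up numbers (not part of the diff):

```python
from keras_core import metrics

m = metrics.TruePositives(thresholds=[0.3, 0.7])
m.update_state([0, 1, 1, 1], [0.2, 0.4, 0.6, 0.9])
# Above 0.3, the predictions 0.4, 0.6 and 0.9 count as positive -> 3 true
# positives; above 0.7 only 0.9 does -> 1 true positive.
print(m.result())  # approximately [3., 1.]
```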
@ -272,13 +272,13 @@ class Precision(Metric):
Args:
thresholds: (Optional) A float value, or a Python list/tuple of float
-threshold values in [0, 1]. A threshold is compared with prediction
-values to determine the truth value of predictions (i.e., above the
-threshold is `true`, below is `false`). If used with a loss function
-that sets `from_logits=True` (i.e. no sigmoid applied to
-predictions), `thresholds` should be set to 0. One metric value is
-generated for each threshold value. If neither thresholds nor top_k
-are set, the default is to calculate precision with
+threshold values in `[0, 1]`. A threshold is compared with
+prediction values to determine the truth value of predictions (i.e.,
+above the threshold is `True`, below is `False`). If used with a
+loss function that sets `from_logits=True` (i.e. no sigmoid applied
+to predictions), `thresholds` should be set to 0. One metric value
+is generated for each threshold value. If neither thresholds nor
+top_k are set, the default is to calculate precision with
`thresholds=0.5`.
top_k: (Optional) Unset by default. An int value specifying the top-k
predictions to consider when calculating precision.
@ -326,7 +326,7 @@ class Precision(Metric):
```python
model.compile(optimizer='adam',
-loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
+loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras_core.metrics.Precision(thresholds=0)])
```
"""
@ -369,7 +369,7 @@ class Precision(Metric):
Can be a `Tensor` whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
-return metrics_utils.update_confusion_matrix_variables(
+metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
@ -403,3 +403,143 @@ class Precision(Metric):
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@keras_core_export("keras_core.metrics.Recall")
class Recall(Metric):
"""Computes the recall of the predictions with respect to the labels.
This metric creates two local variables, `true_positives` and
`false_negatives`, that are used to compute the recall. This value is
ultimately returned as `recall`, an idempotent operation that simply divides
`true_positives` by the sum of `true_positives` and `false_negatives`.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `top_k` is set, recall will be computed as how often on average a class
among the labels of a batch entry is in the top-k predictions.
If `class_id` is specified, we calculate recall by considering only the
entries in the batch for which `class_id` is in the label, and computing the
fraction of them for which `class_id` is above the threshold and/or in the
top-k predictions.
Args:
thresholds: (Optional) A float value, or a Python list/tuple of float
threshold values in `[0, 1]`. A threshold is compared with
prediction values to determine the truth value of predictions (i.e.,
above the threshold is `True`, below is `False`). If used with a loss
function that sets `from_logits=True` (i.e. no sigmoid applied to
predictions), `thresholds` should be set to 0. One metric value is
generated for each threshold value. If neither thresholds nor top_k
are set, the default is to calculate recall with `thresholds=0.5`.
top_k: (Optional) Unset by default. An int value specifying the top-k
predictions to consider when calculating recall.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.Recall()
>>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
>>> m.result()
0.6666667
>>> m.reset_state()
>>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
>>> m.result()
1.0
Usage with `compile()` API:
```python
model.compile(optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.Recall()])
```
Usage with a loss with `from_logits=True`:
```python
model.compile(optimizer='adam',
loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras_core.metrics.Recall(thresholds=0)])
```
"""
def __init__(
self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
self.init_thresholds = thresholds
self.top_k = top_k
self.class_id = class_id
default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
self.thresholds = metrics_utils.parse_init_thresholds(
thresholds, default_threshold=default_threshold
)
self._thresholds_distributed_evenly = (
metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
)
self.true_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="true_positives",
)
self.false_negatives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="false_negatives",
)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates true positive and false negative statistics.
Args:
y_true: The ground truth values, with the same dimensions as
`y_pred`. Will be cast to `bool`.
y_pred: The predicted values. Each element must be in the range
`[0, 1]`.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
},
y_true,
y_pred,
thresholds=self.thresholds,
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
top_k=self.top_k,
class_id=self.class_id,
sample_weight=sample_weight,
)
def result(self):
result = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
return result[0] if len(self.thresholds) == 1 else result
def reset_state(self):
num_thresholds = len(to_list(self.thresholds))
self.true_positives.assign(ops.zeros((num_thresholds,)))
self.false_negatives.assign(ops.zeros((num_thresholds,)))
def get_config(self):
config = {
"thresholds": self.init_thresholds,
"top_k": self.top_k,
"class_id": self.class_id,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
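For reference, the value `result()` computes is plain recall, `true_positives / (true_positives + false_negatives)`, with an epsilon added to the denominator to avoid division by zero. The docstring's standalone example can be reproduced by hand with NumPy (illustrative only, not part of the diff):

```python
import numpy as np

y_true = np.array([0, 1, 1, 1])
y_pred = np.array([1, 0, 1, 1])

pred_pos = y_pred > 0.5                 # default threshold of 0.5
tp = ((y_true == 1) & pred_pos).sum()   # 2
fn = ((y_true == 1) & ~pred_pos).sum()  # 1
print(tp / (tp + fn))                   # 0.666..., as in the docstring
```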

@ -534,3 +534,158 @@ class PrecisionTest(testing.TestCase):
self.assertAlmostEqual(1, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(0, p_obj.false_positives)
class RecallTest(testing.TestCase):
def test_config(self):
r_obj = metrics.Recall(
name="my_recall", thresholds=[0.4, 0.9], top_k=15, class_id=12
)
self.assertEqual(r_obj.name, "my_recall")
self.assertLen(r_obj.variables, 2)
self.assertEqual(
[v.name for v in r_obj.variables],
["true_positives", "false_negatives"],
)
self.assertEqual(r_obj.thresholds, [0.4, 0.9])
self.assertEqual(r_obj.top_k, 15)
self.assertEqual(r_obj.class_id, 12)
# Check save and restore config
r_obj2 = metrics.Recall.from_config(r_obj.get_config())
self.assertEqual(r_obj2.name, "my_recall")
self.assertLen(r_obj2.variables, 2)
self.assertEqual(r_obj2.thresholds, [0.4, 0.9])
self.assertEqual(r_obj2.top_k, 15)
self.assertEqual(r_obj2.class_id, 12)
def test_unweighted(self):
r_obj = metrics.Recall()
y_pred = np.array([1, 0, 1, 0])
y_true = np.array([0, 1, 1, 0])
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
def test_unweighted_all_incorrect(self):
r_obj = metrics.Recall(thresholds=[0.5])
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs)
y_true = np.array(1 - inputs)
self.assertAlmostEqual(0, r_obj(y_true, y_pred))
def test_weighted(self):
r_obj = metrics.Recall()
y_pred = np.array([[1, 0, 1, 0], [0, 1, 0, 1]])
y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])
result = r_obj(
y_true,
y_pred,
sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]),
)
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
expected_recall = weighted_tp / weighted_t
self.assertAlmostEqual(expected_recall, result)
def test_div_by_zero(self):
r_obj = metrics.Recall()
y_pred = np.array([0, 0, 0, 0])
y_true = np.array([0, 0, 0, 0])
self.assertEqual(0, r_obj(y_true, y_pred))
def test_unweighted_with_threshold(self):
r_obj = metrics.Recall(thresholds=[0.5, 0.7])
y_pred = np.array([1, 0, 0.6, 0])
y_true = np.array([0, 1, 1, 0])
self.assertAllClose([0.5, 0.0], r_obj(y_true, y_pred), 0)
def test_weighted_with_threshold(self):
r_obj = metrics.Recall(thresholds=[0.5, 1.0])
y_true = np.array([[0, 1], [1, 0]])
y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
weights = np.array([[1, 4], [3, 2]], dtype="float32")
result = r_obj(y_true, y_pred, sample_weight=weights)
weighted_tp = 0 + 3.0
weighted_positives = (0 + 3.0) + (4.0 + 0.0)
expected_recall = weighted_tp / weighted_positives
self.assertAllClose([expected_recall, 0], result, 1e-3)
def test_multiple_updates(self):
r_obj = metrics.Recall(thresholds=[0.5, 1.0])
y_true = np.array([[0, 1], [1, 0]])
y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
weights = np.array([[1, 4], [3, 2]], dtype="float32")
for _ in range(2):
r_obj.update_state(y_true, y_pred, sample_weight=weights)
weighted_tp = (0 + 3.0) + (0 + 3.0)
weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + (
(0 + 3.0) + (4.0 + 0.0)
)
expected_recall = weighted_tp / weighted_positives
self.assertAllClose([expected_recall, 0], r_obj.result(), 1e-3)
def test_unweighted_top_k(self):
r_obj = metrics.Recall(top_k=3)
y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
def test_weighted_top_k(self):
r_obj = metrics.Recall(top_k=3)
y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]])
y_true1 = np.array([[0, 1, 1, 0, 1]])
r_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]]))
y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2])
y_true2 = np.array([1, 0, 1, 1, 1])
result = r_obj(y_true2, y_pred2, sample_weight=np.array(3))
tp = (2 + 5) + (3 + 3)
positives = (4 + 2 + 5) + (3 + 3 + 3 + 3)
expected_recall = tp / positives
self.assertAlmostEqual(expected_recall, result)
def test_unweighted_class_id(self):
r_obj = metrics.Recall(class_id=2)
y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
self.assertAlmostEqual(1, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(0, r_obj.false_negatives)
y_pred = np.array([0.2, 0.1, 0, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(1, r_obj.false_negatives)
y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
y_true = np.array([0, 1, 0, 0, 0])
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(1, r_obj.false_negatives)
def test_unweighted_top_k_and_class_id(self):
r_obj = metrics.Recall(class_id=2, top_k=2)
y_pred = np.array([0.2, 0.6, 0.3, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
self.assertAlmostEqual(1, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(0, r_obj.false_negatives)
y_pred = np.array([1, 1, 0.9, 1, 1])
y_true = np.array([0, 1, 1, 0, 0])
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(1, r_obj.false_negatives)
def test_unweighted_top_k_and_threshold(self):
r_obj = metrics.Recall(thresholds=0.7, top_k=2)
y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2])
y_true = np.array([1, 1, 1, 0, 1])
self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(3, r_obj.false_negatives)
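The weighted expectations in these tests follow the same recall formula with per-element weights. Recomputing `test_weighted` with plain NumPy makes the arithmetic in its inline comments explicit (illustrative only, not part of the diff):

```python
import numpy as np

y_pred = np.array([[1, 0, 1, 0], [0, 1, 0, 1]])
y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])
w = np.array([[1, 2, 3, 4], [4, 3, 2, 1]])

pred_pos = y_pred > 0.5
weighted_tp = (w * ((y_true == 1) & pred_pos)).sum()  # 3 + 1 = 4
weighted_pos = (w * (y_true == 1)).sum()              # (2 + 3) + (4 + 1) = 10
print(weighted_tp / weighted_pos)                     # 0.4
```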