parent
841b8d702d
commit
07504d53a8
@ -8,6 +8,7 @@ from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
|
|||||||
from keras_core.metrics.confusion_metrics import FalseNegatives
|
from keras_core.metrics.confusion_metrics import FalseNegatives
|
||||||
from keras_core.metrics.confusion_metrics import FalsePositives
|
from keras_core.metrics.confusion_metrics import FalsePositives
|
||||||
from keras_core.metrics.confusion_metrics import Precision
|
from keras_core.metrics.confusion_metrics import Precision
|
||||||
|
from keras_core.metrics.confusion_metrics import Recall
|
||||||
from keras_core.metrics.confusion_metrics import TrueNegatives
|
from keras_core.metrics.confusion_metrics import TrueNegatives
|
||||||
from keras_core.metrics.confusion_metrics import TruePositives
|
from keras_core.metrics.confusion_metrics import TruePositives
|
||||||
from keras_core.metrics.hinge_metrics import CategoricalHinge
|
from keras_core.metrics.hinge_metrics import CategoricalHinge
|
||||||
@ -39,6 +40,7 @@ ALL_OBJECTS = {
|
|||||||
FalseNegatives,
|
FalseNegatives,
|
||||||
FalsePositives,
|
FalsePositives,
|
||||||
Precision,
|
Precision,
|
||||||
|
Recall,
|
||||||
TrueNegatives,
|
TrueNegatives,
|
||||||
TruePositives,
|
TruePositives,
|
||||||
# Hinge
|
# Hinge
|
||||||
|
@ -14,9 +14,9 @@ class _ConfusionMatrixConditionCount(Metric):
|
|||||||
confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
|
confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
|
||||||
conditions.
|
conditions.
|
||||||
thresholds: (Optional) Defaults to 0.5. A float value or a python list /
|
thresholds: (Optional) Defaults to 0.5. A float value or a python list /
|
||||||
tuple of float threshold values in [0, 1]. A threshold is compared
|
tuple of float threshold values in `[0, 1]`. A threshold is compared
|
||||||
with prediction values to determine the truth value of predictions
|
with prediction values to determine the truth value of predictions
|
||||||
(i.e., above the threshold is `true`, below is `false`). One metric
|
(i.e., above the threshold is `True`, below is `False`). One metric
|
||||||
value is generated for each threshold value.
|
value is generated for each threshold value.
|
||||||
name: (Optional) string name of the metric instance.
|
name: (Optional) string name of the metric instance.
|
||||||
dtype: (Optional) data type of the metric result.
|
dtype: (Optional) data type of the metric result.
|
||||||
@ -85,9 +85,9 @@ class FalsePositives(_ConfusionMatrixConditionCount):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
||||||
list/tuple of float threshold values in [0, 1]. A threshold is
|
list/tuple of float threshold values in `[0, 1]`. A threshold is
|
||||||
compared with prediction values to determine the truth value of
|
compared with prediction values to determine the truth value of
|
||||||
predictions (i.e., above the threshold is `true`, below is `false`).
|
predictions (i.e., above the threshold is `True`, below is `False`).
|
||||||
If used with a loss function that sets `from_logits=True` (i.e. no
|
If used with a loss function that sets `from_logits=True` (i.e. no
|
||||||
sigmoid applied to predictions), `thresholds` should be set to 0.
|
sigmoid applied to predictions), `thresholds` should be set to 0.
|
||||||
One metric value is generated for each threshold value.
|
One metric value is generated for each threshold value.
|
||||||
@ -129,9 +129,9 @@ class FalseNegatives(_ConfusionMatrixConditionCount):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
||||||
list/tuple of float threshold values in [0, 1]. A threshold is
|
list/tuple of float threshold values in `[0, 1]`. A threshold is
|
||||||
compared with prediction values to determine the truth value of
|
compared with prediction values to determine the truth value of
|
||||||
predictions (i.e., above the threshold is `true`, below is `false`).
|
predictions (i.e., above the threshold is `True`, below is `False`).
|
||||||
If used with a loss function that sets `from_logits=True` (i.e. no
|
If used with a loss function that sets `from_logits=True` (i.e. no
|
||||||
sigmoid applied to predictions), `thresholds` should be set to 0.
|
sigmoid applied to predictions), `thresholds` should be set to 0.
|
||||||
One metric value is generated for each threshold value.
|
One metric value is generated for each threshold value.
|
||||||
@ -173,9 +173,9 @@ class TrueNegatives(_ConfusionMatrixConditionCount):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
||||||
list/tuple of float threshold values in [0, 1]. A threshold is
|
list/tuple of float threshold values in `[0, 1]`. A threshold is
|
||||||
compared with prediction values to determine the truth value of
|
compared with prediction values to determine the truth value of
|
||||||
predictions (i.e., above the threshold is `true`, below is `false`).
|
predictions (i.e., above the threshold is `True`, below is `False`).
|
||||||
If used with a loss function that sets `from_logits=True` (i.e. no
|
If used with a loss function that sets `from_logits=True` (i.e. no
|
||||||
sigmoid applied to predictions), `thresholds` should be set to 0.
|
sigmoid applied to predictions), `thresholds` should be set to 0.
|
||||||
One metric value is generated for each threshold value.
|
One metric value is generated for each threshold value.
|
||||||
@ -217,9 +217,9 @@ class TruePositives(_ConfusionMatrixConditionCount):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
|
||||||
list/tuple of float threshold values in [0, 1]. A threshold is
|
list/tuple of float threshold values in `[0, 1]`. A threshold is
|
||||||
compared with prediction values to determine the truth value of
|
compared with prediction values to determine the truth value of
|
||||||
predictions (i.e., above the threshold is `true`, below is `false`).
|
predictions (i.e., above the threshold is `True`, below is `False`).
|
||||||
If used with a loss function that sets `from_logits=True` (i.e. no
|
If used with a loss function that sets `from_logits=True` (i.e. no
|
||||||
sigmoid applied to predictions), `thresholds` should be set to 0.
|
sigmoid applied to predictions), `thresholds` should be set to 0.
|
||||||
One metric value is generated for each threshold value.
|
One metric value is generated for each threshold value.
|
||||||
@ -272,13 +272,13 @@ class Precision(Metric):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
thresholds: (Optional) A float value, or a Python list/tuple of float
|
thresholds: (Optional) A float value, or a Python list/tuple of float
|
||||||
threshold values in [0, 1]. A threshold is compared with prediction
|
threshold values in `[0, 1]`. A threshold is compared with
|
||||||
values to determine the truth value of predictions (i.e., above the
|
prediction values to determine the truth value of predictions (i.e.,
|
||||||
threshold is `true`, below is `false`). If used with a loss function
|
above the threshold is `True`, below is `False`). If used with a
|
||||||
that sets `from_logits=True` (i.e. no sigmoid applied to
|
loss function that sets `from_logits=True` (i.e. no sigmoid applied
|
||||||
predictions), `thresholds` should be set to 0. One metric value is
|
to predictions), `thresholds` should be set to 0. One metric value
|
||||||
generated for each threshold value. If neither thresholds nor top_k
|
is generated for each threshold value. If neither thresholds nor
|
||||||
are set, the default is to calculate precision with
|
top_k are set, the default is to calculate precision with
|
||||||
`thresholds=0.5`.
|
`thresholds=0.5`.
|
||||||
top_k: (Optional) Unset by default. An int value specifying the top-k
|
top_k: (Optional) Unset by default. An int value specifying the top-k
|
||||||
predictions to consider when calculating precision.
|
predictions to consider when calculating precision.
|
||||||
@ -326,7 +326,7 @@ class Precision(Metric):
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
|
loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
|
||||||
metrics=[keras_core.metrics.Precision(thresholds=0)])
|
metrics=[keras_core.metrics.Precision(thresholds=0)])
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@ -369,7 +369,7 @@ class Precision(Metric):
|
|||||||
Can be a `Tensor` whose rank is either 0, or the same rank as
|
Can be a `Tensor` whose rank is either 0, or the same rank as
|
||||||
`y_true`, and must be broadcastable to `y_true`.
|
`y_true`, and must be broadcastable to `y_true`.
|
||||||
"""
|
"""
|
||||||
return metrics_utils.update_confusion_matrix_variables(
|
metrics_utils.update_confusion_matrix_variables(
|
||||||
{
|
{
|
||||||
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
|
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
|
||||||
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
|
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
|
||||||
@ -403,3 +403,143 @@ class Precision(Metric):
|
|||||||
}
|
}
|
||||||
base_config = super().get_config()
|
base_config = super().get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.metrics.Recall")
|
||||||
|
class Recall(Metric):
|
||||||
|
"""Computes the recall of the predictions with respect to the labels.
|
||||||
|
|
||||||
|
This metric creates two local variables, `true_positives` and
|
||||||
|
`false_negatives`, that are used to compute the recall. This value is
|
||||||
|
ultimately returned as `recall`, an idempotent operation that simply divides
|
||||||
|
`true_positives` by the sum of `true_positives` and `false_negatives`.
|
||||||
|
|
||||||
|
If `sample_weight` is `None`, weights default to 1.
|
||||||
|
Use `sample_weight` of 0 to mask values.
|
||||||
|
|
||||||
|
If `top_k` is set, recall will be computed as how often on average a class
|
||||||
|
among the labels of a batch entry is in the top-k predictions.
|
||||||
|
|
||||||
|
If `class_id` is specified, we calculate recall by considering only the
|
||||||
|
entries in the batch for which `class_id` is in the label, and computing the
|
||||||
|
fraction of them for which `class_id` is above the threshold and/or in the
|
||||||
|
top-k predictions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thresholds: (Optional) A float value, or a Python list/tuple of float
|
||||||
|
threshold values in `[0, 1]`. A threshold is compared with
|
||||||
|
prediction values to determine the truth value of predictions (i.e.,
|
||||||
|
above thethreshold is `True`, below is `False`). If used with a loss
|
||||||
|
function that sets `from_logits=True` (i.e. no sigmoid applied to
|
||||||
|
predictions), `thresholds` should be set to 0. One metric value is
|
||||||
|
generated for each threshold value. If neither thresholds nor top_k
|
||||||
|
are set, the default is to calculate recall with `thresholds=0.5`.
|
||||||
|
top_k: (Optional) Unset by default. An int value specifying the top-k
|
||||||
|
predictions to consider when calculating recall.
|
||||||
|
class_id: (Optional) Integer class ID for which we want binary metrics.
|
||||||
|
This must be in the half-open interval `[0, num_classes)`, where
|
||||||
|
`num_classes` is the last dimension of predictions.
|
||||||
|
name: (Optional) string name of the metric instance.
|
||||||
|
dtype: (Optional) data type of the metric result.
|
||||||
|
|
||||||
|
Standalone usage:
|
||||||
|
|
||||||
|
>>> m = keras_core.metrics.Recall()
|
||||||
|
>>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
|
||||||
|
>>> m.result()
|
||||||
|
0.6666667
|
||||||
|
|
||||||
|
>>> m.reset_state()
|
||||||
|
>>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
|
||||||
|
>>> m.result()
|
||||||
|
1.0
|
||||||
|
|
||||||
|
Usage with `compile()` API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model.compile(optimizer='sgd',
|
||||||
|
loss='mse',
|
||||||
|
metrics=[keras_core.metrics.Recall()])
|
||||||
|
```
|
||||||
|
|
||||||
|
Usage with a loss with `from_logits=True`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model.compile(optimizer='adam',
|
||||||
|
loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
|
||||||
|
metrics=[keras_core.metrics.Recall(thresholds=0)])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
|
||||||
|
):
|
||||||
|
super().__init__(name=name, dtype=dtype)
|
||||||
|
self.init_thresholds = thresholds
|
||||||
|
self.top_k = top_k
|
||||||
|
self.class_id = class_id
|
||||||
|
|
||||||
|
default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
|
||||||
|
self.thresholds = metrics_utils.parse_init_thresholds(
|
||||||
|
thresholds, default_threshold=default_threshold
|
||||||
|
)
|
||||||
|
self._thresholds_distributed_evenly = (
|
||||||
|
metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
|
||||||
|
)
|
||||||
|
self.true_positives = self.add_variable(
|
||||||
|
shape=(len(self.thresholds),),
|
||||||
|
initializer=initializers.Zeros(),
|
||||||
|
name="true_positives",
|
||||||
|
)
|
||||||
|
self.false_negatives = self.add_variable(
|
||||||
|
shape=(len(self.thresholds),),
|
||||||
|
initializer=initializers.Zeros(),
|
||||||
|
name="false_negatives",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
|
"""Accumulates true positive and false negative statistics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: The ground truth values, with the same dimensions as
|
||||||
|
`y_pred`. Will be cast to `bool`.
|
||||||
|
y_pred: The predicted values. Each element must be in the range
|
||||||
|
`[0, 1]`.
|
||||||
|
sample_weight: Optional weighting of each example. Defaults to 1.
|
||||||
|
Can be a `Tensor` whose rank is either 0, or the same rank as
|
||||||
|
`y_true`, and must be broadcastable to `y_true`.
|
||||||
|
"""
|
||||||
|
metrics_utils.update_confusion_matrix_variables(
|
||||||
|
{
|
||||||
|
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
|
||||||
|
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
|
||||||
|
},
|
||||||
|
y_true,
|
||||||
|
y_pred,
|
||||||
|
thresholds=self.thresholds,
|
||||||
|
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
|
||||||
|
top_k=self.top_k,
|
||||||
|
class_id=self.class_id,
|
||||||
|
sample_weight=sample_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
def result(self):
|
||||||
|
result = ops.divide(
|
||||||
|
self.true_positives,
|
||||||
|
self.true_positives + self.false_negatives + backend.epsilon(),
|
||||||
|
)
|
||||||
|
return result[0] if len(self.thresholds) == 1 else result
|
||||||
|
|
||||||
|
def reset_state(self):
|
||||||
|
num_thresholds = len(to_list(self.thresholds))
|
||||||
|
self.true_positives.assign(ops.zeros((num_thresholds,)))
|
||||||
|
self.false_negatives.assign(ops.zeros((num_thresholds,)))
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = {
|
||||||
|
"thresholds": self.init_thresholds,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"class_id": self.class_id,
|
||||||
|
}
|
||||||
|
base_config = super().get_config()
|
||||||
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
@ -534,3 +534,158 @@ class PrecisionTest(testing.TestCase):
|
|||||||
self.assertAlmostEqual(1, result)
|
self.assertAlmostEqual(1, result)
|
||||||
self.assertAlmostEqual(1, p_obj.true_positives)
|
self.assertAlmostEqual(1, p_obj.true_positives)
|
||||||
self.assertAlmostEqual(0, p_obj.false_positives)
|
self.assertAlmostEqual(0, p_obj.false_positives)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallTest(testing.TestCase):
|
||||||
|
def test_config(self):
|
||||||
|
r_obj = metrics.Recall(
|
||||||
|
name="my_recall", thresholds=[0.4, 0.9], top_k=15, class_id=12
|
||||||
|
)
|
||||||
|
self.assertEqual(r_obj.name, "my_recall")
|
||||||
|
self.assertLen(r_obj.variables, 2)
|
||||||
|
self.assertEqual(
|
||||||
|
[v.name for v in r_obj.variables],
|
||||||
|
["true_positives", "false_negatives"],
|
||||||
|
)
|
||||||
|
self.assertEqual(r_obj.thresholds, [0.4, 0.9])
|
||||||
|
self.assertEqual(r_obj.top_k, 15)
|
||||||
|
self.assertEqual(r_obj.class_id, 12)
|
||||||
|
|
||||||
|
# Check save and restore config
|
||||||
|
r_obj2 = metrics.Recall.from_config(r_obj.get_config())
|
||||||
|
self.assertEqual(r_obj2.name, "my_recall")
|
||||||
|
self.assertLen(r_obj2.variables, 2)
|
||||||
|
self.assertEqual(r_obj2.thresholds, [0.4, 0.9])
|
||||||
|
self.assertEqual(r_obj2.top_k, 15)
|
||||||
|
self.assertEqual(r_obj2.class_id, 12)
|
||||||
|
|
||||||
|
def test_unweighted(self):
|
||||||
|
r_obj = metrics.Recall()
|
||||||
|
y_pred = np.array([1, 0, 1, 0])
|
||||||
|
y_true = np.array([0, 1, 1, 0])
|
||||||
|
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
|
||||||
|
|
||||||
|
def test_unweighted_all_incorrect(self):
|
||||||
|
r_obj = metrics.Recall(thresholds=[0.5])
|
||||||
|
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||||
|
y_pred = np.array(inputs)
|
||||||
|
y_true = np.array(1 - inputs)
|
||||||
|
self.assertAlmostEqual(0, r_obj(y_true, y_pred))
|
||||||
|
|
||||||
|
def test_weighted(self):
|
||||||
|
r_obj = metrics.Recall()
|
||||||
|
y_pred = np.array([[1, 0, 1, 0], [0, 1, 0, 1]])
|
||||||
|
y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])
|
||||||
|
result = r_obj(
|
||||||
|
y_true,
|
||||||
|
y_pred,
|
||||||
|
sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]),
|
||||||
|
)
|
||||||
|
weighted_tp = 3.0 + 1.0
|
||||||
|
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
|
||||||
|
expected_recall = weighted_tp / weighted_t
|
||||||
|
self.assertAlmostEqual(expected_recall, result)
|
||||||
|
|
||||||
|
def test_div_by_zero(self):
|
||||||
|
r_obj = metrics.Recall()
|
||||||
|
y_pred = np.array([0, 0, 0, 0])
|
||||||
|
y_true = np.array([0, 0, 0, 0])
|
||||||
|
self.assertEqual(0, r_obj(y_true, y_pred))
|
||||||
|
|
||||||
|
def test_unweighted_with_threshold(self):
|
||||||
|
r_obj = metrics.Recall(thresholds=[0.5, 0.7])
|
||||||
|
y_pred = np.array([1, 0, 0.6, 0])
|
||||||
|
y_true = np.array([0, 1, 1, 0])
|
||||||
|
self.assertAllClose([0.5, 0.0], r_obj(y_true, y_pred), 0)
|
||||||
|
|
||||||
|
def test_weighted_with_threshold(self):
|
||||||
|
r_obj = metrics.Recall(thresholds=[0.5, 1.0])
|
||||||
|
y_true = np.array([[0, 1], [1, 0]])
|
||||||
|
y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
|
||||||
|
weights = np.array([[1, 4], [3, 2]], dtype="float32")
|
||||||
|
result = r_obj(y_true, y_pred, sample_weight=weights)
|
||||||
|
weighted_tp = 0 + 3.0
|
||||||
|
weighted_positives = (0 + 3.0) + (4.0 + 0.0)
|
||||||
|
expected_recall = weighted_tp / weighted_positives
|
||||||
|
self.assertAllClose([expected_recall, 0], result, 1e-3)
|
||||||
|
|
||||||
|
def test_multiple_updates(self):
|
||||||
|
r_obj = metrics.Recall(thresholds=[0.5, 1.0])
|
||||||
|
y_true = np.array([[0, 1], [1, 0]])
|
||||||
|
y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
|
||||||
|
weights = np.array([[1, 4], [3, 2]], dtype="float32")
|
||||||
|
for _ in range(2):
|
||||||
|
r_obj.update_state(y_true, y_pred, sample_weight=weights)
|
||||||
|
|
||||||
|
weighted_tp = (0 + 3.0) + (0 + 3.0)
|
||||||
|
weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + (
|
||||||
|
(0 + 3.0) + (4.0 + 0.0)
|
||||||
|
)
|
||||||
|
expected_recall = weighted_tp / weighted_positives
|
||||||
|
self.assertAllClose([expected_recall, 0], r_obj.result(), 1e-3)
|
||||||
|
|
||||||
|
def test_unweighted_top_k(self):
|
||||||
|
r_obj = metrics.Recall(top_k=3)
|
||||||
|
y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2])
|
||||||
|
y_true = np.array([0, 1, 1, 0, 0])
|
||||||
|
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
|
||||||
|
|
||||||
|
def test_weighted_top_k(self):
|
||||||
|
r_obj = metrics.Recall(top_k=3)
|
||||||
|
y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]])
|
||||||
|
y_true1 = np.array([[0, 1, 1, 0, 1]])
|
||||||
|
r_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]]))
|
||||||
|
|
||||||
|
y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2])
|
||||||
|
y_true2 = np.array([1, 0, 1, 1, 1])
|
||||||
|
result = r_obj(y_true2, y_pred2, sample_weight=np.array(3))
|
||||||
|
|
||||||
|
tp = (2 + 5) + (3 + 3)
|
||||||
|
positives = (4 + 2 + 5) + (3 + 3 + 3 + 3)
|
||||||
|
expected_recall = tp / positives
|
||||||
|
self.assertAlmostEqual(expected_recall, result)
|
||||||
|
|
||||||
|
def test_unweighted_class_id(self):
|
||||||
|
r_obj = metrics.Recall(class_id=2)
|
||||||
|
|
||||||
|
y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
|
||||||
|
y_true = np.array([0, 1, 1, 0, 0])
|
||||||
|
self.assertAlmostEqual(1, r_obj(y_true, y_pred))
|
||||||
|
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||||
|
self.assertAlmostEqual(0, r_obj.false_negatives)
|
||||||
|
|
||||||
|
y_pred = np.array([0.2, 0.1, 0, 0, 0.2])
|
||||||
|
y_true = np.array([0, 1, 1, 0, 0])
|
||||||
|
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
|
||||||
|
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||||
|
self.assertAlmostEqual(1, r_obj.false_negatives)
|
||||||
|
|
||||||
|
y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
|
||||||
|
y_true = np.array([0, 1, 0, 0, 0])
|
||||||
|
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
|
||||||
|
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||||
|
self.assertAlmostEqual(1, r_obj.false_negatives)
|
||||||
|
|
||||||
|
def test_unweighted_top_k_and_class_id(self):
|
||||||
|
r_obj = metrics.Recall(class_id=2, top_k=2)
|
||||||
|
|
||||||
|
y_pred = np.array([0.2, 0.6, 0.3, 0, 0.2])
|
||||||
|
y_true = np.array([0, 1, 1, 0, 0])
|
||||||
|
self.assertAlmostEqual(1, r_obj(y_true, y_pred))
|
||||||
|
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||||
|
self.assertAlmostEqual(0, r_obj.false_negatives)
|
||||||
|
|
||||||
|
y_pred = np.array([1, 1, 0.9, 1, 1])
|
||||||
|
y_true = np.array([0, 1, 1, 0, 0])
|
||||||
|
self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))
|
||||||
|
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||||
|
self.assertAlmostEqual(1, r_obj.false_negatives)
|
||||||
|
|
||||||
|
def test_unweighted_top_k_and_threshold(self):
|
||||||
|
r_obj = metrics.Recall(thresholds=0.7, top_k=2)
|
||||||
|
|
||||||
|
y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2])
|
||||||
|
y_true = np.array([1, 1, 1, 0, 1])
|
||||||
|
self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))
|
||||||
|
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||||
|
self.assertAlmostEqual(3, r_obj.false_negatives)
|
||||||
|
Loading…
Reference in New Issue
Block a user