Add 4 more Confusion Metrics (#73)

* Add recall

* Add SensitivitySpecificity metrics

* Review comments

* Add missing method from math

* Fix pytest

* Update config dict creation

* Remove sentence fragment

* Review comments

* Add special casing for min/max
Ian Stenbit 2023-05-05 10:53:36 -06:00 committed by Francois Chollet
parent 1c994b749a
commit 55640456ff
8 changed files with 1033 additions and 65 deletions

@@ -25,8 +25,8 @@ def mean(x, axis=None, keepdims=False):
return jnp.mean(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False):
return jnp.max(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False, initial=None):
return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)
def ones(shape, dtype="float32"):
@@ -327,8 +327,8 @@ def meshgrid(*x, indexing="xy"):
return jnp.meshgrid(*x, indexing=indexing)
def min(x, axis=None, keepdims=False):
return jnp.min(x, axis=axis, keepdims=keepdims)
def min(x, axis=None, keepdims=False, initial=None):
return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial)
def minimum(x1, x2):

@@ -26,7 +26,22 @@ def mean(x, axis=None, keepdims=False):
return tfnp.mean(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False):
def max(x, axis=None, keepdims=False, initial=None):
# The TensorFlow numpy API implementation doesn't support `initial`, so we
# handle it manually here.
if initial is not None:
return tf.math.maximum(
tfnp.max(x, axis=axis, keepdims=keepdims), initial
)
# TensorFlow returns -inf by default for an empty list, but for consistency
# with other backends and the numpy API we want to throw in this case.
tf.assert_greater(
size(x),
tf.constant(0, dtype=tf.int64),
message="Cannot compute the max of an empty tensor.",
)
return tfnp.max(x, axis=axis, keepdims=keepdims)
@@ -328,7 +343,22 @@ def meshgrid(*x, indexing="xy"):
return tfnp.meshgrid(*x, indexing=indexing)
def min(x, axis=None, keepdims=False):
def min(x, axis=None, keepdims=False, initial=None):
# The TensorFlow numpy API implementation doesn't support `initial`, so we
# handle it manually here.
if initial is not None:
return tf.math.minimum(
tfnp.min(x, axis=axis, keepdims=keepdims), initial
)
# TensorFlow returns inf by default for an empty list, but for consistency
# with other backends and the numpy API we want to throw in this case.
tf.assert_greater(
size(x),
tf.constant(0, dtype=tf.int64),
message="Cannot compute the min of an empty tensor.",
)
return tfnp.min(x, axis=axis, keepdims=keepdims)
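For reference, the behavior being emulated above mirrors NumPy's `initial` argument, which participates in the reduction like an extra candidate element and therefore also gives empty reductions a well-defined result. A minimal standalone sketch (not part of the diff, illustrative values only):

```python
import numpy as np

# `initial` is folded into the reduction as an extra candidate value.
print(np.max([0.1, 0.7, 0.3], initial=0.9))  # 0.9
print(np.max([0.1, 0.7, 0.3], initial=0.0))  # 0.7
print(np.max([], initial=0.0))               # 0.0 instead of an error
print(np.min([0.2, 0.5], initial=0.1))       # 0.1
```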

@@ -8,7 +8,11 @@ from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
from keras_core.metrics.confusion_metrics import FalseNegatives
from keras_core.metrics.confusion_metrics import FalsePositives
from keras_core.metrics.confusion_metrics import Precision
from keras_core.metrics.confusion_metrics import PrecisionAtRecall
from keras_core.metrics.confusion_metrics import Recall
from keras_core.metrics.confusion_metrics import RecallAtPrecision
from keras_core.metrics.confusion_metrics import SensitivityAtSpecificity
from keras_core.metrics.confusion_metrics import SpecificityAtSensitivity
from keras_core.metrics.confusion_metrics import TrueNegatives
from keras_core.metrics.confusion_metrics import TruePositives
from keras_core.metrics.hinge_metrics import CategoricalHinge
@@ -40,7 +44,11 @@ ALL_OBJECTS = {
FalseNegatives,
FalsePositives,
Precision,
PrecisionAtRecall,
Recall,
RecallAtPrecision,
SensitivityAtSpecificity,
SpecificityAtSensitivity,
TrueNegatives,
TruePositives,
# Hinge

@@ -69,7 +69,7 @@ class _ConfusionMatrixConditionCount(Metric):
def get_config(self):
config = {"thresholds": self.init_thresholds}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
return {**base_config, **config}
@keras_core_export("keras_core.metrics.FalsePositives")
@@ -402,7 +402,7 @@ class Precision(Metric):
"class_id": self.class_id,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
return {**base_config, **config}
@keras_core_export("keras_core.metrics.Recall")
@@ -543,4 +543,512 @@ class Recall(Metric):
"class_id": self.class_id,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
return {**base_config, **config}
class SensitivitySpecificityBase(Metric):
"""Abstract base class for computing sensitivity and specificity.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
"""
def __init__(
self, value, num_thresholds=200, class_id=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
if num_thresholds <= 0:
raise ValueError(
"Argument `num_thresholds` must be an integer > 0. "
f"Received: num_thresholds={num_thresholds}"
)
self.value = value
self.class_id = class_id
# Compute `num_thresholds` thresholds in [0, 1]
if num_thresholds == 1:
self.thresholds = [0.5]
self._thresholds_distributed_evenly = False
else:
thresholds = [
(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)
]
self.thresholds = [0.0] + thresholds + [1.0]
self._thresholds_distributed_evenly = True
self.true_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="true_positives",
)
self.false_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="false_positives",
)
self.true_negatives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="true_negatives",
)
self.false_negatives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="false_negatives",
)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates confusion matrix statistics.
Args:
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
},
y_true,
y_pred,
thresholds=self.thresholds,
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
class_id=self.class_id,
sample_weight=sample_weight,
)
def reset_state(self):
num_thresholds = len(self.thresholds)
self.true_positives.assign(ops.zeros((num_thresholds,)))
self.false_positives.assign(ops.zeros((num_thresholds,)))
self.true_negatives.assign(ops.zeros((num_thresholds,)))
self.false_negatives.assign(ops.zeros((num_thresholds,)))
def get_config(self):
config = {"class_id": self.class_id}
base_config = super().get_config()
return {**base_config, **config}
def _find_max_under_constraint(self, constrained, dependent, predicate):
"""Returns the maximum of dependent_statistic that satisfies the
constraint.
Args:
constrained: Over these values the constraint is specified. A rank-1
tensor.
dependent: From these values the maximum that satisfies the
constraint is selected. Values in this tensor and in
`constrained` are linked by having the same threshold at each
position, hence this tensor must have the same shape.
predicate: A binary boolean functor to be applied to arguments
`constrained` and `self.value`, e.g. `ops.greater`.
Returns:
The maximal dependent value, or 0.0 if no value satisfies the constraint.
"""
feasible = ops.array(ops.nonzero(predicate(constrained, self.value)))
feasible_exists = ops.greater(ops.size(feasible), 0)
max_dependent = ops.max(ops.take(dependent, feasible), initial=0)
return ops.where(feasible_exists, max_dependent, 0.0)
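For intuition, here is a small NumPy sketch of the same constrained-maximum logic, using made-up per-threshold statistics (not part of the diff): thresholds whose constrained statistic meets the target form the feasible set, and the best dependent statistic over that set is returned, or 0.0 when the set is empty.

```python
import numpy as np

specificities = np.array([0.2, 0.5, 0.7, 0.9, 1.0])  # constrained statistic
sensitivities = np.array([1.0, 0.9, 0.6, 0.3, 0.0])  # dependent statistic
value = 0.65                                          # target constraint

# Indices of thresholds that satisfy the constraint.
feasible = np.nonzero(specificities >= value)[0]
# Best dependent value over the feasible set; `initial=0.0` covers an empty set.
best = np.max(sensitivities[feasible], initial=0.0)
print(best)  # 0.6
```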
@keras_core_export("keras_core.metrics.SensitivityAtSpecificity")
class SensitivityAtSpecificity(SensitivitySpecificityBase):
"""Computes best sensitivity where specificity is >= specified value.
`Sensitivity` measures the proportion of actual positives that are correctly
identified as such `(tp / (tp + fn))`.
`Specificity` measures the proportion of actual negatives that are correctly
identified as such `(tn / (tn + fp))`.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the sensitivity at the given specificity. The threshold for the
given specificity value is computed and used to evaluate the corresponding
sensitivity.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate precision by considering only the
entries in the batch for which `class_id` is above the threshold
predictions, and computing the fraction of them for which `class_id` is
indeed a correct label.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
Args:
specificity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given specificity.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.SensitivityAtSpecificity(0.5)
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
>>> m.result()
0.5
>>> m.reset_state()
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
... sample_weight=[1, 1, 2, 2, 1])
>>> m.result()
0.333333
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.SensitivityAtSpecificity()])
```
"""
def __init__(
self,
specificity,
num_thresholds=200,
class_id=None,
name=None,
dtype=None,
):
if specificity < 0 or specificity > 1:
raise ValueError(
"Argument `specificity` must be in the range [0, 1]. "
f"Received: specificity={specificity}"
)
self.specificity = specificity
self.num_thresholds = num_thresholds
super().__init__(
specificity,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
sensitivities = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
specificities = ops.divide(
self.true_negatives,
self.true_negatives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
specificities, sensitivities, ops.greater_equal
)
def get_config(self):
config = {
"num_thresholds": self.num_thresholds,
"specificity": self.specificity,
}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.SpecificityAtSensitivity")
class SpecificityAtSensitivity(SensitivitySpecificityBase):
"""Computes best specificity where sensitivity is >= specified value.
`Sensitivity` measures the proportion of actual positives that are correctly
identified as such `(tp / (tp + fn))`.
`Specificity` measures the proportion of actual negatives that are correctly
identified as such `(tn / (tn + fp))`.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the specificity at the given sensitivity. The threshold for the
given sensitivity value is computed and used to evaluate the corresponding
specificity.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate precision by considering only the
entries in the batch for which `class_id` is above the threshold
predictions, and computing the fraction of them for which `class_id` is
indeed a correct label.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
Args:
sensitivity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given sensitivity.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.SpecificityAtSensitivity(0.5)
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
>>> m.result()
0.66666667
>>> m.reset_state()
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
... sample_weight=[1, 1, 2, 2, 2])
>>> m.result()
0.5
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.SpecificityAtSensitivity()])
```
"""
def __init__(
self,
sensitivity,
num_thresholds=200,
class_id=None,
name=None,
dtype=None,
):
if sensitivity < 0 or sensitivity > 1:
raise ValueError(
"Argument `sensitivity` must be in the range [0, 1]. "
f"Received: sensitivity={sensitivity}"
)
self.sensitivity = sensitivity
self.num_thresholds = num_thresholds
super().__init__(
sensitivity,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
sensitivities = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
specificities = ops.divide(
self.true_negatives,
self.true_negatives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
sensitivities, specificities, ops.greater_equal
)
def get_config(self):
config = {
"num_thresholds": self.num_thresholds,
"sensitivity": self.sensitivity,
}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.PrecisionAtRecall")
class PrecisionAtRecall(SensitivitySpecificityBase):
"""Computes best precision where recall is >= specified value.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the precision at the given recall. The threshold for the given
recall value is computed and used to evaluate the corresponding precision.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate precision by considering only the
entries in the batch for which `class_id` is above the threshold
predictions, and computing the fraction of them for which `class_id` is
indeed a correct label.
Args:
recall: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given recall.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.PrecisionAtRecall(0.5)
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
>>> m.result()
0.5
>>> m.reset_state()
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
... sample_weight=[2, 2, 2, 1, 1])
>>> m.result()
0.33333333
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.PrecisionAtRecall(recall=0.8)])
```
"""
def __init__(
self, recall, num_thresholds=200, class_id=None, name=None, dtype=None
):
if recall < 0 or recall > 1:
raise ValueError(
"Argument `recall` must be in the range [0, 1]. "
f"Received: recall={recall}"
)
self.recall = recall
self.num_thresholds = num_thresholds
super().__init__(
value=recall,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
recalls = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
precisions = ops.divide(
self.true_positives,
self.true_positives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
recalls, precisions, ops.greater_equal
)
def get_config(self):
config = {"num_thresholds": self.num_thresholds, "recall": self.recall}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.RecallAtPrecision")
class RecallAtPrecision(SensitivitySpecificityBase):
"""Computes best recall where precision is >= specified value.
For a given score-label-distribution, the required precision might not
be achievable; in this case, 0.0 is returned as recall.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the recall at the given precision. The threshold for the given
precision value is computed and used to evaluate the corresponding recall.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate precision by considering only the
entries in the batch for which `class_id` is above the threshold
predictions, and computing the fraction of them for which `class_id` is
indeed a correct label.
Args:
precision: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds
to use for matching the given precision.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.RecallAtPrecision(0.8)
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> m.result()
0.5
>>> m.reset_state()
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
... sample_weight=[1, 0, 0, 1])
>>> m.result()
1.0
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.RecallAtPrecision(precision=0.8)])
```
"""
def __init__(
self,
precision,
num_thresholds=200,
class_id=None,
name=None,
dtype=None,
):
if precision < 0 or precision > 1:
raise ValueError(
"Argument `precision` must be in the range [0, 1]. "
f"Received: precision={precision}"
)
self.precision = precision
self.num_thresholds = num_thresholds
super().__init__(
value=precision,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
recalls = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
precisions = ops.divide(
self.true_positives,
self.true_positives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
precisions, recalls, ops.greater_equal
)
def get_config(self):
config = {
"num_thresholds": self.num_thresholds,
"precision": self.precision,
}
base_config = super().get_config()
return {**base_config, **config}
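Taken together, the four new metrics follow the usual stateful-metric API. A minimal usage sketch with made-up labels and scores (not part of the diff):

```python
import numpy as np
from keras_core import metrics

y_true = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0.1, 0.6, 0.4, 0.7, 0.9])

for metric in [
    metrics.SensitivityAtSpecificity(0.5),
    metrics.SpecificityAtSensitivity(0.5),
    metrics.PrecisionAtRecall(0.5),
    metrics.RecallAtPrecision(0.5),
]:
    # Accumulate confusion-matrix statistics, then query the result.
    metric.update_state(y_true, y_pred)
    print(metric.name, metric.result())
```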

@@ -1,7 +1,9 @@
import numpy as np
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config
from keras_core import metrics
from keras_core import operations as ops
from keras_core import testing
# TODO: remove reliance on this (or alternatively, turn it on by default).
@@ -689,3 +691,413 @@ class RecallTest(testing.TestCase):
self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(3, r_obj.false_negatives)
class SensitivityAtSpecificityTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.SensitivityAtSpecificity(
0.4,
num_thresholds=100,
class_id=12,
name="sensitivity_at_specificity_1",
)
self.assertEqual(s_obj.name, "sensitivity_at_specificity_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.specificity, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.SensitivityAtSpecificity.from_config(
s_obj.get_config()
)
self.assertEqual(s_obj2.name, "sensitivity_at_specificity_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.specificity, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.SensitivityAtSpecificity(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_specificity(self):
s_obj = metrics.SensitivityAtSpecificity(0.8)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.8, s_obj(y_true, y_pred))
def test_unweighted_low_specificity(self):
s_obj = metrics.SensitivityAtSpecificity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.SensitivityAtSpecificity(0.4, class_id=2)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.SensitivityAtSpecificity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
self.assertAlmostEqual(0.675, result)
def test_invalid_specificity(self):
with self.assertRaisesRegex(
ValueError, r"`specificity` must be in the range \[0, 1\]."
):
metrics.SensitivityAtSpecificity(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
class SpecificityAtSensitivityTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.SpecificityAtSensitivity(
0.4,
num_thresholds=100,
class_id=12,
name="specificity_at_sensitivity_1",
)
self.assertEqual(s_obj.name, "specificity_at_sensitivity_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.sensitivity, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.SpecificityAtSensitivity.from_config(
s_obj.get_config()
)
self.assertEqual(s_obj2.name, "specificity_at_sensitivity_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.sensitivity, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.SpecificityAtSensitivity(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(1.0)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.2, s_obj(y_true, y_pred))
def test_unweighted_low_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.SpecificityAtSensitivity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
self.assertAlmostEqual(0.4, result)
def test_invalid_sensitivity(self):
with self.assertRaisesRegex(
ValueError, r"`sensitivity` must be in the range \[0, 1\]."
):
metrics.SpecificityAtSensitivity(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)
class PrecisionAtRecallTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.PrecisionAtRecall(
0.4, num_thresholds=100, class_id=12, name="precision_at_recall_1"
)
self.assertEqual(s_obj.name, "precision_at_recall_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.recall, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.PrecisionAtRecall.from_config(s_obj.get_config())
self.assertEqual(s_obj2.name, "precision_at_recall_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.recall, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.PrecisionAtRecall(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_recall(self):
s_obj = metrics.PrecisionAtRecall(0.8)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# For 0.5 < decision threshold < 0.6.
self.assertAlmostEqual(2.0 / 3, s_obj(y_true, y_pred))
def test_unweighted_low_recall(self):
s_obj = metrics.PrecisionAtRecall(0.6)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# For 0.2 < decision threshold < 0.5.
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.PrecisionAtRecall(0.6, class_id=2)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
# For 0.2 < decision threshold < 0.5.
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.PrecisionAtRecall(7.0 / 8)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
# For 0.0 < decision threshold < 0.2.
self.assertAlmostEqual(0.7, result)
def test_invalid_recall(self):
with self.assertRaisesRegex(
ValueError, r"`recall` must be in the range \[0, 1\]."
):
metrics.PrecisionAtRecall(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.PrecisionAtRecall(0.4, num_thresholds=-1)
class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.RecallAtPrecision(
0.4, num_thresholds=100, class_id=12, name="recall_at_precision_1"
)
self.assertEqual(s_obj.name, "recall_at_precision_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.precision, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.RecallAtPrecision.from_config(s_obj.get_config())
self.assertEqual(s_obj2.name, "recall_at_precision_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.precision, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.RecallAtPrecision(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_precision(self):
s_obj = metrics.RecallAtPrecision(0.75)
pred_values = [
0.05,
0.1,
0.2,
0.3,
0.3,
0.35,
0.4,
0.45,
0.5,
0.6,
0.9,
0.95,
]
label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
# 1].
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# The precision 0.75 can be reached at thresholds 0.4<=t<0.45.
self.assertAlmostEqual(0.5, s_obj(y_true, y_pred))
def test_unweighted_low_precision(self):
s_obj = metrics.RecallAtPrecision(2.0 / 3)
pred_values = [
0.05,
0.1,
0.2,
0.3,
0.3,
0.35,
0.4,
0.45,
0.5,
0.6,
0.9,
0.95,
]
label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
# 1].
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# The precision 5/7 can be reached at thresholds 0.3<=t<0.35.
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.RecallAtPrecision(2.0 / 3, class_id=2)
pred_values = [
0.05,
0.1,
0.2,
0.3,
0.3,
0.35,
0.4,
0.45,
0.5,
0.6,
0.9,
0.95,
]
label_values = [0, 2, 0, 0, 0, 2, 2, 0, 2, 2, 0, 2]
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
# 1].
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
# The precision 5/7 can be reached at thresholds 0.3<=t<0.35.
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.RecallAtPrecision(0.75)
pred_values = [0.1, 0.2, 0.3, 0.5, 0.6, 0.9, 0.9]
label_values = [0, 1, 0, 0, 0, 1, 1]
weight_values = [1, 2, 1, 2, 1, 2, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
self.assertAlmostEqual(0.6, result)
def test_unachievable_precision(self):
s_obj = metrics.RecallAtPrecision(2.0 / 3)
pred_values = [0.1, 0.2, 0.3, 0.9]
label_values = [1, 1, 0, 0]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# The highest possible precision is 1/2 which is below the required
# value, expect 0 recall.
self.assertAlmostEqual(0, s_obj(y_true, y_pred))
def test_invalid_precision(self):
with self.assertRaisesRegex(
ValueError, r"`precision` must be in the range \[0, 1\]."
):
metrics.RecallAtPrecision(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.RecallAtPrecision(0.4, num_thresholds=-1)

@@ -74,8 +74,7 @@ def _update_confusion_matrix_variables_optimized(
thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
Given a prediction value p, we can map it to its bucket by
bucket_index(p) = floor( p * (num_thresholds - 1) )
so we can use ops.unsorted_segmenet_sum() to update the buckets in one
pass.
so we can use ops.segment_sum() to update the buckets in one pass.
Consider following example:
y_true = [0, 0, 1, 1]
@@ -93,7 +92,7 @@ def _update_confusion_matrix_variables_optimized(
# Note the second item "1.0" is floored to 0, since the value need to be
# strictly larger than the bucket lower bound.
# In the implementation, we use ops.ceil() - 1 to achieve this.
tp_bucket_value = ops.unsorted_segmenet_sum(true_labels, bucket_indices,
tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
num_segments=num_thresholds)
= [1, 1, 0]
# For [1, 1, 0] here, it means there is 1 true value contributed by bucket
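The bucketing described above can be reproduced with plain NumPy. A sketch (not part of the diff) using hypothetical prediction values chosen to match the `[1, 1, 0]` result shown; the `ceil(...) - 1` form means a prediction must be strictly greater than a bucket's lower bound to land in the higher bucket:

```python
import numpy as np

y_true = np.array([0.0, 0.0, 1.0, 1.0])  # true labels from the example above
y_pred = np.array([0.0, 0.5, 0.3, 0.9])  # hypothetical prediction values
num_thresholds = 3                       # thresholds = [0.0, 0.5, 1.0]

# bucket_index(p) = ceil(p * (num_thresholds - 1)) - 1, clipped at 0, so a
# prediction equal to a bucket's lower bound stays in the bucket below it.
bucket_indices = np.maximum(
    np.ceil(y_pred * (num_thresholds - 1)) - 1, 0
).astype(int)                            # [0, 0, 0, 1]

# Sum of true labels per bucket, analogous to ops.segment_sum.
tp_bucket_value = np.bincount(
    bucket_indices, weights=y_true, minlength=num_thresholds
)
print(tp_bucket_value)                   # [1. 1. 0.]
```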
@@ -109,35 +108,35 @@ def _update_confusion_matrix_variables_optimized(
O(T * N).
Args:
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
and corresponding variables to update as values.
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
cast to `bool`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values are
in the range `[0, 1]`.
thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
It need to be evenly distributed (the diff between each element need to
be the same).
multi_label: Optional boolean indicating whether multidimensional
prediction/labels should be treated as multilabel responses, or
flattened into a single label. When True, the valus of
`variables_to_update` must have a second dimension equal to the number
of labels in y_true and y_pred, and those tensors must not be
RaggedTensors.
sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
must be either `1`, or the same as the corresponding `y_true`
dimension).
label_weights: Optional tensor of non-negative weights for multilabel
data. The weights are applied when calculating TP, FP, FN, and TN
without explicit multilabel handling (i.e. when the data is to be
flattened).
thresholds_with_epsilon: Optional boolean indicating whether the leading
and tailing thresholds has any epsilon added for floating point
imprecisions. It will change how we handle the leading and tailing
bucket.
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
keys and corresponding variables to update as values.
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
cast to `bool`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values
are in the range `[0, 1]`.
thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
It needs to be evenly distributed (the difference between consecutive
elements must be the same).
multi_label: Optional boolean indicating whether multidimensional
prediction/labels should be treated as multilabel responses, or
flattened into a single label. When True, the values of
`variables_to_update` must have a second dimension equal to the
number of labels in y_true and y_pred, and those tensors must not be
RaggedTensors.
sample_weights: Optional `Tensor` whose rank is either 0, or the same
rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
dimensions must be either `1`, or the same as the corresponding
`y_true` dimension).
label_weights: Optional tensor of non-negative weights for multilabel
data. The weights are applied when calculating TP, FP, FN, and TN
without explicit multilabel handling (i.e. when the data is to be
flattened).
thresholds_with_epsilon: Optional boolean indicating whether the leading
and trailing thresholds have any epsilon added for floating point
imprecision. This changes how the leading and trailing buckets are
handled.
"""
num_thresholds = thresholds.shape.as_list()[0]
num_thresholds = ops.shape(thresholds)[0]
if sample_weights is None:
sample_weights = 1.0
@@ -160,7 +159,7 @@ def _update_confusion_matrix_variables_optimized(
# We shouldn't need this, but in case there are predict value that is out of
# the range of [0.0, 1.0]
y_pred = ops.clip(y_pred, clip_value_min=0.0, clip_value_max=1.0)
y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0)
y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
if not multi_label:
@@ -198,7 +197,7 @@ def _update_confusion_matrix_variables_optimized(
label_and_bucket_index[0],
label_and_bucket_index[1],
)
return ops.unsorted_segmenet_sum(
return ops.segment_sum(
data=label,
segment_ids=bucket_index,
num_segments=num_thresholds,
@@ -210,21 +209,21 @@ def _update_confusion_matrix_variables_optimized(
fp_bucket_v = ops.vectorized_map(
gather_bucket, (false_labels, bucket_indices), warn=False
)
tp = ops.transpose(ops.cumsum(tp_bucket_v, reverse=True, axis=1))
fp = ops.transpose(ops.cumsum(fp_bucket_v, reverse=True, axis=1))
tp = ops.transpose(ops.cumsum(ops.flip(tp_bucket_v), axis=1))
fp = ops.transpose(ops.cumsum(ops.flip(fp_bucket_v), axis=1))
else:
tp_bucket_v = ops.unsorted_segmenet_sum(
tp_bucket_v = ops.segment_sum(
data=true_labels,
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
fp_bucket_v = ops.unsorted_segmenet_sum(
fp_bucket_v = ops.segment_sum(
data=false_labels,
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
tp = ops.cumsum(tp_bucket_v, reverse=True)
fp = ops.cumsum(fp_bucket_v, reverse=True)
tp = ops.cumsum(ops.flip(tp_bucket_v))
fp = ops.cumsum(ops.flip(fp_bucket_v))
# fn = sum(true_labels) - tp
# tn = sum(false_labels) - fp

@@ -1,7 +1,6 @@
"""
segment_sum
top_k
in_top_k
"""
from keras_core import backend
@@ -10,14 +9,16 @@ from keras_core.operations.operation import Operation
class SegmentSum(Operation):
def call(self, x, segment_ids, num_segments=None, sorted=False):
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
def call(self, data, segment_ids, num_segments=None, sorted=False):
return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
def segment_sum(x, segment_ids, num_segments=None, sorted=False):
if any_symbolic_tensors((x,)):
return SegmentSum().symbolic_call(x, segment_ids, num_segments, sorted)
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if any_symbolic_tensors((data,)):
return SegmentSum().symbolic_call(
data, segment_ids, num_segments, sorted
)
return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
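A quick usage sketch of the renamed argument (`data` rather than `x`), with illustrative inputs (not part of the diff):

```python
import numpy as np
from keras_core import operations as ops

data = np.array([1.0, 2.0, 3.0, 4.0])
segment_ids = np.array([0, 0, 1, 2])

# Entries sharing a segment id are summed; `num_segments` fixes the output
# length, so trailing empty segments come back as zeros.
out = ops.segment_sum(data, segment_ids, num_segments=4)
print(out)  # [3. 3. 4. 0.]
```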
class TopK(Operation):

@@ -1881,16 +1881,19 @@ def matmul(x1, x2):
class Max(Operation):
def __init__(self, axis=None, keepdims=False):
def __init__(self, axis=None, keepdims=False, initial=None):
super().__init__()
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
self.initial = initial
def call(self, x):
return backend.numpy.max(x, axis=self.axis, keepdims=self.keepdims)
return backend.numpy.max(
x, axis=self.axis, keepdims=self.keepdims, initial=self.initial
)
def compute_output_spec(self, x):
return KerasTensor(
@@ -1899,10 +1902,12 @@ class Max(Operation):
)
def max(x, axis=None, keepdims=False):
def max(x, axis=None, keepdims=False, initial=None):
if any_symbolic_tensors((x,)):
return Max(axis=axis, keepdims=keepdims).symbolic_call(x)
return backend.numpy.max(x, axis=axis, keepdims=keepdims)
return Max(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(
x
)
return backend.numpy.max(x, axis=axis, keepdims=keepdims, initial=initial)
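A brief sketch of what the new `initial` argument enables at the ops level; this is what `_find_max_under_constraint` relies on to return 0 when the feasible set is empty (illustrative values, not part of the diff):

```python
import numpy as np
from keras_core import operations as ops

x = np.array([0.2, 0.6, 0.4])
print(ops.max(x, initial=0.0))             # 0.6
print(ops.max(np.array([]), initial=0.0))  # 0.0 rather than an error
```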
class Maximum(Operation):
@@ -1961,15 +1966,18 @@ def meshgrid(*x, indexing="xy"):
class Min(Operation):
def __init__(self, axis=None, keepdims=False):
def __init__(self, axis=None, keepdims=False, initial=None):
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
self.initial = initial
def call(self, x):
return backend.numpy.min(x, axis=self.axis, keepdims=self.keepdims)
return backend.numpy.min(
x, axis=self.axis, keepdims=self.keepdims, initial=self.initial
)
def compute_output_spec(self, x):
return KerasTensor(
@@ -1978,10 +1986,12 @@ class Min(Operation):
)
def min(x, axis=None, keepdims=False):
def min(x, axis=None, keepdims=False, initial=None):
if any_symbolic_tensors((x,)):
return Min(axis=axis, keepdims=keepdims).symbolic_call(x)
return backend.numpy.min(x, axis=axis, keepdims=keepdims)
return Min(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(
x
)
return backend.numpy.min(x, axis=axis, keepdims=keepdims, initial=initial)
class Minimum(Operation):