Add 4 more Confusion Metrics (#73)

* Add recall

* Add SensitivitySpecificity metrics

* Review comments

* Add missing method from math

* Fix pytest

* Update config dict creation

* Remove sentence fragment

* Review comments

* Add special casing for min/max
Ian Stenbit 2023-05-05 10:53:36 -06:00 committed by Francois Chollet
parent 1c994b749a
commit 55640456ff
8 changed files with 1033 additions and 65 deletions

@@ -25,8 +25,8 @@ def mean(x, axis=None, keepdims=False):
return jnp.mean(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False):
return jnp.max(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False, initial=None):
return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)
def ones(shape, dtype="float32"):
@@ -327,8 +327,8 @@ def meshgrid(*x, indexing="xy"):
return jnp.meshgrid(*x, indexing=indexing)
def min(x, axis=None, keepdims=False):
return jnp.min(x, axis=axis, keepdims=keepdims)
def min(x, axis=None, keepdims=False, initial=None):
return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial)
def minimum(x1, x2):
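The JAX hunks above simply forward the new `initial` argument, since `jnp.max`/`jnp.min` already follow NumPy's reduction semantics for it. A quick NumPy illustration of those semantics (the values are only illustrative):

```python
import numpy as np

# `initial` takes part in the reduction like an extra element...
print(np.max(np.array([0.2, 0.5]), initial=0.9))  # 0.9 -- initial wins
print(np.max(np.array([0.2, 0.5]), initial=0.0))  # 0.5

# ...and it is what an otherwise-empty reduction returns.
print(np.max(np.array([]), initial=0.0))          # 0.0 instead of an error
```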

@@ -26,7 +26,22 @@ def mean(x, axis=None, keepdims=False):
return tfnp.mean(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False):
def max(x, axis=None, keepdims=False, initial=None):
# The TensorFlow numpy API implementation doesn't support `initial` so we
# handle it manually here.
if initial is not None:
return tf.math.maximum(
tfnp.max(x, axis=axis, keepdims=keepdims), initial
)
# TensorFlow returns -inf by default for an empty list, but for consistency
# with other backends and the numpy API we want to throw in this case.
tf.assert_greater(
size(x),
tf.constant(0, dtype=tf.int64),
message="Cannot compute the max of an empty tensor.",
)
return tfnp.max(x, axis=axis, keepdims=keepdims)
@@ -328,7 +343,22 @@ def meshgrid(*x, indexing="xy"):
return tfnp.meshgrid(*x, indexing=indexing)
def min(x, axis=None, keepdims=False):
def min(x, axis=None, keepdims=False, initial=None):
# The TensorFlow numpy API implementation doesn't support `initial` so we
# handle it manually here.
if initial is not None:
return tf.math.minimum(
tfnp.min(x, axis=axis, keepdims=keepdims), initial
)
# TensorFlow returns inf by default for an empty list, but for consistency
# with other backends and the numpy API we want to throw in this case.
tf.assert_greater(
size(x),
tf.constant(0, dtype=tf.int64),
message="Cannot compute the min of an empty tensor.",
)
return tfnp.min(x, axis=axis, keepdims=keepdims)
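Because the TensorFlow NumPy API lacks `initial`, the hunks above fold it in after the reduction. A minimal sketch of that same pattern in isolation (assumes TensorFlow is installed; the tensor values are illustrative):

```python
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp

x = tf.constant([0.2, 0.7, 0.4])

# Reduce first, then let `initial` win if it is larger. This mirrors passing
# `initial` to a NumPy-style max, including the empty-tensor case where
# tfnp.max returns -inf and `initial` becomes the result.
result = tf.math.maximum(tfnp.max(x), 0.9)
print(float(result))  # 0.9
```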

@@ -8,7 +8,11 @@ from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
from keras_core.metrics.confusion_metrics import FalseNegatives
from keras_core.metrics.confusion_metrics import FalsePositives
from keras_core.metrics.confusion_metrics import Precision
from keras_core.metrics.confusion_metrics import PrecisionAtRecall
from keras_core.metrics.confusion_metrics import Recall
from keras_core.metrics.confusion_metrics import RecallAtPrecision
from keras_core.metrics.confusion_metrics import SensitivityAtSpecificity
from keras_core.metrics.confusion_metrics import SpecificityAtSensitivity
from keras_core.metrics.confusion_metrics import TrueNegatives
from keras_core.metrics.confusion_metrics import TruePositives
from keras_core.metrics.hinge_metrics import CategoricalHinge
@@ -40,7 +44,11 @@ ALL_OBJECTS = {
FalseNegatives,
FalsePositives,
Precision,
PrecisionAtRecall,
Recall,
RecallAtPrecision,
SensitivityAtSpecificity,
SpecificityAtSensitivity,
TrueNegatives,
TruePositives,
# Hinge
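With the exports and `ALL_OBJECTS` registrations above, the four new metrics are available straight from `keras_core.metrics`. A quick standalone sketch, mirroring the docstring examples further down:

```python
from keras_core import metrics

# Best sensitivity subject to specificity >= 0.5 (per the docstring example).
m = metrics.SensitivityAtSpecificity(0.5)
m.update_state([0, 0, 0, 1, 1], [0.0, 0.3, 0.8, 0.3, 0.8])
print(float(m.result()))  # 0.5

# The other three follow the same constructor pattern.
metrics.SpecificityAtSensitivity(0.5)
metrics.PrecisionAtRecall(recall=0.8)
metrics.RecallAtPrecision(precision=0.8)
```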

@@ -69,7 +69,7 @@ class _ConfusionMatrixConditionCount(Metric):
def get_config(self):
config = {"thresholds": self.init_thresholds}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
return {**base_config, **config}
@keras_core_export("keras_core.metrics.FalsePositives")
@@ -402,7 +402,7 @@ class Precision(Metric):
"class_id": self.class_id,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
return {**base_config, **config}
@keras_core_export("keras_core.metrics.Recall")
@@ -543,4 +543,512 @@ class Recall(Metric):
"class_id": self.class_id,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
return {**base_config, **config}
class SensitivitySpecificityBase(Metric):
"""Abstract base class for computing sensitivity and specificity.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
"""
def __init__(
self, value, num_thresholds=200, class_id=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
if num_thresholds <= 0:
raise ValueError(
"Argument `num_thresholds` must be an integer > 0. "
f"Received: num_thresholds={num_thresholds}"
)
self.value = value
self.class_id = class_id
# Compute `num_thresholds` thresholds in [0, 1]
if num_thresholds == 1:
self.thresholds = [0.5]
self._thresholds_distributed_evenly = False
else:
thresholds = [
(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)
]
self.thresholds = [0.0] + thresholds + [1.0]
self._thresholds_distributed_evenly = True
self.true_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="true_positives",
)
self.false_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="false_positives",
)
self.true_negatives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="true_negatives",
)
self.false_negatives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="false_negatives",
)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates confusion matrix statistics.
Args:
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
},
y_true,
y_pred,
thresholds=self.thresholds,
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
class_id=self.class_id,
sample_weight=sample_weight,
)
def reset_state(self):
num_thresholds = len(self.thresholds)
self.true_positives.assign(ops.zeros((num_thresholds,)))
self.false_positives.assign(ops.zeros((num_thresholds,)))
self.true_negatives.assign(ops.zeros((num_thresholds,)))
self.false_negatives.assign(ops.zeros((num_thresholds,)))
def get_config(self):
config = {"class_id": self.class_id}
base_config = super().get_config()
return {**base_config, **config}
def _find_max_under_constraint(self, constrained, dependent, predicate):
"""Returns the maximum of dependent_statistic that satisfies the
constraint.
Args:
constrained: Over these values the constraint is specified. A rank-1
tensor.
dependent: From these values the maximum that satisfies the
constraint is selected. Values in this tensor and in
`constrained` are linked by having the same threshold at each
position, hence this tensor must have the same shape.
predicate: A binary boolean functor to be applied to arguments
`constrained` and `self.value`, e.g. `ops.greater`.
Returns:
The maximal dependent value, or 0.0 if no value satisfies the constraint.
"""
feasible = ops.array(ops.nonzero(predicate(constrained, self.value)))
feasible_exists = ops.greater(ops.size(feasible), 0)
max_dependent = ops.max(ops.take(dependent, feasible), initial=0)
return ops.where(feasible_exists, max_dependent, 0.0)
@keras_core_export("keras_core.metrics.SensitivityAtSpecificity")
class SensitivityAtSpecificity(SensitivitySpecificityBase):
"""Computes best sensitivity where specificity is >= specified value.
`Sensitivity` measures the proportion of actual positives that are correctly
identified as such `(tp / (tp + fn))`.
`Specificity` measures the proportion of actual negatives that are correctly
identified as such `(tn / (tn + fp))`.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the sensitivity at the given specificity. The threshold for the
given specificity value is computed and used to evaluate the corresponding
sensitivity.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, the metric is computed by restricting
`y_true` and `y_pred` to that class, i.e. treating the `class_id` slice
as a binary classification problem.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
Args:
specificity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given specificity.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.SensitivityAtSpecificity(0.5)
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
>>> m.result()
0.5
>>> m.reset_state()
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
... sample_weight=[1, 1, 2, 2, 1])
>>> m.result()
0.333333
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.SensitivityAtSpecificity(specificity=0.5)])
```
"""
def __init__(
self,
specificity,
num_thresholds=200,
class_id=None,
name=None,
dtype=None,
):
if specificity < 0 or specificity > 1:
raise ValueError(
"Argument `specificity` must be in the range [0, 1]. "
f"Received: specificity={specificity}"
)
self.specificity = specificity
self.num_thresholds = num_thresholds
super().__init__(
specificity,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
sensitivities = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
specificities = ops.divide(
self.true_negatives,
self.true_negatives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
specificities, sensitivities, ops.greater_equal
)
def get_config(self):
config = {
"num_thresholds": self.num_thresholds,
"specificity": self.specificity,
}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.SpecificityAtSensitivity")
class SpecificityAtSensitivity(SensitivitySpecificityBase):
"""Computes best specificity where sensitivity is >= specified value.
`Sensitivity` measures the proportion of actual positives that are correctly
identified as such `(tp / (tp + fn))`.
`Specificity` measures the proportion of actual negatives that are correctly
identified as such `(tn / (tn + fp))`.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the specificity at the given sensitivity. The threshold for the
given sensitivity value is computed and used to evaluate the corresponding
specificity.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, the metric is computed by restricting
`y_true` and `y_pred` to that class, i.e. treating the `class_id` slice
as a binary classification problem.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
Args:
sensitivity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given sensitivity.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.SpecificityAtSensitivity(0.5)
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
>>> m.result()
0.66666667
>>> m.reset_state()
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
... sample_weight=[1, 1, 2, 2, 2])
>>> m.result()
0.5
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.SpecificityAtSensitivity(sensitivity=0.5)])
```
"""
def __init__(
self,
sensitivity,
num_thresholds=200,
class_id=None,
name=None,
dtype=None,
):
if sensitivity < 0 or sensitivity > 1:
raise ValueError(
"Argument `sensitivity` must be in the range [0, 1]. "
f"Received: sensitivity={sensitivity}"
)
self.sensitivity = sensitivity
self.num_thresholds = num_thresholds
super().__init__(
sensitivity,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
sensitivities = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
specificities = ops.divide(
self.true_negatives,
self.true_negatives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
sensitivities, specificities, ops.greater_equal
)
def get_config(self):
config = {
"num_thresholds": self.num_thresholds,
"sensitivity": self.sensitivity,
}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.PrecisionAtRecall")
class PrecisionAtRecall(SensitivitySpecificityBase):
"""Computes best precision where recall is >= specified value.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the precision at the given recall. The threshold for the given
recall value is computed and used to evaluate the corresponding precision.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, the metric is computed by restricting
`y_true` and `y_pred` to that class, i.e. treating the `class_id` slice
as a binary classification problem.
Args:
recall: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given recall.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.PrecisionAtRecall(0.5)
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
>>> m.result()
0.5
>>> m.reset_state()
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
... sample_weight=[2, 2, 2, 1, 1])
>>> m.result()
0.33333333
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.PrecisionAtRecall(recall=0.8)])
```
"""
def __init__(
self, recall, num_thresholds=200, class_id=None, name=None, dtype=None
):
if recall < 0 or recall > 1:
raise ValueError(
"Argument `recall` must be in the range [0, 1]. "
f"Received: recall={recall}"
)
self.recall = recall
self.num_thresholds = num_thresholds
super().__init__(
value=recall,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
recalls = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
precisions = ops.divide(
self.true_positives,
self.true_positives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
recalls, precisions, ops.greater_equal
)
def get_config(self):
config = {"num_thresholds": self.num_thresholds, "recall": self.recall}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.RecallAtPrecision")
class RecallAtPrecision(SensitivitySpecificityBase):
"""Computes best recall where precision is >= specified value.
For a given score-label distribution, the required precision might not
be achievable; in this case, 0.0 is returned as the recall.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the recall at the given precision. The threshold for the given
precision value is computed and used to evaluate the corresponding recall.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, the metric is computed by restricting
`y_true` and `y_pred` to that class, i.e. treating the `class_id` slice
as a binary classification problem.
Args:
precision: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds
to use for matching the given precision.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.RecallAtPrecision(0.8)
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> m.result()
0.5
>>> m.reset_state()
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
... sample_weight=[1, 0, 0, 1])
>>> m.result()
1.0
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.RecallAtPrecision(precision=0.8)])
```
"""
def __init__(
self,
precision,
num_thresholds=200,
class_id=None,
name=None,
dtype=None,
):
if precision < 0 or precision > 1:
raise ValueError(
"Argument `precision` must be in the range [0, 1]. "
f"Received: precision={precision}"
)
self.precision = precision
self.num_thresholds = num_thresholds
super().__init__(
value=precision,
num_thresholds=num_thresholds,
class_id=class_id,
name=name,
dtype=dtype,
)
def result(self):
recalls = ops.divide(
self.true_positives,
self.true_positives + self.false_negatives + backend.epsilon(),
)
precisions = ops.divide(
self.true_positives,
self.true_positives + self.false_positives + backend.epsilon(),
)
return self._find_max_under_constraint(
precisions, recalls, ops.greater_equal
)
def get_config(self):
config = {
"num_thresholds": self.num_thresholds,
"precision": self.precision,
}
base_config = super().get_config()
return {**base_config, **config}
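All four metrics above share the same mechanics: accumulate per-threshold TP/FP/TN/FN, compute the constrained and dependent statistics at every threshold, then let `_find_max_under_constraint` pick the best feasible value. A hedged NumPy sketch of that final selection step (the per-threshold numbers are made up):

```python
import numpy as np

def best_under_constraint(constrained, dependent, target):
    """Max of `dependent` over thresholds where `constrained` >= target.

    Plain-NumPy stand-in for `_find_max_under_constraint`, for illustration.
    """
    feasible = np.nonzero(constrained >= target)[0]
    if feasible.size == 0:
        return 0.0  # constraint unachievable -> 0.0, as documented
    return float(np.max(dependent[feasible]))

# E.g. PrecisionAtRecall(0.6): recall is constrained, precision is dependent.
recalls = np.array([1.0, 0.8, 0.6, 0.4, 0.2])
precisions = np.array([0.5, 0.6, 0.75, 0.8, 1.0])
print(best_under_constraint(recalls, precisions, target=0.6))  # 0.75
```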

@@ -1,7 +1,9 @@
import numpy as np
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config
from keras_core import metrics
from keras_core import operations as ops
from keras_core import testing
# TODO: remove reliance on this (or alternatively, turn it on by default).
@@ -689,3 +691,413 @@ class RecallTest(testing.TestCase):
self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))
self.assertAlmostEqual(1, r_obj.true_positives)
self.assertAlmostEqual(3, r_obj.false_negatives)
class SensitivityAtSpecificityTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.SensitivityAtSpecificity(
0.4,
num_thresholds=100,
class_id=12,
name="sensitivity_at_specificity_1",
)
self.assertEqual(s_obj.name, "sensitivity_at_specificity_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.specificity, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.SensitivityAtSpecificity.from_config(
s_obj.get_config()
)
self.assertEqual(s_obj2.name, "sensitivity_at_specificity_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.specificity, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.SensitivityAtSpecificity(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_specificity(self):
s_obj = metrics.SensitivityAtSpecificity(0.8)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.8, s_obj(y_true, y_pred))
def test_unweighted_low_specificity(self):
s_obj = metrics.SensitivityAtSpecificity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.SensitivityAtSpecificity(0.4, class_id=2)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.SensitivityAtSpecificity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
self.assertAlmostEqual(0.675, result)
def test_invalid_specificity(self):
with self.assertRaisesRegex(
ValueError, r"`specificity` must be in the range \[0, 1\]."
):
metrics.SensitivityAtSpecificity(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
class SpecificityAtSensitivityTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.SpecificityAtSensitivity(
0.4,
num_thresholds=100,
class_id=12,
name="specificity_at_sensitivity_1",
)
self.assertEqual(s_obj.name, "specificity_at_sensitivity_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.sensitivity, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.SpecificityAtSensitivity.from_config(
s_obj.get_config()
)
self.assertEqual(s_obj2.name, "specificity_at_sensitivity_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.sensitivity, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.SpecificityAtSensitivity(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(1.0)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.2, s_obj(y_true, y_pred))
def test_unweighted_low_sensitivity(self):
s_obj = metrics.SpecificityAtSensitivity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.SpecificityAtSensitivity(0.4)
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
self.assertAlmostEqual(0.4, result)
def test_invalid_sensitivity(self):
with self.assertRaisesRegex(
ValueError, r"`sensitivity` must be in the range \[0, 1\]."
):
metrics.SpecificityAtSensitivity(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)
class PrecisionAtRecallTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.PrecisionAtRecall(
0.4, num_thresholds=100, class_id=12, name="precision_at_recall_1"
)
self.assertEqual(s_obj.name, "precision_at_recall_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.recall, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.PrecisionAtRecall.from_config(s_obj.get_config())
self.assertEqual(s_obj2.name, "precision_at_recall_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.recall, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.PrecisionAtRecall(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_recall(self):
s_obj = metrics.PrecisionAtRecall(0.8)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# For 0.5 < decision threshold < 0.6.
self.assertAlmostEqual(2.0 / 3, s_obj(y_true, y_pred))
def test_unweighted_low_recall(self):
s_obj = metrics.PrecisionAtRecall(0.6)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# For 0.2 < decision threshold < 0.5.
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.PrecisionAtRecall(0.6, class_id=2)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
# For 0.2 < decision threshold < 0.5.
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.PrecisionAtRecall(7.0 / 8)
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
# For 0.0 < decision threshold < 0.2.
self.assertAlmostEqual(0.7, result)
def test_invalid_recall(self):
with self.assertRaisesRegex(
ValueError, r"`recall` must be in the range \[0, 1\]."
):
metrics.PrecisionAtRecall(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.PrecisionAtRecall(0.4, num_thresholds=-1)
class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase):
def test_config(self):
s_obj = metrics.RecallAtPrecision(
0.4, num_thresholds=100, class_id=12, name="recall_at_precision_1"
)
self.assertEqual(s_obj.name, "recall_at_precision_1")
self.assertLen(s_obj.variables, 4)
self.assertEqual(s_obj.precision, 0.4)
self.assertEqual(s_obj.num_thresholds, 100)
self.assertEqual(s_obj.class_id, 12)
# Check save and restore config
s_obj2 = metrics.RecallAtPrecision.from_config(s_obj.get_config())
self.assertEqual(s_obj2.name, "recall_at_precision_1")
self.assertLen(s_obj2.variables, 4)
self.assertEqual(s_obj2.precision, 0.4)
self.assertEqual(s_obj2.num_thresholds, 100)
self.assertEqual(s_obj2.class_id, 12)
def test_unweighted_all_correct(self):
s_obj = metrics.RecallAtPrecision(0.7)
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs, dtype="float32")
y_true = np.array(inputs)
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
def test_unweighted_high_precision(self):
s_obj = metrics.RecallAtPrecision(0.75)
pred_values = [
0.05,
0.1,
0.2,
0.3,
0.3,
0.35,
0.4,
0.45,
0.5,
0.6,
0.9,
0.95,
]
label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
# 1].
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# The precision 0.75 can be reached at thresholds 0.4<=t<0.45.
self.assertAlmostEqual(0.5, s_obj(y_true, y_pred))
def test_unweighted_low_precision(self):
s_obj = metrics.RecallAtPrecision(2.0 / 3)
pred_values = [
0.05,
0.1,
0.2,
0.3,
0.3,
0.35,
0.4,
0.45,
0.5,
0.6,
0.9,
0.95,
]
label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
# 1].
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# The precision 5/7 can be reached at thresholds 0.3<=t<0.35.
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
def test_unweighted_class_id(self):
s_obj = metrics.RecallAtPrecision(2.0 / 3, class_id=2)
pred_values = [
0.05,
0.1,
0.2,
0.3,
0.3,
0.35,
0.4,
0.45,
0.5,
0.6,
0.9,
0.95,
]
label_values = [0, 2, 0, 0, 0, 2, 2, 0, 2, 2, 0, 2]
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
# 1].
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
# The precision 5/7 can be reached at thresholds 0.3<=t<0.35.
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
@parameterized.parameters(["bool", "int32", "float32"])
def test_weighted(self, label_dtype):
s_obj = metrics.RecallAtPrecision(0.75)
pred_values = [0.1, 0.2, 0.3, 0.5, 0.6, 0.9, 0.9]
label_values = [0, 1, 0, 0, 0, 1, 1]
weight_values = [1, 2, 1, 2, 1, 2, 1]
y_pred = np.array(pred_values, dtype="float32")
y_true = ops.cast(label_values, dtype=label_dtype)
weights = np.array(weight_values)
result = s_obj(y_true, y_pred, sample_weight=weights)
self.assertAlmostEqual(0.6, result)
def test_unachievable_precision(self):
s_obj = metrics.RecallAtPrecision(2.0 / 3)
pred_values = [0.1, 0.2, 0.3, 0.9]
label_values = [1, 1, 0, 0]
y_pred = np.array(pred_values, dtype="float32")
y_true = np.array(label_values)
# The highest possible precision is 1/2 which is below the required
# value, expect 0 recall.
self.assertAlmostEqual(0, s_obj(y_true, y_pred))
def test_invalid_precision(self):
with self.assertRaisesRegex(
ValueError, r"`precision` must be in the range \[0, 1\]."
):
metrics.RecallAtPrecision(-1)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.RecallAtPrecision(0.4, num_thresholds=-1)
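As a cross-check on the expectations above, the unweighted high-specificity case can be re-derived with a brute-force threshold sweep. A hedged sketch (the threshold grid and the strict `>` comparison are illustrative, not the library's exact implementation):

```python
import numpy as np

y_pred = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9])
y_true = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

best_sensitivity = 0.0
for t in np.linspace(0.0, 1.0, 200):
    pred_pos = y_pred > t
    tp = np.sum(pred_pos & (y_true == 1))
    fn = np.sum(~pred_pos & (y_true == 1))
    tn = np.sum(~pred_pos & (y_true == 0))
    fp = np.sum(pred_pos & (y_true == 0))
    if tn / (tn + fp) >= 0.8:  # specificity constraint
        best_sensitivity = max(best_sensitivity, tp / (tp + fn))

print(best_sensitivity)  # 0.8, matching test_unweighted_high_specificity
```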

@@ -74,8 +74,7 @@ def _update_confusion_matrix_variables_optimized(
thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
Given a prediction value p, we can map it to its bucket by
bucket_index(p) = floor( p * (num_thresholds - 1) )
so we can use ops.unsorted_segmenet_sum() to update the buckets in one
pass.
so we can use ops.segment_sum() to update the buckets in one pass.
Consider following example:
y_true = [0, 0, 1, 1]
@@ -93,7 +92,7 @@ def _update_confusion_matrix_variables_optimized(
# Note the second item "1.0" is floored to 0, since the value need to be
# strictly larger than the bucket lower bound.
# In the implementation, we use ops.ceil() - 1 to achieve this.
tp_bucket_value = ops.unsorted_segmenet_sum(true_labels, bucket_indices,
tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
num_segments=num_thresholds)
= [1, 1, 0]
# For [1, 1, 0] here, it means there is 1 true value contributed by bucket
@@ -109,25 +108,25 @@ def _update_confusion_matrix_variables_optimized(
O(T * N).
Args:
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
and corresponding variables to update as values.
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
keys and corresponding variables to update as values.
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
cast to `bool`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values are
in the range `[0, 1]`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values
are in the range `[0, 1]`.
thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
It need to be evenly distributed (the diff between each element need to
be the same).
It need to be evenly distributed (the diff between each element need
to be the same).
multi_label: Optional boolean indicating whether multidimensional
prediction/labels should be treated as multilabel responses, or
flattened into a single label. When True, the values of
`variables_to_update` must have a second dimension equal to the number
of labels in y_true and y_pred, and those tensors must not be
`variables_to_update` must have a second dimension equal to the
number of labels in y_true and y_pred, and those tensors must not be
RaggedTensors.
sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
must be either `1`, or the same as the corresponding `y_true`
dimension).
sample_weights: Optional `Tensor` whose rank is either 0, or the same
rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
dimensions must be either `1`, or the same as the corresponding
`y_true` dimension).
label_weights: Optional tensor of non-negative weights for multilabel
data. The weights are applied when calculating TP, FP, FN, and TN
without explicit multilabel handling (i.e. when the data is to be
@@ -137,7 +136,7 @@ def _update_confusion_matrix_variables_optimized(
imprecisions. It will change how we handle the leading and trailing
bucket.
"""
num_thresholds = thresholds.shape.as_list()[0]
num_thresholds = ops.shape(thresholds)[0]
if sample_weights is None:
sample_weights = 1.0
@@ -160,7 +159,7 @@ def _update_confusion_matrix_variables_optimized(
# We shouldn't need this, but in case there are prediction values out of
# the range of [0.0, 1.0]
y_pred = ops.clip(y_pred, clip_value_min=0.0, clip_value_max=1.0)
y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0)
y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
if not multi_label:
@@ -198,7 +197,7 @@ def _update_confusion_matrix_variables_optimized(
label_and_bucket_index[0],
label_and_bucket_index[1],
)
return ops.unsorted_segmenet_sum(
return ops.segment_sum(
data=label,
segment_ids=bucket_index,
num_segments=num_thresholds,
@@ -210,21 +209,21 @@ def _update_confusion_matrix_variables_optimized(
fp_bucket_v = ops.vectorized_map(
gather_bucket, (false_labels, bucket_indices), warn=False
)
tp = ops.transpose(ops.cumsum(tp_bucket_v, reverse=True, axis=1))
fp = ops.transpose(ops.cumsum(fp_bucket_v, reverse=True, axis=1))
tp = ops.transpose(ops.cumsum(ops.flip(tp_bucket_v), axis=1))
fp = ops.transpose(ops.cumsum(ops.flip(fp_bucket_v), axis=1))
else:
tp_bucket_v = ops.unsorted_segmenet_sum(
tp_bucket_v = ops.segment_sum(
data=true_labels,
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
fp_bucket_v = ops.unsorted_segmenet_sum(
fp_bucket_v = ops.segment_sum(
data=false_labels,
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
tp = ops.cumsum(tp_bucket_v, reverse=True)
fp = ops.cumsum(fp_bucket_v, reverse=True)
tp = ops.cumsum(ops.flip(tp_bucket_v))
fp = ops.cumsum(ops.flip(fp_bucket_v))
# fn = sum(true_labels) - tp
# tn = sum(false_labels) - fp
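To make the bucketing described in the docstring above concrete, here is a hedged NumPy sketch of the one-pass accumulation (the inputs are made up, and the `ceil(...) - 1` plus clip-at-zero formula paraphrases the comment about values needing to be strictly larger than a bucket's lower bound):

```python
import numpy as np

num_thresholds = 3  # evenly spaced thresholds [0.0, 0.5, 1.0]
y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0.2, 0.6, 0.7, 0.2])

# Map each prediction to its bucket; ceil(...) - 1 keeps a value that sits
# exactly on a lower bound in the bucket below, clipped so 0.0 stays at 0.
bucket_indices = np.maximum(
    np.ceil(y_pred * (num_thresholds - 1)).astype(int) - 1, 0
)  # -> [0, 1, 1, 0]

# One-pass per-bucket sum of true labels (the role of ops.segment_sum here).
tp_bucket = np.bincount(
    bucket_indices, weights=y_true, minlength=num_thresholds
)
print(tp_bucket)  # [1. 1. 0.]

# Summing from the right turns bucket counts into per-threshold TP counts.
print(np.cumsum(tp_bucket[::-1])[::-1])  # [2. 1. 0.]
```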

@@ -1,7 +1,6 @@
"""
segment_sum
top_k
in_top_k
"""
from keras_core import backend
@@ -10,14 +9,16 @@ from keras_core.operations.operation import Operation
class SegmentSum(Operation):
def call(self, x, segment_ids, num_segments=None, sorted=False):
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
def call(self, data, segment_ids, num_segments=None, sorted=False):
return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
def segment_sum(x, segment_ids, num_segments=None, sorted=False):
if any_symbolic_tensors((x,)):
return SegmentSum().symbolic_call(x, segment_ids, num_segments, sorted)
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if any_symbolic_tensors((data,)):
return SegmentSum().symbolic_call(
data, segment_ids, num_segments, sorted
)
return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
class TopK(Operation):
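The `segment_sum` wrapper above now names its first argument `data`, matching the backend call. A short usage sketch (the values are illustrative, and the zero-padding of missing segments assumes the usual unsorted-segment-sum behavior):

```python
import numpy as np
from keras_core import operations as ops

data = np.array([1.0, 2.0, 3.0, 4.0])
segment_ids = np.array([0, 0, 1, 1])

# Sums entries of `data` that share a segment id; `num_segments` fixes the
# output length, with absent segments contributing zeros.
out = ops.segment_sum(data, segment_ids, num_segments=3)
print(out)  # -> [3. 7. 0.]
```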

@@ -1881,16 +1881,19 @@ def matmul(x1, x2):
class Max(Operation):
def __init__(self, axis=None, keepdims=False):
def __init__(self, axis=None, keepdims=False, initial=None):
super().__init__()
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
self.initial = initial
def call(self, x):
return backend.numpy.max(x, axis=self.axis, keepdims=self.keepdims)
return backend.numpy.max(
x, axis=self.axis, keepdims=self.keepdims, initial=self.initial
)
def compute_output_spec(self, x):
return KerasTensor(
@@ -1899,10 +1902,12 @@ class Max(Operation):
)
def max(x, axis=None, keepdims=False):
def max(x, axis=None, keepdims=False, initial=None):
if any_symbolic_tensors((x,)):
return Max(axis=axis, keepdims=keepdims).symbolic_call(x)
return backend.numpy.max(x, axis=axis, keepdims=keepdims)
return Max(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(
x
)
return backend.numpy.max(x, axis=axis, keepdims=keepdims, initial=initial)
class Maximum(Operation):
@@ -1961,15 +1966,18 @@ def meshgrid(*x, indexing="xy"):
class Min(Operation):
def __init__(self, axis=None, keepdims=False):
def __init__(self, axis=None, keepdims=False, initial=None):
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
self.initial = initial
def call(self, x):
return backend.numpy.min(x, axis=self.axis, keepdims=self.keepdims)
return backend.numpy.min(
x, axis=self.axis, keepdims=self.keepdims, initial=self.initial
)
def compute_output_spec(self, x):
return KerasTensor(
@@ -1978,10 +1986,12 @@ class Min(Operation):
)
def min(x, axis=None, keepdims=False):
def min(x, axis=None, keepdims=False, initial=None):
if any_symbolic_tensors((x,)):
return Min(axis=axis, keepdims=keepdims).symbolic_call(x)
return backend.numpy.min(x, axis=axis, keepdims=keepdims)
return Min(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(
x
)
return backend.numpy.min(x, axis=axis, keepdims=keepdims, initial=initial)
class Minimum(Operation):
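Tying it together, the new `initial` argument on the public `ops.max`/`ops.min` is what lets `_find_max_under_constraint` reduce over a possibly-empty selection. A hedged sketch of that interplay (values are illustrative, assuming the default backend):

```python
import numpy as np
from keras_core import operations as ops

values = np.array([0.3, 0.7, 0.5])

# Non-empty selection: `initial` only matters if it exceeds the data.
print(ops.max(ops.take(values, np.array([0, 2])), initial=0))  # -> 0.5

# Empty selection: `initial` makes the reduction well defined instead of
# erroring out -- the case the metrics guard against with `initial=0`.
print(ops.max(ops.take(values, np.array([], dtype="int32")), initial=0))  # -> 0.0
```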