Add precision metric (#60)

* Add precision metric

* Implement tf nan_to_num

* Add epsilon to divisor

* Formatting
Ian Stenbit 2023-04-30 10:10:07 -06:00 committed by Francois Chollet
parent 018d9fe633
commit a503324861
12 changed files with 423 additions and 5 deletions

@@ -7,9 +7,9 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
     )
-def top_k(x, k, sorted=False):
-    if sorted:
+def top_k(x, k, sorted=True):
+    if not sorted:
         raise ValueError(
-            "Jax backend does not support `sorted=True` for `ops.top_k`"
+            "Jax backend does not support `sorted=False` for `ops.top_k`"
         )
     return jax.lax.top_k(x, k)
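As a minimal usage sketch of the new default (assuming the `keras_core.operations` exports wired up later in this commit), `sorted=True` now returns values in descending order:

```python
import numpy as np
from keras_core import operations as ops

x = np.array([0.1, 0.9, 0.4, 0.7], dtype="float32")
values, indices = ops.top_k(x, k=2)  # sorted defaults to True
# values  -> [0.9, 0.7]
# indices -> [1, 3]
# On the JAX backend, passing sorted=False is rejected (see above).
```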

@@ -339,6 +339,10 @@ def moveaxis(x, source, destination):
return jnp.moveaxis(x, source=source, destination=destination)
def nan_to_num(x):
return jnp.nan_to_num(x)
def ndim(x):
return jnp.ndim(x)

@@ -8,5 +8,5 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
     return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)
-def top_k(x, k, sorted=False):
+def top_k(x, k, sorted=True):
     return tf.math.top_k(x, k, sorted=sorted)

@@ -340,6 +340,19 @@ def moveaxis(x, source, destination):
return tfnp.moveaxis(x, source=source, destination=destination)
def nan_to_num(x):
# Replace NaN with 0
x = tf.where(tf.math.is_nan(x), 0, x)
# Replace positive infinity with dtype.max
x = tf.where(tf.math.is_inf(x) & (x > 0), x.dtype.max, x)
# Replace negative infinity with dtype.min
x = tf.where(tf.math.is_inf(x) & (x < 0), x.dtype.min, x)
return x
def ndim(x):
return tfnp.ndim(x)
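A quick sanity check of the `nan_to_num` semantics implemented above (NaN becomes 0, ±inf saturates to the dtype's extremes), written against the public op added later in this commit; the exact saturation values depend on the input dtype:

```python
import numpy as np
from keras_core import operations as ops

x = np.array([np.nan, np.inf, -np.inf, 2.0], dtype="float32")
y = ops.nan_to_num(x)
# Expected for float32: [0.0, 3.4028235e+38, -3.4028235e+38, 2.0],
# matching np.nan_to_num's default behavior.
```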

@@ -1,6 +1,7 @@
from keras_core.api_export import keras_core_export
from keras_core.metrics.confusion_metrics import FalseNegatives
from keras_core.metrics.confusion_metrics import FalsePositives
from keras_core.metrics.confusion_metrics import Precision
from keras_core.metrics.confusion_metrics import TrueNegatives
from keras_core.metrics.confusion_metrics import TruePositives
from keras_core.metrics.hinge_metrics import CategoricalHinge

@@ -1,8 +1,10 @@
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.metrics import metrics_utils
from keras_core.metrics.metric import Metric
from keras_core.utils.python_utils import to_list
class _ConfusionMatrixConditionCount(Metric):
@@ -244,3 +246,160 @@ class TruePositives(_ConfusionMatrixConditionCount):
name=name,
dtype=dtype,
)
@keras_core_export("keras_core.metrics.Precision")
class Precision(Metric):
"""Computes the precision of the predictions with respect to the labels.
The metric creates two local variables, `true_positives` and
`false_positives` that are used to compute the precision. This value is
ultimately returned as `precision`, an idempotent operation that simply
divides `true_positives` by the sum of `true_positives` and
`false_positives`.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `top_k` is set, we'll calculate precision as how often on average a class
among the top-k classes with the highest predicted values of a batch entry
is correct and can be found in the label for that entry.
If `class_id` is specified, we calculate precision by considering only the
entries in the batch for which `class_id` is above the threshold and/or in
the top-k highest predictions, and computing the fraction of them for which
`class_id` is indeed a correct label.
Args:
thresholds: (Optional) A float value, or a Python list/tuple of float
threshold values in [0, 1]. A threshold is compared with prediction
values to determine the truth value of predictions (i.e., above the
threshold is `true`, below is `false`). If used with a loss function
that sets `from_logits=True` (i.e. no sigmoid applied to
predictions), `thresholds` should be set to 0. One metric value is
generated for each threshold value. If neither thresholds nor top_k
are set, the default is to calculate precision with
`thresholds=0.5`.
top_k: (Optional) Unset by default. An int value specifying the top-k
predictions to consider when calculating precision.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.Precision()
>>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
>>> m.result()
0.6666667
>>> m.reset_state()
>>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
>>> m.result()
1.0
>>> # With top_k=2, it will calculate precision over y_true[:2]
>>> # and y_pred[:2]
>>> m = keras_core.metrics.Precision(top_k=2)
>>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
>>> m.result()
0.0
>>> # With top_k=4, it will calculate precision over y_true[:4]
>>> # and y_pred[:4]
>>> m = keras_core.metrics.Precision(top_k=4)
>>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
>>> m.result()
0.5
Usage with `compile()` API:
```python
model.compile(optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.Precision()])
```
Usage with a loss with `from_logits=True`:
```python
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras_core.metrics.Precision(thresholds=0)])
```
"""
def __init__(
self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
self.init_thresholds = thresholds
self.top_k = top_k
self.class_id = class_id
default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
self.thresholds = metrics_utils.parse_init_thresholds(
thresholds, default_threshold=default_threshold
)
self._thresholds_distributed_evenly = (
metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
)
self.true_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="true_positives",
)
self.false_positives = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="false_positives",
)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates true positive and false positive statistics.
Args:
y_true: The ground truth values, with the same dimensions as
`y_pred`. Will be cast to `bool`.
y_pred: The predicted values. Each element must be in the range
`[0, 1]`.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
return metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
},
y_true,
y_pred,
thresholds=self.thresholds,
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
top_k=self.top_k,
class_id=self.class_id,
sample_weight=sample_weight,
)
def result(self):
result = ops.divide(
self.true_positives,
self.true_positives + self.false_positives + backend.epsilon(),
)
return result[0] if len(self.thresholds) == 1 else result
def reset_state(self):
num_thresholds = len(to_list(self.thresholds))
self.true_positives.assign(ops.zeros((num_thresholds,)))
self.false_positives.assign(ops.zeros((num_thresholds,)))
def get_config(self):
config = {
"thresholds": self.init_thresholds,
"top_k": self.top_k,
"class_id": self.class_id,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
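The `backend.epsilon()` term in `result()` keeps the division well-defined when no positives have been predicted yet. A short sketch of that edge case (mirroring `test_div_by_zero` below):

```python
import numpy as np
from keras_core import metrics

m = metrics.Precision()
m.update_state(np.array([0, 0, 0, 0]), np.array([0, 0, 0, 0]))
# No predicted positives: true_positives == false_positives == 0,
# so result() is 0 / (0 + 0 + epsilon) == 0.0 rather than NaN.
assert float(m.result()) == 0.0
```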

@@ -368,3 +368,169 @@ class TruePositiveTest(testing.TestCase):
r"Threshold values must be in \[0, 1\]. Received: \[None\]",
):
metrics.TruePositives(thresholds=[None])
class PrecisionTest(testing.TestCase):
def test_config(self):
p_obj = metrics.Precision(
name="my_precision", thresholds=[0.4, 0.9], top_k=15, class_id=12
)
self.assertEqual(p_obj.name, "my_precision")
self.assertLen(p_obj.variables, 2)
self.assertEqual(
[v.name for v in p_obj.variables],
["true_positives", "false_positives"],
)
self.assertEqual(p_obj.thresholds, [0.4, 0.9])
self.assertEqual(p_obj.top_k, 15)
self.assertEqual(p_obj.class_id, 12)
# Check save and restore config
p_obj2 = metrics.Precision.from_config(p_obj.get_config())
self.assertEqual(p_obj2.name, "my_precision")
self.assertLen(p_obj2.variables, 2)
self.assertEqual(p_obj2.thresholds, [0.4, 0.9])
self.assertEqual(p_obj2.top_k, 15)
self.assertEqual(p_obj2.class_id, 12)
def test_unweighted(self):
p_obj = metrics.Precision()
y_pred = np.array([1, 0, 1, 0])
y_true = np.array([0, 1, 1, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(0.5, result)
def test_unweighted_all_incorrect(self):
p_obj = metrics.Precision(thresholds=[0.5])
inputs = np.random.randint(0, 2, size=(100, 1))
y_pred = np.array(inputs)
y_true = np.array(1 - inputs)
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(0, result)
def test_weighted(self):
p_obj = metrics.Precision()
y_pred = np.array([[1, 0, 1, 0], [1, 0, 1, 0]])
y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])
result = p_obj(
y_true,
y_pred,
sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]),
)
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
expected_precision = weighted_tp / weighted_positives
self.assertAlmostEqual(expected_precision, result)
def test_div_by_zero(self):
p_obj = metrics.Precision()
y_pred = np.array([0, 0, 0, 0])
y_true = np.array([0, 0, 0, 0])
result = p_obj(y_true, y_pred)
self.assertEqual(0, result)
def test_unweighted_with_threshold(self):
p_obj = metrics.Precision(thresholds=[0.5, 0.7])
y_pred = np.array([1, 0, 0.6, 0])
y_true = np.array([0, 1, 1, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual([0.5, 0.0], result, 0)
def test_weighted_with_threshold(self):
p_obj = metrics.Precision(thresholds=[0.5, 1.0])
y_true = np.array([[0, 1], [1, 0]])
y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
weights = np.array([[4, 0], [3, 1]], dtype="float32")
result = p_obj(y_true, y_pred, sample_weight=weights)
weighted_tp = 0 + 3.0
weighted_positives = (0 + 3.0) + (4.0 + 0.0)
expected_precision = weighted_tp / weighted_positives
self.assertAlmostEqual([expected_precision, 0], result, 1e-3)
def test_multiple_updates(self):
p_obj = metrics.Precision(thresholds=[0.5, 1.0])
y_true = np.array([[0, 1], [1, 0]])
y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32")
weights = np.array([[4, 0], [3, 1]], dtype="float32")
for _ in range(2):
p_obj.update_state(y_true, y_pred, sample_weight=weights)
weighted_tp = (0 + 3.0) + (0 + 3.0)
weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + (
(0 + 3.0) + (4.0 + 0.0)
)
expected_precision = weighted_tp / weighted_positives
self.assertAlmostEqual([expected_precision, 0], p_obj.result(), 1e-3)
def test_unweighted_top_k(self):
p_obj = metrics.Precision(top_k=3)
y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(1.0 / 3, result)
def test_weighted_top_k(self):
p_obj = metrics.Precision(top_k=3)
y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]])
y_true1 = np.array([[0, 1, 1, 0, 1]])
p_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]]))
y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2])
y_true2 = np.array([1, 0, 1, 1, 1])
result = p_obj(y_true2, y_pred2, sample_weight=np.array(3))
tp = (2 + 5) + (3 + 3)
predicted_positives = (1 + 2 + 5) + (3 + 3 + 3)
expected_precision = tp / predicted_positives
self.assertAlmostEqual(expected_precision, result)
def test_unweighted_class_id(self):
p_obj = metrics.Precision(class_id=2)
y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(1, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(0, p_obj.false_positives)
y_pred = np.array([0.2, 0.1, 0, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(1, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(0, p_obj.false_positives)
y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])
y_true = np.array([0, 1, 0, 0, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(0.5, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(1, p_obj.false_positives)
def test_unweighted_top_k_and_class_id(self):
p_obj = metrics.Precision(class_id=2, top_k=2)
y_pred = np.array([0.2, 0.6, 0.3, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(1, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(0, p_obj.false_positives)
y_pred = np.array([1, 1, 0.9, 1, 1])
y_true = np.array([0, 1, 1, 0, 0])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(1, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(0, p_obj.false_positives)
def test_unweighted_top_k_and_threshold(self):
p_obj = metrics.Precision(thresholds=0.7, top_k=2)
y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2])
y_true = np.array([0, 1, 1, 0, 1])
result = p_obj(y_true, y_pred)
self.assertAlmostEqual(1, result)
self.assertAlmostEqual(1, p_obj.true_positives)
self.assertAlmostEqual(0, p_obj.false_positives)

@@ -533,7 +533,7 @@ def _filter_top_k(x, k):
     Returns:
         tensor with same shape and dtype as x.
     """
-    _, top_k_idx = ops.top_k(x, k, sorted=False)
+    _, top_k_idx = ops.top_k(x, k)
     top_k_mask = ops.sum(
         ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2
     )
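The change above only drops the explicit `sorted=False` now that sorting is the default; the masking idea itself is unchanged. A rough NumPy illustration of that idea (a hypothetical standalone sketch, not the `metrics_utils` implementation):

```python
import numpy as np

# Keep the top-k entries of each row and push the rest to a large
# negative value so they can never cross a decision threshold.
x = np.array([[0.2, 0.8, 0.1, 0.5]], dtype="float32")
k = 2
top_k_idx = np.argsort(-x, axis=-1)[..., :k]  # indices of the k largest entries
top_k_mask = np.eye(x.shape[-1], dtype="float32")[top_k_idx].sum(axis=-2)
NEG_INF = -1e10
filtered = x * top_k_mask + NEG_INF * (1.0 - top_k_mask)
# filtered -> [[-1e10, 0.8, -1e10, 0.5]]
```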

@@ -9,5 +9,6 @@ from keras_core.backend import is_tensor
from keras_core.backend import name_scope
from keras_core.backend import random
from keras_core.backend import shape
from keras_core.operations.math import * # noqa: F403
from keras_core.operations.nn import * # noqa: F403
from keras_core.operations.numpy import * # noqa: F403

@@ -0,0 +1,30 @@
"""
segment_sum
top_k
"""
from keras_core import backend
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
class SegmentSum(Operation):
def call(self, x, segment_ids, num_segments=None, sorted=False):
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
def segment_sum(x, segment_ids, num_segments=None, sorted=False):
if any_symbolic_tensors((x,)):
return SegmentSum().symbolic_call(x, segment_ids, num_segments, sorted)
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
class TopK(Operation):
def call(self, x, k, sorted=True):
return backend.math.top_k(x, k, sorted)
def top_k(x, k, sorted=True):
if any_symbolic_tensors((x,)):
return TopK().symbolic_call(x, k, sorted)
return backend.math.top_k(x, k, sorted)
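The tests below only exercise `top_k`; for completeness, a hedged usage sketch of the `segment_sum` wrapper on eager inputs (backend support permitting):

```python
import numpy as np
from keras_core.operations import math as kmath

data = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
segment_ids = np.array([0, 0, 1, 1], dtype="int32")
# Sums entries that share a segment id: [1 + 2, 3 + 4] -> [3., 7.]
out = kmath.segment_sum(data, segment_ids, num_segments=2)
```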

@@ -0,0 +1,35 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import testing
from keras_core.backend.keras_tensor import KerasTensor
from keras_core.operations import math as kmath
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class MathOpsDynamicShapeTest(testing.TestCase):
def test_topk(self):
x = KerasTensor([None, 2, 3])
values, indices = kmath.top_k(x, k=1)
self.assertEqual(values.shape, (None, 2, 1))
self.assertEqual(indices.shape, (None, 2, 1))
class MathOpsStaticShapeTest(testing.TestCase):
def test_topk(self):
x = KerasTensor([1, 2, 3])
values, indices = kmath.top_k(x, k=1)
self.assertEqual(values.shape, (1, 2, 1))
self.assertEqual(indices.shape, (1, 2, 1))
class MathOpsCorrectnessTest(testing.TestCase):
def test_topk(self):
x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32)
values, indices = kmath.top_k(x, k=2)
self.assertAllClose(values, [4, 3])
self.assertAllClose(indices, [1, 4])

@@ -2035,6 +2035,15 @@ def moveaxis(x, source, destination):
return backend.numpy.moveaxis(x, source=source, destination=destination)
class NanToNum(Operation):
def call(self, x):
return backend.numpy.nan_to_num(x)
def nan_to_num(x):
return backend.numpy.nan_to_num(x)
class Ndim(Operation):
def call(self, x):
return backend.numpy.ndim(