First confusion metric (FalsePositives) (#30)

* Added confusion metrics -- still using TF ops

* Fixed structure + tests pass for TF (still need to port to multi-backend)

* Got rid of most tf deps, still a few more to go

* Full removal of TF. Tests pass for both Jax and TF

* Full removal of TF. Tests pass for both Jax and TF

* Formatting

* Formatting

* Review comments

* More review comments + formatting
Ian Stenbit 2023-04-22 10:10:44 -06:00 committed by Francois Chollet
parent 53a858cb7f
commit d6bcc56001
15 changed files with 836 additions and 11 deletions

@@ -15,7 +15,3 @@ def get(identifier):
if identifier == "relu":
return relu
return identifier
def serialize(activation):
return activation.__name__

@@ -5,6 +5,7 @@ from tensorflow import nest
from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype
from keras_core.backend.jax import math
from keras_core.backend.jax import nn
from keras_core.backend.jax import numpy
from keras_core.backend.jax import random
@@ -50,6 +51,10 @@ def name_scope(name):
return jax.named_scope(name)
def vectorized_map(function, elements):
return jax.vmap(function)(elements)
class Variable(KerasVariable):
def __init__(self, value, dtype=None, trainable=True, name=None):
self.name = name or auto_name(self.__class__.__name__)

@@ -0,0 +1,15 @@
import jax
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
return jax.ops.segment_sum(
data, segment_ids, num_segments, indices_are_sorted=sorted
)
def top_k(x, k, sorted=False):
if sorted:
        raise ValueError(
"Jax backend does not support `sorted=True` for `ops.top_k`"
)
return jax.lax.top_k(x, k)
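
For reference, a minimal sketch of the segment_sum semantics used above: each
data[i] is summed into output slot segment_ids[i], and empty segments stay 0.
Illustrative usage only, assuming a working JAX install:

import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 0, 1])
# data[i] is accumulated into out[segment_ids[i]]; segment 2 gets nothing.
print(jax.ops.segment_sum(data, segment_ids, num_segments=3))  # [6. 4. 0.]

# jax.lax.top_k returns values in descending order.
values, indices = jax.lax.top_k(jnp.array([0.1, 0.9, 0.4]), 2)
print(values, indices)  # [0.9 0.4] [1 2]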

@@ -103,3 +103,7 @@ def conv_transpose(
):
# TODO: Implement `conv_transpose`.
raise NotImplementedError
def one_hot(x, num_classes, axis=-1):
return jnn.one_hot(x, num_classes, axis=axis)
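
A quick illustration of the one_hot axis behavior (arbitrary example values,
assuming a working JAX install): axis picks where the class dimension lands.

import jax.nn as jnn
import jax.numpy as jnp

x = jnp.array([0, 2, 1])
print(jnn.one_hot(x, 4, axis=-1).shape)  # (3, 4): class dim appended
print(jnn.one_hot(x, 4, axis=0).shape)   # (4, 3): class dim first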

@@ -6,6 +6,7 @@ from keras_core.backend.common import standardize_dtype
from keras_core.backend.keras_tensor import KerasTensor
from keras_core.backend.stateless_scope import get_stateless_scope
from keras_core.backend.stateless_scope import in_stateless_scope
from keras_core.backend.tensorflow import math
from keras_core.backend.tensorflow import nn
from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random
@@ -240,6 +241,10 @@ def name_scope(name):
return tf.name_scope(name)
def vectorized_map(function, elements):
return tf.vectorized_map(function, elements)
def compute_output_spec(fn, *args, **kwargs):
graph_name = auto_name("scratch_graph")
with tf.__internal__.FuncGraph(graph_name).as_default():

@@ -0,0 +1,12 @@
import tensorflow as tf
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if sorted:
return tf.math.segment_sum(data, segment_ids)
else:
return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)
def top_k(x, k, sorted=False):
return tf.math.top_k(x, k, sorted=sorted)
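
The TF wrapper dispatches on `sorted`: tf.math.segment_sum assumes the segment
ids are already sorted and infers the segment count, while the unsorted
variant needs num_segments up front. A rough equivalence sketch, assuming a
working TF install:

import tensorflow as tf

data = tf.constant([1.0, 2.0, 3.0, 4.0])
ids = tf.constant([0, 0, 1, 1])  # sorted segment ids
print(tf.math.segment_sum(data, ids))                           # [3. 7.]
print(tf.math.unsorted_segment_sum(data, ids, num_segments=2))  # [3. 7.]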

@@ -390,3 +390,7 @@ def conv_transpose(
data_format=tf_data_format,
dilations=dilation_rate,
)
def one_hot(x, num_classes, axis=-1):
return tf.one_hot(x, num_classes, axis=axis)

@@ -61,15 +61,10 @@ class Dense(Layer):
# TODO
config = {
"units": self.units,
"activation": activations.serialize(self.activation),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"bias_constraint": constraints.serialize(self.bias_constraint),
}
return {**base_config, **config}

@@ -1,4 +1,5 @@
from keras_core.api_export import keras_core_export
from keras_core.metrics.confusion_metrics import FalsePositives
from keras_core.metrics.metric import Metric
from keras_core.metrics.reduction_metrics import Mean
from keras_core.metrics.reduction_metrics import MeanMetricWrapper

@@ -0,0 +1,115 @@
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.metrics import metrics_utils
from keras_core.metrics.metric import Metric
class _ConfusionMatrixConditionCount(Metric):
"""Calculates the number of the given confusion matrix condition.
Args:
confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
conditions.
thresholds: (Optional) Defaults to 0.5. A float value or a python list /
tuple of float threshold values in [0, 1]. A threshold is compared
with prediction values to determine the truth value of predictions
(i.e., above the threshold is `true`, below is `false`). One metric
value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
def __init__(
self, confusion_matrix_cond, thresholds=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
self._confusion_matrix_cond = confusion_matrix_cond
self.init_thresholds = thresholds
self.thresholds = metrics_utils.parse_init_thresholds(
thresholds, default_threshold=0.5
)
self._thresholds_distributed_evenly = (
metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
)
self.accumulator = self.add_variable(
shape=(len(self.thresholds),),
initializer=initializers.Zeros(),
name="accumulator",
)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the metric statistics.
Args:
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a tensor whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
return metrics_utils.update_confusion_matrix_variables(
{self._confusion_matrix_cond: self.accumulator},
y_true,
y_pred,
thresholds=self.thresholds,
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
sample_weight=sample_weight,
)
def result(self):
if len(self.thresholds) == 1:
result = self.accumulator[0]
else:
result = self.accumulator
return backend.convert_to_tensor(result)
def get_config(self):
config = {"thresholds": self.init_thresholds}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@keras_core_export("keras_core.metrics.FalsePositives")
class FalsePositives(_ConfusionMatrixConditionCount):
"""Calculates the number of false positives.
If `sample_weight` is given, calculates the sum of the weights of
    false positives. This metric creates one local variable, `accumulator`,
    that is used to keep track of the number of false positives.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
Args:
thresholds: (Optional) Defaults to 0.5. A float value, or a Python
list/tuple of float threshold values in [0, 1]. A threshold is
compared with prediction values to determine the truth value of
predictions (i.e., above the threshold is `true`, below is `false`).
If used with a loss function that sets `from_logits=True` (i.e. no
sigmoid applied to predictions), `thresholds` should be set to 0.
One metric value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Standalone usage:
>>> m = keras_core.metrics.FalsePositives()
>>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
>>> m.result()
2.0
>>> m.reset_state()
>>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
>>> m.result()
1.0
"""
def __init__(self, thresholds=None, name=None, dtype=None):
super().__init__(
confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
thresholds=thresholds,
name=name,
dtype=dtype,
)
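
The doctest numbers above can be checked by hand: with the default threshold
of 0.5, a false positive is any example with y_pred > 0.5 and y_true == 0,
and sample weights turn the count into a weighted sum. A numpy sketch of that
definition (reference arithmetic only; the metric itself runs on backend ops):

import numpy as np

y_true = np.array([0, 1, 0, 0])
y_pred = np.array([0, 0, 1, 1])
fp_mask = (y_pred > 0.5) & (y_true == 0)
print(np.sum(fp_mask))  # 2, matching the unweighted doctest

w = np.array([0, 0, 1, 0])
print(np.sum(w * fp_mask))  # 1, matching the weighted doctest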

@@ -0,0 +1,103 @@
import numpy as np
from tensorflow.python.ops.numpy_ops import np_config
from keras_core import metrics
from keras_core import testing
np_config.enable_numpy_behavior()
class FalsePositivesTest(testing.TestCase):
def test_config(self):
fp_obj = metrics.FalsePositives(name="my_fp", thresholds=[0.4, 0.9])
self.assertEqual(fp_obj.name, "my_fp")
self.assertLen(fp_obj.variables, 1)
self.assertEqual(fp_obj.thresholds, [0.4, 0.9])
# Check save and restore config
fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config())
self.assertEqual(fp_obj2.name, "my_fp")
self.assertLen(fp_obj2.variables, 1)
self.assertEqual(fp_obj2.thresholds, [0.4, 0.9])
def test_unweighted(self):
fp_obj = metrics.FalsePositives()
y_true = np.array(
((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))
)
y_pred = np.array(
((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))
)
fp_obj.update_state(y_true, y_pred)
result = fp_obj.result()
self.assertAllClose(7.0, result)
def test_weighted(self):
fp_obj = metrics.FalsePositives()
y_true = np.array(
((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))
)
y_pred = np.array(
((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))
)
sample_weight = np.array((1.0, 1.5, 2.0, 2.5))
result = fp_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(14.0, result)
def test_unweighted_with_thresholds(self):
fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])
y_pred = np.array(
(
(0.9, 0.2, 0.8, 0.1),
(0.2, 0.9, 0.7, 0.6),
(0.1, 0.2, 0.4, 0.3),
(0, 1, 0.7, 0.3),
)
)
y_true = np.array(
((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))
)
fp_obj.update_state(y_true, y_pred)
result = fp_obj.result()
self.assertAllClose([7.0, 4.0, 2.0], result)
def test_weighted_with_thresholds(self):
fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])
y_pred = np.array(
(
(0.9, 0.2, 0.8, 0.1),
(0.2, 0.9, 0.7, 0.6),
(0.1, 0.2, 0.4, 0.3),
(0, 1, 0.7, 0.3),
)
)
y_true = np.array(
((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))
)
sample_weight = (
(1.0, 2.0, 3.0, 5.0),
(7.0, 11.0, 13.0, 17.0),
(19.0, 23.0, 29.0, 31.0),
(5.0, 15.0, 10.0, 0),
)
result = fp_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose([125.0, 42.0, 12.0], result)
def test_threshold_limit(self):
with self.assertRaisesRegex(
ValueError,
r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]",
):
metrics.FalsePositives(thresholds=[-1, 0.5, 2])
with self.assertRaisesRegex(
ValueError,
r"Threshold values must be in \[0, 1\]. Received: \[None\]",
):
metrics.FalsePositives(thresholds=[None])
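
The multi-threshold expectations follow from the same counting rule evaluated
once per threshold, e.g. [7.0, 4.0, 2.0] for thresholds [0.15, 0.5, 0.85]. A
numpy reference mirroring the unweighted test data (illustrative only):

import numpy as np

y_pred = np.array(
    [
        [0.9, 0.2, 0.8, 0.1],
        [0.2, 0.9, 0.7, 0.6],
        [0.1, 0.2, 0.4, 0.3],
        [0.0, 1.0, 0.7, 0.3],
    ]
)
y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]])
for t in [0.15, 0.5, 0.85]:
    print(t, np.sum((y_pred > t) & (y_true == 0)))  # 7, then 4, then 2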

@@ -0,0 +1,540 @@
from enum import Enum
import numpy as np
from keras_core import backend
from keras_core import operations as ops
from keras_core.losses.loss import squeeze_to_same_rank
from keras_core.utils.python_utils import to_list
NEG_INF = -1e10
def assert_thresholds_range(thresholds):
if thresholds is not None:
invalid_thresholds = [
t for t in thresholds if t is None or t < 0 or t > 1
]
if invalid_thresholds:
raise ValueError(
"Threshold values must be in [0, 1]. "
f"Received: {invalid_thresholds}"
)
def parse_init_thresholds(thresholds, default_threshold=0.5):
if thresholds is not None:
assert_thresholds_range(to_list(thresholds))
thresholds = to_list(
default_threshold if thresholds is None else thresholds
)
return thresholds
class ConfusionMatrix(Enum):
TRUE_POSITIVES = "tp"
FALSE_POSITIVES = "fp"
TRUE_NEGATIVES = "tn"
FALSE_NEGATIVES = "fn"
def _update_confusion_matrix_variables_optimized(
variables_to_update,
y_true,
y_pred,
thresholds,
multi_label=False,
sample_weights=None,
label_weights=None,
thresholds_with_epsilon=False,
):
"""Update confusion matrix variables with memory efficient alternative.
    Note that the thresholds need to be evenly distributed within the list,
    i.e., the diff between consecutive elements is the same.
To compute TP/FP/TN/FN, we are measuring a binary classifier
C(t) = (predictions >= t)
at each threshold 't'. So we have
TP(t) = sum( C(t) * true_labels )
FP(t) = sum( C(t) * false_labels )
But, computing C(t) requires computation for each t. To make it fast,
observe that C(t) is a cumulative integral, and so if we have
thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1}
where n = num_thresholds, and if we can compute the bucket function
    B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
then we get
C(t_i) = sum( B(j), j >= i )
which is the reversed cumulative sum in ops.cumsum().
We can compute B(i) efficiently by taking advantage of the fact that
our thresholds are evenly distributed, in that
width = 1.0 / (num_thresholds - 1)
thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
Given a prediction value p, we can map it to its bucket by
bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use ops.segment_sum() to update the buckets in one
pass.
Consider following example:
y_true = [0, 0, 1, 1]
y_pred = [0.1, 0.5, 0.3, 0.9]
thresholds = [0.0, 0.5, 1.0]
num_buckets = 2 # [0.0, 1.0], (1.0, 2.0]
bucket_index(y_pred) = ops.floor(y_pred * num_buckets)
= ops.floor([0.2, 1.0, 0.6, 1.8])
= [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the labels is true,
# then 1 will be added to the corresponding bucket with the index.
# Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the
# label for 1.8 is true, then 1 will be added to bucket 1.
#
# Note the second item "1.0" is floored to 0, since the value need to be
# strictly larger than the bucket lower bound.
# In the implementation, we use ops.ceil() - 1 to achieve this.
    tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
num_segments=num_thresholds)
= [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket
    # 0, and 1 value contributed by bucket 1. When we aggregate them
    # together, the result becomes [a + b + c, b + c, c], since larger
    # thresholds will always contribute to the value for smaller thresholds.
true_positive = ops.cumsum(tp_bucket_value, reverse=True)
= [2, 1, 0]
This implementation exhibits a run time and space complexity of O(T + N),
where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on the standard implementation instead exhibit a
    complexity of O(T * N).
Args:
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
and corresponding variables to update as values.
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
cast to `bool`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values are
in the range `[0, 1]`.
      thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
        It needs to be evenly distributed (the diff between consecutive
        elements needs to be the same).
multi_label: Optional boolean indicating whether multidimensional
prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
`variables_to_update` must have a second dimension equal to the number
of labels in y_true and y_pred, and those tensors must not be
RaggedTensors.
sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
must be either `1`, or the same as the corresponding `y_true`
dimension).
label_weights: Optional tensor of non-negative weights for multilabel
data. The weights are applied when calculating TP, FP, FN, and TN
without explicit multilabel handling (i.e. when the data is to be
flattened).
      thresholds_with_epsilon: Optional boolean indicating whether the leading
        and trailing thresholds have any epsilon added for floating point
        imprecision. It changes how we handle the leading and trailing
        buckets.
"""
    num_thresholds = ops.shape(thresholds)[0]
if sample_weights is None:
sample_weights = 1.0
else:
sample_weights = ops.broadcast_to(
ops.cast(sample_weights, dtype=y_pred.dtype), y_pred.shape
)
if not multi_label:
sample_weights = ops.reshape(sample_weights, [-1])
if label_weights is None:
label_weights = 1.0
else:
label_weights = ops.expand_dims(label_weights, 0)
label_weights = ops.broadcast_to(label_weights, y_pred.shape)
if not multi_label:
label_weights = ops.reshape(label_weights, [-1])
weights = ops.cast(
ops.multiply(sample_weights, label_weights), y_true.dtype
)
    # We shouldn't need this, but in case there are prediction values that
    # are out of the range [0.0, 1.0].
    y_pred = ops.clip(y_pred, 0.0, 1.0)
y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
if not multi_label:
y_true = ops.reshape(y_true, [-1])
y_pred = ops.reshape(y_pred, [-1])
true_labels = ops.multiply(y_true, weights)
false_labels = ops.multiply((1.0 - y_true), weights)
    # Compute the bucket indices for each prediction value.
    # Since the prediction value has to be strictly greater than the
    # threshold, e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs
    # to the first bucket. We have to use ops.ceil(val) - 1 for the bucket
    # index.
bucket_indices = ops.ceil(y_pred * (num_thresholds - 1)) - 1
if thresholds_with_epsilon:
        # In this case, the first bucket should actually be taken into
        # account, since any prediction in [0.0, 1.0] is larger than the
        # first threshold. We change the bucket index from -1 to 0.
bucket_indices = ops.relu(bucket_indices)
bucket_indices = ops.cast(bucket_indices, "int32")
if multi_label:
        # We need to run the bucket segment sum for each label class. In the
        # multi_label case, the rank of the labels is 2. We first transpose
        # so that the label dim becomes the first and we can run them in
        # parallel.
true_labels = ops.transpose(true_labels)
false_labels = ops.transpose(false_labels)
bucket_indices = ops.transpose(bucket_indices)
def gather_bucket(label_and_bucket_index):
label, bucket_index = (
label_and_bucket_index[0],
label_and_bucket_index[1],
)
            return ops.segment_sum(
data=label,
segment_ids=bucket_index,
num_segments=num_thresholds,
)
tp_bucket_v = ops.vectorized_map(
            gather_bucket, (true_labels, bucket_indices)
)
fp_bucket_v = ops.vectorized_map(
            gather_bucket, (false_labels, bucket_indices)
)
tp = ops.transpose(ops.cumsum(tp_bucket_v, reverse=True, axis=1))
fp = ops.transpose(ops.cumsum(fp_bucket_v, reverse=True, axis=1))
else:
        tp_bucket_v = ops.segment_sum(
data=true_labels,
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
        fp_bucket_v = ops.segment_sum(
data=false_labels,
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
tp = ops.cumsum(tp_bucket_v, reverse=True)
fp = ops.cumsum(fp_bucket_v, reverse=True)
# fn = sum(true_labels) - tp
# tn = sum(false_labels) - fp
if (
ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
):
if multi_label:
total_true_labels = ops.sum(true_labels, axis=1)
total_false_labels = ops.sum(false_labels, axis=1)
else:
total_true_labels = ops.sum(true_labels)
total_false_labels = ops.sum(false_labels)
if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
variable.assign(variable + tp)
if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
variable.assign(variable + fp)
if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
tn = total_false_labels - fp
variable.assign(variable + tn)
if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
fn = total_true_labels - tp
variable.assign(variable + fn)
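
The bucket trick in the docstring can be reproduced end to end with plain
numpy. The sketch below is illustrative (numpy stands in for the
backend-neutral ops): each prediction is bucketed once, and a reversed
cumulative sum recovers the per-threshold counts, with strictly-greater
comparison semantics.

import numpy as np

y_true = np.array([0.0, 0.0, 1.0, 1.0])
y_pred = np.array([0.1, 0.5, 0.3, 0.9])
thresholds = np.array([0.0, 0.5, 1.0])
n = len(thresholds)

# ceil(p * (n - 1)) - 1 implements strictly-greater bucketing; clamping -1
# to 0 mirrors the thresholds_with_epsilon corner case.
buckets = np.maximum(np.ceil(y_pred * (n - 1)) - 1, 0).astype(int)

tp_bucket = np.zeros(n)
fp_bucket = np.zeros(n)
np.add.at(tp_bucket, buckets, y_true)        # segment-sum of true labels
np.add.at(fp_bucket, buckets, 1.0 - y_true)  # segment-sum of false labels

# Reversed cumsum: bucket j contributes to every threshold i <= j.
print(np.cumsum(tp_bucket[::-1])[::-1])  # [2. 1. 0.]
print(np.cumsum(fp_bucket[::-1])[::-1])  # [2. 0. 0.]
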
def is_evenly_distributed_thresholds(thresholds):
"""Check if the thresholds list is evenly distributed.
    We could leverage evenly distributed thresholds to use less memory when
    calculating metrics like AUC, where each individual threshold needs to
    be evaluated.
Args:
      thresholds: A python list or tuple, or 1D numpy array whose values
        are in [0, 1].
Returns:
boolean, whether the values in the inputs are evenly distributed.
"""
# Check the list value and see if it is evenly distributed.
num_thresholds = len(thresholds)
if num_thresholds < 3:
return False
even_thresholds = np.arange(num_thresholds, dtype=np.float32) / (
num_thresholds - 1
)
return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
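
Note that the check compares against the uniform grid np.arange(n) / (n - 1),
so the optimized path only triggers for grids like [0.0, 0.5, 1.0]; the
[0.15, 0.5, 0.85] thresholds used in the tests are equally spaced but
shifted, so they take the generic tiled path below. A quick illustration
(ignoring the epsilon tolerance):

import numpy as np

def is_even(thresholds):
    n = len(thresholds)
    if n < 3:
        return False
    return np.allclose(thresholds, np.arange(n, dtype=np.float32) / (n - 1))

print(is_even([0.0, 0.5, 1.0]))    # True  -> optimized bucket path
print(is_even([0.15, 0.5, 0.85]))  # False -> generic tiled path
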
def update_confusion_matrix_variables(
variables_to_update,
y_true,
y_pred,
thresholds,
top_k=None,
class_id=None,
sample_weight=None,
multi_label=False,
label_weights=None,
thresholds_distributed_evenly=False,
):
"""Updates the given confusion matrix variables.
For every pair of values in y_true and y_pred:
    true_positives: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positives: y_true == False and y_pred > thresholds
The results will be weighted and added together. When multiple thresholds
are provided, we will repeat the same for every threshold.
For estimation of these metrics over a stream of data, the function creates
an `update_op` operation that updates the given variables.
If `sample_weight` is `None`, weights default to 1.
Use weights of 0 to mask values.
Args:
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
and corresponding variables to update as values.
y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values are
in the range `[0, 1]`.
thresholds: A float value, float tensor, python list, or tuple of float
thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
top_k: Optional int, indicates that the positive labels should be limited
to the top k predictions.
class_id: Optional int, limits the prediction and labels to the class
specified by this argument.
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
must be either `1`, or the same as the corresponding `y_true`
dimension).
multi_label: Optional boolean indicating whether multidimensional
prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
`variables_to_update` must have a second dimension equal to the number
of labels in y_true and y_pred, and those tensors must not be
RaggedTensors.
label_weights: (optional) tensor of non-negative weights for multilabel
data. The weights are applied when calculating TP, FP, FN, and TN
without explicit multilabel handling (i.e. when the data is to be
flattened).
thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
distributed within the list. An optimized method will be used if this is
the case. See _update_confusion_matrix_variables_optimized() for more
details.
Raises:
ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
`sample_weight` is not `None` and its shape doesn't match `y_pred`, or
if `variables_to_update` contains invalid keys.
"""
if multi_label and label_weights is not None:
raise ValueError(
"`label_weights` for multilabel data should be handled "
"outside of `update_confusion_matrix_variables` when "
"`multi_label` is True."
)
if variables_to_update is None:
return
if not any(
key for key in variables_to_update if key in list(ConfusionMatrix)
):
raise ValueError(
"Please provide at least one valid confusion matrix "
"variable to update. Valid variable key options are: "
f'"{list(ConfusionMatrix)}". '
f'Received: "{variables_to_update.keys()}"'
)
variable_dtype = list(variables_to_update.values())[0].dtype
y_true = ops.cast(y_true, dtype=variable_dtype)
y_pred = ops.cast(y_pred, dtype=variable_dtype)
if thresholds_distributed_evenly:
        # Check whether the thresholds have any leading or trailing epsilon
        # added for floating point imprecision. The leading and trailing
        # thresholds will be handled a bit differently as corner cases. At
        # this point, thresholds should be a list/array with at least 3
        # items, with values ranged in [0, 1]. See
        # is_evenly_distributed_thresholds() for more details.
thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0
thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype)
num_thresholds = ops.shape(thresholds)[0]
if multi_label:
        one_thresh = ops.equal(
            ops.cast(1, dtype="int32"),
            thresholds.ndim,
        )
else:
one_thresh = ops.cast(True, dtype="bool")
invalid_keys = [
key for key in variables_to_update if key not in list(ConfusionMatrix)
]
if invalid_keys:
raise ValueError(
f'Invalid keys: "{invalid_keys}". '
f'Valid variable key options are: "{list(ConfusionMatrix)}"'
)
y_pred, y_true = squeeze_to_same_rank(y_pred, y_true)
if sample_weight is not None:
sample_weight = ops.expand_dims(
ops.cast(sample_weight, dtype=variable_dtype), axis=-1
)
_, sample_weight = squeeze_to_same_rank(y_true, sample_weight)
if top_k is not None:
y_pred = _filter_top_k(y_pred, top_k)
if class_id is not None:
# Preserve dimension to match with sample_weight
y_true = y_true[..., class_id, None]
y_pred = y_pred[..., class_id, None]
if thresholds_distributed_evenly:
return _update_confusion_matrix_variables_optimized(
variables_to_update,
y_true,
y_pred,
thresholds,
multi_label=multi_label,
sample_weights=sample_weight,
label_weights=label_weights,
thresholds_with_epsilon=thresholds_with_epsilon,
)
pred_shape = ops.shape(y_pred)
num_predictions = pred_shape[0]
if y_pred.ndim == 1:
num_labels = 1
else:
num_labels = ops.cast(
ops.prod(ops.array(pred_shape[1:]), axis=0), "int32"
)
thresh_label_tile = ops.where(one_thresh, num_labels, 1)
# Reshape predictions and labels, adding a dim for thresholding.
if multi_label:
predictions_extra_dim = ops.expand_dims(y_pred, 0)
labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype="bool"), 0)
else:
# Flatten predictions and labels when not multilabel.
predictions_extra_dim = ops.reshape(y_pred, [1, -1])
labels_extra_dim = ops.reshape(ops.cast(y_true, dtype="bool"), [1, -1])
# Tile the thresholds for every prediction.
if multi_label:
thresh_pretile_shape = [num_thresholds, 1, -1]
thresh_tiles = [1, num_predictions, thresh_label_tile]
data_tiles = [num_thresholds, 1, 1]
else:
thresh_pretile_shape = [num_thresholds, -1]
thresh_tiles = [1, num_predictions * num_labels]
data_tiles = [num_thresholds, 1]
thresh_tiled = ops.tile(
ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
)
# Tile the predictions for every threshold.
preds_tiled = ops.tile(predictions_extra_dim, data_tiles)
# Compare predictions and threshold.
pred_is_pos = ops.greater(preds_tiled, thresh_tiled)
# Tile labels by number of thresholds
label_is_pos = ops.tile(labels_extra_dim, data_tiles)
if sample_weight is not None:
sample_weight = ops.broadcast_to(
ops.cast(sample_weight, dtype=y_pred.dtype), y_pred.shape
)
weights_tiled = ops.tile(
ops.reshape(sample_weight, thresh_tiles), data_tiles
)
else:
weights_tiled = None
if label_weights is not None and not multi_label:
label_weights = ops.expand_dims(label_weights, 0)
label_weights = ops.broadcast_to(label_weights, y_pred.shape)
label_weights_tiled = ops.tile(
ops.reshape(label_weights, thresh_tiles), data_tiles
)
if weights_tiled is None:
weights_tiled = label_weights_tiled
else:
weights_tiled = ops.multiply(weights_tiled, label_weights_tiled)
def weighted_assign_add(label, pred, weights, var):
label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)
if weights is not None:
label_and_pred *= ops.cast(weights, dtype=var.dtype)
var.assign(var + ops.sum(label_and_pred, 1))
loop_vars = {
ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
}
update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
if update_fn or update_tn:
pred_is_neg = ops.logical_not(pred_is_pos)
loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)
if update_fp or update_tn:
label_is_neg = ops.logical_not(label_is_pos)
loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
if update_tn:
loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
label_is_neg,
pred_is_neg,
)
for matrix_cond, (label, pred) in loop_vars.items():
if matrix_cond in variables_to_update:
weighted_assign_add(
label, pred, weights_tiled, variables_to_update[matrix_cond]
)
def _filter_top_k(x, k):
"""Filters top-k values in the last dim of x and set the rest to NEG_INF.
Used for computing top-k prediction values in dense labels (which has the
same shape as predictions) for recall and precision top-k metrics.
Args:
x: tensor with any dimensions.
k: the number of values to keep.
Returns:
tensor with same shape and dtype as x.
"""
_, top_k_idx = ops.top_k(x, k, sorted=False)
top_k_mask = ops.sum(
ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2
)
return x * top_k_mask + NEG_INF * (1 - top_k_mask)
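
_filter_top_k builds a {0, 1} mask by one-hot encoding the top-k indices and
summing over them, then pushes everything outside the top k to NEG_INF so it
can never clear a threshold in [0, 1]. A numpy sketch of the same masking
(illustrative only):

import numpy as np

NEG_INF = -1e10
x = np.array([0.3, 0.9, 0.2, 0.6])
k = 2

top_k_idx = np.argpartition(x, -k)[-k:]  # indices of the k largest values
mask = np.zeros_like(x)
mask[top_k_idx] = 1.0
print(x * mask + NEG_INF * (1 - mask))
# Only the top-2 entries survive: [-1e10, 0.9, -1e10, 0.6]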

@@ -146,7 +146,7 @@ def serialize_keras_object(obj):
"class_name": "__tensor__",
"config": {
"value": np.array(obj).tolist(),
"dtype": backend.standardize_dtype(obj.dtype),
"dtype": str(obj.dtype),
},
}
if type(obj).__module__ == np.__name__:
@@ -155,7 +155,7 @@ def serialize_keras_object(obj):
"class_name": "__numpy__",
"config": {
"value": obj.tolist(),
"dtype": backend.standardize_dtype(obj.dtype),
"dtype": backend.standardize_dytpe(obj.dtype),
},
}
else:

@@ -9,3 +9,16 @@ class TestCase(unittest.TestCase):
def assertAlmostEqual(self, x1, x2, decimal=3):
np.testing.assert_almost_equal(x1, x2, decimal=decimal)
def assertEqual(self, x1, x2):
np.testing.assert_equal(x1, x2)
def assertLen(self, iterable, expected_len):
np.testing.assert_equal(len(iterable), expected_len)
def assertRaisesRegex(
self, exception_class, expected_regexp, *args, **kwargs
):
return np.testing.assert_raises_regex(
exception_class, expected_regexp, *args, **kwargs
)

@@ -86,3 +86,20 @@ def func_load(code, defaults=None, closure=None, globs=None):
return python_types.FunctionType(
code, globs, name=code.co_name, argdefs=defaults, closure=closure
)
def to_list(x):
"""Normalizes a list/tensor into a list.
If a tensor is passed, we return
a list of size 1 containing the tensor.
Args:
x: target object to be normalized.
Returns:
A list.
"""
if isinstance(x, list):
return x
return [x]