First confusion metric (FalsePositives) (#30)
* Added confusion metrics -- still using TF ops * Fixed structure + tests pass for TF (still need to port to multi-backend) * Got rid of most tf deps, still a few more to go * Full removal of TF. Tests pass for both Jax and TF * Full removal of TF. Tests pass for both Jax and TF * Formatting * Formatting * Review comments * More review comments + formatting
This commit is contained in:
parent
53a858cb7f
commit
d6bcc56001
@ -15,7 +15,3 @@ def get(identifier):
|
||||
if identifier == "relu":
|
||||
return relu
|
||||
return identifier
|
||||
|
||||
|
||||
def serialize(activation):
    """Return the `__name__` of the given activation callable."""
    return getattr(activation, "__name__")
|
||||
|
@ -5,6 +5,7 @@ from tensorflow import nest
|
||||
|
||||
from keras_core.backend.common import KerasVariable
|
||||
from keras_core.backend.common import standardize_dtype
|
||||
from keras_core.backend.jax import math
|
||||
from keras_core.backend.jax import nn
|
||||
from keras_core.backend.jax import numpy
|
||||
from keras_core.backend.jax import random
|
||||
@ -50,6 +51,10 @@ def name_scope(name):
|
||||
return jax.named_scope(name)
|
||||
|
||||
|
||||
def vectorized_map(function, elements):
    """Apply `function` to `elements` through `jax.vmap`."""
    batched_function = jax.vmap(function)
    return batched_function(elements)
|
||||
|
||||
|
||||
class Variable(KerasVariable):
|
||||
def __init__(self, value, dtype=None, trainable=True, name=None):
|
||||
self.name = name or auto_name(self.__class__.__name__)
|
||||
|
15
keras_core/backend/jax/math.py
Normal file
15
keras_core/backend/jax/math.py
Normal file
@ -0,0 +1,15 @@
|
||||
import jax
|
||||
|
||||
|
||||
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Sum the entries of `data` that share a segment id.

    Thin wrapper over `jax.ops.segment_sum`; `sorted` is forwarded as
    `indices_are_sorted`.
    """
    summed = jax.ops.segment_sum(
        data,
        segment_ids,
        num_segments=num_segments,
        indices_are_sorted=sorted,
    )
    return summed
|
||||
|
||||
|
||||
def top_k(x, k, sorted=False):
    """Return the `k` largest entries of `x` and their indices.

    Args:
        x: Input array.
        k: Number of entries to keep.
        sorted: Must be falsy; this backend rejects `sorted=True`.

    Returns:
        The `(values, indices)` pair produced by `jax.lax.top_k`.

    Raises:
        ValueError: If `sorted` is truthy.
    """
    if sorted:
        # The error has to be raised, not returned: returning it handed
        # callers a `ValueError` instance where they expect a
        # `(values, indices)` pair.
        raise ValueError(
            "Jax backend does not support `sorted=True` for `ops.top_k`"
        )
    return jax.lax.top_k(x, k)
|
@ -103,3 +103,7 @@ def conv_transpose(
|
||||
):
|
||||
# TODO: Implement `conv_transpose`.
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def one_hot(x, num_classes, axis=-1):
    """One-hot encode `x` into `num_classes` classes via `jnn.one_hot`."""
    encoded = jnn.one_hot(x, num_classes, axis=axis)
    return encoded
|
||||
|
@ -6,6 +6,7 @@ from keras_core.backend.common import standardize_dtype
|
||||
from keras_core.backend.keras_tensor import KerasTensor
|
||||
from keras_core.backend.stateless_scope import get_stateless_scope
|
||||
from keras_core.backend.stateless_scope import in_stateless_scope
|
||||
from keras_core.backend.tensorflow import math
|
||||
from keras_core.backend.tensorflow import nn
|
||||
from keras_core.backend.tensorflow import numpy
|
||||
from keras_core.backend.tensorflow import random
|
||||
@ -240,6 +241,10 @@ def name_scope(name):
|
||||
return tf.name_scope(name)
|
||||
|
||||
|
||||
def vectorized_map(function, elements):
    """Apply `function` to `elements` through `tf.vectorized_map`."""
    mapped = tf.vectorized_map(function, elements)
    return mapped
|
||||
|
||||
|
||||
def compute_output_spec(fn, *args, **kwargs):
|
||||
graph_name = auto_name("scratch_graph")
|
||||
with tf.__internal__.FuncGraph(graph_name).as_default():
|
||||
|
12
keras_core/backend/tensorflow/math.py
Normal file
12
keras_core/backend/tensorflow/math.py
Normal file
@ -0,0 +1,12 @@
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Sum the entries of `data` that share a segment id.

    Args:
        data: Tensor of values to reduce.
        segment_ids: Integer tensor assigning each entry of `data` to a
            segment.
        num_segments: Number of output segments for the unsorted path. When
            `None`, it is derived as `max(segment_ids) + 1`.
        sorted: When truthy, `segment_ids` is treated as sorted and
            `tf.math.segment_sum` is used; `num_segments` is not consulted
            on that path.

    Returns:
        Tensor of per-segment sums.
    """
    if sorted:
        return tf.math.segment_sum(data, segment_ids)
    if num_segments is None:
        # `tf.math.unsorted_segment_sum` needs an explicit segment count, so
        # the default `num_segments=None` would otherwise fail. Use the
        # smallest count that covers every id present.
        num_segments = tf.reduce_max(segment_ids) + 1
    return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)
|
||||
|
||||
|
||||
def top_k(x, k, sorted=False):
    """Return the result of `tf.math.top_k` for `x`, forwarding `sorted`."""
    result = tf.math.top_k(x, k, sorted=sorted)
    return result
|
@ -390,3 +390,7 @@ def conv_transpose(
|
||||
data_format=tf_data_format,
|
||||
dilations=dilation_rate,
|
||||
)
|
||||
|
||||
|
||||
def one_hot(x, num_classes, axis=-1):
    """One-hot encode `x` into `num_classes` classes via `tf.one_hot`."""
    encoded = tf.one_hot(x, num_classes, axis=axis)
    return encoded
|
||||
|
@ -61,15 +61,10 @@ class Dense(Layer):
|
||||
# TODO
|
||||
config = {
|
||||
"units": self.units,
|
||||
"activation": activations.serialize(self.activation),
|
||||
"use_bias": self.use_bias,
|
||||
"kernel_initializer": initializers.serialize(
|
||||
self.kernel_initializer
|
||||
),
|
||||
"bias_initializer": initializers.serialize(self.bias_initializer),
|
||||
"kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
|
||||
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
|
||||
"kernel_constraint": constraints.serialize(self.kernel_constraint),
|
||||
"bias_constraint": constraints.serialize(self.bias_constraint),
|
||||
}
|
||||
return {**base_config, **config}
|
||||
|
@ -1,4 +1,5 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.metrics.confusion_metrics import FalsePositives
|
||||
from keras_core.metrics.metric import Metric
|
||||
from keras_core.metrics.reduction_metrics import Mean
|
||||
from keras_core.metrics.reduction_metrics import MeanMetricWrapper
|
||||
|
115
keras_core/metrics/confusion_metrics.py
Normal file
115
keras_core/metrics/confusion_metrics.py
Normal file
@ -0,0 +1,115 @@
|
||||
from keras_core import backend
|
||||
from keras_core import initializers
|
||||
from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.metrics import metrics_utils
|
||||
from keras_core.metrics.metric import Metric
|
||||
|
||||
|
||||
class _ConfusionMatrixConditionCount(Metric):
    """Base metric that tallies a single confusion-matrix condition.

    Args:
        confusion_matrix_cond: A `metrics_utils.ConfusionMatrix` member
            selecting which condition is counted.
        thresholds: (Optional) A float, or a Python list / tuple of floats
            in [0, 1]. `None` falls back to 0.5. Predictions are compared
            against each threshold (above is `true`, below is `false`) and
            one metric value is produced per threshold.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
    """

    def __init__(
        self, confusion_matrix_cond, thresholds=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        self._confusion_matrix_cond = confusion_matrix_cond
        # The raw argument is kept so `get_config` can report it unchanged.
        self.init_thresholds = thresholds
        self.thresholds = metrics_utils.parse_init_thresholds(
            thresholds, default_threshold=0.5
        )
        evenly_spaced = metrics_utils.is_evenly_distributed_thresholds(
            self.thresholds
        )
        self._thresholds_distributed_evenly = evenly_spaced
        # One accumulator slot per threshold.
        accumulator_shape = (len(self.thresholds),)
        self.accumulator = self.add_variable(
            shape=accumulator_shape,
            initializer=initializers.Zeros(),
            name="accumulator",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Fold a batch of labels and predictions into the accumulator.

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to 1.
                Can be a tensor whose rank is either 0, or the same rank as
                `y_true`, and must be broadcastable to `y_true`.
        """
        targets = {self._confusion_matrix_cond: self.accumulator}
        return metrics_utils.update_confusion_matrix_variables(
            targets,
            y_true,
            y_pred,
            thresholds=self.thresholds,
            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
            sample_weight=sample_weight,
        )

    def result(self):
        """Return the count: a scalar for one threshold, else all counts."""
        single_threshold = len(self.thresholds) == 1
        value = self.accumulator[0] if single_threshold else self.accumulator
        return backend.convert_to_tensor(value)

    def get_config(self):
        """Return the base config extended with the original thresholds."""
        own_config = {"thresholds": self.init_thresholds}
        return {**super().get_config(), **own_config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.FalsePositives")
class FalsePositives(_ConfusionMatrixConditionCount):
    """Counts false positives.

    When `sample_weight` is supplied, the metric sums the weights of the
    false positives instead of counting them. A single local variable,
    `accumulator`, holds the running total.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        thresholds: (Optional) Defaults to 0.5. A float value, or a Python
            list/tuple of float threshold values in [0, 1]. Each threshold
            is compared with the prediction values to decide the truth
            value of predictions (i.e., above the threshold is `true`,
            below is `false`). If used with a loss function that sets
            `from_logits=True` (i.e. no sigmoid applied to predictions),
            `thresholds` should be set to 0. One metric value is generated
            for each threshold value.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = keras_core.metrics.FalsePositives()
    >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
    >>> m.result()
    2.0

    >>> m.reset_state()
    >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
    >>> m.result()
    1.0
    """

    def __init__(self, thresholds=None, name=None, dtype=None):
        condition = metrics_utils.ConfusionMatrix.FALSE_POSITIVES
        super().__init__(
            confusion_matrix_cond=condition,
            thresholds=thresholds,
            name=name,
            dtype=dtype,
        )
|
103
keras_core/metrics/confusion_metrics_test.py
Normal file
103
keras_core/metrics/confusion_metrics_test.py
Normal file
@ -0,0 +1,103 @@
|
||||
import numpy as np
|
||||
from tensorflow.python.ops.numpy_ops import np_config
|
||||
|
||||
from keras_core import metrics
|
||||
from keras_core import testing
|
||||
|
||||
np_config.enable_numpy_behavior()
|
||||
|
||||
|
||||
class FalsePositivesTest(testing.TestCase):
    """Tests for `metrics.FalsePositives`: config, counts, thresholds."""

    def test_config(self):
        """Name, variable count and thresholds survive a config round trip."""
        fp_obj = metrics.FalsePositives(name="my_fp", thresholds=[0.4, 0.9])
        self.assertEqual(fp_obj.name, "my_fp")
        self.assertLen(fp_obj.variables, 1)
        self.assertEqual(fp_obj.thresholds, [0.4, 0.9])

        # Check save and restore config
        fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config())
        self.assertEqual(fp_obj2.name, "my_fp")
        self.assertLen(fp_obj2.variables, 1)
        self.assertEqual(fp_obj2.thresholds, [0.4, 0.9])

    def test_unweighted(self):
        """Default threshold, no weights."""
        fp_obj = metrics.FalsePositives()

        y_true = np.array(
            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))
        )
        y_pred = np.array(
            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))
        )

        fp_obj.update_state(y_true, y_pred)
        result = fp_obj.result()
        # Cells with y_true == 0 and y_pred == 1, per row: 1 + 2 + 0 + 4.
        self.assertAllClose(7.0, result)

    def test_weighted(self):
        """Default threshold with one weight per row."""
        fp_obj = metrics.FalsePositives()
        y_true = np.array(
            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))
        )
        y_pred = np.array(
            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))
        )
        sample_weight = np.array((1.0, 1.5, 2.0, 2.5))
        result = fp_obj(y_true, y_pred, sample_weight=sample_weight)
        # Per-row false positives (1, 2, 0, 4) times the row weights:
        # 1 * 1.0 + 2 * 1.5 + 0 * 2.0 + 4 * 2.5.
        self.assertAllClose(14.0, result)

    def test_unweighted_with_thresholds(self):
        """Three thresholds yield three unweighted counts."""
        fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])

        y_pred = np.array(
            (
                (0.9, 0.2, 0.8, 0.1),
                (0.2, 0.9, 0.7, 0.6),
                (0.1, 0.2, 0.4, 0.3),
                (0, 1, 0.7, 0.3),
            )
        )
        y_true = np.array(
            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))
        )

        fp_obj.update_state(y_true, y_pred)
        result = fp_obj.result()
        self.assertAllClose([7.0, 4.0, 2.0], result)

    def test_weighted_with_thresholds(self):
        """Three thresholds with a per-element weight matrix."""
        fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])

        y_pred = np.array(
            (
                (0.9, 0.2, 0.8, 0.1),
                (0.2, 0.9, 0.7, 0.6),
                (0.1, 0.2, 0.4, 0.3),
                (0, 1, 0.7, 0.3),
            )
        )
        y_true = np.array(
            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))
        )
        # Passed as a plain nested tuple rather than a numpy array.
        sample_weight = (
            (1.0, 2.0, 3.0, 5.0),
            (7.0, 11.0, 13.0, 17.0),
            (19.0, 23.0, 29.0, 31.0),
            (5.0, 15.0, 10.0, 0),
        )

        result = fp_obj(y_true, y_pred, sample_weight=sample_weight)
        self.assertAllClose([125.0, 42.0, 12.0], result)

    def test_threshold_limit(self):
        """Out-of-range and `None` thresholds raise `ValueError`."""
        with self.assertRaisesRegex(
            ValueError,
            r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]",
        ):
            metrics.FalsePositives(thresholds=[-1, 0.5, 2])

        with self.assertRaisesRegex(
            ValueError,
            r"Threshold values must be in \[0, 1\]. Received: \[None\]",
        ):
            metrics.FalsePositives(thresholds=[None])
|
540
keras_core/metrics/metrics_utils.py
Normal file
540
keras_core/metrics/metrics_utils.py
Normal file
@ -0,0 +1,540 @@
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import operations as ops
|
||||
from keras_core.losses.loss import squeeze_to_same_rank
|
||||
from keras_core.utils.python_utils import to_list
|
||||
|
||||
NEG_INF = -1e10
|
||||
|
||||
|
||||
def assert_thresholds_range(thresholds):
    """Raise `ValueError` if any threshold is `None` or outside [0, 1].

    Args:
        thresholds: Iterable of threshold values, or `None` to skip the
            check entirely.

    Raises:
        ValueError: Listing every offending threshold.
    """
    if thresholds is None:
        return
    invalid_thresholds = []
    for value in thresholds:
        if value is None or value < 0 or value > 1:
            invalid_thresholds.append(value)
    if not invalid_thresholds:
        return
    raise ValueError(
        "Threshold values must be in [0, 1]. "
        f"Received: {invalid_thresholds}"
    )
|
||||
|
||||
|
||||
def parse_init_thresholds(thresholds, default_threshold=0.5):
    """Validate user thresholds and return them through `to_list`.

    Args:
        thresholds: A threshold value, a collection of them, or `None`.
        default_threshold: Value used when `thresholds` is `None`; it is not
            range-checked.

    Returns:
        `to_list` applied to the thresholds, or to the default when none
        were given.
    """
    if thresholds is None:
        return to_list(default_threshold)
    assert_thresholds_range(to_list(thresholds))
    return to_list(thresholds)
|
||||
|
||||
|
||||
class ConfusionMatrix(Enum):
    """Confusion-matrix conditions, used as keys of `variables_to_update`."""

    # Label is true, prediction is positive.
    TRUE_POSITIVES = "tp"
    # Label is false, prediction is positive.
    FALSE_POSITIVES = "fp"
    # Label is false, prediction is negative.
    TRUE_NEGATIVES = "tn"
    # Label is true, prediction is negative.
    FALSE_NEGATIVES = "fn"
|
||||
|
||||
|
||||
def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with memory efficient alternative.

    Note that the thresholds need to be evenly distributed within the list, eg,
    the diff between consecutive elements are the same.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But, computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in ops.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use ops.unsorted_segmenet_sum() to update the buckets in one
    pass.

    Consider following example:
    y_true = [0, 0, 1, 1]
    y_pred = [0.1, 0.5, 0.3, 0.9]
    thresholds = [0.0, 0.5, 1.0]
    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
    bucket_index(y_pred) = ops.floor(y_pred * num_buckets)
                         = ops.floor([0.2, 1.0, 0.6, 1.8])
                         = [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the label is true,
    # then 1 will be added to the corresponding bucket with the index.
    # Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the
    # label for 1.8 is true, then 1 will be added to bucket 1.
    #
    # Note the second item "1.0" is floored to 0, since the value need to be
    # strictly larger than the bucket lower bound.
    # In the implementation, we use ops.ceil() - 1 to achieve this.
    tp_bucket_value = ops.unsorted_segmenet_sum(true_labels, bucket_indices,
                                                num_segments=num_thresholds)
                    = [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket
    # 0, and 1 value contributed by bucket 1. When we aggregate them to
    # together, the result become [a + b + c, b + c, c], since large thresholds
    # will always contribute to the value for smaller thresholds.
    true_positive = ops.cumsum(tp_bucket_value, reverse=True)
                  = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on standard implementation instead exhibit a complexity of
    O(T * N).

    Args:
      variables_to_update: Dictionary with `ConfusionMatrix` members as keys
        (the code below looks up `ConfusionMatrix.TRUE_POSITIVES` etc.) and
        corresponding variables to update as values.
      y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
        cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values are
        in the range `[0, 1]`.
      thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
        It need to be evenly distributed (the diff between each element need to
        be the same).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the number
        of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `y_true`
        dimension).
      label_weights: Optional tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_with_epsilon: Optional boolean indicating whether the leading
        and tailing thresholds has any epsilon added for floating point
        imprecisions. It will change how we handle the leading and tailing
        bucket.

    Returns:
      None. The requested variables are updated through `variable.assign`.
    """
    # NOTE(review): `.shape.as_list()` assumes the tensor's shape object
    # exposes `as_list()` -- confirm this holds for every backend.
    num_thresholds = thresholds.shape.as_list()[0]

    # Combine sample and label weights into one tensor; a missing weight acts
    # as the scalar 1.0.
    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = ops.broadcast_to(
            ops.cast(sample_weights, dtype=y_pred.dtype), y_pred.shape
        )
        if not multi_label:
            sample_weights = ops.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = ops.expand_dims(label_weights, 0)
        label_weights = ops.broadcast_to(label_weights, y_pred.shape)
        if not multi_label:
            label_weights = ops.reshape(label_weights, [-1])
    weights = ops.cast(
        ops.multiply(sample_weights, label_weights), y_true.dtype
    )

    # We shouldn't need this, but in case there are predict value that is out of
    # the range of [0.0, 1.0]
    y_pred = ops.clip(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    # Round-trip through bool so every nonzero label becomes exactly 1.
    y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
    if not multi_label:
        y_true = ops.reshape(y_true, [-1])
        y_pred = ops.reshape(y_pred, [-1])

    true_labels = ops.multiply(y_true, weights)
    false_labels = ops.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since the predict value has to be strictly greater than the thresholds,
    # eg, buckets like [0, 0.5], (0.5, 1], and 0.5 belongs to first bucket.
    # We have to use math.ceil(val) - 1 for the bucket.
    bucket_indices = ops.ceil(y_pred * (num_thresholds - 1)) - 1

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually take into account since
        # the any prediction between [0.0, 1.0] should be larger than the first
        # threshold. We change the bucket value from -1 to 0.
        bucket_indices = ops.relu(bucket_indices)

    bucket_indices = ops.cast(bucket_indices, "int32")

    if multi_label:
        # We need to run bucket segment sum for each of the label class. In the
        # multi_label case, the rank of the label is 2. We first transpose it so
        # that the label dim becomes the first and we can parallel run though
        # them.
        true_labels = ops.transpose(true_labels)
        false_labels = ops.transpose(false_labels)
        bucket_indices = ops.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            # Per-class segment sum: element 0 holds the weighted labels,
            # element 1 the bucket indices.
            label, bucket_index = (
                label_and_bucket_index[0],
                label_and_bucket_index[1],
            )
            # NOTE(review): `unsorted_segmenet_sum` looks misspelled; the
            # backend math modules in this change define `segment_sum`.
            # Confirm this op exists in `keras_core.operations`.
            return ops.unsorted_segmenet_sum(
                data=label,
                segment_ids=bucket_index,
                num_segments=num_thresholds,
            )

        # NOTE(review): the backend `vectorized_map` wrappers in this change
        # accept only `(function, elements)`; confirm `warn` is supported.
        tp_bucket_v = ops.vectorized_map(
            gather_bucket, (true_labels, bucket_indices), warn=False
        )
        fp_bucket_v = ops.vectorized_map(
            gather_bucket, (false_labels, bucket_indices), warn=False
        )
        # Reverse cumulative sum over the bucket axis, then transpose back.
        tp = ops.transpose(ops.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = ops.transpose(ops.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        # NOTE(review): same `unsorted_segmenet_sum` spelling concern as in
        # `gather_bucket` above.
        tp_bucket_v = ops.unsorted_segmenet_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = ops.unsorted_segmenet_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = ops.cumsum(tp_bucket_v, reverse=True)
        fp = ops.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    # The totals are only computed when a negative-condition variable needs
    # them.
    if (
        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
    ):
        if multi_label:
            total_true_labels = ops.sum(true_labels, axis=1)
            total_false_labels = ops.sum(false_labels, axis=1)
        else:
            total_true_labels = ops.sum(true_labels)
            total_false_labels = ops.sum(false_labels)

    # Add this batch's counts onto each requested variable.
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        variable.assign(variable + tp)
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        variable.assign(variable + fp)
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        variable.assign(variable + tn)
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        variable.assign(variable + fn)
|
||||
|
||||
|
||||
def is_evenly_distributed_thresholds(thresholds):
    """Tell whether `thresholds` matches an even grid from 0 to 1.

    Evenly spaced thresholds allow a less memory-hungry update path for
    metrics such as AUC that evaluate every threshold.

    Args:
        thresholds: A python list or tuple, or 1D numpy array whose value is
            ranged in [0, 1].

    Returns:
        boolean, whether the values in the inputs are evenly distributed.
        Inputs with fewer than three thresholds always yield `False`.
    """
    count = len(thresholds)
    if count < 3:
        return False
    # Reference grid: 0, 1/(count-1), 2/(count-1), ..., 1.
    steps = np.arange(count, dtype=np.float32)
    reference_grid = steps / (count - 1)
    return np.allclose(thresholds, reference_grid, atol=backend.epsilon())
|
||||
|
||||
|
||||
def update_confusion_matrix_variables(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    top_k=None,
    class_id=None,
    sample_weight=None,
    multi_label=False,
    label_weights=None,
    thresholds_distributed_evenly=False,
):
    """Updates the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positive: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positive: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, each call adds its
    counts onto the given variables through `assign`.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    Args:
      variables_to_update: Dictionary with `ConfusionMatrix` members as keys
        and corresponding variables to update as values. `None` makes the
        call a no-op.
      y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values are
        in the range `[0, 1]`.
      thresholds: A float value, float tensor, python list, or tuple of float
        thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
      top_k: Optional int, indicates that the positive labels should be limited
        to the top k predictions.
      class_id: Optional int, limits the prediction and labels to the class
        specified by this argument.
      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `y_true`
        dimension).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the number
        of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      label_weights: (optional) tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
        distributed within the list. An optimized method will be used if this is
        the case. See _update_confusion_matrix_variables_optimized() for more
        details.

    Raises:
      ValueError: If `multi_label` is True while `label_weights` is given, if
        `variables_to_update` has no valid key, or if it contains invalid
        keys.
    """
    if multi_label and label_weights is not None:
        raise ValueError(
            "`label_weights` for multilabel data should be handled "
            "outside of `update_confusion_matrix_variables` when "
            "`multi_label` is True."
        )
    if variables_to_update is None:
        return
    # At least one key must be a `ConfusionMatrix` member.
    if not any(
        key for key in variables_to_update if key in list(ConfusionMatrix)
    ):
        raise ValueError(
            "Please provide at least one valid confusion matrix "
            "variable to update. Valid variable key options are: "
            f'"{list(ConfusionMatrix)}". '
            f'Received: "{variables_to_update.keys()}"'
        )

    # The first variable's dtype is used for labels, predictions, thresholds.
    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = ops.cast(y_true, dtype=variable_dtype)
    y_pred = ops.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds has any leading or tailing epsilon added
        # for floating point imprecision. The leading and tailing threshold will
        # be handled bit differently as the corner case. At this point,
        # thresholds should be a list/array with more than 2 items, and ranged
        # between [0, 1]. See is_evenly_distributed_thresholds() for more
        # details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = ops.shape(thresholds)[0]

    if multi_label:
        # NOTE(review): `name=` is passed to `ops.equal`; confirm the op
        # accepts that keyword on every backend.
        one_thresh = ops.equal(
            ops.cast(1, dtype="int32"),
            thresholds.ndim,
            name="one_set_of_thresholds_cond",
        )
    else:
        one_thresh = ops.cast(True, dtype="bool")

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
        )

    y_pred, y_true = squeeze_to_same_rank(y_pred, y_true)
    if sample_weight is not None:
        sample_weight = ops.expand_dims(
            ops.cast(sample_weight, dtype=variable_dtype), axis=-1
        )
        _, sample_weight = squeeze_to_same_rank(y_true, sample_weight)

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        # Preserve dimension to match with sample_weight
        y_true = y_true[..., class_id, None]
        y_pred = y_pred[..., class_id, None]

    if thresholds_distributed_evenly:
        # `thresholds_with_epsilon` was bound above under the same condition.
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon,
        )

    pred_shape = ops.shape(y_pred)
    num_predictions = pred_shape[0]
    if y_pred.ndim == 1:
        num_labels = 1
    else:
        # Product of every dimension after the first.
        num_labels = ops.cast(
            ops.prod(ops.array(pred_shape[1:]), axis=0), "int32"
        )
    thresh_label_tile = ops.where(one_thresh, num_labels, 1)

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = ops.expand_dims(y_pred, 0)
        labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype="bool"), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = ops.reshape(y_pred, [1, -1])
        labels_extra_dim = ops.reshape(ops.cast(y_true, dtype="bool"), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = ops.tile(
        ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
    )

    # Tile the predictions for every threshold.
    preds_tiled = ops.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = ops.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = ops.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = ops.broadcast_to(
            ops.cast(sample_weight, dtype=y_pred.dtype), y_pred.shape
        )
        # `thresh_tiles` is reused here as the reshape target shape.
        weights_tiled = ops.tile(
            ops.reshape(sample_weight, thresh_tiles), data_tiles
        )
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = ops.expand_dims(label_weights, 0)
        label_weights = ops.broadcast_to(label_weights, y_pred.shape)
        label_weights_tiled = ops.tile(
            ops.reshape(label_weights, thresh_tiles), data_tiles
        )
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = ops.multiply(weights_tiled, label_weights_tiled)

    def weighted_assign_add(label, pred, weights, var):
        # Add the (optionally weighted) count of `label AND pred`, summed
        # over axis 1, onto `var`.
        label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= ops.cast(weights, dtype=var.dtype)
        var.assign(var + ops.sum(label_and_pred, 1))

    # Map each condition to its (label mask, prediction mask) pair. The
    # negated masks are only built when a requested variable needs them.
    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = ops.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = ops.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg,
                pred_is_neg,
            )

    # Only conditions the caller asked for are written.
    for matrix_cond, (label, pred) in loop_vars.items():
        if matrix_cond in variables_to_update:
            weighted_assign_add(
                label, pred, weights_tiled, variables_to_update[matrix_cond]
            )
|
||||
|
||||
|
||||
def _filter_top_k(x, k):
    """Keep the top-k entries along the last axis of `x`; mask the others.

    Entries that are not among the `k` largest of the last dimension are
    replaced by `NEG_INF`. This is used by the top-k variants of the
    precision and recall metrics, where the dense labels share the shape of
    the predictions.

    Args:
        x: tensor with any dimensions.
        k: the number of values to keep.

    Returns:
        tensor with same shape and dtype as x.
    """
    last_dim = ops.shape(x)[-1]
    _, kept_indices = ops.top_k(x, k, sorted=False)
    # One-hot encode each selected index over the last axis, then sum over
    # the second-to-last axis to merge the per-index encodings into one mask.
    # NOTE(review): presumably the mask is 1 at kept positions and 0
    # elsewhere; this relies on `ops.top_k` returning distinct indices.
    per_index_one_hot = ops.one_hot(kept_indices, last_dim, axis=-1)
    keep_mask = ops.sum(per_index_one_hot, axis=-2)
    return x * keep_mask + NEG_INF * (1 - keep_mask)
|
@ -146,7 +146,7 @@ def serialize_keras_object(obj):
|
||||
"class_name": "__tensor__",
|
||||
"config": {
|
||||
"value": np.array(obj).tolist(),
|
||||
"dtype": backend.standardize_dtype(obj.dtype),
|
||||
"dtype": str(obj.dtype),
|
||||
},
|
||||
}
|
||||
if type(obj).__module__ == np.__name__:
|
||||
@ -155,7 +155,7 @@ def serialize_keras_object(obj):
|
||||
"class_name": "__numpy__",
|
||||
"config": {
|
||||
"value": obj.tolist(),
|
||||
"dtype": backend.standardize_dtype(obj.dtype),
|
||||
"dtype": backend.standardize_dytpe(obj.dtype),
|
||||
},
|
||||
}
|
||||
else:
|
||||
|
@ -9,3 +9,16 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
    """Assert that `x1` and `x2` are equal up to `decimal` decimal places.

    Delegates to `np.testing.assert_almost_equal`. The optional `msg`
    mirrors the `msg` argument of the `unittest.TestCase` assertion this
    method overrides, so callers can attach context to a failure.

    Args:
        x1: the actual value.
        x2: the desired value.
        decimal: desired precision forwarded to numpy. Defaults to 3.
        msg: optional message included in the failure report.

    Raises:
        AssertionError: if the values differ beyond the requested precision.
    """
    # An empty string is numpy's own default for `err_msg`, so omitting
    # `msg` behaves exactly as before.
    np.testing.assert_almost_equal(
        x1, x2, decimal=decimal, err_msg=msg or ""
    )
|
||||
|
||||
def assertEqual(self, x1, x2, msg=None):
    """Assert that `x1` and `x2` are equal according to numpy.

    Delegates to `np.testing.assert_equal`. The optional `msg` keeps this
    override call-compatible with `unittest.TestCase.assertEqual`, so
    `self.assertEqual(a, b, "context")` no longer fails with a `TypeError`.

    Args:
        x1: the actual object.
        x2: the desired object.
        msg: optional message included in the failure report.

    Raises:
        AssertionError: if the two objects are not equal.
    """
    # An empty string is numpy's own default for `err_msg`, so omitting
    # `msg` behaves exactly as before.
    np.testing.assert_equal(x1, x2, err_msg=msg or "")
|
||||
|
||||
def assertLen(self, iterable, expected_len, msg=None):
    """Assert that `len(iterable)` equals `expected_len`.

    Args:
        iterable: any object supporting `len()`.
        expected_len: the expected length.
        msg: optional message included in the failure report.

    Raises:
        AssertionError: if the length differs from `expected_len`.
        TypeError: if `iterable` does not support `len()`.
    """
    # An empty string is numpy's own default for `err_msg`, so omitting
    # `msg` behaves exactly as before.
    np.testing.assert_equal(len(iterable), expected_len, err_msg=msg or "")
|
||||
|
||||
def assertRaisesRegex(self, exception_class, expected_regexp, *args, **kwargs):
    """Assert that `exception_class` is raised with a matching message.

    Thin wrapper that forwards every argument unchanged to
    `np.testing.assert_raises_regex` and returns whatever that call returns.

    Args:
        exception_class: the exception type expected to be raised.
        expected_regexp: regular expression the exception message must match.
        *args: extra positional arguments forwarded to numpy.
        **kwargs: extra keyword arguments forwarded to numpy.

    Returns:
        The result of `np.testing.assert_raises_regex`.
    """
    checker = np.testing.assert_raises_regex
    return checker(exception_class, expected_regexp, *args, **kwargs)
|
||||
|
@ -86,3 +86,20 @@ def func_load(code, defaults=None, closure=None, globs=None):
|
||||
return python_types.FunctionType(
|
||||
code, globs, name=code.co_name, argdefs=defaults, closure=closure
|
||||
)
|
||||
|
||||
|
||||
def to_list(x):
    """Wrap `x` in a list unless it already is one.

    A `list` instance is returned as-is (the same object, not a copy). Any
    other value, such as a single tensor, is returned as a one-element list
    containing it.

    Args:
        x: target object to be normalized.

    Returns:
        A list.
    """
    return x if isinstance(x, list) else [x]
|
||||
|
Loading…
Reference in New Issue
Block a user