Add 4 more Confusion Metrics (#73)
* Add recall * Add SensitivitySpecificity metrics * Review comments * Add missing method from math * Fix pytest * Update config dict creation * Remove sentence fragment * Review comments * Add special casing for min/max
This commit is contained in:
parent
1c994b749a
commit
55640456ff
@ -25,8 +25,8 @@ def mean(x, axis=None, keepdims=False):
|
||||
return jnp.mean(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
def max(x, axis=None, keepdims=False):
|
||||
return jnp.max(x, axis=axis, keepdims=keepdims)
|
||||
def max(x, axis=None, keepdims=False, initial=None):
|
||||
return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)
|
||||
|
||||
|
||||
def ones(shape, dtype="float32"):
|
||||
@ -327,8 +327,8 @@ def meshgrid(*x, indexing="xy"):
|
||||
return jnp.meshgrid(*x, indexing=indexing)
|
||||
|
||||
|
||||
def min(x, axis=None, keepdims=False):
|
||||
return jnp.min(x, axis=axis, keepdims=keepdims)
|
||||
def min(x, axis=None, keepdims=False, initial=None):
|
||||
return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial)
|
||||
|
||||
|
||||
def minimum(x1, x2):
|
||||
|
@ -26,7 +26,22 @@ def mean(x, axis=None, keepdims=False):
|
||||
return tfnp.mean(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
def max(x, axis=None, keepdims=False):
|
||||
def max(x, axis=None, keepdims=False, initial=None):
|
||||
# The TensorFlow numpy API implementation doesn't support `initial` so we
|
||||
# handle it manually here.
|
||||
if initial is not None:
|
||||
return tf.math.maximum(
|
||||
tfnp.max(x, axis=axis, keepdims=keepdims), initial
|
||||
)
|
||||
|
||||
# TensorFlow returns -inf by default for an empty list, but for consistency
|
||||
# with other backends and the numpy API we want to throw in this case.
|
||||
tf.assert_greater(
|
||||
size(x),
|
||||
tf.constant(0, dtype=tf.int64),
|
||||
message="Cannot compute the max of an empty tensor.",
|
||||
)
|
||||
|
||||
return tfnp.max(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
@ -328,7 +343,22 @@ def meshgrid(*x, indexing="xy"):
|
||||
return tfnp.meshgrid(*x, indexing=indexing)
|
||||
|
||||
|
||||
def min(x, axis=None, keepdims=False):
|
||||
def min(x, axis=None, keepdims=False, initial=None):
|
||||
# The TensorFlow numpy API implementation doesn't support `initial` so we
|
||||
# handle it manually here.
|
||||
if initial is not None:
|
||||
return tf.math.minimum(
|
||||
tfnp.min(x, axis=axis, keepdims=keepdims), initial
|
||||
)
|
||||
|
||||
# TensorFlow returns inf by default for an empty list, but for consistency
|
||||
# with other backends and the numpy API we want to throw in this case.
|
||||
tf.assert_greater(
|
||||
size(x),
|
||||
tf.constant(0, dtype=tf.int64),
|
||||
message="Cannot compute the min of an empty tensor.",
|
||||
)
|
||||
|
||||
return tfnp.min(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
|
@ -8,7 +8,11 @@ from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
|
||||
from keras_core.metrics.confusion_metrics import FalseNegatives
|
||||
from keras_core.metrics.confusion_metrics import FalsePositives
|
||||
from keras_core.metrics.confusion_metrics import Precision
|
||||
from keras_core.metrics.confusion_metrics import PrecisionAtRecall
|
||||
from keras_core.metrics.confusion_metrics import Recall
|
||||
from keras_core.metrics.confusion_metrics import RecallAtPrecision
|
||||
from keras_core.metrics.confusion_metrics import SensitivityAtSpecificity
|
||||
from keras_core.metrics.confusion_metrics import SpecificityAtSensitivity
|
||||
from keras_core.metrics.confusion_metrics import TrueNegatives
|
||||
from keras_core.metrics.confusion_metrics import TruePositives
|
||||
from keras_core.metrics.hinge_metrics import CategoricalHinge
|
||||
@ -40,7 +44,11 @@ ALL_OBJECTS = {
|
||||
FalseNegatives,
|
||||
FalsePositives,
|
||||
Precision,
|
||||
PrecisionAtRecall,
|
||||
Recall,
|
||||
RecallAtPrecision,
|
||||
SensitivityAtSpecificity,
|
||||
SpecificityAtSensitivity,
|
||||
TrueNegatives,
|
||||
TruePositives,
|
||||
# Hinge
|
||||
|
@ -69,7 +69,7 @@ class _ConfusionMatrixConditionCount(Metric):
|
||||
def get_config(self):
|
||||
config = {"thresholds": self.init_thresholds}
|
||||
base_config = super().get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.FalsePositives")
|
||||
@ -402,7 +402,7 @@ class Precision(Metric):
|
||||
"class_id": self.class_id,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.Recall")
|
||||
@ -543,4 +543,512 @@ class Recall(Metric):
|
||||
"class_id": self.class_id,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
class SensitivitySpecificityBase(Metric):
|
||||
"""Abstract base class for computing sensitivity and specificity.
|
||||
|
||||
For additional information about specificity and sensitivity, see
|
||||
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, value, num_thresholds=200, class_id=None, name=None, dtype=None
|
||||
):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
if num_thresholds <= 0:
|
||||
raise ValueError(
|
||||
"Argument `num_thresholds` must be an integer > 0. "
|
||||
f"Received: num_thresholds={num_thresholds}"
|
||||
)
|
||||
self.value = value
|
||||
self.class_id = class_id
|
||||
|
||||
# Compute `num_thresholds` thresholds in [0, 1]
|
||||
if num_thresholds == 1:
|
||||
self.thresholds = [0.5]
|
||||
self._thresholds_distributed_evenly = False
|
||||
else:
|
||||
thresholds = [
|
||||
(i + 1) * 1.0 / (num_thresholds - 1)
|
||||
for i in range(num_thresholds - 2)
|
||||
]
|
||||
self.thresholds = [0.0] + thresholds + [1.0]
|
||||
self._thresholds_distributed_evenly = True
|
||||
|
||||
self.true_positives = self.add_variable(
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=initializers.Zeros(),
|
||||
name="true_positives",
|
||||
)
|
||||
self.false_positives = self.add_variable(
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=initializers.Zeros(),
|
||||
name="false_positives",
|
||||
)
|
||||
self.true_negatives = self.add_variable(
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=initializers.Zeros(),
|
||||
name="true_negatives",
|
||||
)
|
||||
self.false_negatives = self.add_variable(
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=initializers.Zeros(),
|
||||
name="false_negatives",
|
||||
)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
"""Accumulates confusion matrix statistics.
|
||||
|
||||
Args:
|
||||
y_true: The ground truth values.
|
||||
y_pred: The predicted values.
|
||||
sample_weight: Optional weighting of each example. Defaults to 1.
|
||||
Can be a `Tensor` whose rank is either 0, or the same rank as
|
||||
`y_true`, and must be broadcastable to `y_true`.
|
||||
"""
|
||||
metrics_utils.update_confusion_matrix_variables(
|
||||
{
|
||||
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
|
||||
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
|
||||
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
|
||||
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
|
||||
},
|
||||
y_true,
|
||||
y_pred,
|
||||
thresholds=self.thresholds,
|
||||
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
|
||||
class_id=self.class_id,
|
||||
sample_weight=sample_weight,
|
||||
)
|
||||
|
||||
def reset_state(self):
|
||||
num_thresholds = len(self.thresholds)
|
||||
self.true_positives.assign(ops.zeros((num_thresholds,)))
|
||||
self.false_positives.assign(ops.zeros((num_thresholds,)))
|
||||
self.true_negatives.assign(ops.zeros((num_thresholds,)))
|
||||
self.false_negatives.assign(ops.zeros((num_thresholds,)))
|
||||
|
||||
def get_config(self):
|
||||
config = {"class_id": self.class_id}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
def _find_max_under_constraint(self, constrained, dependent, predicate):
|
||||
"""Returns the maximum of dependent_statistic that satisfies the
|
||||
constraint.
|
||||
|
||||
Args:
|
||||
constrained: Over these values the constraint is specified. A rank-1
|
||||
tensor.
|
||||
dependent: From these values the maximum that satiesfies the
|
||||
constraint is selected. Values in this tensor and in
|
||||
`constrained` are linked by having the same threshold at each
|
||||
position, hence this tensor must have the same shape.
|
||||
predicate: A binary boolean functor to be applied to arguments
|
||||
`constrained` and `self.value`, e.g. `ops.greater`.
|
||||
|
||||
Returns:
|
||||
maximal dependent value, if no value satisfies the constraint 0.0.
|
||||
"""
|
||||
feasible = ops.array(ops.nonzero(predicate(constrained, self.value)))
|
||||
|
||||
print(feasible)
|
||||
feasible_exists = ops.greater(ops.size(feasible), 0)
|
||||
max_dependent = ops.max(ops.take(dependent, feasible), initial=0)
|
||||
|
||||
return ops.where(feasible_exists, max_dependent, 0.0)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.SensitivityAtSpecificity")
|
||||
class SensitivityAtSpecificity(SensitivitySpecificityBase):
|
||||
"""Computes best sensitivity where specificity is >= specified value.
|
||||
|
||||
`Sensitivity` measures the proportion of actual positives that are correctly
|
||||
identified as such `(tp / (tp + fn))`.
|
||||
`Specificity` measures the proportion of actual negatives that are correctly
|
||||
identified as such `(tn / (tn + fp))`.
|
||||
|
||||
This metric creates four local variables, `true_positives`,
|
||||
`true_negatives`, `false_positives` and `false_negatives` that are used to
|
||||
compute the sensitivity at the given specificity. The threshold for the
|
||||
given specificity value is computed and used to evaluate the corresponding
|
||||
sensitivity.
|
||||
|
||||
If `sample_weight` is `None`, weights default to 1.
|
||||
Use `sample_weight` of 0 to mask values.
|
||||
|
||||
If `class_id` is specified, we calculate precision by considering only the
|
||||
entries in the batch for which `class_id` is above the threshold
|
||||
predictions, and computing the fraction of them for which `class_id` is
|
||||
indeed a correct label.
|
||||
|
||||
For additional information about specificity and sensitivity, see
|
||||
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
|
||||
|
||||
Args:
|
||||
specificity: A scalar value in range `[0, 1]`.
|
||||
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
|
||||
use for matching the given specificity.
|
||||
class_id: (Optional) Integer class ID for which we want binary metrics.
|
||||
This must be in the half-open interval `[0, num_classes)`, where
|
||||
`num_classes` is the last dimension of predictions.
|
||||
name: (Optional) string name of the metric instance.
|
||||
dtype: (Optional) data type of the metric result.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> m = keras_core.metrics.SensitivityAtSpecificity(0.5)
|
||||
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
|
||||
>>> m.result()
|
||||
0.5
|
||||
|
||||
>>> m.reset_state()
|
||||
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
|
||||
... sample_weight=[1, 1, 2, 2, 1])
|
||||
>>> m.result()
|
||||
0.333333
|
||||
|
||||
Usage with `compile()` API:
|
||||
|
||||
```python
|
||||
model.compile(
|
||||
optimizer='sgd',
|
||||
loss='mse',
|
||||
metrics=[keras_core.metrics.SensitivityAtSpecificity()])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specificity,
|
||||
num_thresholds=200,
|
||||
class_id=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
):
|
||||
if specificity < 0 or specificity > 1:
|
||||
raise ValueError(
|
||||
"Argument `specificity` must be in the range [0, 1]. "
|
||||
f"Received: specificity={specificity}"
|
||||
)
|
||||
self.specificity = specificity
|
||||
self.num_thresholds = num_thresholds
|
||||
super().__init__(
|
||||
specificity,
|
||||
num_thresholds=num_thresholds,
|
||||
class_id=class_id,
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def result(self):
|
||||
sensitivities = ops.divide(
|
||||
self.true_positives,
|
||||
self.true_positives + self.false_negatives + backend.epsilon(),
|
||||
)
|
||||
specificities = ops.divide(
|
||||
self.true_negatives,
|
||||
self.true_negatives + self.false_positives + backend.epsilon(),
|
||||
)
|
||||
return self._find_max_under_constraint(
|
||||
specificities, sensitivities, ops.greater_equal
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_thresholds": self.num_thresholds,
|
||||
"specificity": self.specificity,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.SpecificityAtSensitivity")
|
||||
class SpecificityAtSensitivity(SensitivitySpecificityBase):
|
||||
"""Computes best specificity where sensitivity is >= specified value.
|
||||
|
||||
`Sensitivity` measures the proportion of actual positives that are correctly
|
||||
identified as such `(tp / (tp + fn))`.
|
||||
`Specificity` measures the proportion of actual negatives that are correctly
|
||||
identified as such `(tn / (tn + fp))`.
|
||||
|
||||
This metric creates four local variables, `true_positives`,
|
||||
`true_negatives`, `false_positives` and `false_negatives` that are used to
|
||||
compute the specificity at the given sensitivity. The threshold for the
|
||||
given sensitivity value is computed and used to evaluate the corresponding
|
||||
specificity.
|
||||
|
||||
If `sample_weight` is `None`, weights default to 1.
|
||||
Use `sample_weight` of 0 to mask values.
|
||||
|
||||
If `class_id` is specified, we calculate precision by considering only the
|
||||
entries in the batch for which `class_id` is above the threshold
|
||||
predictions, and computing the fraction of them for which `class_id` is
|
||||
indeed a correct label.
|
||||
|
||||
For additional information about specificity and sensitivity, see
|
||||
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
|
||||
|
||||
Args:
|
||||
sensitivity: A scalar value in range `[0, 1]`.
|
||||
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
|
||||
use for matching the given sensitivity.
|
||||
class_id: (Optional) Integer class ID for which we want binary metrics.
|
||||
This must be in the half-open interval `[0, num_classes)`, where
|
||||
`num_classes` is the last dimension of predictions.
|
||||
name: (Optional) string name of the metric instance.
|
||||
dtype: (Optional) data type of the metric result.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> m = keras_core.metrics.SpecificityAtSensitivity(0.5)
|
||||
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
|
||||
>>> m.result()
|
||||
0.66666667
|
||||
|
||||
>>> m.reset_state()
|
||||
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
|
||||
... sample_weight=[1, 1, 2, 2, 2])
|
||||
>>> m.result()
|
||||
0.5
|
||||
|
||||
Usage with `compile()` API:
|
||||
|
||||
```python
|
||||
model.compile(
|
||||
optimizer='sgd',
|
||||
loss='mse',
|
||||
metrics=[keras_core.metrics.SpecificityAtSensitivity()])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sensitivity,
|
||||
num_thresholds=200,
|
||||
class_id=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
):
|
||||
if sensitivity < 0 or sensitivity > 1:
|
||||
raise ValueError(
|
||||
"Argument `sensitivity` must be in the range [0, 1]. "
|
||||
f"Received: sensitivity={sensitivity}"
|
||||
)
|
||||
self.sensitivity = sensitivity
|
||||
self.num_thresholds = num_thresholds
|
||||
super().__init__(
|
||||
sensitivity,
|
||||
num_thresholds=num_thresholds,
|
||||
class_id=class_id,
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def result(self):
|
||||
sensitivities = ops.divide(
|
||||
self.true_positives,
|
||||
self.true_positives + self.false_negatives + backend.epsilon(),
|
||||
)
|
||||
specificities = ops.divide(
|
||||
self.true_negatives,
|
||||
self.true_negatives + self.false_positives + backend.epsilon(),
|
||||
)
|
||||
return self._find_max_under_constraint(
|
||||
sensitivities, specificities, ops.greater_equal
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_thresholds": self.num_thresholds,
|
||||
"sensitivity": self.sensitivity,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.PrecisionAtRecall")
|
||||
class PrecisionAtRecall(SensitivitySpecificityBase):
|
||||
"""Computes best precision where recall is >= specified value.
|
||||
|
||||
This metric creates four local variables, `true_positives`,
|
||||
`true_negatives`, `false_positives` and `false_negatives` that are used to
|
||||
compute the precision at the given recall. The threshold for the given
|
||||
recall value is computed and used to evaluate the corresponding precision.
|
||||
|
||||
If `sample_weight` is `None`, weights default to 1.
|
||||
Use `sample_weight` of 0 to mask values.
|
||||
|
||||
If `class_id` is specified, we calculate precision by considering only the
|
||||
entries in the batch for which `class_id` is above the threshold
|
||||
predictions, and computing the fraction of them for which `class_id` is
|
||||
indeed a correct label.
|
||||
|
||||
Args:
|
||||
recall: A scalar value in range `[0, 1]`.
|
||||
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
|
||||
use for matching the given recall.
|
||||
class_id: (Optional) Integer class ID for which we want binary metrics.
|
||||
This must be in the half-open interval `[0, num_classes)`, where
|
||||
`num_classes` is the last dimension of predictions.
|
||||
name: (Optional) string name of the metric instance.
|
||||
dtype: (Optional) data type of the metric result.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> m = keras_core.metrics.PrecisionAtRecall(0.5)
|
||||
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
|
||||
>>> m.result()
|
||||
0.5
|
||||
|
||||
>>> m.reset_state()
|
||||
>>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
|
||||
... sample_weight=[2, 2, 2, 1, 1])
|
||||
>>> m.result()
|
||||
0.33333333
|
||||
|
||||
Usage with `compile()` API:
|
||||
|
||||
```python
|
||||
model.compile(
|
||||
optimizer='sgd',
|
||||
loss='mse',
|
||||
metrics=[keras_core.metrics.PrecisionAtRecall(recall=0.8)])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, recall, num_thresholds=200, class_id=None, name=None, dtype=None
|
||||
):
|
||||
if recall < 0 or recall > 1:
|
||||
raise ValueError(
|
||||
"Argument `recall` must be in the range [0, 1]. "
|
||||
f"Received: recall={recall}"
|
||||
)
|
||||
self.recall = recall
|
||||
self.num_thresholds = num_thresholds
|
||||
super().__init__(
|
||||
value=recall,
|
||||
num_thresholds=num_thresholds,
|
||||
class_id=class_id,
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def result(self):
|
||||
recalls = ops.divide(
|
||||
self.true_positives,
|
||||
self.true_positives + self.false_negatives + backend.epsilon(),
|
||||
)
|
||||
precisions = ops.divide(
|
||||
self.true_positives,
|
||||
self.true_positives + self.false_positives + backend.epsilon(),
|
||||
)
|
||||
return self._find_max_under_constraint(
|
||||
recalls, precisions, ops.greater_equal
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = {"num_thresholds": self.num_thresholds, "recall": self.recall}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.RecallAtPrecision")
|
||||
class RecallAtPrecision(SensitivitySpecificityBase):
|
||||
"""Computes best recall where precision is >= specified value.
|
||||
|
||||
For a given score-label-distribution the required precision might not
|
||||
be achievable, in this case 0.0 is returned as recall.
|
||||
|
||||
This metric creates four local variables, `true_positives`,
|
||||
`true_negatives`, `false_positives` and `false_negatives` that are used to
|
||||
compute the recall at the given precision. The threshold for the given
|
||||
precision value is computed and used to evaluate the corresponding recall.
|
||||
|
||||
If `sample_weight` is `None`, weights default to 1.
|
||||
Use `sample_weight` of 0 to mask values.
|
||||
|
||||
If `class_id` is specified, we calculate precision by considering only the
|
||||
entries in the batch for which `class_id` is above the threshold
|
||||
predictions, and computing the fraction of them for which `class_id` is
|
||||
indeed a correct label.
|
||||
|
||||
Args:
|
||||
precision: A scalar value in range `[0, 1]`.
|
||||
num_thresholds: (Optional) Defaults to 200. The number of thresholds
|
||||
to use for matching the given precision.
|
||||
class_id: (Optional) Integer class ID for which we want binary metrics.
|
||||
This must be in the half-open interval `[0, num_classes)`, where
|
||||
`num_classes` is the last dimension of predictions.
|
||||
name: (Optional) string name of the metric instance.
|
||||
dtype: (Optional) data type of the metric result.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> m = keras_core.metrics.RecallAtPrecision(0.8)
|
||||
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
||||
>>> m.result()
|
||||
0.5
|
||||
|
||||
>>> m.reset_state()
|
||||
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
|
||||
... sample_weight=[1, 0, 0, 1])
|
||||
>>> m.result()
|
||||
1.0
|
||||
|
||||
Usage with `compile()` API:
|
||||
|
||||
```python
|
||||
model.compile(
|
||||
optimizer='sgd',
|
||||
loss='mse',
|
||||
metrics=[keras_core.metrics.RecallAtPrecision(precision=0.8)])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision,
|
||||
num_thresholds=200,
|
||||
class_id=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
):
|
||||
if precision < 0 or precision > 1:
|
||||
raise ValueError(
|
||||
"Argument `precision` must be in the range [0, 1]. "
|
||||
f"Received: precision={precision}"
|
||||
)
|
||||
self.precision = precision
|
||||
self.num_thresholds = num_thresholds
|
||||
super().__init__(
|
||||
value=precision,
|
||||
num_thresholds=num_thresholds,
|
||||
class_id=class_id,
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def result(self):
|
||||
recalls = ops.divide(
|
||||
self.true_positives,
|
||||
self.true_positives + self.false_negatives + backend.epsilon(),
|
||||
)
|
||||
precisions = ops.divide(
|
||||
self.true_positives,
|
||||
self.true_positives + self.false_positives + backend.epsilon(),
|
||||
)
|
||||
return self._find_max_under_constraint(
|
||||
precisions, recalls, ops.greater_equal
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_thresholds": self.num_thresholds,
|
||||
"precision": self.precision,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
@ -1,7 +1,9 @@
|
||||
import numpy as np
|
||||
from absl.testing import parameterized
|
||||
from tensorflow.python.ops.numpy_ops import np_config
|
||||
|
||||
from keras_core import metrics
|
||||
from keras_core import operations as ops
|
||||
from keras_core import testing
|
||||
|
||||
# TODO: remove reliance on this (or alternatively, turn it on by default).
|
||||
@ -689,3 +691,413 @@ class RecallTest(testing.TestCase):
|
||||
self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))
|
||||
self.assertAlmostEqual(1, r_obj.true_positives)
|
||||
self.assertAlmostEqual(3, r_obj.false_negatives)
|
||||
|
||||
|
||||
class SensitivityAtSpecificityTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_config(self):
|
||||
s_obj = metrics.SensitivityAtSpecificity(
|
||||
0.4,
|
||||
num_thresholds=100,
|
||||
class_id=12,
|
||||
name="sensitivity_at_specificity_1",
|
||||
)
|
||||
self.assertEqual(s_obj.name, "sensitivity_at_specificity_1")
|
||||
self.assertLen(s_obj.variables, 4)
|
||||
self.assertEqual(s_obj.specificity, 0.4)
|
||||
self.assertEqual(s_obj.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
# Check save and restore config
|
||||
s_obj2 = metrics.SensitivityAtSpecificity.from_config(
|
||||
s_obj.get_config()
|
||||
)
|
||||
self.assertEqual(s_obj2.name, "sensitivity_at_specificity_1")
|
||||
self.assertLen(s_obj2.variables, 4)
|
||||
self.assertEqual(s_obj2.specificity, 0.4)
|
||||
self.assertEqual(s_obj2.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
def test_unweighted_all_correct(self):
|
||||
s_obj = metrics.SensitivityAtSpecificity(0.7)
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
y_pred = np.array(inputs, dtype="float32")
|
||||
y_true = np.array(inputs)
|
||||
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_high_specificity(self):
|
||||
s_obj = metrics.SensitivityAtSpecificity(0.8)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
self.assertAlmostEqual(0.8, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_low_specificity(self):
|
||||
s_obj = metrics.SensitivityAtSpecificity(0.4)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_class_id(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
|
||||
|
||||
y_pred = ops.transpose(np.array([pred_values] * 3))
|
||||
y_true = ops.one_hot(label_values, num_classes=3)
|
||||
|
||||
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
|
||||
|
||||
@parameterized.parameters(["bool", "int32", "float32"])
|
||||
def test_weighted(self, label_dtype):
|
||||
s_obj = metrics.SensitivityAtSpecificity(0.4)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = ops.cast(label_values, dtype=label_dtype)
|
||||
weights = np.array(weight_values)
|
||||
|
||||
result = s_obj(y_true, y_pred, sample_weight=weights)
|
||||
self.assertAlmostEqual(0.675, result)
|
||||
|
||||
def test_invalid_specificity(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"`specificity` must be in the range \[0, 1\]."
|
||||
):
|
||||
metrics.SensitivityAtSpecificity(-1)
|
||||
|
||||
def test_invalid_num_thresholds(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 0"
|
||||
):
|
||||
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
|
||||
|
||||
|
||||
class SpecificityAtSensitivityTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_config(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(
|
||||
0.4,
|
||||
num_thresholds=100,
|
||||
class_id=12,
|
||||
name="specificity_at_sensitivity_1",
|
||||
)
|
||||
self.assertEqual(s_obj.name, "specificity_at_sensitivity_1")
|
||||
self.assertLen(s_obj.variables, 4)
|
||||
self.assertEqual(s_obj.sensitivity, 0.4)
|
||||
self.assertEqual(s_obj.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
# Check save and restore config
|
||||
s_obj2 = metrics.SpecificityAtSensitivity.from_config(
|
||||
s_obj.get_config()
|
||||
)
|
||||
self.assertEqual(s_obj2.name, "specificity_at_sensitivity_1")
|
||||
self.assertLen(s_obj2.variables, 4)
|
||||
self.assertEqual(s_obj2.sensitivity, 0.4)
|
||||
self.assertEqual(s_obj2.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
def test_unweighted_all_correct(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.7)
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
y_pred = np.array(inputs, dtype="float32")
|
||||
y_true = np.array(inputs)
|
||||
|
||||
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_high_sensitivity(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(1.0)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
self.assertAlmostEqual(0.2, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_low_sensitivity(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.4)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_class_id(self):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
|
||||
|
||||
y_pred = ops.transpose(np.array([pred_values] * 3))
|
||||
y_true = ops.one_hot(label_values, num_classes=3)
|
||||
|
||||
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
|
||||
|
||||
@parameterized.parameters(["bool", "int32", "float32"])
|
||||
def test_weighted(self, label_dtype):
|
||||
s_obj = metrics.SpecificityAtSensitivity(0.4)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = ops.cast(label_values, dtype=label_dtype)
|
||||
weights = np.array(weight_values)
|
||||
|
||||
result = s_obj(y_true, y_pred, sample_weight=weights)
|
||||
self.assertAlmostEqual(0.4, result)
|
||||
|
||||
def test_invalid_sensitivity(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"`sensitivity` must be in the range \[0, 1\]."
|
||||
):
|
||||
metrics.SpecificityAtSensitivity(-1)
|
||||
|
||||
def test_invalid_num_thresholds(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 0"
|
||||
):
|
||||
metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)
|
||||
|
||||
|
||||
class PrecisionAtRecallTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_config(self):
|
||||
s_obj = metrics.PrecisionAtRecall(
|
||||
0.4, num_thresholds=100, class_id=12, name="precision_at_recall_1"
|
||||
)
|
||||
self.assertEqual(s_obj.name, "precision_at_recall_1")
|
||||
self.assertLen(s_obj.variables, 4)
|
||||
self.assertEqual(s_obj.recall, 0.4)
|
||||
self.assertEqual(s_obj.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
# Check save and restore config
|
||||
s_obj2 = metrics.PrecisionAtRecall.from_config(s_obj.get_config())
|
||||
self.assertEqual(s_obj2.name, "precision_at_recall_1")
|
||||
self.assertLen(s_obj2.variables, 4)
|
||||
self.assertEqual(s_obj2.recall, 0.4)
|
||||
self.assertEqual(s_obj2.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
def test_unweighted_all_correct(self):
|
||||
s_obj = metrics.PrecisionAtRecall(0.7)
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
y_pred = np.array(inputs, dtype="float32")
|
||||
y_true = np.array(inputs)
|
||||
|
||||
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_high_recall(self):
|
||||
s_obj = metrics.PrecisionAtRecall(0.8)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
# For 0.5 < decision threshold < 0.6.
|
||||
self.assertAlmostEqual(2.0 / 3, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_low_recall(self):
|
||||
s_obj = metrics.PrecisionAtRecall(0.6)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
# For 0.2 < decision threshold < 0.5.
|
||||
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_class_id(self):
|
||||
s_obj = metrics.PrecisionAtRecall(0.6, class_id=2)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
|
||||
|
||||
y_pred = ops.transpose(np.array([pred_values] * 3))
|
||||
y_true = ops.one_hot(label_values, num_classes=3)
|
||||
|
||||
# For 0.2 < decision threshold < 0.5.
|
||||
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
|
||||
|
||||
@parameterized.parameters(["bool", "int32", "float32"])
|
||||
def test_weighted(self, label_dtype):
|
||||
s_obj = metrics.PrecisionAtRecall(7.0 / 8)
|
||||
pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]
|
||||
label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
|
||||
weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2]
|
||||
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = ops.cast(label_values, dtype=label_dtype)
|
||||
weights = np.array(weight_values)
|
||||
|
||||
result = s_obj(y_true, y_pred, sample_weight=weights)
|
||||
# For 0.0 < decision threshold < 0.2.
|
||||
self.assertAlmostEqual(0.7, result)
|
||||
|
||||
def test_invalid_sensitivity(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"`recall` must be in the range \[0, 1\]."
|
||||
):
|
||||
metrics.PrecisionAtRecall(-1)
|
||||
|
||||
def test_invalid_num_thresholds(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 0"
|
||||
):
|
||||
metrics.PrecisionAtRecall(0.4, num_thresholds=-1)
|
||||
|
||||
|
||||
class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_config(self):
|
||||
s_obj = metrics.RecallAtPrecision(
|
||||
0.4, num_thresholds=100, class_id=12, name="recall_at_precision_1"
|
||||
)
|
||||
self.assertEqual(s_obj.name, "recall_at_precision_1")
|
||||
self.assertLen(s_obj.variables, 4)
|
||||
self.assertEqual(s_obj.precision, 0.4)
|
||||
self.assertEqual(s_obj.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
# Check save and restore config
|
||||
s_obj2 = metrics.RecallAtPrecision.from_config(s_obj.get_config())
|
||||
self.assertEqual(s_obj2.name, "recall_at_precision_1")
|
||||
self.assertLen(s_obj2.variables, 4)
|
||||
self.assertEqual(s_obj2.precision, 0.4)
|
||||
self.assertEqual(s_obj2.num_thresholds, 100)
|
||||
self.assertEqual(s_obj.class_id, 12)
|
||||
|
||||
def test_unweighted_all_correct(self):
|
||||
s_obj = metrics.RecallAtPrecision(0.7)
|
||||
inputs = np.random.randint(0, 2, size=(100, 1))
|
||||
y_pred = np.array(inputs, dtype="float32")
|
||||
y_true = np.array(inputs)
|
||||
|
||||
self.assertAlmostEqual(1, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_high_precision(self):
|
||||
s_obj = metrics.RecallAtPrecision(0.75)
|
||||
pred_values = [
|
||||
0.05,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.3,
|
||||
0.35,
|
||||
0.4,
|
||||
0.45,
|
||||
0.5,
|
||||
0.6,
|
||||
0.9,
|
||||
0.95,
|
||||
]
|
||||
label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]
|
||||
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
|
||||
# 1].
|
||||
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
|
||||
# 1/6].
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
# The precision 0.75 can be reached at thresholds 0.4<=t<0.45.
|
||||
self.assertAlmostEqual(0.5, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_low_precision(self):
|
||||
s_obj = metrics.RecallAtPrecision(2.0 / 3)
|
||||
pred_values = [
|
||||
0.05,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.3,
|
||||
0.35,
|
||||
0.4,
|
||||
0.45,
|
||||
0.5,
|
||||
0.6,
|
||||
0.9,
|
||||
0.95,
|
||||
]
|
||||
label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]
|
||||
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
|
||||
# 1].
|
||||
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
|
||||
# 1/6].
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
# The precision 5/7 can be reached at thresholds 00.3<=t<0.35.
|
||||
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
|
||||
|
||||
def test_unweighted_class_id(self):
|
||||
s_obj = metrics.RecallAtPrecision(2.0 / 3, class_id=2)
|
||||
pred_values = [
|
||||
0.05,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.3,
|
||||
0.35,
|
||||
0.4,
|
||||
0.45,
|
||||
0.5,
|
||||
0.6,
|
||||
0.9,
|
||||
0.95,
|
||||
]
|
||||
label_values = [0, 2, 0, 0, 0, 2, 2, 0, 2, 2, 0, 2]
|
||||
# precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,
|
||||
# 1].
|
||||
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
|
||||
# 1/6].
|
||||
y_pred = ops.transpose(np.array([pred_values] * 3))
|
||||
y_true = ops.one_hot(label_values, num_classes=3)
|
||||
|
||||
# The precision 5/7 can be reached at thresholds 00.3<=t<0.35.
|
||||
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
|
||||
|
||||
@parameterized.parameters(["bool", "int32", "float32"])
|
||||
def test_weighted(self, label_dtype):
|
||||
s_obj = metrics.RecallAtPrecision(0.75)
|
||||
pred_values = [0.1, 0.2, 0.3, 0.5, 0.6, 0.9, 0.9]
|
||||
label_values = [0, 1, 0, 0, 0, 1, 1]
|
||||
weight_values = [1, 2, 1, 2, 1, 2, 1]
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = ops.cast(label_values, dtype=label_dtype)
|
||||
weights = np.array(weight_values)
|
||||
|
||||
result = s_obj(y_true, y_pred, sample_weight=weights)
|
||||
self.assertAlmostEqual(0.6, result)
|
||||
|
||||
def test_unachievable_precision(self):
|
||||
s_obj = metrics.RecallAtPrecision(2.0 / 3)
|
||||
pred_values = [0.1, 0.2, 0.3, 0.9]
|
||||
label_values = [1, 1, 0, 0]
|
||||
y_pred = np.array(pred_values, dtype="float32")
|
||||
y_true = np.array(label_values)
|
||||
|
||||
# The highest possible precision is 1/2 which is below the required
|
||||
# value, expect 0 recall.
|
||||
self.assertAlmostEqual(0, s_obj(y_true, y_pred))
|
||||
|
||||
def test_invalid_sensitivity(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"`precision` must be in the range \[0, 1\]."
|
||||
):
|
||||
metrics.RecallAtPrecision(-1)
|
||||
|
||||
def test_invalid_num_thresholds(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 0"
|
||||
):
|
||||
metrics.RecallAtPrecision(0.4, num_thresholds=-1)
|
||||
|
@ -74,8 +74,7 @@ def _update_confusion_matrix_variables_optimized(
|
||||
thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
|
||||
Given a prediction value p, we can map it to its bucket by
|
||||
bucket_index(p) = floor( p * (num_thresholds - 1) )
|
||||
so we can use ops.unsorted_segmenet_sum() to update the buckets in one
|
||||
pass.
|
||||
so we can use ops.segment_sum() to update the buckets in one pass.
|
||||
|
||||
Consider following example:
|
||||
y_true = [0, 0, 1, 1]
|
||||
@ -93,7 +92,7 @@ def _update_confusion_matrix_variables_optimized(
|
||||
# Note the second item "1.0" is floored to 0, since the value need to be
|
||||
# strictly larger than the bucket lower bound.
|
||||
# In the implementation, we use ops.ceil() - 1 to achieve this.
|
||||
tp_bucket_value = ops.unsorted_segmenet_sum(true_labels, bucket_indices,
|
||||
tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
|
||||
num_segments=num_thresholds)
|
||||
= [1, 1, 0]
|
||||
# For [1, 1, 0] here, it means there is 1 true value contributed by bucket
|
||||
@ -109,35 +108,35 @@ def _update_confusion_matrix_variables_optimized(
|
||||
O(T * N).
|
||||
|
||||
Args:
|
||||
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
|
||||
and corresponding variables to update as values.
|
||||
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
|
||||
cast to `bool`.
|
||||
y_pred: A floating point `Tensor` of arbitrary shape and whose values are
|
||||
in the range `[0, 1]`.
|
||||
thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
|
||||
It need to be evenly distributed (the diff between each element need to
|
||||
be the same).
|
||||
multi_label: Optional boolean indicating whether multidimensional
|
||||
prediction/labels should be treated as multilabel responses, or
|
||||
flattened into a single label. When True, the valus of
|
||||
`variables_to_update` must have a second dimension equal to the number
|
||||
of labels in y_true and y_pred, and those tensors must not be
|
||||
RaggedTensors.
|
||||
sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
|
||||
as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
|
||||
must be either `1`, or the same as the corresponding `y_true`
|
||||
dimension).
|
||||
label_weights: Optional tensor of non-negative weights for multilabel
|
||||
data. The weights are applied when calculating TP, FP, FN, and TN
|
||||
without explicit multilabel handling (i.e. when the data is to be
|
||||
flattened).
|
||||
thresholds_with_epsilon: Optional boolean indicating whether the leading
|
||||
and tailing thresholds has any epsilon added for floating point
|
||||
imprecisions. It will change how we handle the leading and tailing
|
||||
bucket.
|
||||
variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
|
||||
keys and corresponding variables to update as values.
|
||||
y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
|
||||
cast to `bool`.
|
||||
y_pred: A floating point `Tensor` of arbitrary shape and whose values
|
||||
are in the range `[0, 1]`.
|
||||
thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
|
||||
It need to be evenly distributed (the diff between each element need
|
||||
to be the same).
|
||||
multi_label: Optional boolean indicating whether multidimensional
|
||||
prediction/labels should be treated as multilabel responses, or
|
||||
flattened into a single label. When True, the valus of
|
||||
`variables_to_update` must have a second dimension equal to the
|
||||
number of labels in y_true and y_pred, and those tensors must not be
|
||||
RaggedTensors.
|
||||
sample_weights: Optional `Tensor` whose rank is either 0, or the same
|
||||
rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
|
||||
dimensions must be either `1`, or the same as the corresponding
|
||||
`y_true` dimension).
|
||||
label_weights: Optional tensor of non-negative weights for multilabel
|
||||
data. The weights are applied when calculating TP, FP, FN, and TN
|
||||
without explicit multilabel handling (i.e. when the data is to be
|
||||
flattened).
|
||||
thresholds_with_epsilon: Optional boolean indicating whether the leading
|
||||
and tailing thresholds has any epsilon added for floating point
|
||||
imprecisions. It will change how we handle the leading and tailing
|
||||
bucket.
|
||||
"""
|
||||
num_thresholds = thresholds.shape.as_list()[0]
|
||||
num_thresholds = ops.shape(thresholds)[0]
|
||||
|
||||
if sample_weights is None:
|
||||
sample_weights = 1.0
|
||||
@ -160,7 +159,7 @@ def _update_confusion_matrix_variables_optimized(
|
||||
|
||||
# We shouldn't need this, but in case there are predict value that is out of
|
||||
# the range of [0.0, 1.0]
|
||||
y_pred = ops.clip(y_pred, clip_value_min=0.0, clip_value_max=1.0)
|
||||
y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0)
|
||||
|
||||
y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
|
||||
if not multi_label:
|
||||
@ -198,7 +197,7 @@ def _update_confusion_matrix_variables_optimized(
|
||||
label_and_bucket_index[0],
|
||||
label_and_bucket_index[1],
|
||||
)
|
||||
return ops.unsorted_segmenet_sum(
|
||||
return ops.segment_sum(
|
||||
data=label,
|
||||
segment_ids=bucket_index,
|
||||
num_segments=num_thresholds,
|
||||
@ -210,21 +209,21 @@ def _update_confusion_matrix_variables_optimized(
|
||||
fp_bucket_v = ops.vectorized_map(
|
||||
gather_bucket, (false_labels, bucket_indices), warn=False
|
||||
)
|
||||
tp = ops.transpose(ops.cumsum(tp_bucket_v, reverse=True, axis=1))
|
||||
fp = ops.transpose(ops.cumsum(fp_bucket_v, reverse=True, axis=1))
|
||||
tp = ops.transpose(ops.cumsum(ops.flip(tp_bucket_v), axis=1))
|
||||
fp = ops.transpose(ops.cumsum(ops.flip(fp_bucket_v), axis=1))
|
||||
else:
|
||||
tp_bucket_v = ops.unsorted_segmenet_sum(
|
||||
tp_bucket_v = ops.segment_sum(
|
||||
data=true_labels,
|
||||
segment_ids=bucket_indices,
|
||||
num_segments=num_thresholds,
|
||||
)
|
||||
fp_bucket_v = ops.unsorted_segmenet_sum(
|
||||
fp_bucket_v = ops.segment_sum(
|
||||
data=false_labels,
|
||||
segment_ids=bucket_indices,
|
||||
num_segments=num_thresholds,
|
||||
)
|
||||
tp = ops.cumsum(tp_bucket_v, reverse=True)
|
||||
fp = ops.cumsum(fp_bucket_v, reverse=True)
|
||||
tp = ops.cumsum(ops.flip(tp_bucket_v))
|
||||
fp = ops.cumsum(ops.flip(fp_bucket_v))
|
||||
|
||||
# fn = sum(true_labels) - tp
|
||||
# tn = sum(false_labels) - fp
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
segment_sum
|
||||
top_k
|
||||
in_top_k
|
||||
"""
|
||||
|
||||
from keras_core import backend
|
||||
@ -10,14 +9,16 @@ from keras_core.operations.operation import Operation
|
||||
|
||||
|
||||
class SegmentSum(Operation):
|
||||
def call(self, x, segment_ids, num_segments=None, sorted=False):
|
||||
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
|
||||
def call(self, data, segment_ids, num_segments=None, sorted=False):
|
||||
return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
|
||||
|
||||
|
||||
def segment_sum(x, segment_ids, num_segments=None, sorted=False):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return SegmentSum().symbolic_call(x, segment_ids, num_segments, sorted)
|
||||
return backend.math.segment_sum(x, segment_ids, num_segments, sorted)
|
||||
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
|
||||
if any_symbolic_tensors((data,)):
|
||||
return SegmentSum().symbolic_call(
|
||||
data, segment_ids, num_segments, sorted
|
||||
)
|
||||
return backend.math.segment_sum(data, segment_ids, num_segments, sorted)
|
||||
|
||||
|
||||
class TopK(Operation):
|
||||
|
@ -1881,16 +1881,19 @@ def matmul(x1, x2):
|
||||
|
||||
|
||||
class Max(Operation):
|
||||
def __init__(self, axis=None, keepdims=False):
|
||||
def __init__(self, axis=None, keepdims=False, initial=None):
|
||||
super().__init__()
|
||||
if isinstance(axis, int):
|
||||
self.axis = [axis]
|
||||
else:
|
||||
self.axis = axis
|
||||
self.keepdims = keepdims
|
||||
self.initial = initial
|
||||
|
||||
def call(self, x):
|
||||
return backend.numpy.max(x, axis=self.axis, keepdims=self.keepdims)
|
||||
return backend.numpy.max(
|
||||
x, axis=self.axis, keepdims=self.keepdims, initial=self.initial
|
||||
)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(
|
||||
@ -1899,10 +1902,12 @@ class Max(Operation):
|
||||
)
|
||||
|
||||
|
||||
def max(x, axis=None, keepdims=False):
|
||||
def max(x, axis=None, keepdims=False, initial=None):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Max(axis=axis, keepdims=keepdims).symbolic_call(x)
|
||||
return backend.numpy.max(x, axis=axis, keepdims=keepdims)
|
||||
return Max(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(
|
||||
x
|
||||
)
|
||||
return backend.numpy.max(x, axis=axis, keepdims=keepdims, initial=initial)
|
||||
|
||||
|
||||
class Maximum(Operation):
|
||||
@ -1961,15 +1966,18 @@ def meshgrid(*x, indexing="xy"):
|
||||
|
||||
|
||||
class Min(Operation):
|
||||
def __init__(self, axis=None, keepdims=False):
|
||||
def __init__(self, axis=None, keepdims=False, initial=None):
|
||||
if isinstance(axis, int):
|
||||
self.axis = [axis]
|
||||
else:
|
||||
self.axis = axis
|
||||
self.keepdims = keepdims
|
||||
self.initial = initial
|
||||
|
||||
def call(self, x):
|
||||
return backend.numpy.min(x, axis=self.axis, keepdims=self.keepdims)
|
||||
return backend.numpy.min(
|
||||
x, axis=self.axis, keepdims=self.keepdims, initial=self.initial
|
||||
)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(
|
||||
@ -1978,10 +1986,12 @@ class Min(Operation):
|
||||
)
|
||||
|
||||
|
||||
def min(x, axis=None, keepdims=False):
|
||||
def min(x, axis=None, keepdims=False, initial=None):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Min(axis=axis, keepdims=keepdims).symbolic_call(x)
|
||||
return backend.numpy.min(x, axis=axis, keepdims=keepdims)
|
||||
return Min(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(
|
||||
x
|
||||
)
|
||||
return backend.numpy.min(x, axis=axis, keepdims=keepdims, initial=initial)
|
||||
|
||||
|
||||
class Minimum(Operation):
|
||||
|
Loading…
Reference in New Issue
Block a user