Add IoU metrics: IoU, BinaryIoU, MeanIoU, OneHotIoU, OneHotMeanIoU (#127)

* Begin IoU metrics

* Attempt conversion without confusion matrix backend

* Working IoU metrics, missing scatter op

* Formatting

* Docstring formatting

* Add IoU metrics to manifest

* Update with scatter op

* Fix scatter op for repeated indices

* Formatting

* Suppress warning for core operation import

* Formatting
Gabriel Rasskin 2023-05-15 20:50:55 -04:00 committed by Francois Chollet
parent e989cb7a05
commit a426717f10
7 changed files with 1270 additions and 1 deletion

@@ -156,4 +156,4 @@ def vectorized_map(function, elements):
 def scatter(indices, values, shape):
     zeros = jnp.zeros(shape, values.dtype)
     key = tuple(jnp.moveaxis(indices, -1, 0))
-    return zeros.at[key].set(values)
+    return zeros.at[key].add(values)
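The one-line change above is the crux of the scatter fix: `at[...].set` with repeated indices keeps only one of the colliding writes, while `at[...].add` accumulates them, which is what a weighted confusion matrix needs. A minimal sketch of the difference (assumes JAX is installed; the arrays are illustrative):

```python
import jax.numpy as jnp

zeros = jnp.zeros((3,))
idx = jnp.array([0, 0, 2])           # index 0 appears twice
values = jnp.array([1.0, 2.0, 5.0])

print(zeros.at[idx].set(values))     # e.g. [2. 0. 5.] -- only one colliding write survives
print(zeros.at[idx].add(values))     # [3. 0. 5.]      -- duplicate indices accumulate
```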

@@ -21,6 +21,11 @@ from keras_core.metrics.f_score_metrics import FBetaScore
 from keras_core.metrics.hinge_metrics import CategoricalHinge
 from keras_core.metrics.hinge_metrics import Hinge
 from keras_core.metrics.hinge_metrics import SquaredHinge
+from keras_core.metrics.iou_metrics import BinaryIoU
+from keras_core.metrics.iou_metrics import IoU
+from keras_core.metrics.iou_metrics import MeanIoU
+from keras_core.metrics.iou_metrics import OneHotIoU
+from keras_core.metrics.iou_metrics import OneHotMeanIoU
 from keras_core.metrics.metric import Metric
 from keras_core.metrics.probabilistic_metrics import BinaryCrossentropy
 from keras_core.metrics.probabilistic_metrics import CategoricalCrossentropy
@@ -89,6 +94,12 @@ ALL_OBJECTS = {
     # F-Score
     F1Score,
     FBetaScore,
+    # IoU
+    IoU,
+    BinaryIoU,
+    MeanIoU,
+    OneHotIoU,
+    OneHotMeanIoU,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 ALL_OBJECTS_DICT.update(
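With the imports and registry entries above in place, the new metrics are reachable from the public `keras_core.metrics` namespace. A quick smoke check, mirroring the standalone example in the `MeanIoU` docstring below:

```python
from keras_core import metrics

m = metrics.MeanIoU(num_classes=2)
m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
print(float(m.result()))  # ~0.33333334
```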

@@ -0,0 +1,748 @@
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.metrics.metric import Metric
from keras_core.metrics.metrics_utils import confusion_matrix
class _IoUBase(Metric):
"""Computes the confusion matrix for Intersection-Over-Union metrics.
Formula:
```python
iou = true_positives / (true_positives + false_positives + false_negatives)
```
Intersection-Over-Union is a common evaluation metric for semantic image
segmentation.
From IoUs of individual classes, the MeanIoU can be computed as the mean of
the individual IoUs.
To compute IoUs, the predictions are accumulated in a confusion matrix,
weighted by `sample_weight` and the metric is then calculated from it.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
Args:
num_classes: The possible number of labels the prediction task can have.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
ignore_class: Optional integer. The ID of a class to be ignored during
metric computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (`ignore_class=None`), all classes are
considered.
sparse_y_true: Whether labels are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
sparse_y_pred: Whether predictions are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
"""
def __init__(
self,
num_classes,
name=None,
dtype=None,
ignore_class=None,
sparse_y_true=True,
sparse_y_pred=True,
axis=-1,
):
# Default to float32 to avoid issues with the confusion matrix.
super().__init__(name=name, dtype=dtype or "float32")
self.num_classes = num_classes
self.ignore_class = ignore_class
self.sparse_y_true = sparse_y_true
self.sparse_y_pred = sparse_y_pred
self.axis = axis
self.total_cm = self.add_variable(
name="total_confusion_matrix",
shape=(num_classes, num_classes),
initializer=initializers.Zeros(),
)
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the confusion matrix statistics.
Args:
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Can
be a `Tensor` whose rank is either 0, or the same as `y_true`,
and must be broadcastable to `y_true`. Defaults to `1`.
Returns:
Update op.
"""
if not self.sparse_y_true:
y_true = ops.argmax(y_true, axis=self.axis)
if not self.sparse_y_pred:
y_pred = ops.argmax(y_pred, axis=self.axis)
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
# Flatten the input if its rank > 1.
if len(y_pred.shape) > 1:
y_pred = ops.reshape(y_pred, [-1])
if len(y_true.shape) > 1:
y_true = ops.reshape(y_true, [-1])
if sample_weight is None:
sample_weight = 1
sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)
if len(sample_weight.shape) > 1:
sample_weight = ops.reshape(sample_weight, [-1])
sample_weight = ops.broadcast_to(sample_weight, y_true.shape)
if self.ignore_class is not None:
ignore_class = ops.convert_to_tensor(
self.ignore_class, y_true.dtype
)
valid_mask = ops.not_equal(y_true, ignore_class)
y_true = y_true[valid_mask]
y_pred = y_pred[valid_mask]
if sample_weight is not None:
sample_weight = sample_weight[valid_mask]
y_pred = ops.cast(y_pred, dtype=self.dtype)
y_true = ops.cast(y_true, dtype=self.dtype)
sample_weight = ops.cast(sample_weight, dtype=self.dtype)
current_cm = confusion_matrix(
y_true,
y_pred,
self.num_classes,
weights=sample_weight,
dtype="float32",
)
return self.total_cm.assign(self.total_cm + current_cm)
def reset_state(self):
self.total_cm.assign(
ops.zeros(self.total_cm.shape, dtype=self.total_cm.dtype)
)
@keras_core_export("keras_core.metrics.IoU")
class IoU(_IoUBase):
"""Computes the Intersection-Over-Union metric for specific target classes.
Formula:
```python
iou = true_positives / (true_positives + false_positives + false_negatives)
```
Intersection-Over-Union is a common evaluation metric for semantic image
segmentation.
To compute IoUs, the predictions are accumulated in a confusion matrix,
weighted by `sample_weight` and the metric is then calculated from it.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
Note, this class first computes IoUs for all individual classes, then
returns the mean of IoUs for the classes that are specified by
`target_class_ids`. If `target_class_ids` has only one id value, the IoU of
that specific class is returned.
Args:
num_classes: The possible number of labels the prediction task can have.
target_class_ids: A tuple or list of target class ids for which the
metric is returned. To compute IoU for a specific class, a list
(or tuple) of a single id value should be provided.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
ignore_class: Optional integer. The ID of a class to be ignored during
metric computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (`ignore_class=None`), all classes are
considered.
sparse_y_true: Whether labels are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
sparse_y_pred: Whether predictions are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
Examples:
Standalone usage:
>>> # cm = [[1, 1],
>>> # [1, 1]]
>>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
>>> # iou = true_positives / (sum_row + sum_col - true_positives)
>>> # iou = [0.33, 0.33]
>>> m = keras_core.metrics.IoU(num_classes=2, target_class_ids=[0])
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
>>> m.result()
0.33333334
>>> m.reset_state()
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
... sample_weight=[0.3, 0.3, 0.3, 0.1])
>>> # cm = [[0.3, 0.3],
>>> # [0.3, 0.1]]
>>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4],
>>> # true_positives = [0.3, 0.1]
>>> # iou = [0.33, 0.14]
>>> m.result()
0.33333334
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.IoU(num_classes=2, target_class_ids=[0])])
```
"""
def __init__(
self,
num_classes,
target_class_ids,
name=None,
dtype=None,
ignore_class=None,
sparse_y_true=True,
sparse_y_pred=True,
axis=-1,
):
super().__init__(
name=name,
num_classes=num_classes,
ignore_class=ignore_class,
sparse_y_true=sparse_y_true,
sparse_y_pred=sparse_y_pred,
axis=axis,
dtype=dtype,
)
if max(target_class_ids) >= num_classes:
raise ValueError(
f"Target class id {max(target_class_ids)} "
"is out of range, which is "
f"[{0}, {num_classes})."
)
self.target_class_ids = list(target_class_ids)
def result(self):
"""Compute the intersection-over-union via the confusion matrix."""
sum_over_row = ops.cast(
ops.sum(self.total_cm, axis=0), dtype=self.dtype
)
sum_over_col = ops.cast(
ops.sum(self.total_cm, axis=1), dtype=self.dtype
)
true_positives = ops.cast(ops.diag(self.total_cm), dtype=self.dtype)
# sum_over_row + sum_over_col =
# 2 * true_positives + false_positives + false_negatives.
denominator = sum_over_row + sum_over_col - true_positives
target_class_ids = ops.convert_to_tensor(
self.target_class_ids, dtype="int32"
)
# Only keep the target classes
true_positives = ops.take_along_axis(
true_positives, target_class_ids, axis=-1
)
denominator = ops.take_along_axis(
denominator, target_class_ids, axis=-1
)
# If the denominator is 0, we need to ignore the class.
num_valid_entries = ops.sum(
ops.cast(ops.greater(denominator, 1e-9), dtype=self.dtype)
)
iou = ops.divide(true_positives, denominator + backend.epsilon())
return ops.divide(
ops.sum(iou, axis=self.axis), num_valid_entries + backend.epsilon()
)
def get_config(self):
config = {
"num_classes": self.num_classes,
"target_class_ids": self.target_class_ids,
"ignore_class": self.ignore_class,
"sparse_y_true": self.sparse_y_true,
"sparse_y_pred": self.sparse_y_pred,
"axis": self.axis,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@keras_core_export("keras_core.metrics.BinaryIoU")
class BinaryIoU(IoU):
"""Computes the Intersection-Over-Union metric for class 0 and/or 1.
Formula:
```python
iou = true_positives / (true_positives + false_positives + false_negatives)
```
Intersection-Over-Union is a common evaluation metric for semantic image
segmentation.
To compute IoUs, the predictions are accumulated in a confusion matrix,
weighted by `sample_weight` and the metric is then calculated from it.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
This class can be used to compute IoUs for a binary classification task
where the predictions are provided as logits. First a `threshold` is applied
to the predicted values such that those that are below the `threshold` are
converted to class 0 and those that are above the `threshold` are converted
to class 1.
IoUs for classes 0 and 1 are then computed, the mean of IoUs for the classes
that are specified by `target_class_ids` is returned.
Note: with `threshold=0`, this metric has the same behavior as `IoU`.
Args:
target_class_ids: A tuple or list of target class ids for which the
metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With
`[0]` (or `[1]`), the IoU metric for class 0 (or class 1,
respectively) is returned. With `[0, 1]`, the mean of IoUs for the
two classes is returned.
threshold: A threshold that applies to the prediction logits to convert
them to either predicted class 0 if the logit is below `threshold`
or predicted class 1 if the logit is above `threshold`.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Examples:
Standalone usage:
>>> m = keras_core.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
>>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7])
>>> m.result()
0.33333334
>>> m.reset_state()
>>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7],
... sample_weight=[0.2, 0.3, 0.4, 0.1])
>>> # cm = [[0.2, 0.4],
>>> # [0.3, 0.1]]
>>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5],
>>> # true_positives = [0.2, 0.1]
>>> # iou = [0.222, 0.125]
>>> m.result()
0.17361112
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.BinaryIoU(
target_class_ids=[0],
threshold=0.5
)]
)
```
"""
def __init__(
self,
target_class_ids=(0, 1),
threshold=0.5,
name=None,
dtype=None,
):
super().__init__(
num_classes=2,
target_class_ids=target_class_ids,
name=name,
dtype=dtype,
)
self.threshold = threshold
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the confusion matrix statistics.
Before the confusion matrix is updated, the predicted values are
thresholded to be:
0 for values that are smaller than the `threshold`
1 for values that are larger or equal to the `threshold`
Args:
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Can
be a `Tensor` whose rank is either 0, or the same as `y_true`,
and must be broadcastable to `y_true`. Defaults to `1`.
Returns:
Update op.
"""
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
y_pred = ops.cast(y_pred >= self.threshold, self.dtype)
return super().update_state(y_true, y_pred, sample_weight)
def get_config(self):
return {
"target_class_ids": self.target_class_ids,
"threshold": self.threshold,
"name": self.name,
"dtype": self._dtype,
}
@keras_core_export("keras_core.metrics.MeanIoU")
class MeanIoU(IoU):
"""Computes the mean Intersection-Over-Union metric.
Formula:
```python
iou = true_positives / (true_positives + false_positives + false_negatives)
```
Intersection-Over-Union is a common evaluation metric for semantic image
segmentation.
To compute IoUs, the predictions are accumulated in a confusion matrix,
weighted by `sample_weight` and the metric is then calculated from it.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
Note that this class first computes IoUs for all individual classes, then
returns the mean of these values.
Args:
num_classes: The possible number of labels the prediction task can have.
This value must be provided, since a confusion matrix of dimension =
[num_classes, num_classes] will be allocated.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
ignore_class: Optional integer. The ID of a class to be ignored during
metric computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (`ignore_class=None`), all classes are
considered.
sparse_y_true: Whether labels are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
sparse_y_pred: Whether predictions are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
Examples:
Standalone usage:
>>> # cm = [[1, 1],
>>> # [1, 1]]
>>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
>>> # iou = true_positives / (sum_row + sum_col - true_positives)
>>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
>>> m = keras_core.metrics.MeanIoU(num_classes=2)
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
>>> m.result()
0.33333334
>>> m.reset_state()
>>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
... sample_weight=[0.3, 0.3, 0.3, 0.1])
>>> m.result()
0.23809525
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.MeanIoU(num_classes=2)])
```
"""
def __init__(
self,
num_classes,
name=None,
dtype=None,
ignore_class=None,
sparse_y_true=True,
sparse_y_pred=True,
axis=-1,
):
target_class_ids = list(range(num_classes))
super().__init__(
name=name,
num_classes=num_classes,
target_class_ids=target_class_ids,
axis=axis,
dtype=dtype,
ignore_class=ignore_class,
sparse_y_true=sparse_y_true,
sparse_y_pred=sparse_y_pred,
)
def get_config(self):
return {
"num_classes": self.num_classes,
"name": self.name,
"dtype": self._dtype,
"ignore_class": self.ignore_class,
"sparse_y_true": self.sparse_y_true,
"sparse_y_pred": self.sparse_y_pred,
"axis": self.axis,
}
@keras_core_export("keras_core.metrics.OneHotIoU")
class OneHotIoU(IoU):
"""Computes the Intersection-Over-Union metric for one-hot encoded labels.
Formula:
```python
iou = true_positives / (true_positives + false_positives + false_negatives)
```
Intersection-Over-Union is a common evaluation metric for semantic image
segmentation.
To compute IoUs, the predictions are accumulated in a confusion matrix,
weighted by `sample_weight` and the metric is then calculated from it.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
This class can be used to compute IoU for multi-class classification tasks
where the labels are one-hot encoded (the last axis should have one
dimension per class). Note that the predictions should also have the same
shape. To compute the IoU, first the labels and predictions are converted
back into integer format by taking the argmax over the class axis. Then the
same computation steps as for the base `IoU` class apply.
Note, if there is only one channel in the labels and predictions, this class
is the same as class `IoU`. In this case, use `IoU` instead.
Also, make sure that `num_classes` is equal to the number of classes in the
data, to avoid a "labels out of bound" error when the confusion matrix is
computed.
Args:
num_classes: The possible number of labels the prediction task can have.
target_class_ids: A tuple or list of target class ids for which the
metric is returned. To compute IoU for a specific class, a list
(or tuple) of a single id value should be provided.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
ignore_class: Optional integer. The ID of a class to be ignored during
metric computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (`ignore_class=None`), all classes are
considered.
sparse_y_pred: Whether predictions are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
Examples:
Standalone usage:
>>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
>>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
... [0.1, 0.4, 0.5]])
>>> sample_weight = [0.1, 0.2, 0.3, 0.4]
>>> m = keras_core.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
>>> m.update_state(
... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
>>> # cm = [[0, 0, 0.2+0.4],
>>> # [0.3, 0, 0],
>>> # [0, 0, 0.1]]
>>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
>>> # true_positives = [0, 0, 0.1]
>>> # single_iou = true_positives / (sum_row + sum_col - true_positives)
>>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
>>> m.result()
0.071
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.OneHotIoU(
num_classes=3,
target_class_ids=[1]
)]
)
```
"""
def __init__(
self,
num_classes,
target_class_ids,
name=None,
dtype=None,
ignore_class=None,
sparse_y_pred=False,
axis=-1,
):
super().__init__(
num_classes=num_classes,
target_class_ids=target_class_ids,
name=name,
dtype=dtype,
ignore_class=ignore_class,
sparse_y_true=False,
sparse_y_pred=sparse_y_pred,
axis=axis,
)
def get_config(self):
return {
"num_classes": self.num_classes,
"target_class_ids": self.target_class_ids,
"name": self.name,
"dtype": self._dtype,
"ignore_class": self.ignore_class,
"sparse_y_pred": self.sparse_y_pred,
"axis": self.axis,
}
@keras_core_export("keras_core.metrics.OneHotMeanIoU")
class OneHotMeanIoU(MeanIoU):
"""Computes mean Intersection-Over-Union metric for one-hot encoded labels.
Formula:
```python
iou = true_positives / (true_positives + false_positives + false_negatives)
```
Intersection-Over-Union is a common evaluation metric for semantic image
segmentation.
To compute IoUs, the predictions are accumulated in a confusion matrix,
weighted by `sample_weight` and the metric is then calculated from it.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
This class can be used to compute the mean IoU for multi-class
classification tasks where the labels are one-hot encoded (the last axis
should have one dimension per class). Note that the predictions should also
have the same shape. To compute the mean IoU, first the labels and
predictions are converted back into integer format by taking the argmax over
the class axis. Then the same computation steps as for the base `MeanIoU`
class apply.
Note, if there is only one channel in the labels and predictions, this class
is the same as class `MeanIoU`. In this case, use `MeanIoU` instead.
Also, make sure that `num_classes` is equal to the number of classes in the
data, to avoid a "labels out of bound" error when the confusion matrix is
computed.
Args:
num_classes: The possible number of labels the prediction task can have.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
ignore_class: Optional integer. The ID of a class to be ignored during
metric computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (`ignore_class=None`), all classes are
considered.
sparse_y_pred: Whether predictions are encoded using integers or
dense floating point vectors. If `False`, the `argmax` function
is used to determine each sample's most likely associated label.
axis: (Optional) The dimension containing the logits. Defaults to `-1`.
Examples:
Standalone usage:
>>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
>>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
... [0.1, 0.4, 0.5]])
>>> sample_weight = [0.1, 0.2, 0.3, 0.4]
>>> m = keras_core.metrics.OneHotMeanIoU(num_classes=3)
>>> m.update_state(
... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
>>> # cm = [[0, 0, 0.2+0.4],
>>> # [0.3, 0, 0],
>>> # [0, 0, 0.1]]
>>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
>>> # true_positives = [0, 0, 0.1]
>>> # single_iou = true_positives / (sum_row + sum_col - true_positives)
>>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3
>>> m.result()
0.048
Usage with `compile()` API:
```python
model.compile(
optimizer='sgd',
loss='mse',
metrics=[keras_core.metrics.OneHotMeanIoU(num_classes=3)])
```
"""
def __init__(
self,
num_classes,
name=None,
dtype=None,
ignore_class=None,
sparse_y_pred=False,
axis=-1,
):
super().__init__(
num_classes=num_classes,
axis=axis,
name=name,
dtype=dtype,
ignore_class=ignore_class,
sparse_y_true=False,
sparse_y_pred=sparse_y_pred,
)
def get_config(self):
return {
"num_classes": self.num_classes,
"name": self.name,
"dtype": self._dtype,
"ignore_class": self.ignore_class,
"sparse_y_pred": self.sparse_y_pred,
"axis": self.axis,
}
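The `result()` arithmetic is easy to verify by hand. A NumPy sketch checking the weighted `MeanIoU` docstring example above (an illustration only, independent of the class):

```python
import numpy as np

# Weighted confusion matrix for y_true=[0, 0, 1, 1], y_pred=[0, 1, 0, 1]
# with sample_weight=[0.3, 0.3, 0.3, 0.1]:
cm = np.array([[0.3, 0.3],
               [0.3, 0.1]])
tp = np.diag(cm)                              # [0.3, 0.1]
denom = cm.sum(axis=0) + cm.sum(axis=1) - tp  # [0.9, 0.7]
print((tp / denom).mean())  # ~0.2380952, matching the docstring value 0.23809525
```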

@@ -0,0 +1,430 @@
import numpy as np
from keras_core import testing
from keras_core.metrics import iou_metrics as metrics
class IoUTest(testing.TestCase):
def test_config(self):
obj = metrics.IoU(
num_classes=2, target_class_ids=[1, 0], name="iou_class_1_0"
)
self.assertEqual(obj.name, "iou_class_1_0")
self.assertEqual(obj.num_classes, 2)
self.assertEqual(obj.target_class_ids, [1, 0])
obj2 = metrics.IoU.from_config(obj.get_config())
self.assertEqual(obj2.name, "iou_class_1_0")
self.assertEqual(obj2.num_classes, 2)
self.assertEqual(obj2.target_class_ids, [1, 0])
def test_unweighted(self):
y_pred = [0, 1, 0, 1]
y_true = [0, 0, 1, 1]
obj = metrics.IoU(
num_classes=2, target_class_ids=[0, 1], dtype="float32"
)
result = obj(y_true, y_pred)
# cm = [[1, 1],
# [1, 1]]
# sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_weighted(self):
y_pred = np.array([0, 1, 0, 1], dtype=np.float32)
y_true = np.array([0, 0, 1, 1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])
obj = metrics.IoU(
num_classes=2, target_class_ids=[1, 0], dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
# cm = [[0.2, 0.3],
# [0.4, 0.1]]
# sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,
# 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.1 / (0.4 + 0.5 - 0.1) + 0.2 / (0.6 + 0.5 - 0.2)
) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_multi_dim_input(self):
y_pred = np.array([[0, 1], [0, 1]], dtype=np.float32)
y_true = np.array([[0, 0], [1, 1]])
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])
obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])
result = obj(y_true, y_pred, sample_weight=sample_weight)
# cm = [[0.2, 0.3],
# [0.4, 0.1]]
# sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,
# 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)
) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_zero_valid_entries(self):
obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])
self.assertAllClose(obj.result(), 0, atol=1e-3)
def test_zero_and_non_zero_entries(self):
y_pred = np.array([1], dtype=np.float32)
y_true = np.array([1])
obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])
result = obj(y_true, y_pred)
# cm = [[0, 0],
# [0, 1]]
# sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (1 / (1 + 1 - 1)) / 1
self.assertAllClose(result, expected_result, atol=1e-3)
class BinaryIoUTest(testing.TestCase):
def test_config(self):
obj = metrics.BinaryIoU(
target_class_ids=[1, 0], threshold=0.1, name="iou_class_1_0"
)
self.assertEqual(obj.name, "iou_class_1_0")
self.assertAlmostEqual(obj.threshold, 0.1)
self.assertEqual(obj.target_class_ids, [1, 0])
obj2 = metrics.BinaryIoU.from_config(obj.get_config())
self.assertEqual(obj2.name, "iou_class_1_0")
self.assertAlmostEqual(obj2.threshold, 0.1)
self.assertEqual(obj2.target_class_ids, [1, 0])
def test_different_thresholds_weighted(self):
y_true = [0, 1, 0, 1]
y_pred = [0.1, 0.2, 0.4, 0.7]
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])
# with threshold = 0.3, y_pred will be converted to [0, 0, 1, 1]
# cm = [[0.2, 0.4],
# [0.3, 0.1]]
# sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,
# 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)
sample_weight = np.array([0.1, 0.2, 0.4, 0.3])
# with threshold = 0.5, y_pred will be converted to [0, 0, 0, 1]
# cm = [[0.1+0.4, 0],
# [0.2, 0.3]]
# sum_row = [0.5, 0.5], sum_col = [0.7, 0.3], true_positives = [0.5,
# 0.3]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.5 / (0.5 + 0.7 - 0.5) + 0.3 / (0.5 + 0.3 - 0.3)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)
def test_different_thresholds_unweighted(self):
y_true = [0, 1, 0, 1]
y_pred = [0.1, 0.2, 0.4, 0.7]
# with threshold = 0.3, y_pred will be converted to [0, 0, 1, 1]
# cm = [[1, 1],
# [1, 1]]
# sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
result = obj(y_true, y_pred)
self.assertAllClose(result, expected_result, atol=1e-3)
# with threshold = 0.5, y_pred will be converted to [0, 0, 0, 1]
# cm = [[2, 0],
# [1, 1]]
# sum_row = [2, 2], sum_col = [3, 1], true_positives = [2, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (2 / (2 + 3 - 2) + 1 / (2 + 1 - 1)) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
result = obj(y_true, y_pred)
self.assertAllClose(result, expected_result, atol=1e-3)
def test_multi_dim_input(self):
y_true = np.array([[0, 1], [0, 1]], dtype=np.float32)
y_pred = np.array([[0.1, 0.7], [0.9, 0.3]])
threshold = 0.4 # y_pred will become [[0, 1], [1, 0]]
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])
# cm = [[0.2, 0.4],
# [0.1, 0.3]]
# sum_row = [0.6, 0.4], sum_col = [0.3, 0.7], true_positives = [0.2,
# 0.3]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.2 / (0.6 + 0.3 - 0.2) + 0.3 / (0.4 + 0.7 - 0.3)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=threshold)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)
def test_zero_valid_entries(self):
obj = metrics.BinaryIoU(target_class_ids=[0, 1])
self.assertAllClose(obj.result(), 0, atol=1e-3)
def test_zero_and_non_zero_entries(self):
y_pred = np.array([0.6], dtype=np.float32)
threshold = 0.5
y_true = np.array([1])
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=threshold)
result = obj(y_true, y_pred)
# cm = [[0, 0],
# [0, 1]]
# sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = 1 / (1 + 1 - 1)
self.assertAllClose(result, expected_result, atol=1e-3)
class MeanIoUTest(testing.TestCase):
def test_config(self):
m_obj = metrics.MeanIoU(num_classes=2, name="mean_iou")
self.assertEqual(m_obj.name, "mean_iou")
self.assertEqual(m_obj.num_classes, 2)
m_obj2 = metrics.MeanIoU.from_config(m_obj.get_config())
self.assertEqual(m_obj2.name, "mean_iou")
self.assertEqual(m_obj2.num_classes, 2)
def test_unweighted(self):
y_pred = [0, 1, 0, 1]
y_true = [0, 0, 1, 1]
m_obj = metrics.MeanIoU(num_classes=2)
result = m_obj(y_true, y_pred)
# cm = [[1, 1],
# [1, 1]]
# sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_unweighted_ignore_class_255(self):
y_pred = [0, 1, 1, 1]
y_true = [0, 1, 2, 255]
m_obj = metrics.MeanIoU(num_classes=3, ignore_class=255)
result = m_obj(y_true, y_pred)
# cm = [[1, 0, 0],
# [0, 1, 0],
# [0, 1, 0]]
# sum_row = [1, 1, 1], sum_col = [1, 2, 0], true_positives = [1, 1, 0]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
1 / (1 + 1 - 1) + 1 / (2 + 1 - 1) + 0 / (0 + 1 - 0)
) / 3
self.assertAllClose(result, expected_result, atol=1e-3)
def test_unweighted_ignore_class_1(self):
y_pred = [0, 1, 1, 1]
y_true = [0, 1, 2, -1]
m_obj = metrics.MeanIoU(num_classes=3, ignore_class=-1)
result = m_obj(y_true, y_pred)
# cm = [[1, 0, 0],
# [0, 1, 0],
# [0, 1, 0]]
# sum_row = [1, 1, 1], sum_col = [1, 2, 0], true_positives = [1, 1, 0]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
1 / (1 + 1 - 1) + 1 / (2 + 1 - 1) + 0 / (0 + 1 - 0)
) / 3
self.assertAllClose(result, expected_result, atol=1e-3)
def test_weighted(self):
y_pred = np.array([0, 1, 0, 1], dtype=np.float32)
y_true = np.array([0, 0, 1, 1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])
m_obj = metrics.MeanIoU(num_classes=2)
result = m_obj(y_true, y_pred, sample_weight=sample_weight)
# cm = [[0.2, 0.3],
# [0.4, 0.1]]
# sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,
# 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)
) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_weighted_ignore_class_1(self):
y_pred = np.array([0, 1, 0, 1], dtype=np.float32)
y_true = np.array([0, 0, 1, -1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])
m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1)
result = m_obj(y_true, y_pred, sample_weight=sample_weight)
# cm = [[0.2, 0.3],
# [0.4, 0.0]]
# sum_row = [0.6, 0.3], sum_col = [0.5, 0.4], true_positives = [0.2,
# 0.0]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.0 / (0.3 + 0.4 - 0.0)
) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_multi_dim_input(self):
y_pred = np.array([[0, 1], [0, 1]], dtype=np.float32)
y_true = np.array([[0, 0], [1, 1]])
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])
m_obj = metrics.MeanIoU(num_classes=2)
result = m_obj(y_true, y_pred, sample_weight=sample_weight)
# cm = [[0.2, 0.3],
# [0.4, 0.1]]
# sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,
# 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)
) / 2
self.assertAllClose(result, expected_result, atol=1e-3)
def test_zero_valid_entries(self):
m_obj = metrics.MeanIoU(num_classes=2)
self.assertAllClose(m_obj.result(), 0, atol=1e-3)
def test_zero_and_non_zero_entries(self):
y_pred = np.array([1], dtype=np.float32)
y_true = np.array([1])
m_obj = metrics.MeanIoU(num_classes=2)
result = m_obj(y_true, y_pred)
# cm = [[0, 0],
# [0, 1]]
# sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (0 + 1 / (1 + 1 - 1)) / 1
self.assertAllClose(result, expected_result, atol=1e-3)
class OneHotIoUTest(testing.TestCase):
def test_unweighted(self):
y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
# y_true will be converted to [2, 0, 1, 0]
y_pred = np.array(
[[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]]
)
# y_pred will be converted to [2, 2, 0, 2]
# cm = [[0, 0, 2],
# [1, 0, 0],
# [0, 0, 1]]
# sum_row = [1, 0, 3], sum_col = [2, 1, 1], true_positives = [0, 0, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (0 / (1 + 2 - 0) + 1 / (3 + 1 - 1)) / 2
obj = metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
result = obj(y_true, y_pred)
self.assertAllClose(result, expected_result, atol=1e-3)
def test_weighted(self):
y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
# y_true will be converted to [2, 0, 1, 0]
y_pred = np.array(
[[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]]
)
# y_pred will be converted to [2, 2, 0, 2]
sample_weight = [0.1, 0.2, 0.3, 0.4]
# cm = [[0, 0, 0.2+0.4],
# [0.3, 0, 0],
# [0, 0, 0.1]]
# sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
# true_positives = [0, 0, 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
obj = metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)
class OneHotMeanIoUTest(testing.TestCase):
def test_unweighted(self):
y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
# y_true will be converted to [2, 0, 1, 0]
y_pred = np.array(
[[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]]
)
# y_pred will be converted to [2, 2, 0, 2]
# cm = [[0, 0, 2],
# [1, 0, 0],
# [0, 0, 1]]
# sum_row = [1, 0, 3], sum_col = [2, 1, 1], true_positives = [0, 0, 1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (0 + 0 + 1 / (3 + 1 - 1)) / 3
obj = metrics.OneHotMeanIoU(num_classes=3)
result = obj(y_true, y_pred)
self.assertAllClose(result, expected_result, atol=1e-3)
def test_weighted(self):
y_true = np.array(
[
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[1, 0, 0],
[1, 0, 0],
]
)
# y_true will be converted to [2, 0, 1, 0, 0]
y_pred = np.array(
[
[0.2, 0.3, 0.5],
[0.1, 0.2, 0.7],
[0.5, 0.3, 0.1],
[0.1, 0.4, 0.5],
[0.6, 0.2, 0.2],
]
)
# y_pred will be converted to [2, 2, 0, 2, 0]
sample_weight = [0.1, 0.2, 0.3, 0.3, 0.1]
# cm = [[0.1, 0, 0.2+0.3],
# [0.3, 0, 0],
# [0, 0, 0.1]]
# sum_row = [0.4, 0, 0.6], sum_col = [0.6, 0.3, 0.1]
# true_positives = [0.1, 0, 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives)
expected_result = (
0.1 / (0.4 + 0.6 - 0.1) + 0 + 0.1 / (0.6 + 0.1 - 0.1)
) / 3
obj = metrics.OneHotMeanIoU(num_classes=3)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)
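The `ignore_class` tests above pin down the masking behavior: pixels carrying the ignored label contribute nothing to the confusion matrix. The same behavior as a short usage snippet (mirroring `test_unweighted_ignore_class_255`):

```python
from keras_core import metrics

# The 255-labeled sample is dropped before the confusion matrix is built.
m = metrics.MeanIoU(num_classes=3, ignore_class=255)
m.update_state([0, 1, 2, 255], [0, 1, 1, 1])
print(float(m.result()))  # (1/1 + 1/2 + 0/1) / 3 = 0.5
```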

@@ -590,3 +590,78 @@ def _filter_top_k(x, k):
         ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2
     )
     return x * top_k_mask + NEG_INF * (1 - top_k_mask)
+
+
+def confusion_matrix(
+    labels,
+    predictions,
+    num_classes=None,
+    weights=None,
+    dtype="int32",
+):
+    """Computes the confusion matrix from predictions and labels.
+
+    The matrix columns represent the prediction labels and the rows
+    represent the real labels. The confusion matrix is always a 2-D array
+    of shape `(n, n)`, where `n` is the number of valid labels for a given
+    classification task. Both predictions and labels must be 1-D arrays of
+    the same shape in order for this function to work.
+
+    If `num_classes` is `None`, then `num_classes` will be set to one plus
+    the maximum value in either predictions or labels. Class labels are
+    expected to start at 0. For example, if `num_classes` is 3, then the
+    possible labels would be `[0, 1, 2]`.
+
+    If `weights` is not `None`, then each prediction contributes its
+    corresponding weight to the total value of the confusion matrix cell.
+
+    For example:
+
+    ```python
+    keras_core.metrics.metrics_utils.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
+        [[0 0 0 0 0]
+         [0 0 1 0 0]
+         [0 0 1 0 0]
+         [0 0 0 0 0]
+         [0 0 0 0 1]]
+    ```
+
+    Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
+    resulting in a 5x5 confusion matrix.
+
+    Args:
+        labels: 1-D tensor of real labels for the classification task.
+        predictions: 1-D tensor of predictions for a given classification.
+        num_classes: The possible number of labels the classification task
+            can have. If this value is not provided, it will be calculated
+            from both the predictions and labels arrays.
+        weights: An optional tensor whose shape matches `predictions`.
+        dtype: Data type of the confusion matrix.
+
+    Returns:
+        A tensor of type `dtype` with shape `(n, n)` representing the
+        confusion matrix, where `n` is the number of possible labels in
+        the classification task.
+    """
+    labels = ops.convert_to_tensor(labels, dtype)
+    predictions = ops.convert_to_tensor(predictions, dtype)
+    labels, predictions = squeeze_to_same_rank(labels, predictions)
+
+    predictions = ops.cast(predictions, dtype)
+    labels = ops.cast(labels, dtype)
+
+    if num_classes is None:
+        num_classes = ops.maximum(ops.max(predictions), ops.max(labels)) + 1
+    else:
+        num_classes = ops.cast(num_classes, dtype)
+
+    if weights is not None:
+        weights = ops.convert_to_tensor(weights, dtype)
+
+    indices = ops.stack([labels, predictions], axis=1)
+    values = ops.ones_like(predictions, dtype) if weights is None else weights
+
+    indices = ops.cast(indices, dtype="int64")
+    values = ops.cast(values, dtype=dtype)
+    num_classes = ops.cast(num_classes, "int64")
+    confusion_matrix = ops.scatter(indices, values, (num_classes, num_classes))
+    return confusion_matrix
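The docstring example can be reproduced with plain NumPy, since the scatter in the last line has add semantics for repeated `(label, prediction)` pairs (an illustrative sketch only):

```python
import numpy as np

labels = np.array([1, 2, 4])
predictions = np.array([2, 2, 4])
n = max(labels.max(), predictions.max()) + 1  # num_classes inferred: 5

cm = np.zeros((n, n), dtype=np.int32)
np.add.at(cm, (labels, predictions), 1)       # scatter-add; rows=true, cols=pred
print(cm)  # cm[1, 2] == cm[2, 2] == cm[4, 4] == 1, as in the docstring
```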

@@ -11,6 +11,7 @@ from keras_core.backend import random
 from keras_core.backend import shape
 from keras_core.operations import image
 from keras_core.operations import operation_utils
+from keras_core.operations.core import *  # noqa: F403
 from keras_core.operations.math import *  # noqa: F403
 from keras_core.operations.nn import *  # noqa: F403
 from keras_core.operations.numpy import *  # noqa: F403

@@ -62,3 +62,7 @@ class CoreOpsCorrectnessTest(testing.TestCase):
             core.scatter(indices, values, (6, 3)),
             [[0, 0, 0], [0, 0, 0], [1, 2, 3], [0, 0, 0], [4, 5, 6], [0, 0, 0]],
         )
+        # Duplicate indices should accumulate rather than overwrite.
+        indices = np.array([[0], [0]])
+        values = np.array([1, 1])
+        self.assertAllClose(core.scatter(indices, values, (1,)), [2])