Add CSVLogger and TerminateOnNaN Callbacks (#95)

* Add CSVLogger and TerminateOnNaN

* Add CSVLogger and TerminateOnNaN tests

* Update CSV Logger docstring
Ramesh Sampath 2023-05-06 00:09:26 -05:00 committed by Francois Chollet
parent 02d0c27ce9
commit cc89199f1e
15 changed files with 1496 additions and 252 deletions

@@ -22,7 +22,6 @@ def resize(
"Argument `size` must be a tuple of two elements "
f"(height, width). Received: size={size}"
)
size = tuple(size)
if len(image.shape) == 4:
if data_format == "channels_last":
size = (image.shape[0],) + size + (image.shape[-1],)

@@ -22,7 +22,7 @@ def resize(
"Argument `size` must be a tuple of two elements "
f"(height, width). Received: size={size}"
)
size = tuple(size)
if data_format == "channels_first":
if len(image.shape) == 4:
image = tf.transpose(image, (0, 2, 3, 1))

@@ -1,6 +1,8 @@
from keras_core.callbacks.callback import Callback
from keras_core.callbacks.callback_list import CallbackList
from keras_core.callbacks.csv_logger import CSVLogger
from keras_core.callbacks.early_stopping import EarlyStopping
from keras_core.callbacks.history import History
from keras_core.callbacks.lambda_callback import LambdaCallback
from keras_core.callbacks.progbar_logger import ProgbarLogger
from keras_core.callbacks.terminate_on_nan import TerminateOnNaN

@@ -0,0 +1,101 @@
import collections
import csv
import numpy as np
import tensorflow as tf
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import file_utils
@keras_core_export("keras_core.callbacks.CSVLogger")
class CSVLogger(Callback):
"""Callback that streams epoch results to a CSV file.
Supports all values that can be represented as a string,
including 1D iterables such as `np.ndarray`.
Args:
filename: Filename of the CSV file, e.g. `'run/log.csv'`.
separator: String used to separate elements in the CSV file.
append: Boolean. True: append if file exists (useful for continuing
training). False: overwrite existing file.
Example:
```python
csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])
```
"""
def __init__(self, filename, separator=",", append=False):
super().__init__()
self.sep = separator
self.filename = file_utils.path_to_string(filename)
self.append = append
self.writer = None
self.keys = None
self.append_header = True
def on_train_begin(self, logs=None):
if self.append:
if tf.io.gfile.exists(self.filename):
with tf.io.gfile.GFile(self.filename, "r") as f:
self.append_header = not bool(len(f.readline()))
mode = "a"
else:
mode = "w"
self.csv_file = tf.io.gfile.GFile(self.filename, mode)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
def handle_value(k):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
if isinstance(k, str):
return k
elif (
isinstance(k, collections.abc.Iterable)
and not is_zero_dim_ndarray
):
return f"\"[{', '.join(map(str, k))}]\""
else:
return k
if self.keys is None:
self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in the first
            # epoch's logs. Add them here so that they are part of the
            # writer's fieldnames.
val_keys_found = False
for key in self.keys:
if key.startswith("val_"):
val_keys_found = True
break
if not val_keys_found:
self.keys.extend(["val_" + k for k in self.keys])
if not self.writer:
class CustomDialect(csv.excel):
delimiter = self.sep
fieldnames = ["epoch"] + self.keys
self.writer = csv.DictWriter(
self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
)
if self.append_header:
self.writer.writeheader()
row_dict = collections.OrderedDict({"epoch": epoch})
row_dict.update(
(key, handle_value(logs.get(key, "NA"))) for key in self.keys
)
self.writer.writerow(row_dict)
self.csv_file.flush()
def on_train_end(self, logs=None):
self.csv_file.close()
self.writer = None
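
For reference, a minimal usage sketch of the new callback, assuming a compiled `model` and NumPy training arrays (`training.log` is an arbitrary path):

```python
from keras_core import callbacks

# First run: create the file and write a header row.
logger = callbacks.CSVLogger("training.log", separator=",", append=False)
model.fit(x_train, y_train, epochs=5, callbacks=[logger])

# Resumed run: append=True keeps the existing rows and, per
# on_train_begin above, skips rewriting the header.
logger = callbacks.CSVLogger("training.log", append=True)
model.fit(x_train, y_train, epochs=5, callbacks=[logger])
```

Note that `handle_value` writes iterable log values (e.g. a 1D `np.ndarray`) as a single quoted cell such as `"[0.1, 0.2]"`, so each row stays valid CSV.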

@@ -0,0 +1,176 @@
import csv
import os
import re
import tempfile
import numpy as np
from keras_core import callbacks
from keras_core import initializers
from keras_core import layers
from keras_core import testing
from keras_core.models import Sequential
from keras_core.utils import numerical_utils
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
INPUT_DIM = 3
BATCH_SIZE = 4
class CSVLoggerTest(testing.TestCase):
def test_CSVLogger(self):
OUTPUT_DIM = 1
np.random.seed(1337)
temp_dir = tempfile.TemporaryDirectory()
filepath = os.path.join(temp_dir.name, "log.tsv")
sep = "\t"
x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
y_train = np.random.random((TRAIN_SAMPLES, OUTPUT_DIM))
x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
y_test = np.random.random((TEST_SAMPLES, OUTPUT_DIM))
def make_model():
np.random.seed(1337)
model = Sequential(
[
layers.Dense(2, activation="relu"),
layers.Dense(OUTPUT_DIM),
]
)
model.compile(
loss="mse",
optimizer="sgd",
metrics=["mse"],
)
return model
# case 1, create new file with defined separator
model = make_model()
cbks = [callbacks.CSVLogger(filepath, separator=sep)]
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=1,
verbose=0,
)
assert os.path.exists(filepath)
with open(filepath) as csvfile:
dialect = csv.Sniffer().sniff(csvfile.read())
assert dialect.delimiter == sep
del model
del cbks
# case 2, append data to existing file, skip header
model = make_model()
cbks = [callbacks.CSVLogger(filepath, separator=sep, append=True)]
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=1,
verbose=0,
)
# case 3, reuse of CSVLogger object
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=2,
verbose=0,
)
with open(filepath) as csvfile:
list_lines = csvfile.readlines()
for line in list_lines:
assert line.count(sep) == 4
assert len(list_lines) == 5
output = " ".join(list_lines)
assert len(re.findall("epoch", output)) == 1
os.remove(filepath)
        # case 4, verify that val. loss is also registered when
        # validation_freq > 1
model = make_model()
cbks = [callbacks.CSVLogger(filepath, separator=sep)]
hist = model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
validation_freq=3,
callbacks=cbks,
epochs=5,
verbose=0,
)
assert os.path.exists(filepath)
# Verify that validation loss is registered at val. freq
with open(filepath) as csvfile:
rows = csv.DictReader(csvfile, delimiter=sep)
for idx, row in enumerate(rows, 1):
self.assertIn("val_loss", row)
if idx == 3:
self.assertEqual(
row["val_loss"], str(hist.history["val_loss"][0])
)
else:
self.assertEqual(row["val_loss"], "NA")
def test_stop_training_csv(self):
# Test that using the CSVLogger callback with the TerminateOnNaN
# callback does not result in invalid CSVs.
tmpdir = tempfile.TemporaryDirectory()
csv_logfile = os.path.join(tmpdir.name, "csv_logger.csv")
NUM_CLASSES = 2
np.random.seed(1337)
x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)
x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)
y_test = numerical_utils.to_categorical(y_test)
y_train = numerical_utils.to_categorical(y_train)
model = Sequential()
initializer = initializers.Constant(value=1e5)
for _ in range(5):
model.add(
layers.Dense(
2,
activation="relu",
kernel_initializer=initializer,
)
)
model.add(layers.Dense(NUM_CLASSES))
model.compile(loss="mean_squared_error", optimizer="sgd")
history = model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=[
callbacks.TerminateOnNaN(),
callbacks.CSVLogger(csv_logfile),
],
epochs=20,
)
loss = history.history["loss"]
self.assertEqual(len(loss), 1)
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
values = []
with open(csv_logfile) as f:
# On Windows, due to \r\n line ends, we may end up reading empty
# lines after each line. Skip empty lines.
values = [x for x in csv.reader(f) if x]
self.assertIn("nan", values[-1], "NaN not logged in CSV Logger.")

@@ -0,0 +1,20 @@
import numpy as np
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils
@keras_core_export("keras_core.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered."""
def on_batch_end(self, batch, logs=None):
logs = logs or {}
loss = logs.get("loss")
if loss is not None:
if np.isnan(loss) or np.isinf(loss):
io_utils.print_msg(
f"Batch {batch}: Invalid loss, terminating training"
)
self.model.stop_training = True
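
The callback works by setting the model's `stop_training` flag, which `fit()` honors by ending training early (the tests in this commit rely on this: of 20 requested epochs, only one epoch's loss is recorded). The same mechanism supports any custom termination rule; here is a hypothetical sketch (`TerminateOnHighLoss` and its `max_loss` argument are illustrative, not part of this commit):

```python
from keras_core.callbacks.callback import Callback


class TerminateOnHighLoss(Callback):
    """Hypothetical callback: stop training once loss exceeds a bound."""

    def __init__(self, max_loss=1e3):
        super().__init__()
        self.max_loss = max_loss

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss > self.max_loss:
            # Same mechanism as TerminateOnNaN above.
            self.model.stop_training = True
```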

@@ -0,0 +1,50 @@
import numpy as np
from keras_core import callbacks
from keras_core import initializers
from keras_core import layers
from keras_core import testing
from keras_core.models import Sequential
from keras_core.utils import numerical_utils
class TerminateOnNaNTest(testing.TestCase):
def test_TerminateOnNaN(self):
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
INPUT_DIM = 3
NUM_CLASSES = 2
BATCH_SIZE = 4
np.random.seed(1337)
x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)
x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)
y_test = numerical_utils.to_categorical(y_test)
y_train = numerical_utils.to_categorical(y_train)
model = Sequential()
initializer = initializers.Constant(value=1e5)
for _ in range(5):
model.add(
layers.Dense(
2,
activation="relu",
kernel_initializer=initializer,
)
)
model.add(layers.Dense(NUM_CLASSES))
model.compile(loss="mean_squared_error", optimizer="sgd")
history = model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=[callbacks.TerminateOnNaN()],
epochs=20,
)
loss = history.history["loss"]
self.assertEqual(len(loss), 1)
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))

@@ -48,7 +48,6 @@ from keras_core.layers.pooling.global_max_pooling3d import GlobalMaxPooling3D
from keras_core.layers.pooling.max_pooling1d import MaxPooling1D
from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
from keras_core.layers.preprocessing.center_crop import CenterCrop
from keras_core.layers.preprocessing.normalization import Normalization
from keras_core.layers.preprocessing.rescaling import Rescaling
from keras_core.layers.preprocessing.resizing import Resizing

@@ -1,130 +0,0 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.CenterCrop")
class CenterCrop(Layer):
"""A preprocessing layer which crops images.
    This layer crops the central portion of the images to a target size. If an
image is smaller than the target size, it will be resized and cropped
so as to return the largest possible window in the image that matches
the target aspect ratio.
Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`).
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format,
or `(..., channels, height, width)`, in `"channels_first"` format.
Output shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., target_height, target_width, channels)`,
or `(..., channels, target_height, target_width)`,
in `"channels_first"` format.
If the input height/width is even and the target height/width is odd (or
    vice versa), the input image is left-padded by 1 pixel.
Args:
height: Integer, the height of the output shape.
width: Integer, the width of the output shape.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
"""
def __init__(self, height, width, data_format=None, **kwargs):
super().__init__(**kwargs)
self.height = height
self.width = width
self.data_format = data_format or backend.image_data_format()
def call(self, inputs):
if self.data_format == "channels_first":
init_height = inputs.shape[-2]
init_width = inputs.shape[-1]
else:
init_height = inputs.shape[-3]
init_width = inputs.shape[-2]
if init_height is None or init_width is None:
# Dynamic size case. TODO.
raise ValueError(
"At this time, CenterCrop can only "
"process images with a static spatial "
f"shape. Received: inputs.shape={inputs.shape}"
)
h_diff = init_height - self.height
w_diff = init_width - self.width
h_start = int(h_diff / 2)
w_start = int(w_diff / 2)
if h_diff >= 0 and w_diff >= 0:
if len(inputs.shape) == 4:
if self.data_format == "channels_first":
return inputs[
:,
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
]
return inputs[
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
:,
]
elif len(inputs.shape) == 3:
if self.data_format == "channels_first":
return inputs[
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
]
return inputs[
h_start : h_start + self.height,
w_start : w_start + self.width,
:,
]
return image_utils.smart_resize(
inputs, [self.height, self.width], data_format=self.data_format
)
def compute_output_shape(self, input_shape):
input_shape = list(input_shape)
if len(input_shape) == 4:
if self.data_format == "channels_last":
input_shape[1] = self.height
input_shape[2] = self.width
else:
input_shape[2] = self.height
input_shape[3] = self.width
else:
if self.data_format == "channels_last":
input_shape[0] = self.height
input_shape[1] = self.width
else:
input_shape[1] = self.height
input_shape[2] = self.width
return tuple(input_shape)
def get_config(self):
base_config = super().get_config()
config = {
"height": self.height,
"width": self.width,
"data_format": self.data_format,
}
return {**base_config, **config}

@@ -1,96 +0,0 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class CenterCropTest(testing.TestCase, parameterized.TestCase):
def test_center_crop_basics(self):
self.run_layer_test(
layers.CenterCrop,
init_kwargs={
"height": 6,
"width": 6,
"data_format": "channels_last",
},
input_shape=(2, 12, 12, 3),
expected_output_shape=(2, 6, 6, 3),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
)
self.run_layer_test(
layers.CenterCrop,
init_kwargs={
"height": 7,
"width": 7,
"data_format": "channels_first",
},
input_shape=(2, 3, 13, 13),
expected_output_shape=(2, 3, 7, 7),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
)
@parameterized.parameters(
[
((5, 7), "channels_first"),
((5, 7), "channels_last"),
((15, 10), "channels_first"),
((10, 17), "channels_last"),
]
)
def test_center_crop_correctness(self, size, data_format):
# batched case
if data_format == "channels_first":
img = np.random.random((2, 3, 9, 11))
else:
img = np.random.random((2, 9, 11, 3))
out = layers.CenterCrop(
size[0],
size[1],
data_format=data_format,
)(img)
if data_format == "channels_first":
img_transpose = np.transpose(img, (0, 2, 3, 1))
ref_out = tf.transpose(
tf.keras.layers.CenterCrop(size[0], size[1])(img_transpose),
(0, 3, 1, 2),
)
else:
ref_out = tf.keras.layers.CenterCrop(size[0], size[1])(img)
self.assertAllClose(ref_out, out)
# unbatched case
if data_format == "channels_first":
img = np.random.random((3, 9, 11))
else:
img = np.random.random((9, 11, 3))
out = layers.CenterCrop(
size[0],
size[1],
data_format=data_format,
)(img)
if data_format == "channels_first":
img_transpose = np.transpose(img, (1, 2, 0))
ref_out = tf.transpose(
tf.keras.layers.CenterCrop(
size[0],
size[1],
)(img_transpose),
(2, 0, 1),
)
else:
ref_out = tf.keras.layers.CenterCrop(
size[0],
size[1],
)(img)
self.assertAllClose(ref_out, out)

@@ -14,17 +14,6 @@ class Resizing(Layer):
format. Input pixel values can be of any range
(e.g. `[0., 1.)` or `[0, 255]`).
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format,
or `(..., channels, height, width)`, in `"channels_first"` format.
Output shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., target_height, target_width, channels)`,
or `(..., channels, target_height, target_width)`,
in `"channels_first"` format.
Args:
height: Integer, the height of the output shape.
width: Integer, the width of the output shape.

@@ -5,6 +5,7 @@ from keras_core.metrics.accuracy_metrics import CategoricalAccuracy
from keras_core.metrics.accuracy_metrics import SparseCategoricalAccuracy
from keras_core.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy
from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
from keras_core.metrics.confusion_metrics import AUC
from keras_core.metrics.confusion_metrics import FalseNegatives
from keras_core.metrics.confusion_metrics import FalsePositives
from keras_core.metrics.confusion_metrics import Precision
@@ -41,6 +42,7 @@ ALL_OBJECTS = {
# Regression
MeanSquaredError,
# Classification
AUC,
FalseNegatives,
FalsePositives,
Precision,

@@ -1,3 +1,6 @@
import numpy as np
from keras_core import activations
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
@@ -366,7 +369,7 @@ class Precision(Metric):
y_pred: The predicted values. Each element must be in the range
`[0, 1]`.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
Can be a tensor whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
metrics_utils.update_confusion_matrix_variables(
@@ -507,7 +510,7 @@ class Recall(Metric):
y_pred: The predicted values. Each element must be in the range
`[0, 1]`.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
Can be a tensor whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
metrics_utils.update_confusion_matrix_variables(
@@ -605,7 +608,7 @@ class SensitivitySpecificityBase(Metric):
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Defaults to 1.
Can be a `Tensor` whose rank is either 0, or the same rank as
Can be a tensor whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`.
"""
metrics_utils.update_confusion_matrix_variables(
@@ -1052,3 +1055,518 @@ class RecallAtPrecision(SensitivitySpecificityBase):
}
base_config = super().get_config()
return {**base_config, **config}
@keras_core_export("keras_core.metrics.AUC")
class AUC(Metric):
"""Approximates the AUC (Area under the curve) of the ROC or PR curves.
The AUC (Area under the curve) of the ROC (Receiver operating
characteristic; default) or PR (Precision Recall) curves are quality
measures of binary classifiers. Unlike the accuracy, and like cross-entropy
losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum. During the metric
    accumulation phase, predictions are accumulated within predefined buckets
by value. The AUC is then computed by interpolating per-bucket averages.
These buckets define the evaluated operational points.
This metric creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the AUC. To discretize the AUC curve, a linearly spaced set of
thresholds is used to compute pairs of recall and precision values. The area
under the ROC-curve is therefore computed using the height of the recall
    values by the false positive rate, while the area under the PR-curve is
computed using the height of the precision values by the recall.
This value is ultimately returned as `auc`, an idempotent operation that
computes the area under a discretized curve of precision versus recall
values (computed using the aforementioned variables). The `num_thresholds`
variable controls the degree of discretization with larger numbers of
thresholds more closely approximating the true AUC. The quality of the
approximation may vary dramatically depending on `num_thresholds`. The
`thresholds` parameter can be used to manually specify thresholds which
split the predictions more evenly.
    For the best approximation of the real AUC, `predictions` should be
distributed approximately uniformly in the range `[0, 1]` (if
`from_logits=False`). The quality of the AUC approximation may be poor if
this is not the case. Setting `summation_method` to 'minoring' or 'majoring'
can help quantify the error in the approximation by providing lower or upper
bound estimate of the AUC.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
Args:
num_thresholds: (Optional) The number of thresholds to
use when discretizing the roc curve. Values must be > 1.
Defaults to `200`.
curve: (Optional) Specifies the name of the curve to be computed,
`'ROC'` (default) or `'PR'` for the Precision-Recall-curve.
summation_method: (Optional) Specifies the [Riemann summation method](
https://en.wikipedia.org/wiki/Riemann_sum) used.
'interpolation' (default) applies mid-point summation scheme for
`ROC`. For PR-AUC, interpolates (true/false) positives but not
the ratio that is precision (see Davis & Goadrich 2006 for
details); 'minoring' applies left summation for increasing
intervals and right summation for decreasing intervals; 'majoring'
does the opposite.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
thresholds: (Optional) A list of floating point values to use as the
thresholds for discretizing the curve. If set, the `num_thresholds`
parameter is ignored. Values should be in `[0, 1]`. Endpoint
thresholds equal to {`-epsilon`, `1+epsilon`} for a small positive
epsilon value will be automatically included with these to correctly
handle predictions equal to exactly 0 or 1.
multi_label: boolean indicating whether multilabel data should be
treated as such, wherein AUC is computed separately for each label
and then averaged across labels, or (when `False`) if the data
should be flattened into a single label before AUC computation. In
the latter case, when multilabel data is passed to AUC, each
label-prediction pair is treated as an individual data point. Should
be set to False for multi-class data.
num_labels: (Optional) The number of labels, used when `multi_label` is
True. If `num_labels` is not specified, then state variables get
created on the first call to `update_state`.
label_weights: (Optional) list, array, or tensor of non-negative weights
used to compute AUCs for multilabel data. When `multi_label` is
True, the weights are applied to the individual label AUCs when they
are averaged to produce the multi-label AUC. When it's False, they
are used to weight the individual label predictions in computing the
confusion matrix on the flattened data. Note that this is unlike
`class_weights` in that `class_weights` weights the example
depending on the value of its label, whereas `label_weights` depends
only on the index of that label before flattening; therefore
`label_weights` should not be used for multi-class data.
from_logits: boolean indicating whether the predictions (`y_pred` in
`update_state`) are probabilities or sigmoid logits. As a rule of thumb,
when using a keras loss, the `from_logits` constructor argument of the
loss should match the AUC `from_logits` constructor argument.
Standalone usage:
>>> m = keras_core.metrics.AUC(num_thresholds=3)
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
>>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
>>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
>>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
>>> # auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
>>> # = 0.75
>>> m.result()
0.75
>>> m.reset_state()
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
... sample_weight=[1, 0, 0, 1])
>>> m.result()
1.0
Usage with `compile()` API:
```python
# Reports the AUC of a model outputting a probability.
model.compile(optimizer='sgd',
loss=keras_core.losses.BinaryCrossentropy(),
metrics=[keras_core.metrics.AUC()])
# Reports the AUC of a model outputting a logit.
model.compile(optimizer='sgd',
loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras_core.metrics.AUC(from_logits=True)])
```
"""
def __init__(
self,
num_thresholds=200,
curve="ROC",
summation_method="interpolation",
name=None,
dtype=None,
thresholds=None,
multi_label=False,
num_labels=None,
label_weights=None,
from_logits=False,
):
# Validate configurations.
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
metrics_utils.AUCCurve
):
raise ValueError(
f'Invalid `curve` argument value "{curve}". '
f"Expected one of: {list(metrics_utils.AUCCurve)}"
)
if isinstance(
summation_method, metrics_utils.AUCSummationMethod
) and summation_method not in list(metrics_utils.AUCSummationMethod):
raise ValueError(
"Invalid `summation_method` "
f'argument value "{summation_method}". '
f"Expected one of: {list(metrics_utils.AUCSummationMethod)}"
)
# Update properties.
self._init_from_thresholds = thresholds is not None
if thresholds is not None:
# If specified, use the supplied thresholds.
self.num_thresholds = len(thresholds) + 2
thresholds = sorted(thresholds)
self._thresholds_distributed_evenly = (
metrics_utils.is_evenly_distributed_thresholds(
np.array([0.0] + thresholds + [1.0])
)
)
else:
if num_thresholds <= 1:
raise ValueError(
"Argument `num_thresholds` must be an integer > 1. "
f"Received: num_thresholds={num_thresholds}"
)
# Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
# (0, 1).
self.num_thresholds = num_thresholds
thresholds = [
(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)
]
self._thresholds_distributed_evenly = True
# Add an endpoint "threshold" below zero and above one for either
# threshold method to account for floating point imprecisions.
self._thresholds = np.array(
[0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()]
)
if isinstance(curve, metrics_utils.AUCCurve):
self.curve = curve
else:
self.curve = metrics_utils.AUCCurve.from_str(curve)
if isinstance(summation_method, metrics_utils.AUCSummationMethod):
self.summation_method = summation_method
else:
self.summation_method = metrics_utils.AUCSummationMethod.from_str(
summation_method
)
super().__init__(name=name, dtype=dtype)
# Handle multilabel arguments.
self.multi_label = multi_label
self.num_labels = num_labels
if label_weights is not None:
label_weights = ops.array(label_weights, dtype=self.dtype)
self.label_weights = label_weights
else:
self.label_weights = None
self._from_logits = from_logits
self._built = False
if self.multi_label:
if num_labels:
shape = [None, num_labels]
self._build(shape)
else:
if num_labels:
raise ValueError(
"`num_labels` is needed only when `multi_label` is True."
)
self._build(None)
@property
def thresholds(self):
"""The thresholds used for evaluating AUC."""
return list(self._thresholds)
def _build(self, shape):
"""Initialize TP, FP, TN, and FN tensors, given the shape of the
data."""
if self.multi_label:
if len(shape) != 2:
raise ValueError(
"`y_pred` must have rank 2 when `multi_label=True`. "
f"Found rank {len(shape)}. "
f"Full shape received for `y_pred`: {shape}"
)
self._num_labels = shape[1]
variable_shape = [self.num_thresholds, self._num_labels]
else:
variable_shape = [self.num_thresholds]
self._build_input_shape = shape
# Create metric variables
self.true_positives = self.add_variable(
shape=variable_shape,
initializer=initializers.Zeros(),
name="true_positives",
)
self.false_positives = self.add_variable(
shape=variable_shape,
initializer=initializers.Zeros(),
name="false_positives",
)
self.true_negatives = self.add_variable(
shape=variable_shape,
initializer=initializers.Zeros(),
name="true_negatives",
)
self.false_negatives = self.add_variable(
shape=variable_shape,
initializer=initializers.Zeros(),
name="false_negatives",
)
self._built = True
def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates confusion matrix statistics.
Args:
y_true: The ground truth values.
y_pred: The predicted values.
sample_weight: Optional weighting of each example. Can
be a tensor whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true`. Defaults to
`1`.
"""
if not self._built:
self._build(y_pred.shape)
if self.multi_label or (self.label_weights is not None):
# y_true should have shape (number of examples, number of labels).
shapes = [(y_true, ("N", "L"))]
if self.multi_label:
# TP, TN, FP, and FN should all have shape
# (number of thresholds, number of labels).
shapes.extend(
[
(self.true_positives, ("T", "L")),
(self.true_negatives, ("T", "L")),
(self.false_positives, ("T", "L")),
(self.false_negatives, ("T", "L")),
]
)
if self.label_weights is not None:
# label_weights should be of length equal to the number of
# labels.
shapes.append((self.label_weights, ("L",)))
# Only forward label_weights to update_confusion_matrix_variables when
# multi_label is False. Otherwise the averaging of individual label AUCs
# is handled in AUC.result
label_weights = None if self.multi_label else self.label_weights
if self._from_logits:
y_pred = activations.sigmoid(y_pred)
metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
},
y_true,
y_pred,
self._thresholds,
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
sample_weight=sample_weight,
multi_label=self.multi_label,
label_weights=label_weights,
)
def interpolate_pr_auc(self):
"""Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
https://www.biostat.wisc.edu/~page/rocpr.pdf
Note here we derive & use a closed formula not present in the paper
as follows:
Precision = TP / (TP + FP) = TP / P
Modeling all of TP (true positive), FP (false positive) and their sum
P = TP + FP (predicted positive) as varying linearly within each
interval [A, B] between successive thresholds, we get
Precision slope = dTP / dP
= (TP_B - TP_A) / (P_B - P_A)
= (TP - TP_A) / (P - P_A)
Precision = (TP_A + slope * (P - P_A)) / P
The area within the interval is (slope / total_pos_weight) times
int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
Bringing back the factor (slope / total_pos_weight) we'd put aside, we
get
slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
where dTP == TP_B - TP_A.
Note that when P_A == 0 the above calculation simplifies into
int_A^B{Precision.dTP} = int_A^B{slope * dTP}
= slope * (TP_B - TP_A)
which is really equivalent to imputing constant precision throughout the
first bucket having >0 true positives.
Returns:
pr_auc: an approximation of the area under the P-R curve.
"""
dtp = (
self.true_positives[: self.num_thresholds - 1]
- self.true_positives[1:]
)
p = ops.add(self.true_positives, self.false_positives)
dp = p[: self.num_thresholds - 1] - p[1:]
prec_slope = ops.divide(dtp, ops.maximum(dp, backend.epsilon()))
intercept = self.true_positives[1:] - ops.multiply(prec_slope, p[1:])
safe_p_ratio = ops.where(
ops.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
ops.divide(
p[: self.num_thresholds - 1],
ops.maximum(p[1:], backend.epsilon()),
),
ops.ones_like(p[1:]),
)
pr_auc_increment = ops.divide(
prec_slope * (dtp + intercept * ops.log(safe_p_ratio)),
ops.maximum(
self.true_positives[1:] + self.false_negatives[1:],
backend.epsilon(),
),
)
if self.multi_label:
by_label_auc = ops.sum(pr_auc_increment, axis=0)
if self.label_weights is None:
# Evenly weighted average of the label AUCs.
return ops.mean(by_label_auc)
else:
# Weighted average of the label AUCs.
return ops.divide(
ops.sum(ops.multiply(by_label_auc, self.label_weights)),
ops.maximum(ops.sum(self.label_weights), backend.epsilon()),
)
else:
return ops.sum(pr_auc_increment)
def result(self):
if (
self.curve == metrics_utils.AUCCurve.PR
and self.summation_method
== metrics_utils.AUCSummationMethod.INTERPOLATION
):
# This use case is different and is handled separately.
return self.interpolate_pr_auc()
# Set `x` and `y` values for the curves based on `curve` config.
recall = ops.divide(
self.true_positives,
ops.maximum(
ops.add(self.true_positives, self.false_negatives),
backend.epsilon(),
),
)
if self.curve == metrics_utils.AUCCurve.ROC:
fp_rate = ops.divide(
self.false_positives,
ops.maximum(
ops.add(self.false_positives, self.true_negatives),
backend.epsilon(),
),
)
x = fp_rate
y = recall
else: # curve == 'PR'.
precision = ops.divide(
self.true_positives,
ops.maximum(
ops.add(self.true_positives, self.false_positives),
backend.epsilon(),
),
)
x = recall
y = precision
# Find the rectangle heights based on `summation_method`.
if (
self.summation_method
== metrics_utils.AUCSummationMethod.INTERPOLATION
):
# Note: the case ('PR', 'interpolation') has been handled above.
heights = (y[: self.num_thresholds - 1] + y[1:]) / 2.0
elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
heights = ops.minimum(y[: self.num_thresholds - 1], y[1:])
        # else: self.summation_method == metrics_utils.AUCSummationMethod.MAJORING
else:
heights = ops.maximum(y[: self.num_thresholds - 1], y[1:])
# Sum up the areas of all the rectangles.
if self.multi_label:
riemann_terms = ops.multiply(
x[: self.num_thresholds - 1] - x[1:], heights
)
by_label_auc = ops.sum(riemann_terms, axis=0)
if self.label_weights is None:
# Unweighted average of the label AUCs.
return ops.mean(by_label_auc)
else:
# Weighted average of the label AUCs.
return ops.divide(
ops.sum(ops.multiply(by_label_auc, self.label_weights)),
ops.maximum(ops.sum(self.label_weights), backend.epsilon()),
)
else:
return ops.sum(
ops.multiply(x[: self.num_thresholds - 1] - x[1:], heights)
)
def reset_state(self):
if self._built:
if self.multi_label:
variable_shape = (self.num_thresholds, self._num_labels)
else:
variable_shape = (self.num_thresholds,)
self.true_positives.assign(ops.zeros(variable_shape))
self.false_positives.assign(ops.zeros(variable_shape))
self.true_negatives.assign(ops.zeros(variable_shape))
self.false_negatives.assign(ops.zeros(variable_shape))
def get_config(self):
label_weights = self.label_weights
config = {
"num_thresholds": self.num_thresholds,
"curve": self.curve.value,
"summation_method": self.summation_method.value,
"multi_label": self.multi_label,
"num_labels": self.num_labels,
"label_weights": label_weights,
"from_logits": self._from_logits,
}
# optimization to avoid serializing a large number of generated
# thresholds
if self._init_from_thresholds:
# We remove the endpoint thresholds as an inverse of how the
# thresholds were initialized. This ensures that a metric
# initialized from this config has the same thresholds.
config["thresholds"] = self.thresholds[1:-1]
base_config = super().get_config()
return {**base_config, **config}
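
The docstring's standalone example can be verified with plain NumPy. This is an illustrative sketch, not the metric's implementation: it rebuilds the thresholds for `num_thresholds=3` (one interior threshold plus the two epsilon endpoints) and applies the midpoint "interpolation" Riemann sum to the ROC curve:

```python
import numpy as np

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0.0, 0.5, 0.3, 0.9])

# num_thresholds=3 -> one interior threshold, plus endpoints just outside
# [0, 1] so predictions of exactly 0 or 1 are bucketed correctly.
thresholds = np.array([0.0 - 1e-7, 0.5, 1.0 + 1e-7])

pred_pos = y_pred[None, :] > thresholds[:, None]  # shape (3, 4)
tp = np.sum(pred_pos & (y_true == 1), axis=1)     # [2, 1, 0]
fp = np.sum(pred_pos & (y_true == 0), axis=1)     # [2, 0, 0]
fn = np.sum(~pred_pos & (y_true == 1), axis=1)    # [0, 1, 2]
tn = np.sum(~pred_pos & (y_true == 0), axis=1)    # [0, 2, 2]

recall = tp / np.maximum(tp + fn, 1e-7)   # y values: [1, 0.5, 0]
fp_rate = fp / np.maximum(fp + tn, 1e-7)  # x values: [1, 0, 0]

heights = (recall[:-1] + recall[1:]) / 2.0  # midpoints: [0.75, 0.25]
widths = fp_rate[:-1] - fp_rate[1:]         # [1.0, 0.0]
print(heights @ widths)                     # 0.75, matching m.result()
```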

@@ -1,10 +1,16 @@
import json
import numpy as np
from absl import logging
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config
from keras_core import layers
from keras_core import metrics
from keras_core import models
from keras_core import operations as ops
from keras_core import testing
from keras_core.metrics import metrics_utils
# TODO: remove reliance on this (or alternatively, turn it on by default).
# This is no longer needed with tf-nightly.
@@ -1101,3 +1107,558 @@ class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase):
ValueError, "Argument `num_thresholds` must be an integer > 0"
):
metrics.RecallAtPrecision(0.4, num_thresholds=-1)
class AUCTest(testing.TestCase):
def setUp(self):
self.num_thresholds = 3
self.y_pred = np.array([0, 0.5, 0.3, 0.9], dtype="float32")
self.y_pred_multi_label = np.array(
[[0.0, 0.4], [0.5, 0.7], [0.3, 0.2], [0.9, 0.3]], dtype="float32"
)
epsilon = 1e-12
self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
self.y_true = np.array([0, 0, 1, 1])
self.y_true_multi_label = np.array([[0, 0], [1, 1], [1, 1], [1, 0]])
self.sample_weight = [1, 2, 3, 4]
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
# y_pred when threshold = 0 - 1e-7 : [1, 1, 1, 1]
# y_pred when threshold = 0.5 : [0, 0, 0, 1]
# y_pred when threshold = 1 + 1e-7 : [0, 0, 0, 0]
# without sample_weight:
# tp = np.sum([[0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]], axis=1)
# fp = np.sum([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)
# fn = np.sum([[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]], axis=1)
# tn = np.sum([[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0]], axis=1)
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# with sample_weight:
# tp = np.sum([[0, 0, 3, 4], [0, 0, 0, 4], [0, 0, 0, 0]], axis=1)
# fp = np.sum([[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)
# fn = np.sum([[0, 0, 0, 0], [0, 0, 3, 0], [0, 0, 3, 4]], axis=1)
# tn = np.sum([[0, 0, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0]], axis=1)
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
def test_config(self):
auc_obj = metrics.AUC(
num_thresholds=100,
curve="PR",
summation_method="majoring",
name="auc_1",
dtype="float64",
multi_label=True,
num_labels=2,
from_logits=True,
)
auc_obj.update_state(self.y_true_multi_label, self.y_pred_multi_label)
self.assertEqual(auc_obj.name, "auc_1")
self.assertEqual(auc_obj._dtype, "float64")
self.assertLen(auc_obj.variables, 4)
self.assertEqual(auc_obj.num_thresholds, 100)
self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(
auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING
)
self.assertTrue(auc_obj.multi_label)
self.assertEqual(auc_obj.num_labels, 2)
self.assertTrue(auc_obj._from_logits)
old_config = auc_obj.get_config()
self.assertNotIn("thresholds", old_config)
self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
# Check save and restore config.
auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
auc_obj2.update_state(self.y_true_multi_label, self.y_pred_multi_label)
self.assertEqual(auc_obj2.name, "auc_1")
self.assertLen(auc_obj2.variables, 4)
self.assertEqual(auc_obj2.num_thresholds, 100)
self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(
auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING
)
self.assertTrue(auc_obj2.multi_label)
self.assertEqual(auc_obj2.num_labels, 2)
self.assertTrue(auc_obj2._from_logits)
new_config = auc_obj2.get_config()
self.assertNotIn("thresholds", new_config)
self.assertDictEqual(old_config, new_config)
self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
def test_config_manual_thresholds(self):
auc_obj = metrics.AUC(
num_thresholds=None,
curve="PR",
summation_method="majoring",
name="auc_1",
thresholds=[0.3, 0.5],
)
auc_obj.update_state(self.y_true, self.y_pred)
self.assertEqual(auc_obj.name, "auc_1")
self.assertLen(auc_obj.variables, 4)
self.assertEqual(auc_obj.num_thresholds, 4)
self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(
auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING
)
old_config = auc_obj.get_config()
self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
# Check save and restore config.
auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
auc_obj2.update_state(self.y_true, self.y_pred)
self.assertEqual(auc_obj2.name, "auc_1")
self.assertLen(auc_obj2.variables, 4)
self.assertEqual(auc_obj2.num_thresholds, 4)
self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
self.assertEqual(
auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING
)
new_config = auc_obj2.get_config()
self.assertDictEqual(old_config, new_config)
self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
def test_unweighted_all_correct(self):
auc_obj = metrics.AUC()
self.assertEqual(auc_obj(self.y_true, self.y_true), 1)
def test_unweighted(self):
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
result = auc_obj(self.y_true, self.y_pred)
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = 0.75 * 1 + 0.25 * 0
self.assertAllClose(result, expected_result, 1e-3)
def test_unweighted_from_logits(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, from_logits=True
)
result = auc_obj(self.y_true, self.y_pred_logits)
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = 0.75 * 1 + 0.25 * 0
self.assertAllClose(result, expected_result, 1e-3)
def test_manual_thresholds(self):
# Verify that when specified, thresholds are used instead of
# num_thresholds.
auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])
self.assertEqual(auc_obj.num_thresholds, 3)
self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
result = auc_obj(self.y_true, self.y_pred)
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = 0.75 * 1 + 0.25 * 0
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_roc_interpolation(self):
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
result = auc_obj(
self.y_true, self.y_pred, sample_weight=self.sample_weight
)
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
# fp_rate = [3/3, 0, 0] = [1, 0, 0]
# heights = [(1 + 0.571)/2, (0.571 + 0)/2] = [0.7855, 0.2855]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = 0.7855 * 1 + 0.2855 * 0
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_roc_majoring(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, summation_method="majoring"
)
result = auc_obj(
self.y_true, self.y_pred, sample_weight=self.sample_weight
)
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
# fp_rate = [3/3, 0, 0] = [1, 0, 0]
# heights = [max(1, 0.571), max(0.571, 0)] = [1, 0.571]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = 1 * 1 + 0.571 * 0
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_roc_minoring(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, summation_method="minoring"
)
result = auc_obj(
self.y_true, self.y_pred, sample_weight=self.sample_weight
)
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
# fp_rate = [3/3, 0, 0] = [1, 0, 0]
# heights = [min(1, 0.571), min(0.571, 0)] = [0.571, 0]
# widths = [(1 - 0), (0 - 0)] = [1, 0]
expected_result = 0.571 * 1 + 0 * 0
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_pr_majoring(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
curve="PR",
summation_method="majoring",
)
result = auc_obj(
self.y_true, self.y_pred, sample_weight=self.sample_weight
)
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
# precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
# heights = [max(0.7, 1), max(1, 0)] = [1, 1]
# widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
expected_result = 1 * 0.429 + 1 * 0.571
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_pr_minoring(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
curve="PR",
summation_method="minoring",
)
result = auc_obj(
self.y_true, self.y_pred, sample_weight=self.sample_weight
)
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
# precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
# heights = [min(0.7, 1), min(1, 0)] = [0.7, 0]
# widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
expected_result = 0.7 * 0.429 + 0 * 0.571
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_pr_interpolation(self):
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR")
result = auc_obj(
self.y_true, self.y_pred, sample_weight=self.sample_weight
)
        # auc = (slope / Total Pos) * [dTP + intercept * log(Pb/Pa)]
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
# P = tp + fp = [10, 4, 0]
# dTP = [7-4, 4-0] = [3, 4]
# dP = [10-4, 4-0] = [6, 4]
# slope = dTP/dP = [0.5, 1]
        # intercept = TPb - (slope * Pb) = [(4 - 0.5*4), (0 - 1*0)] = [2, 0]
# (Pb/Pa) = (Pb/Pa) if Pb > 0 AND Pa > 0 else 1 = [10/4, 4/0] = [2.5, 1]
# auc * TotalPos = [(0.5 * (3 + 2 * log(2.5))), (1 * (4 + 0))]
# = [2.416, 4]
# auc = [2.416, 4]/(tp[1:]+fn[1:])
expected_result = 2.416 / 7 + 4 / 7
self.assertAllClose(result, expected_result, 1e-3)
def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 1"
):
metrics.AUC(num_thresholds=-1)
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 1."
):
metrics.AUC(num_thresholds=1)
def test_invalid_curve(self):
with self.assertRaisesRegex(
ValueError, 'Invalid AUC curve value: "Invalid".'
):
metrics.AUC(curve="Invalid")
def test_invalid_summation_method(self):
with self.assertRaisesRegex(
ValueError, 'Invalid AUC summation method value: "Invalid".'
):
metrics.AUC(summation_method="Invalid")
def test_extra_dims(self):
try:
from scipy import special
logits = special.expit(
-np.array(
[
[[-10.0, 10.0, -10.0], [10.0, -10.0, 10.0]],
[[-12.0, 12.0, -12.0], [12.0, -12.0, 12.0]],
],
dtype=np.float32,
)
)
labels = np.array(
[[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]], dtype=np.int64
)
auc_obj = metrics.AUC()
result = auc_obj(labels, logits)
self.assertEqual(result, 0.5)
except ImportError as e:
logging.warning(f"Cannot test special functions: {str(e)}")
class MultiAUCTest(testing.TestCase):
def setUp(self):
self.num_thresholds = 5
self.y_pred = np.array(
[[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]], dtype="float32"
).T
epsilon = 1e-12
self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
self.y_true_good = np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T
self.y_true_bad = np.array([[0, 0, 1, 1], [1, 1, 0, 0]]).T
self.sample_weight = [1, 2, 3, 4]
# threshold values are [0 - 1e-7, 0.25, 0.5, 0.75, 1 + 1e-7]
# y_pred when threshold = 0 - 1e-7 : [[1, 1, 1, 1], [1, 1, 1, 1]]
# y_pred when threshold = 0.25 : [[0, 1, 1, 1], [0, 0, 1, 1]]
# y_pred when threshold = 0.5 : [[0, 0, 0, 1], [0, 0, 0, 0]]
# y_pred when threshold = 0.75 : [[0, 0, 0, 1], [0, 0, 0, 0]]
# y_pred when threshold = 1 + 1e-7 : [[0, 0, 0, 0], [0, 0, 0, 0]]
# for y_true_good, over thresholds:
# tp = [[2, 2, 1, 1, 0], [2, 2, 0, 0, 0]]
        # fp = [[2, 1, 0, 0, 0], [2, 0, 0, 0, 0]]
# fn = [[0, 0, 1, 1, 2], [0, 0, 2, 2, 2]]
# tn = [[0, 1, 2, 2, 2], [0, 2, 2, 2, 2]]
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
# for y_true_bad:
# tp = [[2, 2, 1, 1, 0], [2, 0, 0, 0, 0]]
        # fp = [[2, 1, 0, 0, 0], [2, 2, 0, 0, 0]]
# fn = [[0, 0, 1, 1, 2], [0, 2, 2, 2, 2]]
# tn = [[0, 1, 2, 2, 2], [0, 0, 2, 2, 2]]
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 0, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 1, 0, 0, 0]]
# for y_true_good with sample_weights:
# tp = [[7, 7, 4, 4, 0], [7, 7, 0, 0, 0]]
# fp = [[3, 2, 0, 0, 0], [3, 0, 0, 0, 0]]
# fn = [[0, 0, 3, 3, 7], [0, 0, 7, 7, 7]]
# tn = [[0, 1, 3, 3, 3], [0, 3, 3, 3, 3]]
# tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
def test_unweighted_all_correct(self):
auc_obj = metrics.AUC(multi_label=True)
result = auc_obj(self.y_true_good, self.y_true_good)
self.assertEqual(result, 1)
def test_unweighted_all_correct_flat(self):
auc_obj = metrics.AUC(multi_label=False)
result = auc_obj(self.y_true_good, self.y_true_good)
self.assertEqual(result, 1)
def test_unweighted(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=True
)
result = auc_obj(self.y_true_good, self.y_pred)
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = (0.875 + 1.0) / 2.0
self.assertAllClose(result, expected_result, 1e-3)
def test_unweighted_from_logits(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
multi_label=True,
from_logits=True,
)
result = auc_obj(self.y_true_good, self.y_pred_logits)
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = (0.875 + 1.0) / 2.0
self.assertAllClose(result, expected_result, 1e-3)
def test_sample_weight_flat(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=False
)
result = auc_obj(
self.y_true_good, self.y_pred, sample_weight=[1, 2, 3, 4]
)
# tpr = [1, 1, 0.2857, 0.2857, 0]
# fpr = [1, 0.3333, 0, 0, 0]
expected_result = 1.0 - (0.3333 * (1.0 - 0.2857) / 2.0)
self.assertAllClose(result, expected_result, 1e-3)
def test_full_sample_weight_flat(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=False
)
sw = np.arange(4 * 2)
sw = sw.reshape(4, 2)
result = auc_obj(self.y_true_good, self.y_pred, sample_weight=sw)
# tpr = [1, 1, 0.2727, 0.2727, 0]
# fpr = [1, 0.3333, 0, 0, 0]
expected_result = 1.0 - (0.3333 * (1.0 - 0.2727) / 2.0)
self.assertAllClose(result, expected_result, 1e-3)
def test_label_weights(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
multi_label=True,
label_weights=[0.75, 0.25],
)
result = auc_obj(self.y_true_good, self.y_pred)
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25)
self.assertAllClose(result, expected_result, 1e-3)
def test_label_weights_flat(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
multi_label=False,
label_weights=[0.75, 0.25],
)
result = auc_obj(self.y_true_good, self.y_pred)
# tpr = [1, 1, 0.375, 0.375, 0]
# fpr = [1, 0.375, 0, 0, 0]
expected_result = 1.0 - ((1.0 - 0.375) * 0.375 / 2.0)
self.assertAllClose(result, expected_result, 1e-2)
def test_unweighted_flat(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=False
)
result = auc_obj(self.y_true_good, self.y_pred)
# tp = [4, 4, 1, 1, 0]
# fp = [4, 1, 0, 0, 0]
# fn = [0, 0, 3, 3, 4]
# tn = [0, 3, 4, 4, 4]
# tpr = [1, 1, 0.25, 0.25, 0]
# fpr = [1, 0.25, 0, 0, 0]
expected_result = 1.0 - (3.0 / 32.0)
self.assertAllClose(result, expected_result, 1e-3)
def test_unweighted_flat_from_logits(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds,
multi_label=False,
from_logits=True,
)
result = auc_obj(self.y_true_good, self.y_pred_logits)
# tp = [4, 4, 1, 1, 0]
# fp = [4, 1, 0, 0, 0]
# fn = [0, 0, 3, 3, 4]
# tn = [0, 3, 4, 4, 4]
# tpr = [1, 1, 0.25, 0.25, 0]
# fpr = [1, 0.25, 0, 0, 0]
expected_result = 1.0 - (3.0 / 32.0)
self.assertAllClose(result, expected_result, 1e-3)
def test_manual_thresholds(self):
# Verify that when specified, thresholds are used instead of
# num_thresholds.
auc_obj = metrics.AUC(
num_thresholds=2, thresholds=[0.5], multi_label=True
)
self.assertEqual(auc_obj.num_thresholds, 3)
self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
result = auc_obj(self.y_true_good, self.y_pred)
# tp = [[2, 1, 0], [2, 0, 0]]
        # fp = [[2, 0, 0], [2, 0, 0]]
# fn = [[0, 1, 2], [0, 2, 2]]
# tn = [[0, 2, 2], [0, 2, 2]]
# tpr = [[1, 0.5, 0], [1, 0, 0]]
# fpr = [[1, 0, 0], [1, 0, 0]]
# auc by slice = [0.75, 0.5]
expected_result = (0.75 + 0.5) / 2.0
self.assertAllClose(result, expected_result, 1e-3)
def test_weighted_roc_interpolation(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=True
)
result = auc_obj(
self.y_true_good, self.y_pred, sample_weight=self.sample_weight
)
# tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
# fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
expected_result = 1.0 - 0.5 * 0.43 * 0.67
self.assertAllClose(result, expected_result, 1e-1)
def test_pr_interpolation_unweighted(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, curve="PR", multi_label=True
)
good_result = auc_obj(self.y_true_good, self.y_pred)
with self.subTest(name="good"):
# PR AUCs are 0.917 and 1.0 respectively
self.assertAllClose(good_result, (0.91667 + 1.0) / 2.0, 1e-1)
bad_result = auc_obj(self.y_true_bad, self.y_pred)
with self.subTest(name="bad"):
# PR AUCs are 0.917 and 0.5 respectively
self.assertAllClose(bad_result, (0.91667 + 0.5) / 2.0, 1e-1)
def test_pr_interpolation(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, curve="PR", multi_label=True
)
good_result = auc_obj(
self.y_true_good, self.y_pred, sample_weight=self.sample_weight
)
# PR AUCs are 0.939 and 1.0 respectively
self.assertAllClose(good_result, (0.939 + 1.0) / 2.0, 1e-1)
def test_keras_model_compiles(self):
inputs = layers.Input(shape=(10,), batch_size=1)
output = layers.Dense(3, activation="sigmoid")(inputs)
model = models.Model(inputs=inputs, outputs=output)
model.compile(
optimizer="adam",
loss="binary_crossentropy",
metrics=[metrics.AUC(multi_label=True)],
)
def test_reset_state(self):
auc_obj = metrics.AUC(
num_thresholds=self.num_thresholds, multi_label=True
)
auc_obj(self.y_true_good, self.y_pred)
auc_obj.reset_state()
self.assertAllClose(auc_obj.true_positives, np.zeros((5, 2)))
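
The multi-label averaging checked in `test_unweighted` above can be reproduced directly from the per-label tpr/fpr values quoted in its comments; an illustrative NumPy check (not part of this commit):

```python
import numpy as np

# Per-label ROC points from the comments in test_unweighted.
tpr = np.array([[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]], dtype=float)
fpr = np.array([[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]], dtype=float)

heights = (tpr[:, :-1] + tpr[:, 1:]) / 2.0        # midpoint heights
widths = fpr[:, :-1] - fpr[:, 1:]
per_label_auc = np.sum(heights * widths, axis=1)  # [0.875, 1.0]
print(per_label_auc.mean())                       # 0.9375 = (0.875 + 1.0) / 2
```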

@@ -38,6 +38,59 @@ class ConfusionMatrix(Enum):
FALSE_NEGATIVES = "fn"
class AUCCurve(Enum):
"""Type of AUC Curve (ROC or PR)."""
ROC = "ROC"
PR = "PR"
@staticmethod
def from_str(key):
if key in ("pr", "PR"):
return AUCCurve.PR
elif key in ("roc", "ROC"):
return AUCCurve.ROC
else:
raise ValueError(
f'Invalid AUC curve value: "{key}". '
'Expected values are ["PR", "ROC"]'
)
class AUCSummationMethod(Enum):
"""Type of AUC summation method.
    https://en.wikipedia.org/wiki/Riemann_sum
Contains the following values:
* 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
`PR` curve, interpolates (true/false) positives but not the ratio that is
precision (see Davis & Goadrich 2006 for details).
* 'minoring': Applies left summation for increasing intervals and right
summation for decreasing intervals.
* 'majoring': Applies right summation for increasing intervals and left
summation for decreasing intervals.
"""
INTERPOLATION = "interpolation"
MAJORING = "majoring"
MINORING = "minoring"
@staticmethod
def from_str(key):
if key in ("interpolation", "Interpolation"):
return AUCSummationMethod.INTERPOLATION
elif key in ("majoring", "Majoring"):
return AUCSummationMethod.MAJORING
elif key in ("minoring", "Minoring"):
return AUCSummationMethod.MINORING
else:
raise ValueError(
f'Invalid AUC summation method value: "{key}". '
'Expected values are ["interpolation", "majoring", "minoring"]'
)
def _update_confusion_matrix_variables_optimized(
variables_to_update,
y_true,
@@ -203,14 +256,15 @@ def _update_confusion_matrix_variables_optimized(
num_segments=num_thresholds,
)
tp_bucket_v = ops.vectorized_map(
gather_bucket, (true_labels, bucket_indices), warn=False
tp_bucket_v = backend.vectorized_map(
gather_bucket,
(true_labels, bucket_indices),
)
fp_bucket_v = ops.vectorized_map(
gather_bucket, (false_labels, bucket_indices), warn=False
fp_bucket_v = backend.vectorized_map(
gather_bucket, (false_labels, bucket_indices)
)
tp = ops.transpose(ops.cumsum(ops.flip(tp_bucket_v), axis=1))
fp = ops.transpose(ops.cumsum(ops.flip(fp_bucket_v), axis=1))
tp = ops.transpose(ops.flip(ops.cumsum(ops.flip(tp_bucket_v), axis=1)))
fp = ops.transpose(ops.flip(ops.cumsum(ops.flip(fp_bucket_v), axis=1)))
else:
tp_bucket_v = ops.segment_sum(
data=true_labels,
@@ -222,8 +276,8 @@ def _update_confusion_matrix_variables_optimized(
segment_ids=bucket_indices,
num_segments=num_thresholds,
)
tp = ops.cumsum(ops.flip(tp_bucket_v))
fp = ops.cumsum(ops.flip(fp_bucket_v))
tp = ops.flip(ops.cumsum(ops.flip(tp_bucket_v)))
fp = ops.flip(ops.cumsum(ops.flip(fp_bucket_v)))
# fn = sum(true_labels) - tp
# tn = sum(false_labels) - fp
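
The two hunks above fix the cumulative-sum direction: `tp_bucket_v[i]` counts positives whose prediction falls into threshold bucket `i`, so the true positives at threshold `i` are the positives in buckets `i` and above. That is a reverse cumulative sum, which must be flipped back into threshold order. A minimal NumPy sketch of the difference:

```python
import numpy as np

tp_bucket = np.array([1, 2, 3, 4])  # per-bucket positive counts

old = np.cumsum(tp_bucket[::-1])        # [4, 7, 9, 10] -- still reversed
new = np.cumsum(tp_bucket[::-1])[::-1]  # [10, 9, 7, 4] -- threshold order

# new[i] counts the positives in buckets i..3, which is what tp per
# threshold should mean; the old version left the result reversed.
```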
@@ -383,7 +437,6 @@ def update_confusion_matrix_variables(
one_thresh = ops.equal(
ops.cast(1, dtype="int32"),
thresholds.ndim,
name="one_set_of_thresholds_cond",
)
else:
one_thresh = ops.cast(True, dtype="bool")