Add CSVLogger and TerminateOnNaN Callbacks (#95)
* Add CSV Logger and Terminate on Nan * Add CSVLogger and Terminate on Nan tests * Update CSV Logger docstring
This commit is contained in:
parent
02d0c27ce9
commit
cc89199f1e
@ -22,7 +22,6 @@ def resize(
|
||||
"Argument `size` must be a tuple of two elements "
|
||||
f"(height, width). Received: size={size}"
|
||||
)
|
||||
size = tuple(size)
|
||||
if len(image.shape) == 4:
|
||||
if data_format == "channels_last":
|
||||
size = (image.shape[0],) + size + (image.shape[-1],)
|
||||
|
@ -22,7 +22,7 @@ def resize(
|
||||
"Argument `size` must be a tuple of two elements "
|
||||
f"(height, width). Received: size={size}"
|
||||
)
|
||||
size = tuple(size)
|
||||
|
||||
if data_format == "channels_first":
|
||||
if len(image.shape) == 4:
|
||||
image = tf.transpose(image, (0, 2, 3, 1))
|
||||
|
@ -1,6 +1,8 @@
|
||||
from keras_core.callbacks.callback import Callback
|
||||
from keras_core.callbacks.callback_list import CallbackList
|
||||
from keras_core.callbacks.csv_logger import CSVLogger
|
||||
from keras_core.callbacks.early_stopping import EarlyStopping
|
||||
from keras_core.callbacks.history import History
|
||||
from keras_core.callbacks.lambda_callback import LambdaCallback
|
||||
from keras_core.callbacks.progbar_logger import ProgbarLogger
|
||||
from keras_core.callbacks.terminate_on_nan import TerminateOnNaN
|
||||
|
101
keras_core/callbacks/csv_logger.py
Normal file
101
keras_core/callbacks/csv_logger.py
Normal file
@ -0,0 +1,101 @@
|
||||
import collections
|
||||
import csv
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.callbacks.callback import Callback
|
||||
from keras_core.utils import file_utils
|
||||
|
||||
|
||||
@keras_core_export("keras_core.callbacks.CSVLogger")
|
||||
class CSVLogger(Callback):
|
||||
"""Callback that streams epoch results to a CSV file.
|
||||
|
||||
Supports all values that can be represented as a string,
|
||||
including 1D iterables such as `np.ndarray`.
|
||||
|
||||
Args:
|
||||
filename: Filename of the CSV file, e.g. `'run/log.csv'`.
|
||||
separator: String used to separate elements in the CSV file.
|
||||
append: Boolean. True: append if file exists (useful for continuing
|
||||
training). False: overwrite existing file.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
csv_logger = CSVLogger('training.log')
|
||||
model.fit(X_train, Y_train, callbacks=[csv_logger])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, filename, separator=",", append=False):
|
||||
super().__init__()
|
||||
self.sep = separator
|
||||
self.filename = file_utils.path_to_string(filename)
|
||||
self.append = append
|
||||
self.writer = None
|
||||
self.keys = None
|
||||
self.append_header = True
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
if self.append:
|
||||
if tf.io.gfile.exists(self.filename):
|
||||
with tf.io.gfile.GFile(self.filename, "r") as f:
|
||||
self.append_header = not bool(len(f.readline()))
|
||||
mode = "a"
|
||||
else:
|
||||
mode = "w"
|
||||
self.csv_file = tf.io.gfile.GFile(self.filename, mode)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
logs = logs or {}
|
||||
|
||||
def handle_value(k):
|
||||
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
|
||||
if isinstance(k, str):
|
||||
return k
|
||||
elif (
|
||||
isinstance(k, collections.abc.Iterable)
|
||||
and not is_zero_dim_ndarray
|
||||
):
|
||||
return f"\"[{', '.join(map(str, k))}]\""
|
||||
else:
|
||||
return k
|
||||
|
||||
if self.keys is None:
|
||||
self.keys = sorted(logs.keys())
|
||||
# When validation_freq > 1, `val_` keys are not in first epoch logs
|
||||
# Add the `val_` keys so that its part of the fieldnames of writer.
|
||||
val_keys_found = False
|
||||
for key in self.keys:
|
||||
if key.startswith("val_"):
|
||||
val_keys_found = True
|
||||
break
|
||||
if not val_keys_found:
|
||||
self.keys.extend(["val_" + k for k in self.keys])
|
||||
|
||||
if not self.writer:
|
||||
|
||||
class CustomDialect(csv.excel):
|
||||
delimiter = self.sep
|
||||
|
||||
fieldnames = ["epoch"] + self.keys
|
||||
|
||||
self.writer = csv.DictWriter(
|
||||
self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
|
||||
)
|
||||
if self.append_header:
|
||||
self.writer.writeheader()
|
||||
|
||||
row_dict = collections.OrderedDict({"epoch": epoch})
|
||||
row_dict.update(
|
||||
(key, handle_value(logs.get(key, "NA"))) for key in self.keys
|
||||
)
|
||||
self.writer.writerow(row_dict)
|
||||
self.csv_file.flush()
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
self.csv_file.close()
|
||||
self.writer = None
|
176
keras_core/callbacks/csv_logger_test.py
Normal file
176
keras_core/callbacks/csv_logger_test.py
Normal file
@ -0,0 +1,176 @@
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from keras_core import callbacks
|
||||
from keras_core import initializers
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
from keras_core.models import Sequential
|
||||
from keras_core.utils import numerical_utils
|
||||
|
||||
TRAIN_SAMPLES = 10
|
||||
TEST_SAMPLES = 10
|
||||
INPUT_DIM = 3
|
||||
BATCH_SIZE = 4
|
||||
|
||||
|
||||
class CSVLoggerTest(testing.TestCase):
|
||||
def test_CSVLogger(self):
|
||||
OUTPUT_DIM = 1
|
||||
np.random.seed(1337)
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
filepath = os.path.join(temp_dir.name, "log.tsv")
|
||||
|
||||
sep = "\t"
|
||||
x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
|
||||
y_train = np.random.random((TRAIN_SAMPLES, OUTPUT_DIM))
|
||||
x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
|
||||
y_test = np.random.random((TEST_SAMPLES, OUTPUT_DIM))
|
||||
|
||||
def make_model():
|
||||
np.random.seed(1337)
|
||||
model = Sequential(
|
||||
[
|
||||
layers.Dense(2, activation="relu"),
|
||||
layers.Dense(OUTPUT_DIM),
|
||||
]
|
||||
)
|
||||
model.compile(
|
||||
loss="mse",
|
||||
optimizer="sgd",
|
||||
metrics=["mse"],
|
||||
)
|
||||
return model
|
||||
|
||||
# case 1, create new file with defined separator
|
||||
model = make_model()
|
||||
cbks = [callbacks.CSVLogger(filepath, separator=sep)]
|
||||
model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=cbks,
|
||||
epochs=1,
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
assert os.path.exists(filepath)
|
||||
with open(filepath) as csvfile:
|
||||
dialect = csv.Sniffer().sniff(csvfile.read())
|
||||
assert dialect.delimiter == sep
|
||||
del model
|
||||
del cbks
|
||||
|
||||
# case 2, append data to existing file, skip header
|
||||
model = make_model()
|
||||
cbks = [callbacks.CSVLogger(filepath, separator=sep, append=True)]
|
||||
model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=cbks,
|
||||
epochs=1,
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
# case 3, reuse of CSVLogger object
|
||||
model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=cbks,
|
||||
epochs=2,
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
with open(filepath) as csvfile:
|
||||
list_lines = csvfile.readlines()
|
||||
for line in list_lines:
|
||||
assert line.count(sep) == 4
|
||||
assert len(list_lines) == 5
|
||||
output = " ".join(list_lines)
|
||||
assert len(re.findall("epoch", output)) == 1
|
||||
|
||||
os.remove(filepath)
|
||||
|
||||
# case 3, Verify Val. loss also registered when Validation Freq > 1
|
||||
model = make_model()
|
||||
cbks = [callbacks.CSVLogger(filepath, separator=sep)]
|
||||
hist = model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
validation_data=(x_test, y_test),
|
||||
validation_freq=3,
|
||||
callbacks=cbks,
|
||||
epochs=5,
|
||||
verbose=0,
|
||||
)
|
||||
assert os.path.exists(filepath)
|
||||
# Verify that validation loss is registered at val. freq
|
||||
with open(filepath) as csvfile:
|
||||
rows = csv.DictReader(csvfile, delimiter=sep)
|
||||
for idx, row in enumerate(rows, 1):
|
||||
self.assertIn("val_loss", row)
|
||||
if idx == 3:
|
||||
self.assertEqual(
|
||||
row["val_loss"], str(hist.history["val_loss"][0])
|
||||
)
|
||||
else:
|
||||
self.assertEqual(row["val_loss"], "NA")
|
||||
|
||||
def test_stop_training_csv(self):
|
||||
# Test that using the CSVLogger callback with the TerminateOnNaN
|
||||
# callback does not result in invalid CSVs.
|
||||
tmpdir = tempfile.TemporaryDirectory()
|
||||
csv_logfile = os.path.join(tmpdir.name, "csv_logger.csv")
|
||||
NUM_CLASSES = 2
|
||||
np.random.seed(1337)
|
||||
x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
|
||||
y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)
|
||||
x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
|
||||
y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)
|
||||
|
||||
y_test = numerical_utils.to_categorical(y_test)
|
||||
y_train = numerical_utils.to_categorical(y_train)
|
||||
model = Sequential()
|
||||
initializer = initializers.Constant(value=1e5)
|
||||
for _ in range(5):
|
||||
model.add(
|
||||
layers.Dense(
|
||||
2,
|
||||
activation="relu",
|
||||
kernel_initializer=initializer,
|
||||
)
|
||||
)
|
||||
model.add(layers.Dense(NUM_CLASSES))
|
||||
model.compile(loss="mean_squared_error", optimizer="sgd")
|
||||
|
||||
history = model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=[
|
||||
callbacks.TerminateOnNaN(),
|
||||
callbacks.CSVLogger(csv_logfile),
|
||||
],
|
||||
epochs=20,
|
||||
)
|
||||
loss = history.history["loss"]
|
||||
self.assertEqual(len(loss), 1)
|
||||
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
|
||||
|
||||
values = []
|
||||
with open(csv_logfile) as f:
|
||||
# On Windows, due to \r\n line ends, we may end up reading empty
|
||||
# lines after each line. Skip empty lines.
|
||||
values = [x for x in csv.reader(f) if x]
|
||||
self.assertIn("nan", values[-1], "NaN not logged in CSV Logger.")
|
20
keras_core/callbacks/terminate_on_nan.py
Normal file
20
keras_core/callbacks/terminate_on_nan.py
Normal file
@ -0,0 +1,20 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.callbacks.callback import Callback
|
||||
from keras_core.utils import io_utils
|
||||
|
||||
|
||||
@keras_core_export("keras_core.callbacks.TerminateOnNaN")
|
||||
class TerminateOnNaN(Callback):
|
||||
"""Callback that terminates training when a NaN loss is encountered."""
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
logs = logs or {}
|
||||
loss = logs.get("loss")
|
||||
if loss is not None:
|
||||
if np.isnan(loss) or np.isinf(loss):
|
||||
io_utils.print_msg(
|
||||
f"Batch {batch}: Invalid loss, terminating training"
|
||||
)
|
||||
self.model.stop_training = True
|
50
keras_core/callbacks/terminate_on_nan_test.py
Normal file
50
keras_core/callbacks/terminate_on_nan_test.py
Normal file
@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core import callbacks
|
||||
from keras_core import initializers
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
from keras_core.models import Sequential
|
||||
from keras_core.utils import numerical_utils
|
||||
|
||||
|
||||
class TerminateOnNaNTest(testing.TestCase):
|
||||
def test_TerminateOnNaN(self):
|
||||
TRAIN_SAMPLES = 10
|
||||
TEST_SAMPLES = 10
|
||||
INPUT_DIM = 3
|
||||
NUM_CLASSES = 2
|
||||
BATCH_SIZE = 4
|
||||
|
||||
np.random.seed(1337)
|
||||
x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
|
||||
y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)
|
||||
x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
|
||||
y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)
|
||||
|
||||
y_test = numerical_utils.to_categorical(y_test)
|
||||
y_train = numerical_utils.to_categorical(y_train)
|
||||
model = Sequential()
|
||||
initializer = initializers.Constant(value=1e5)
|
||||
for _ in range(5):
|
||||
model.add(
|
||||
layers.Dense(
|
||||
2,
|
||||
activation="relu",
|
||||
kernel_initializer=initializer,
|
||||
)
|
||||
)
|
||||
model.add(layers.Dense(NUM_CLASSES))
|
||||
model.compile(loss="mean_squared_error", optimizer="sgd")
|
||||
|
||||
history = model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=BATCH_SIZE,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=[callbacks.TerminateOnNaN()],
|
||||
epochs=20,
|
||||
)
|
||||
loss = history.history["loss"]
|
||||
self.assertEqual(len(loss), 1)
|
||||
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
|
@ -48,7 +48,6 @@ from keras_core.layers.pooling.global_max_pooling3d import GlobalMaxPooling3D
|
||||
from keras_core.layers.pooling.max_pooling1d import MaxPooling1D
|
||||
from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
|
||||
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
|
||||
from keras_core.layers.preprocessing.center_crop import CenterCrop
|
||||
from keras_core.layers.preprocessing.normalization import Normalization
|
||||
from keras_core.layers.preprocessing.rescaling import Rescaling
|
||||
from keras_core.layers.preprocessing.resizing import Resizing
|
||||
|
@ -1,130 +0,0 @@
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import image_utils
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.CenterCrop")
|
||||
class CenterCrop(Layer):
|
||||
"""A preprocessing layer which crops images.
|
||||
|
||||
This layers crops the central portion of the images to a target size. If an
|
||||
image is smaller than the target size, it will be resized and cropped
|
||||
so as to return the largest possible window in the image that matches
|
||||
the target aspect ratio.
|
||||
|
||||
Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`).
|
||||
|
||||
Input shape:
|
||||
3D (unbatched) or 4D (batched) tensor with shape:
|
||||
`(..., height, width, channels)`, in `"channels_last"` format,
|
||||
or `(..., channels, height, width)`, in `"channels_first"` format.
|
||||
|
||||
Output shape:
|
||||
3D (unbatched) or 4D (batched) tensor with shape:
|
||||
`(..., target_height, target_width, channels)`,
|
||||
or `(..., channels, target_height, target_width)`,
|
||||
in `"channels_first"` format.
|
||||
|
||||
If the input height/width is even and the target height/width is odd (or
|
||||
inversely), the input image is left-padded by 1 pixel.
|
||||
|
||||
Args:
|
||||
height: Integer, the height of the output shape.
|
||||
width: Integer, the width of the output shape.
|
||||
data_format: string, either `"channels_last"` or `"channels_first"`.
|
||||
The ordering of the dimensions in the inputs. `"channels_last"`
|
||||
corresponds to inputs with shape `(batch, height, width, channels)`
|
||||
while `"channels_first"` corresponds to inputs with shape
|
||||
`(batch, channels, height, width)`. It defaults to the
|
||||
`image_data_format` value found in your Keras config file at
|
||||
`~/.keras/keras.json`. If you never set it, then it will be
|
||||
`"channels_last"`.
|
||||
"""
|
||||
|
||||
def __init__(self, height, width, data_format=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.data_format = data_format or backend.image_data_format()
|
||||
|
||||
def call(self, inputs):
|
||||
if self.data_format == "channels_first":
|
||||
init_height = inputs.shape[-2]
|
||||
init_width = inputs.shape[-1]
|
||||
else:
|
||||
init_height = inputs.shape[-3]
|
||||
init_width = inputs.shape[-2]
|
||||
|
||||
if init_height is None or init_width is None:
|
||||
# Dynamic size case. TODO.
|
||||
raise ValueError(
|
||||
"At this time, CenterCrop can only "
|
||||
"process images with a static spatial "
|
||||
f"shape. Received: inputs.shape={inputs.shape}"
|
||||
)
|
||||
|
||||
h_diff = init_height - self.height
|
||||
w_diff = init_width - self.width
|
||||
|
||||
h_start = int(h_diff / 2)
|
||||
w_start = int(w_diff / 2)
|
||||
|
||||
if h_diff >= 0 and w_diff >= 0:
|
||||
if len(inputs.shape) == 4:
|
||||
if self.data_format == "channels_first":
|
||||
return inputs[
|
||||
:,
|
||||
:,
|
||||
h_start : h_start + self.height,
|
||||
w_start : w_start + self.width,
|
||||
]
|
||||
return inputs[
|
||||
:,
|
||||
h_start : h_start + self.height,
|
||||
w_start : w_start + self.width,
|
||||
:,
|
||||
]
|
||||
elif len(inputs.shape) == 3:
|
||||
if self.data_format == "channels_first":
|
||||
return inputs[
|
||||
:,
|
||||
h_start : h_start + self.height,
|
||||
w_start : w_start + self.width,
|
||||
]
|
||||
return inputs[
|
||||
h_start : h_start + self.height,
|
||||
w_start : w_start + self.width,
|
||||
:,
|
||||
]
|
||||
|
||||
return image_utils.smart_resize(
|
||||
inputs, [self.height, self.width], data_format=self.data_format
|
||||
)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
input_shape = list(input_shape)
|
||||
if len(input_shape) == 4:
|
||||
if self.data_format == "channels_last":
|
||||
input_shape[1] = self.height
|
||||
input_shape[2] = self.width
|
||||
else:
|
||||
input_shape[2] = self.height
|
||||
input_shape[3] = self.width
|
||||
else:
|
||||
if self.data_format == "channels_last":
|
||||
input_shape[0] = self.height
|
||||
input_shape[1] = self.width
|
||||
else:
|
||||
input_shape[1] = self.height
|
||||
input_shape[2] = self.width
|
||||
return tuple(input_shape)
|
||||
|
||||
def get_config(self):
|
||||
base_config = super().get_config()
|
||||
config = {
|
||||
"height": self.height,
|
||||
"width": self.width,
|
||||
"data_format": self.data_format,
|
||||
}
|
||||
return {**base_config, **config}
|
@ -1,96 +0,0 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
class CenterCropTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_center_crop_basics(self):
|
||||
self.run_layer_test(
|
||||
layers.CenterCrop,
|
||||
init_kwargs={
|
||||
"height": 6,
|
||||
"width": 6,
|
||||
"data_format": "channels_last",
|
||||
},
|
||||
input_shape=(2, 12, 12, 3),
|
||||
expected_output_shape=(2, 6, 6, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=False,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.CenterCrop,
|
||||
init_kwargs={
|
||||
"height": 7,
|
||||
"width": 7,
|
||||
"data_format": "channels_first",
|
||||
},
|
||||
input_shape=(2, 3, 13, 13),
|
||||
expected_output_shape=(2, 3, 7, 7),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=False,
|
||||
)
|
||||
|
||||
@parameterized.parameters(
|
||||
[
|
||||
((5, 7), "channels_first"),
|
||||
((5, 7), "channels_last"),
|
||||
((15, 10), "channels_first"),
|
||||
((10, 17), "channels_last"),
|
||||
]
|
||||
)
|
||||
def test_center_crop_correctness(self, size, data_format):
|
||||
# batched case
|
||||
if data_format == "channels_first":
|
||||
img = np.random.random((2, 3, 9, 11))
|
||||
else:
|
||||
img = np.random.random((2, 9, 11, 3))
|
||||
out = layers.CenterCrop(
|
||||
size[0],
|
||||
size[1],
|
||||
data_format=data_format,
|
||||
)(img)
|
||||
if data_format == "channels_first":
|
||||
img_transpose = np.transpose(img, (0, 2, 3, 1))
|
||||
|
||||
ref_out = tf.transpose(
|
||||
tf.keras.layers.CenterCrop(size[0], size[1])(img_transpose),
|
||||
(0, 3, 1, 2),
|
||||
)
|
||||
else:
|
||||
ref_out = tf.keras.layers.CenterCrop(size[0], size[1])(img)
|
||||
self.assertAllClose(ref_out, out)
|
||||
|
||||
# unbatched case
|
||||
if data_format == "channels_first":
|
||||
img = np.random.random((3, 9, 11))
|
||||
else:
|
||||
img = np.random.random((9, 11, 3))
|
||||
out = layers.CenterCrop(
|
||||
size[0],
|
||||
size[1],
|
||||
data_format=data_format,
|
||||
)(img)
|
||||
if data_format == "channels_first":
|
||||
img_transpose = np.transpose(img, (1, 2, 0))
|
||||
ref_out = tf.transpose(
|
||||
tf.keras.layers.CenterCrop(
|
||||
size[0],
|
||||
size[1],
|
||||
)(img_transpose),
|
||||
(2, 0, 1),
|
||||
)
|
||||
else:
|
||||
ref_out = tf.keras.layers.CenterCrop(
|
||||
size[0],
|
||||
size[1],
|
||||
)(img)
|
||||
self.assertAllClose(ref_out, out)
|
@ -14,17 +14,6 @@ class Resizing(Layer):
|
||||
format. Input pixel values can be of any range
|
||||
(e.g. `[0., 1.)` or `[0, 255]`).
|
||||
|
||||
Input shape:
|
||||
3D (unbatched) or 4D (batched) tensor with shape:
|
||||
`(..., height, width, channels)`, in `"channels_last"` format,
|
||||
or `(..., channels, height, width)`, in `"channels_first"` format.
|
||||
|
||||
Output shape:
|
||||
3D (unbatched) or 4D (batched) tensor with shape:
|
||||
`(..., target_height, target_width, channels)`,
|
||||
or `(..., channels, target_height, target_width)`,
|
||||
in `"channels_first"` format.
|
||||
|
||||
Args:
|
||||
height: Integer, the height of the output shape.
|
||||
width: Integer, the width of the output shape.
|
||||
|
@ -5,6 +5,7 @@ from keras_core.metrics.accuracy_metrics import CategoricalAccuracy
|
||||
from keras_core.metrics.accuracy_metrics import SparseCategoricalAccuracy
|
||||
from keras_core.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy
|
||||
from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
|
||||
from keras_core.metrics.confusion_metrics import AUC
|
||||
from keras_core.metrics.confusion_metrics import FalseNegatives
|
||||
from keras_core.metrics.confusion_metrics import FalsePositives
|
||||
from keras_core.metrics.confusion_metrics import Precision
|
||||
@ -41,6 +42,7 @@ ALL_OBJECTS = {
|
||||
# Regression
|
||||
MeanSquaredError,
|
||||
# Classification
|
||||
AUC,
|
||||
FalseNegatives,
|
||||
FalsePositives,
|
||||
Precision,
|
||||
|
@ -1,3 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core import activations
|
||||
from keras_core import backend
|
||||
from keras_core import initializers
|
||||
from keras_core import operations as ops
|
||||
@ -366,7 +369,7 @@ class Precision(Metric):
|
||||
y_pred: The predicted values. Each element must be in the range
|
||||
`[0, 1]`.
|
||||
sample_weight: Optional weighting of each example. Defaults to 1.
|
||||
Can be a `Tensor` whose rank is either 0, or the same rank as
|
||||
Can be a tensor whose rank is either 0, or the same rank as
|
||||
`y_true`, and must be broadcastable to `y_true`.
|
||||
"""
|
||||
metrics_utils.update_confusion_matrix_variables(
|
||||
@ -507,7 +510,7 @@ class Recall(Metric):
|
||||
y_pred: The predicted values. Each element must be in the range
|
||||
`[0, 1]`.
|
||||
sample_weight: Optional weighting of each example. Defaults to 1.
|
||||
Can be a `Tensor` whose rank is either 0, or the same rank as
|
||||
Can be a tensor whose rank is either 0, or the same rank as
|
||||
`y_true`, and must be broadcastable to `y_true`.
|
||||
"""
|
||||
metrics_utils.update_confusion_matrix_variables(
|
||||
@ -605,7 +608,7 @@ class SensitivitySpecificityBase(Metric):
|
||||
y_true: The ground truth values.
|
||||
y_pred: The predicted values.
|
||||
sample_weight: Optional weighting of each example. Defaults to 1.
|
||||
Can be a `Tensor` whose rank is either 0, or the same rank as
|
||||
Can be a tensor whose rank is either 0, or the same rank as
|
||||
`y_true`, and must be broadcastable to `y_true`.
|
||||
"""
|
||||
metrics_utils.update_confusion_matrix_variables(
|
||||
@ -1052,3 +1055,518 @@ class RecallAtPrecision(SensitivitySpecificityBase):
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.AUC")
|
||||
class AUC(Metric):
|
||||
"""Approximates the AUC (Area under the curve) of the ROC or PR curves.
|
||||
|
||||
The AUC (Area under the curve) of the ROC (Receiver operating
|
||||
characteristic; default) or PR (Precision Recall) curves are quality
|
||||
measures of binary classifiers. Unlike the accuracy, and like cross-entropy
|
||||
losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
|
||||
|
||||
This class approximates AUCs using a Riemann sum. During the metric
|
||||
accumulation phrase, predictions are accumulated within predefined buckets
|
||||
by value. The AUC is then computed by interpolating per-bucket averages.
|
||||
These buckets define the evaluated operational points.
|
||||
|
||||
This metric creates four local variables, `true_positives`,
|
||||
`true_negatives`, `false_positives` and `false_negatives` that are used to
|
||||
compute the AUC. To discretize the AUC curve, a linearly spaced set of
|
||||
thresholds is used to compute pairs of recall and precision values. The area
|
||||
under the ROC-curve is therefore computed using the height of the recall
|
||||
values by the false positive rate, while the area under the PR-curve is the
|
||||
computed using the height of the precision values by the recall.
|
||||
|
||||
This value is ultimately returned as `auc`, an idempotent operation that
|
||||
computes the area under a discretized curve of precision versus recall
|
||||
values (computed using the aforementioned variables). The `num_thresholds`
|
||||
variable controls the degree of discretization with larger numbers of
|
||||
thresholds more closely approximating the true AUC. The quality of the
|
||||
approximation may vary dramatically depending on `num_thresholds`. The
|
||||
`thresholds` parameter can be used to manually specify thresholds which
|
||||
split the predictions more evenly.
|
||||
|
||||
For a best approximation of the real AUC, `predictions` should be
|
||||
distributed approximately uniformly in the range `[0, 1]` (if
|
||||
`from_logits=False`). The quality of the AUC approximation may be poor if
|
||||
this is not the case. Setting `summation_method` to 'minoring' or 'majoring'
|
||||
can help quantify the error in the approximation by providing lower or upper
|
||||
bound estimate of the AUC.
|
||||
|
||||
If `sample_weight` is `None`, weights default to 1.
|
||||
Use `sample_weight` of 0 to mask values.
|
||||
|
||||
Args:
|
||||
num_thresholds: (Optional) The number of thresholds to
|
||||
use when discretizing the roc curve. Values must be > 1.
|
||||
Defaults to `200`.
|
||||
curve: (Optional) Specifies the name of the curve to be computed,
|
||||
`'ROC'` (default) or `'PR'` for the Precision-Recall-curve.
|
||||
summation_method: (Optional) Specifies the [Riemann summation method](
|
||||
https://en.wikipedia.org/wiki/Riemann_sum) used.
|
||||
'interpolation' (default) applies mid-point summation scheme for
|
||||
`ROC`. For PR-AUC, interpolates (true/false) positives but not
|
||||
the ratio that is precision (see Davis & Goadrich 2006 for
|
||||
details); 'minoring' applies left summation for increasing
|
||||
intervals and right summation for decreasing intervals; 'majoring'
|
||||
does the opposite.
|
||||
name: (Optional) string name of the metric instance.
|
||||
dtype: (Optional) data type of the metric result.
|
||||
thresholds: (Optional) A list of floating point values to use as the
|
||||
thresholds for discretizing the curve. If set, the `num_thresholds`
|
||||
parameter is ignored. Values should be in `[0, 1]`. Endpoint
|
||||
thresholds equal to {`-epsilon`, `1+epsilon`} for a small positive
|
||||
epsilon value will be automatically included with these to correctly
|
||||
handle predictions equal to exactly 0 or 1.
|
||||
multi_label: boolean indicating whether multilabel data should be
|
||||
treated as such, wherein AUC is computed separately for each label
|
||||
and then averaged across labels, or (when `False`) if the data
|
||||
should be flattened into a single label before AUC computation. In
|
||||
the latter case, when multilabel data is passed to AUC, each
|
||||
label-prediction pair is treated as an individual data point. Should
|
||||
be set to False for multi-class data.
|
||||
num_labels: (Optional) The number of labels, used when `multi_label` is
|
||||
True. If `num_labels` is not specified, then state variables get
|
||||
created on the first call to `update_state`.
|
||||
label_weights: (Optional) list, array, or tensor of non-negative weights
|
||||
used to compute AUCs for multilabel data. When `multi_label` is
|
||||
True, the weights are applied to the individual label AUCs when they
|
||||
are averaged to produce the multi-label AUC. When it's False, they
|
||||
are used to weight the individual label predictions in computing the
|
||||
confusion matrix on the flattened data. Note that this is unlike
|
||||
`class_weights` in that `class_weights` weights the example
|
||||
depending on the value of its label, whereas `label_weights` depends
|
||||
only on the index of that label before flattening; therefore
|
||||
`label_weights` should not be used for multi-class data.
|
||||
from_logits: boolean indicating whether the predictions (`y_pred` in
|
||||
`update_state`) are probabilities or sigmoid logits. As a rule of thumb,
|
||||
when using a keras loss, the `from_logits` constructor argument of the
|
||||
loss should match the AUC `from_logits` constructor argument.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> m = keras_core.metrics.AUC(num_thresholds=3)
|
||||
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
|
||||
>>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
|
||||
>>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||
>>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
|
||||
>>> # auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
|
||||
>>> # = 0.75
|
||||
>>> m.result()
|
||||
0.75
|
||||
|
||||
>>> m.reset_state()
|
||||
>>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
|
||||
... sample_weight=[1, 0, 0, 1])
|
||||
>>> m.result()
|
||||
1.0
|
||||
|
||||
Usage with `compile()` API:
|
||||
|
||||
```python
|
||||
# Reports the AUC of a model outputting a probability.
|
||||
model.compile(optimizer='sgd',
|
||||
loss=keras_core.losses.BinaryCrossentropy(),
|
||||
metrics=[keras_core.metrics.AUC()])
|
||||
|
||||
# Reports the AUC of a model outputting a logit.
|
||||
model.compile(optimizer='sgd',
|
||||
loss=keras_core.losses.BinaryCrossentropy(from_logits=True),
|
||||
metrics=[keras_core.metrics.AUC(from_logits=True)])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_thresholds=200,
|
||||
curve="ROC",
|
||||
summation_method="interpolation",
|
||||
name=None,
|
||||
dtype=None,
|
||||
thresholds=None,
|
||||
multi_label=False,
|
||||
num_labels=None,
|
||||
label_weights=None,
|
||||
from_logits=False,
|
||||
):
|
||||
# Validate configurations.
|
||||
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
|
||||
metrics_utils.AUCCurve
|
||||
):
|
||||
raise ValueError(
|
||||
f'Invalid `curve` argument value "{curve}". '
|
||||
f"Expected one of: {list(metrics_utils.AUCCurve)}"
|
||||
)
|
||||
if isinstance(
|
||||
summation_method, metrics_utils.AUCSummationMethod
|
||||
) and summation_method not in list(metrics_utils.AUCSummationMethod):
|
||||
raise ValueError(
|
||||
"Invalid `summation_method` "
|
||||
f'argument value "{summation_method}". '
|
||||
f"Expected one of: {list(metrics_utils.AUCSummationMethod)}"
|
||||
)
|
||||
|
||||
# Update properties.
|
||||
self._init_from_thresholds = thresholds is not None
|
||||
if thresholds is not None:
|
||||
# If specified, use the supplied thresholds.
|
||||
self.num_thresholds = len(thresholds) + 2
|
||||
thresholds = sorted(thresholds)
|
||||
self._thresholds_distributed_evenly = (
|
||||
metrics_utils.is_evenly_distributed_thresholds(
|
||||
np.array([0.0] + thresholds + [1.0])
|
||||
)
|
||||
)
|
||||
else:
|
||||
if num_thresholds <= 1:
|
||||
raise ValueError(
|
||||
"Argument `num_thresholds` must be an integer > 1. "
|
||||
f"Received: num_thresholds={num_thresholds}"
|
||||
)
|
||||
|
||||
# Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
|
||||
# (0, 1).
|
||||
self.num_thresholds = num_thresholds
|
||||
thresholds = [
|
||||
(i + 1) * 1.0 / (num_thresholds - 1)
|
||||
for i in range(num_thresholds - 2)
|
||||
]
|
||||
self._thresholds_distributed_evenly = True
|
||||
|
||||
# Add an endpoint "threshold" below zero and above one for either
|
||||
# threshold method to account for floating point imprecisions.
|
||||
self._thresholds = np.array(
|
||||
[0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()]
|
||||
)
|
||||
|
||||
if isinstance(curve, metrics_utils.AUCCurve):
|
||||
self.curve = curve
|
||||
else:
|
||||
self.curve = metrics_utils.AUCCurve.from_str(curve)
|
||||
if isinstance(summation_method, metrics_utils.AUCSummationMethod):
|
||||
self.summation_method = summation_method
|
||||
else:
|
||||
self.summation_method = metrics_utils.AUCSummationMethod.from_str(
|
||||
summation_method
|
||||
)
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
|
||||
# Handle multilabel arguments.
|
||||
self.multi_label = multi_label
|
||||
self.num_labels = num_labels
|
||||
if label_weights is not None:
|
||||
label_weights = ops.array(label_weights, dtype=self.dtype)
|
||||
self.label_weights = label_weights
|
||||
|
||||
else:
|
||||
self.label_weights = None
|
||||
|
||||
self._from_logits = from_logits
|
||||
|
||||
self._built = False
|
||||
if self.multi_label:
|
||||
if num_labels:
|
||||
shape = [None, num_labels]
|
||||
self._build(shape)
|
||||
else:
|
||||
if num_labels:
|
||||
raise ValueError(
|
||||
"`num_labels` is needed only when `multi_label` is True."
|
||||
)
|
||||
self._build(None)
|
||||
|
||||
@property
|
||||
def thresholds(self):
|
||||
"""The thresholds used for evaluating AUC."""
|
||||
return list(self._thresholds)
|
||||
|
||||
def _build(self, shape):
|
||||
"""Initialize TP, FP, TN, and FN tensors, given the shape of the
|
||||
data."""
|
||||
if self.multi_label:
|
||||
if len(shape) != 2:
|
||||
raise ValueError(
|
||||
"`y_pred` must have rank 2 when `multi_label=True`. "
|
||||
f"Found rank {len(shape)}. "
|
||||
f"Full shape received for `y_pred`: {shape}"
|
||||
)
|
||||
self._num_labels = shape[1]
|
||||
variable_shape = [self.num_thresholds, self._num_labels]
|
||||
else:
|
||||
variable_shape = [self.num_thresholds]
|
||||
|
||||
self._build_input_shape = shape
|
||||
# Create metric variables
|
||||
self.true_positives = self.add_variable(
|
||||
shape=variable_shape,
|
||||
initializer=initializers.Zeros(),
|
||||
name="true_positives",
|
||||
)
|
||||
self.false_positives = self.add_variable(
|
||||
shape=variable_shape,
|
||||
initializer=initializers.Zeros(),
|
||||
name="false_positives",
|
||||
)
|
||||
self.true_negatives = self.add_variable(
|
||||
shape=variable_shape,
|
||||
initializer=initializers.Zeros(),
|
||||
name="true_negatives",
|
||||
)
|
||||
self.false_negatives = self.add_variable(
|
||||
shape=variable_shape,
|
||||
initializer=initializers.Zeros(),
|
||||
name="false_negatives",
|
||||
)
|
||||
|
||||
self._built = True
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
"""Accumulates confusion matrix statistics.
|
||||
|
||||
Args:
|
||||
y_true: The ground truth values.
|
||||
y_pred: The predicted values.
|
||||
sample_weight: Optional weighting of each example. Can
|
||||
be a tensor whose rank is either 0, or the same rank as
|
||||
`y_true`, and must be broadcastable to `y_true`. Defaults to
|
||||
`1`.
|
||||
"""
|
||||
if not self._built:
|
||||
self._build(y_pred.shape)
|
||||
|
||||
if self.multi_label or (self.label_weights is not None):
|
||||
# y_true should have shape (number of examples, number of labels).
|
||||
shapes = [(y_true, ("N", "L"))]
|
||||
if self.multi_label:
|
||||
# TP, TN, FP, and FN should all have shape
|
||||
# (number of thresholds, number of labels).
|
||||
shapes.extend(
|
||||
[
|
||||
(self.true_positives, ("T", "L")),
|
||||
(self.true_negatives, ("T", "L")),
|
||||
(self.false_positives, ("T", "L")),
|
||||
(self.false_negatives, ("T", "L")),
|
||||
]
|
||||
)
|
||||
if self.label_weights is not None:
|
||||
# label_weights should be of length equal to the number of
|
||||
# labels.
|
||||
shapes.append((self.label_weights, ("L",)))
|
||||
|
||||
# Only forward label_weights to update_confusion_matrix_variables when
|
||||
# multi_label is False. Otherwise the averaging of individual label AUCs
|
||||
# is handled in AUC.result
|
||||
label_weights = None if self.multi_label else self.label_weights
|
||||
|
||||
if self._from_logits:
|
||||
y_pred = activations.sigmoid(y_pred)
|
||||
|
||||
metrics_utils.update_confusion_matrix_variables(
|
||||
{
|
||||
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
|
||||
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
|
||||
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
|
||||
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
|
||||
},
|
||||
y_true,
|
||||
y_pred,
|
||||
self._thresholds,
|
||||
thresholds_distributed_evenly=self._thresholds_distributed_evenly,
|
||||
sample_weight=sample_weight,
|
||||
multi_label=self.multi_label,
|
||||
label_weights=label_weights,
|
||||
)
|
||||
|
||||
def interpolate_pr_auc(self):
|
||||
"""Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
|
||||
|
||||
https://www.biostat.wisc.edu/~page/rocpr.pdf
|
||||
|
||||
Note here we derive & use a closed formula not present in the paper
|
||||
as follows:
|
||||
|
||||
Precision = TP / (TP + FP) = TP / P
|
||||
|
||||
Modeling all of TP (true positive), FP (false positive) and their sum
|
||||
P = TP + FP (predicted positive) as varying linearly within each
|
||||
interval [A, B] between successive thresholds, we get
|
||||
|
||||
Precision slope = dTP / dP
|
||||
= (TP_B - TP_A) / (P_B - P_A)
|
||||
= (TP - TP_A) / (P - P_A)
|
||||
Precision = (TP_A + slope * (P - P_A)) / P
|
||||
|
||||
The area within the interval is (slope / total_pos_weight) times
|
||||
|
||||
int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
|
||||
int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
|
||||
|
||||
where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
|
||||
|
||||
int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
|
||||
|
||||
Bringing back the factor (slope / total_pos_weight) we'd put aside, we
|
||||
get
|
||||
|
||||
slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
|
||||
|
||||
where dTP == TP_B - TP_A.
|
||||
|
||||
Note that when P_A == 0 the above calculation simplifies into
|
||||
|
||||
int_A^B{Precision.dTP} = int_A^B{slope * dTP}
|
||||
= slope * (TP_B - TP_A)
|
||||
|
||||
which is really equivalent to imputing constant precision throughout the
|
||||
first bucket having >0 true positives.
|
||||
|
||||
Returns:
|
||||
pr_auc: an approximation of the area under the P-R curve.
|
||||
"""
|
||||
|
||||
dtp = (
|
||||
self.true_positives[: self.num_thresholds - 1]
|
||||
- self.true_positives[1:]
|
||||
)
|
||||
p = ops.add(self.true_positives, self.false_positives)
|
||||
dp = p[: self.num_thresholds - 1] - p[1:]
|
||||
prec_slope = ops.divide(dtp, ops.maximum(dp, backend.epsilon()))
|
||||
intercept = self.true_positives[1:] - ops.multiply(prec_slope, p[1:])
|
||||
|
||||
safe_p_ratio = ops.where(
|
||||
ops.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
|
||||
ops.divide(
|
||||
p[: self.num_thresholds - 1],
|
||||
ops.maximum(p[1:], backend.epsilon()),
|
||||
),
|
||||
ops.ones_like(p[1:]),
|
||||
)
|
||||
|
||||
pr_auc_increment = ops.divide(
|
||||
prec_slope * (dtp + intercept * ops.log(safe_p_ratio)),
|
||||
ops.maximum(
|
||||
self.true_positives[1:] + self.false_negatives[1:],
|
||||
backend.epsilon(),
|
||||
),
|
||||
)
|
||||
|
||||
if self.multi_label:
|
||||
by_label_auc = ops.sum(pr_auc_increment, axis=0)
|
||||
if self.label_weights is None:
|
||||
# Evenly weighted average of the label AUCs.
|
||||
return ops.mean(by_label_auc)
|
||||
else:
|
||||
# Weighted average of the label AUCs.
|
||||
return ops.divide(
|
||||
ops.sum(ops.multiply(by_label_auc, self.label_weights)),
|
||||
ops.maximum(ops.sum(self.label_weights), backend.epsilon()),
|
||||
)
|
||||
else:
|
||||
return ops.sum(pr_auc_increment)
|
||||
|
||||
def result(self):
|
||||
if (
|
||||
self.curve == metrics_utils.AUCCurve.PR
|
||||
and self.summation_method
|
||||
== metrics_utils.AUCSummationMethod.INTERPOLATION
|
||||
):
|
||||
# This use case is different and is handled separately.
|
||||
return self.interpolate_pr_auc()
|
||||
|
||||
# Set `x` and `y` values for the curves based on `curve` config.
|
||||
recall = ops.divide(
|
||||
self.true_positives,
|
||||
ops.maximum(
|
||||
ops.add(self.true_positives, self.false_negatives),
|
||||
backend.epsilon(),
|
||||
),
|
||||
)
|
||||
if self.curve == metrics_utils.AUCCurve.ROC:
|
||||
fp_rate = ops.divide(
|
||||
self.false_positives,
|
||||
ops.maximum(
|
||||
ops.add(self.false_positives, self.true_negatives),
|
||||
backend.epsilon(),
|
||||
),
|
||||
)
|
||||
x = fp_rate
|
||||
y = recall
|
||||
else: # curve == 'PR'.
|
||||
precision = ops.divide(
|
||||
self.true_positives,
|
||||
ops.maximum(
|
||||
ops.add(self.true_positives, self.false_positives),
|
||||
backend.epsilon(),
|
||||
),
|
||||
)
|
||||
x = recall
|
||||
y = precision
|
||||
|
||||
# Find the rectangle heights based on `summation_method`.
|
||||
if (
|
||||
self.summation_method
|
||||
== metrics_utils.AUCSummationMethod.INTERPOLATION
|
||||
):
|
||||
# Note: the case ('PR', 'interpolation') has been handled above.
|
||||
heights = (y[: self.num_thresholds - 1] + y[1:]) / 2.0
|
||||
elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
|
||||
heights = ops.minimum(y[: self.num_thresholds - 1], y[1:])
|
||||
# self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
|
||||
else:
|
||||
heights = ops.maximum(y[: self.num_thresholds - 1], y[1:])
|
||||
|
||||
# Sum up the areas of all the rectangles.
|
||||
if self.multi_label:
|
||||
riemann_terms = ops.multiply(
|
||||
x[: self.num_thresholds - 1] - x[1:], heights
|
||||
)
|
||||
by_label_auc = ops.sum(riemann_terms, axis=0)
|
||||
|
||||
if self.label_weights is None:
|
||||
# Unweighted average of the label AUCs.
|
||||
return ops.mean(by_label_auc)
|
||||
else:
|
||||
# Weighted average of the label AUCs.
|
||||
return ops.divide(
|
||||
ops.sum(ops.multiply(by_label_auc, self.label_weights)),
|
||||
ops.maximum(ops.sum(self.label_weights), backend.epsilon()),
|
||||
)
|
||||
else:
|
||||
return ops.sum(
|
||||
ops.multiply(x[: self.num_thresholds - 1] - x[1:], heights)
|
||||
)
|
||||
|
||||
def reset_state(self):
|
||||
if self._built:
|
||||
if self.multi_label:
|
||||
variable_shape = (self.num_thresholds, self._num_labels)
|
||||
else:
|
||||
variable_shape = (self.num_thresholds,)
|
||||
|
||||
self.true_positives.assign(ops.zeros(variable_shape))
|
||||
self.false_positives.assign(ops.zeros(variable_shape))
|
||||
self.true_negatives.assign(ops.zeros(variable_shape))
|
||||
self.false_negatives.assign(ops.zeros(variable_shape))
|
||||
|
||||
def get_config(self):
|
||||
label_weights = self.label_weights
|
||||
config = {
|
||||
"num_thresholds": self.num_thresholds,
|
||||
"curve": self.curve.value,
|
||||
"summation_method": self.summation_method.value,
|
||||
"multi_label": self.multi_label,
|
||||
"num_labels": self.num_labels,
|
||||
"label_weights": label_weights,
|
||||
"from_logits": self._from_logits,
|
||||
}
|
||||
# optimization to avoid serializing a large number of generated
|
||||
# thresholds
|
||||
if self._init_from_thresholds:
|
||||
# We remove the endpoint thresholds as an inverse of how the
|
||||
# thresholds were initialized. This ensures that a metric
|
||||
# initialized from this config has the same thresholds.
|
||||
config["thresholds"] = self.thresholds[1:-1]
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
@ -1,10 +1,16 @@
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
from absl.testing import parameterized
|
||||
from tensorflow.python.ops.numpy_ops import np_config
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import metrics
|
||||
from keras_core import models
|
||||
from keras_core import operations as ops
|
||||
from keras_core import testing
|
||||
from keras_core.metrics import metrics_utils
|
||||
|
||||
# TODO: remove reliance on this (or alternatively, turn it on by default).
|
||||
# This is no longer needed with tf-nightly.
|
||||
@ -1101,3 +1107,558 @@ class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase):
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 0"
|
||||
):
|
||||
metrics.RecallAtPrecision(0.4, num_thresholds=-1)
|
||||
|
||||
|
||||
class AUCTest(testing.TestCase):
|
||||
def setUp(self):
|
||||
self.num_thresholds = 3
|
||||
self.y_pred = np.array([0, 0.5, 0.3, 0.9], dtype="float32")
|
||||
self.y_pred_multi_label = np.array(
|
||||
[[0.0, 0.4], [0.5, 0.7], [0.3, 0.2], [0.9, 0.3]], dtype="float32"
|
||||
)
|
||||
epsilon = 1e-12
|
||||
self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
|
||||
self.y_true = np.array([0, 0, 1, 1])
|
||||
self.y_true_multi_label = np.array([[0, 0], [1, 1], [1, 1], [1, 0]])
|
||||
self.sample_weight = [1, 2, 3, 4]
|
||||
|
||||
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
|
||||
# y_pred when threshold = 0 - 1e-7 : [1, 1, 1, 1]
|
||||
# y_pred when threshold = 0.5 : [0, 0, 0, 1]
|
||||
# y_pred when threshold = 1 + 1e-7 : [0, 0, 0, 0]
|
||||
|
||||
# without sample_weight:
|
||||
# tp = np.sum([[0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]], axis=1)
|
||||
# fp = np.sum([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)
|
||||
# fn = np.sum([[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]], axis=1)
|
||||
# tn = np.sum([[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0]], axis=1)
|
||||
|
||||
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||
|
||||
# with sample_weight:
|
||||
# tp = np.sum([[0, 0, 3, 4], [0, 0, 0, 4], [0, 0, 0, 0]], axis=1)
|
||||
# fp = np.sum([[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)
|
||||
# fn = np.sum([[0, 0, 0, 0], [0, 0, 3, 0], [0, 0, 3, 4]], axis=1)
|
||||
# tn = np.sum([[0, 0, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0]], axis=1)
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
|
||||
def test_config(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=100,
|
||||
curve="PR",
|
||||
summation_method="majoring",
|
||||
name="auc_1",
|
||||
dtype="float64",
|
||||
multi_label=True,
|
||||
num_labels=2,
|
||||
from_logits=True,
|
||||
)
|
||||
auc_obj.update_state(self.y_true_multi_label, self.y_pred_multi_label)
|
||||
self.assertEqual(auc_obj.name, "auc_1")
|
||||
self.assertEqual(auc_obj._dtype, "float64")
|
||||
self.assertLen(auc_obj.variables, 4)
|
||||
self.assertEqual(auc_obj.num_thresholds, 100)
|
||||
self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
|
||||
self.assertEqual(
|
||||
auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING
|
||||
)
|
||||
self.assertTrue(auc_obj.multi_label)
|
||||
self.assertEqual(auc_obj.num_labels, 2)
|
||||
self.assertTrue(auc_obj._from_logits)
|
||||
old_config = auc_obj.get_config()
|
||||
self.assertNotIn("thresholds", old_config)
|
||||
self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
|
||||
|
||||
# Check save and restore config.
|
||||
auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
|
||||
auc_obj2.update_state(self.y_true_multi_label, self.y_pred_multi_label)
|
||||
self.assertEqual(auc_obj2.name, "auc_1")
|
||||
self.assertLen(auc_obj2.variables, 4)
|
||||
self.assertEqual(auc_obj2.num_thresholds, 100)
|
||||
self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
|
||||
self.assertEqual(
|
||||
auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING
|
||||
)
|
||||
self.assertTrue(auc_obj2.multi_label)
|
||||
self.assertEqual(auc_obj2.num_labels, 2)
|
||||
self.assertTrue(auc_obj2._from_logits)
|
||||
new_config = auc_obj2.get_config()
|
||||
self.assertNotIn("thresholds", new_config)
|
||||
self.assertDictEqual(old_config, new_config)
|
||||
self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
|
||||
|
||||
def test_config_manual_thresholds(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=None,
|
||||
curve="PR",
|
||||
summation_method="majoring",
|
||||
name="auc_1",
|
||||
thresholds=[0.3, 0.5],
|
||||
)
|
||||
auc_obj.update_state(self.y_true, self.y_pred)
|
||||
self.assertEqual(auc_obj.name, "auc_1")
|
||||
self.assertLen(auc_obj.variables, 4)
|
||||
self.assertEqual(auc_obj.num_thresholds, 4)
|
||||
self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])
|
||||
self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
|
||||
self.assertEqual(
|
||||
auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING
|
||||
)
|
||||
old_config = auc_obj.get_config()
|
||||
self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))
|
||||
|
||||
# Check save and restore config.
|
||||
auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
|
||||
auc_obj2.update_state(self.y_true, self.y_pred)
|
||||
self.assertEqual(auc_obj2.name, "auc_1")
|
||||
self.assertLen(auc_obj2.variables, 4)
|
||||
self.assertEqual(auc_obj2.num_thresholds, 4)
|
||||
self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
|
||||
self.assertEqual(
|
||||
auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING
|
||||
)
|
||||
new_config = auc_obj2.get_config()
|
||||
self.assertDictEqual(old_config, new_config)
|
||||
self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)
|
||||
|
||||
def test_unweighted_all_correct(self):
|
||||
auc_obj = metrics.AUC()
|
||||
self.assertEqual(auc_obj(self.y_true, self.y_true), 1)
|
||||
|
||||
def test_unweighted(self):
|
||||
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
|
||||
result = auc_obj(self.y_true, self.y_pred)
|
||||
|
||||
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
|
||||
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
|
||||
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = 0.75 * 1 + 0.25 * 0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_unweighted_from_logits(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, from_logits=True
|
||||
)
|
||||
result = auc_obj(self.y_true, self.y_pred_logits)
|
||||
|
||||
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
|
||||
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
|
||||
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = 0.75 * 1 + 0.25 * 0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_manual_thresholds(self):
|
||||
# Verify that when specified, thresholds are used instead of
|
||||
# num_thresholds.
|
||||
auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])
|
||||
self.assertEqual(auc_obj.num_thresholds, 3)
|
||||
self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
|
||||
result = auc_obj(self.y_true, self.y_pred)
|
||||
|
||||
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
|
||||
# recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
|
||||
# fp_rate = [2/2, 0, 0] = [1, 0, 0]
|
||||
# heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = 0.75 * 1 + 0.25 * 0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_roc_interpolation(self):
|
||||
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
|
||||
result = auc_obj(
|
||||
self.y_true, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
|
||||
# fp_rate = [3/3, 0, 0] = [1, 0, 0]
|
||||
# heights = [(1 + 0.571)/2, (0.571 + 0)/2] = [0.7855, 0.2855]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = 0.7855 * 1 + 0.2855 * 0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_roc_majoring(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, summation_method="majoring"
|
||||
)
|
||||
result = auc_obj(
|
||||
self.y_true, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
|
||||
# fp_rate = [3/3, 0, 0] = [1, 0, 0]
|
||||
# heights = [max(1, 0.571), max(0.571, 0)] = [1, 0.571]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = 1 * 1 + 0.571 * 0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_roc_minoring(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, summation_method="minoring"
|
||||
)
|
||||
result = auc_obj(
|
||||
self.y_true, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
|
||||
# fp_rate = [3/3, 0, 0] = [1, 0, 0]
|
||||
# heights = [min(1, 0.571), min(0.571, 0)] = [0.571, 0]
|
||||
# widths = [(1 - 0), (0 - 0)] = [1, 0]
|
||||
expected_result = 0.571 * 1 + 0 * 0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_pr_majoring(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
curve="PR",
|
||||
summation_method="majoring",
|
||||
)
|
||||
result = auc_obj(
|
||||
self.y_true, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
# precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
|
||||
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
|
||||
# heights = [max(0.7, 1), max(1, 0)] = [1, 1]
|
||||
# widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
|
||||
expected_result = 1 * 0.429 + 1 * 0.571
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_pr_minoring(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
curve="PR",
|
||||
summation_method="minoring",
|
||||
)
|
||||
result = auc_obj(
|
||||
self.y_true, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
# precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
|
||||
# recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
|
||||
# heights = [min(0.7, 1), min(1, 0)] = [0.7, 0]
|
||||
# widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
|
||||
expected_result = 0.7 * 0.429 + 0 * 0.571
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_pr_interpolation(self):
|
||||
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR")
|
||||
result = auc_obj(
|
||||
self.y_true, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# auc = (slope / Total Pos) * [dTP - intercept * log(Pb/Pa)]
|
||||
|
||||
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
|
||||
# P = tp + fp = [10, 4, 0]
|
||||
# dTP = [7-4, 4-0] = [3, 4]
|
||||
# dP = [10-4, 4-0] = [6, 4]
|
||||
# slope = dTP/dP = [0.5, 1]
|
||||
# intercept = (TPa+(slope*Pa) = [(4 - 0.5*4), (0 - 1*0)] = [2, 0]
|
||||
# (Pb/Pa) = (Pb/Pa) if Pb > 0 AND Pa > 0 else 1 = [10/4, 4/0] = [2.5, 1]
|
||||
# auc * TotalPos = [(0.5 * (3 + 2 * log(2.5))), (1 * (4 + 0))]
|
||||
# = [2.416, 4]
|
||||
# auc = [2.416, 4]/(tp[1:]+fn[1:])
|
||||
expected_result = 2.416 / 7 + 4 / 7
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_invalid_num_thresholds(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 1"
|
||||
):
|
||||
metrics.AUC(num_thresholds=-1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Argument `num_thresholds` must be an integer > 1."
|
||||
):
|
||||
metrics.AUC(num_thresholds=1)
|
||||
|
||||
def test_invalid_curve(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Invalid AUC curve value: "Invalid".'
|
||||
):
|
||||
metrics.AUC(curve="Invalid")
|
||||
|
||||
def test_invalid_summation_method(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Invalid AUC summation method value: "Invalid".'
|
||||
):
|
||||
metrics.AUC(summation_method="Invalid")
|
||||
|
||||
def test_extra_dims(self):
|
||||
try:
|
||||
from scipy import special
|
||||
|
||||
logits = special.expit(
|
||||
-np.array(
|
||||
[
|
||||
[[-10.0, 10.0, -10.0], [10.0, -10.0, 10.0]],
|
||||
[[-12.0, 12.0, -12.0], [12.0, -12.0, 12.0]],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
)
|
||||
labels = np.array(
|
||||
[[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]], dtype=np.int64
|
||||
)
|
||||
auc_obj = metrics.AUC()
|
||||
result = auc_obj(labels, logits)
|
||||
self.assertEqual(result, 0.5)
|
||||
except ImportError as e:
|
||||
logging.warning(f"Cannot test special functions: {str(e)}")
|
||||
|
||||
|
||||
class MultiAUCTest(testing.TestCase):
|
||||
def setUp(self):
|
||||
self.num_thresholds = 5
|
||||
self.y_pred = np.array(
|
||||
[[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]], dtype="float32"
|
||||
).T
|
||||
|
||||
epsilon = 1e-12
|
||||
self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0)
|
||||
|
||||
self.y_true_good = np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T
|
||||
self.y_true_bad = np.array([[0, 0, 1, 1], [1, 1, 0, 0]]).T
|
||||
self.sample_weight = [1, 2, 3, 4]
|
||||
|
||||
# threshold values are [0 - 1e-7, 0.25, 0.5, 0.75, 1 + 1e-7]
|
||||
# y_pred when threshold = 0 - 1e-7 : [[1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# y_pred when threshold = 0.25 : [[0, 1, 1, 1], [0, 0, 1, 1]]
|
||||
# y_pred when threshold = 0.5 : [[0, 0, 0, 1], [0, 0, 0, 0]]
|
||||
# y_pred when threshold = 0.75 : [[0, 0, 0, 1], [0, 0, 0, 0]]
|
||||
# y_pred when threshold = 1 + 1e-7 : [[0, 0, 0, 0], [0, 0, 0, 0]]
|
||||
|
||||
# for y_true_good, over thresholds:
|
||||
# tp = [[2, 2, 1, 1, 0], [2, 2, 0, 0, 0]]
|
||||
# fp = [[2, 1, 0, 0 , 0], [2, 0, 0 ,0, 0]]
|
||||
# fn = [[0, 0, 1, 1, 2], [0, 0, 2, 2, 2]]
|
||||
# tn = [[0, 1, 2, 2, 2], [0, 2, 2, 2, 2]]
|
||||
|
||||
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
|
||||
# for y_true_bad:
|
||||
# tp = [[2, 2, 1, 1, 0], [2, 0, 0, 0, 0]]
|
||||
# fp = [[2, 1, 0, 0 , 0], [2, 2, 0 ,0, 0]]
|
||||
# fn = [[0, 0, 1, 1, 2], [0, 2, 2, 2, 2]]
|
||||
# tn = [[0, 1, 2, 2, 2], [0, 0, 2, 2, 2]]
|
||||
|
||||
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 0, 0, 0, 0]]
|
||||
# fpr = [[1, 0.5, 0, 0, 0], [1, 1, 0, 0, 0]]
|
||||
|
||||
# for y_true_good with sample_weights:
|
||||
|
||||
# tp = [[7, 7, 4, 4, 0], [7, 7, 0, 0, 0]]
|
||||
# fp = [[3, 2, 0, 0, 0], [3, 0, 0, 0, 0]]
|
||||
# fn = [[0, 0, 3, 3, 7], [0, 0, 7, 7, 7]]
|
||||
# tn = [[0, 1, 3, 3, 3], [0, 3, 3, 3, 3]]
|
||||
|
||||
# tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
|
||||
def test_unweighted_all_correct(self):
|
||||
auc_obj = metrics.AUC(multi_label=True)
|
||||
result = auc_obj(self.y_true_good, self.y_true_good)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
def test_unweighted_all_correct_flat(self):
|
||||
auc_obj = metrics.AUC(multi_label=False)
|
||||
result = auc_obj(self.y_true_good, self.y_true_good)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
def test_unweighted(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=True
|
||||
)
|
||||
result = auc_obj(self.y_true_good, self.y_pred)
|
||||
|
||||
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
expected_result = (0.875 + 1.0) / 2.0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_unweighted_from_logits(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
multi_label=True,
|
||||
from_logits=True,
|
||||
)
|
||||
result = auc_obj(self.y_true_good, self.y_pred_logits)
|
||||
|
||||
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
expected_result = (0.875 + 1.0) / 2.0
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_sample_weight_flat(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=False
|
||||
)
|
||||
result = auc_obj(
|
||||
self.y_true_good, self.y_pred, sample_weight=[1, 2, 3, 4]
|
||||
)
|
||||
|
||||
# tpr = [1, 1, 0.2857, 0.2857, 0]
|
||||
# fpr = [1, 0.3333, 0, 0, 0]
|
||||
expected_result = 1.0 - (0.3333 * (1.0 - 0.2857) / 2.0)
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_full_sample_weight_flat(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=False
|
||||
)
|
||||
sw = np.arange(4 * 2)
|
||||
sw = sw.reshape(4, 2)
|
||||
result = auc_obj(self.y_true_good, self.y_pred, sample_weight=sw)
|
||||
|
||||
# tpr = [1, 1, 0.2727, 0.2727, 0]
|
||||
# fpr = [1, 0.3333, 0, 0, 0]
|
||||
expected_result = 1.0 - (0.3333 * (1.0 - 0.2727) / 2.0)
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_label_weights(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
multi_label=True,
|
||||
label_weights=[0.75, 0.25],
|
||||
)
|
||||
result = auc_obj(self.y_true_good, self.y_pred)
|
||||
|
||||
# tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25)
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_label_weights_flat(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
multi_label=False,
|
||||
label_weights=[0.75, 0.25],
|
||||
)
|
||||
result = auc_obj(self.y_true_good, self.y_pred)
|
||||
|
||||
# tpr = [1, 1, 0.375, 0.375, 0]
|
||||
# fpr = [1, 0.375, 0, 0, 0]
|
||||
expected_result = 1.0 - ((1.0 - 0.375) * 0.375 / 2.0)
|
||||
self.assertAllClose(result, expected_result, 1e-2)
|
||||
|
||||
def test_unweighted_flat(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=False
|
||||
)
|
||||
result = auc_obj(self.y_true_good, self.y_pred)
|
||||
|
||||
# tp = [4, 4, 1, 1, 0]
|
||||
# fp = [4, 1, 0, 0, 0]
|
||||
# fn = [0, 0, 3, 3, 4]
|
||||
# tn = [0, 3, 4, 4, 4]
|
||||
|
||||
# tpr = [1, 1, 0.25, 0.25, 0]
|
||||
# fpr = [1, 0.25, 0, 0, 0]
|
||||
expected_result = 1.0 - (3.0 / 32.0)
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_unweighted_flat_from_logits(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds,
|
||||
multi_label=False,
|
||||
from_logits=True,
|
||||
)
|
||||
result = auc_obj(self.y_true_good, self.y_pred_logits)
|
||||
|
||||
# tp = [4, 4, 1, 1, 0]
|
||||
# fp = [4, 1, 0, 0, 0]
|
||||
# fn = [0, 0, 3, 3, 4]
|
||||
# tn = [0, 3, 4, 4, 4]
|
||||
|
||||
# tpr = [1, 1, 0.25, 0.25, 0]
|
||||
# fpr = [1, 0.25, 0, 0, 0]
|
||||
expected_result = 1.0 - (3.0 / 32.0)
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_manual_thresholds(self):
|
||||
# Verify that when specified, thresholds are used instead of
|
||||
# num_thresholds.
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=2, thresholds=[0.5], multi_label=True
|
||||
)
|
||||
self.assertEqual(auc_obj.num_thresholds, 3)
|
||||
self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])
|
||||
result = auc_obj(self.y_true_good, self.y_pred)
|
||||
|
||||
# tp = [[2, 1, 0], [2, 0, 0]]
|
||||
# fp = [2, 0, 0], [2, 0, 0]]
|
||||
# fn = [[0, 1, 2], [0, 2, 2]]
|
||||
# tn = [[0, 2, 2], [0, 2, 2]]
|
||||
|
||||
# tpr = [[1, 0.5, 0], [1, 0, 0]]
|
||||
# fpr = [[1, 0, 0], [1, 0, 0]]
|
||||
|
||||
# auc by slice = [0.75, 0.5]
|
||||
expected_result = (0.75 + 0.5) / 2.0
|
||||
|
||||
self.assertAllClose(result, expected_result, 1e-3)
|
||||
|
||||
def test_weighted_roc_interpolation(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=True
|
||||
)
|
||||
result = auc_obj(
|
||||
self.y_true_good, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
|
||||
# tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]]
|
||||
# fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]]
|
||||
expected_result = 1.0 - 0.5 * 0.43 * 0.67
|
||||
self.assertAllClose(result, expected_result, 1e-1)
|
||||
|
||||
def test_pr_interpolation_unweighted(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, curve="PR", multi_label=True
|
||||
)
|
||||
good_result = auc_obj(self.y_true_good, self.y_pred)
|
||||
with self.subTest(name="good"):
|
||||
# PR AUCs are 0.917 and 1.0 respectively
|
||||
self.assertAllClose(good_result, (0.91667 + 1.0) / 2.0, 1e-1)
|
||||
bad_result = auc_obj(self.y_true_bad, self.y_pred)
|
||||
with self.subTest(name="bad"):
|
||||
# PR AUCs are 0.917 and 0.5 respectively
|
||||
self.assertAllClose(bad_result, (0.91667 + 0.5) / 2.0, 1e-1)
|
||||
|
||||
def test_pr_interpolation(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, curve="PR", multi_label=True
|
||||
)
|
||||
good_result = auc_obj(
|
||||
self.y_true_good, self.y_pred, sample_weight=self.sample_weight
|
||||
)
|
||||
# PR AUCs are 0.939 and 1.0 respectively
|
||||
self.assertAllClose(good_result, (0.939 + 1.0) / 2.0, 1e-1)
|
||||
|
||||
def test_keras_model_compiles(self):
|
||||
inputs = layers.Input(shape=(10,), batch_size=1)
|
||||
output = layers.Dense(3, activation="sigmoid")(inputs)
|
||||
model = models.Model(inputs=inputs, outputs=output)
|
||||
model.compile(
|
||||
optimizer="adam",
|
||||
loss="binary_crossentropy",
|
||||
metrics=[metrics.AUC(multi_label=True)],
|
||||
)
|
||||
|
||||
def test_reset_state(self):
|
||||
auc_obj = metrics.AUC(
|
||||
num_thresholds=self.num_thresholds, multi_label=True
|
||||
)
|
||||
auc_obj(self.y_true_good, self.y_pred)
|
||||
auc_obj.reset_state()
|
||||
self.assertAllClose(auc_obj.true_positives, np.zeros((5, 2)))
|
||||
|
@ -38,6 +38,59 @@ class ConfusionMatrix(Enum):
|
||||
FALSE_NEGATIVES = "fn"
|
||||
|
||||
|
||||
class AUCCurve(Enum):
|
||||
"""Type of AUC Curve (ROC or PR)."""
|
||||
|
||||
ROC = "ROC"
|
||||
PR = "PR"
|
||||
|
||||
@staticmethod
|
||||
def from_str(key):
|
||||
if key in ("pr", "PR"):
|
||||
return AUCCurve.PR
|
||||
elif key in ("roc", "ROC"):
|
||||
return AUCCurve.ROC
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Invalid AUC curve value: "{key}". '
|
||||
'Expected values are ["PR", "ROC"]'
|
||||
)
|
||||
|
||||
|
||||
class AUCSummationMethod(Enum):
|
||||
"""Type of AUC summation method.
|
||||
|
||||
https://en.wikipedia.org/wiki/Riemann_sum)
|
||||
|
||||
Contains the following values:
|
||||
* 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
|
||||
`PR` curve, interpolates (true/false) positives but not the ratio that is
|
||||
precision (see Davis & Goadrich 2006 for details).
|
||||
* 'minoring': Applies left summation for increasing intervals and right
|
||||
summation for decreasing intervals.
|
||||
* 'majoring': Applies right summation for increasing intervals and left
|
||||
summation for decreasing intervals.
|
||||
"""
|
||||
|
||||
INTERPOLATION = "interpolation"
|
||||
MAJORING = "majoring"
|
||||
MINORING = "minoring"
|
||||
|
||||
@staticmethod
|
||||
def from_str(key):
|
||||
if key in ("interpolation", "Interpolation"):
|
||||
return AUCSummationMethod.INTERPOLATION
|
||||
elif key in ("majoring", "Majoring"):
|
||||
return AUCSummationMethod.MAJORING
|
||||
elif key in ("minoring", "Minoring"):
|
||||
return AUCSummationMethod.MINORING
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Invalid AUC summation method value: "{key}". '
|
||||
'Expected values are ["interpolation", "majoring", "minoring"]'
|
||||
)
|
||||
|
||||
|
||||
def _update_confusion_matrix_variables_optimized(
|
||||
variables_to_update,
|
||||
y_true,
|
||||
@ -203,14 +256,15 @@ def _update_confusion_matrix_variables_optimized(
|
||||
num_segments=num_thresholds,
|
||||
)
|
||||
|
||||
tp_bucket_v = ops.vectorized_map(
|
||||
gather_bucket, (true_labels, bucket_indices), warn=False
|
||||
tp_bucket_v = backend.vectorized_map(
|
||||
gather_bucket,
|
||||
(true_labels, bucket_indices),
|
||||
)
|
||||
fp_bucket_v = ops.vectorized_map(
|
||||
gather_bucket, (false_labels, bucket_indices), warn=False
|
||||
fp_bucket_v = backend.vectorized_map(
|
||||
gather_bucket, (false_labels, bucket_indices)
|
||||
)
|
||||
tp = ops.transpose(ops.cumsum(ops.flip(tp_bucket_v), axis=1))
|
||||
fp = ops.transpose(ops.cumsum(ops.flip(fp_bucket_v), axis=1))
|
||||
tp = ops.transpose(ops.flip(ops.cumsum(ops.flip(tp_bucket_v), axis=1)))
|
||||
fp = ops.transpose(ops.flip(ops.cumsum(ops.flip(fp_bucket_v), axis=1)))
|
||||
else:
|
||||
tp_bucket_v = ops.segment_sum(
|
||||
data=true_labels,
|
||||
@ -222,8 +276,8 @@ def _update_confusion_matrix_variables_optimized(
|
||||
segment_ids=bucket_indices,
|
||||
num_segments=num_thresholds,
|
||||
)
|
||||
tp = ops.cumsum(ops.flip(tp_bucket_v))
|
||||
fp = ops.cumsum(ops.flip(fp_bucket_v))
|
||||
tp = ops.flip(ops.cumsum(ops.flip(tp_bucket_v)))
|
||||
fp = ops.flip(ops.cumsum(ops.flip(fp_bucket_v)))
|
||||
|
||||
# fn = sum(true_labels) - tp
|
||||
# tn = sum(false_labels) - fp
|
||||
@ -383,7 +437,6 @@ def update_confusion_matrix_variables(
|
||||
one_thresh = ops.equal(
|
||||
ops.cast(1, dtype="int32"),
|
||||
thresholds.ndim,
|
||||
name="one_set_of_thresholds_cond",
|
||||
)
|
||||
else:
|
||||
one_thresh = ops.cast(True, dtype="bool")
|
||||
|
Loading…
Reference in New Issue
Block a user