From d6bcc56001b9a89629cbba80f314d3f95a9ac871 Mon Sep 17 00:00:00 2001
From: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
Date: Sat, 22 Apr 2023 10:10:44 -0600
Subject: [PATCH] First confusion metric (FalsePositives) (#30)

* Added confusion metrics -- still using TF ops

* Fixed structure + tests pass for TF (still need to port to multi-backend)

* Got rid of most tf deps, still a few more to go

* Full removal of TF. Tests pass for both Jax and TF

* Full removal of TF. Tests pass for both Jax and TF

* Formatting

* Formatting

* Review comments

* More review comments + formatting
---
 keras_core/activations/__init__.py           |   4 -
 keras_core/backend/jax/__init__.py           |   5 +
 keras_core/backend/jax/math.py               |  15 +
 keras_core/backend/jax/nn.py                 |   4 +
 keras_core/backend/tensorflow/__init__.py    |   5 +
 keras_core/backend/tensorflow/math.py        |  12 +
 keras_core/backend/tensorflow/nn.py          |   4 +
 keras_core/layers/core/dense.py              |   5 -
 keras_core/metrics/__init__.py               |   1 +
 keras_core/metrics/confusion_metrics.py      | 115 ++++
 keras_core/metrics/confusion_metrics_test.py | 103 ++++
 keras_core/metrics/metrics_utils.py          | 540 +++++++++++++++++++
 keras_core/saving/serialization_lib.py       |   4 +-
 keras_core/testing/test_case.py              |  13 +
 keras_core/utils/python_utils.py             |  17 +
 15 files changed, 836 insertions(+), 11 deletions(-)
 create mode 100644 keras_core/backend/jax/math.py
 create mode 100644 keras_core/backend/tensorflow/math.py
 create mode 100644 keras_core/metrics/confusion_metrics.py
 create mode 100644 keras_core/metrics/confusion_metrics_test.py
 create mode 100644 keras_core/metrics/metrics_utils.py

diff --git a/keras_core/activations/__init__.py b/keras_core/activations/__init__.py
index 2f7a86b99..1dfb4a189 100644
--- a/keras_core/activations/__init__.py
+++ b/keras_core/activations/__init__.py
@@ -15,7 +15,3 @@ def get(identifier):
     if identifier == "relu":
         return relu
     return identifier
-
-
-def serialize(activation):
-    return activation.__name__
diff --git a/keras_core/backend/jax/__init__.py b/keras_core/backend/jax/__init__.py
index c5f74cd04..fabedc87f 100644
--- a/keras_core/backend/jax/__init__.py
+++ b/keras_core/backend/jax/__init__.py
@@ -5,6 +5,7 @@ from tensorflow import nest
 from keras_core.backend.common import KerasVariable
 from keras_core.backend.common import standardize_dtype
+from keras_core.backend.jax import math
 from keras_core.backend.jax import nn
 from keras_core.backend.jax import numpy
 from keras_core.backend.jax import random
@@ -50,6 +51,10 @@ def name_scope(name):
     return jax.named_scope(name)
 
 
+def vectorized_map(function, elements):
+    return jax.vmap(function)(elements)
+
+
 class Variable(KerasVariable):
     def __init__(self, value, dtype=None, trainable=True, name=None):
         self.name = name or auto_name(self.__class__.__name__)
diff --git a/keras_core/backend/jax/math.py b/keras_core/backend/jax/math.py
new file mode 100644
index 000000000..d9ddbdc50
--- /dev/null
+++ b/keras_core/backend/jax/math.py
@@ -0,0 +1,15 @@
+import jax
+
+
+def segment_sum(data, segment_ids, num_segments=None, sorted=False):
+    return jax.ops.segment_sum(
+        data, segment_ids, num_segments, indices_are_sorted=sorted
+    )
+
+
+def top_k(x, k, sorted=False):
+    if sorted:
+        raise ValueError(
+            "Jax backend does not support `sorted=True` for `ops.top_k`"
+        )
+    return jax.lax.top_k(x, k)
diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py
index 6ce73a282..eb0ebff1a 100644
--- a/keras_core/backend/jax/nn.py
+++ b/keras_core/backend/jax/nn.py
@@ -103,3 +103,7 @@ def conv_transpose(
 ):
     # TODO: Implement
`conv_transpose`. raise NotImplementedError + + +def one_hot(x, num_classes, axis=-1): + return jnn.one_hot(x, num_classes, axis=axis) diff --git a/keras_core/backend/tensorflow/__init__.py b/keras_core/backend/tensorflow/__init__.py index 5b0f9e47f..0dd6a4db8 100644 --- a/keras_core/backend/tensorflow/__init__.py +++ b/keras_core/backend/tensorflow/__init__.py @@ -6,6 +6,7 @@ from keras_core.backend.common import standardize_dtype from keras_core.backend.keras_tensor import KerasTensor from keras_core.backend.stateless_scope import get_stateless_scope from keras_core.backend.stateless_scope import in_stateless_scope +from keras_core.backend.tensorflow import math from keras_core.backend.tensorflow import nn from keras_core.backend.tensorflow import numpy from keras_core.backend.tensorflow import random @@ -240,6 +241,10 @@ def name_scope(name): return tf.name_scope(name) +def vectorized_map(function, elements): + return tf.vectorized_map(function, elements) + + def compute_output_spec(fn, *args, **kwargs): graph_name = auto_name("scratch_graph") with tf.__internal__.FuncGraph(graph_name).as_default(): diff --git a/keras_core/backend/tensorflow/math.py b/keras_core/backend/tensorflow/math.py new file mode 100644 index 000000000..6bc2e2309 --- /dev/null +++ b/keras_core/backend/tensorflow/math.py @@ -0,0 +1,12 @@ +import tensorflow as tf + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + if sorted: + return tf.math.segment_sum(data, segment_ids) + else: + return tf.math.unsorted_segment_sum(data, segment_ids, num_segments) + + +def top_k(x, k, sorted=False): + return tf.math.top_k(x, k, sorted=sorted) diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py index b75aaa8d6..a2e782c00 100644 --- a/keras_core/backend/tensorflow/nn.py +++ b/keras_core/backend/tensorflow/nn.py @@ -390,3 +390,7 @@ def conv_transpose( data_format=tf_data_format, dilations=dilation_rate, ) + + +def one_hot(x, num_classes, axis=-1): + return tf.one_hot(x, num_classes, axis=axis) diff --git a/keras_core/layers/core/dense.py b/keras_core/layers/core/dense.py index a309aa123..c4ffe8f5d 100644 --- a/keras_core/layers/core/dense.py +++ b/keras_core/layers/core/dense.py @@ -61,15 +61,10 @@ class Dense(Layer): # TODO config = { "units": self.units, - "activation": activations.serialize(self.activation), "use_bias": self.use_bias, "kernel_initializer": initializers.serialize( self.kernel_initializer ), "bias_initializer": initializers.serialize(self.bias_initializer), - "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), - "bias_regularizer": regularizers.serialize(self.bias_regularizer), - "kernel_constraint": constraints.serialize(self.kernel_constraint), - "bias_constraint": constraints.serialize(self.bias_constraint), } return {**base_config, **config} diff --git a/keras_core/metrics/__init__.py b/keras_core/metrics/__init__.py index c1e4836eb..278a3e40f 100644 --- a/keras_core/metrics/__init__.py +++ b/keras_core/metrics/__init__.py @@ -1,4 +1,5 @@ from keras_core.api_export import keras_core_export +from keras_core.metrics.confusion_metrics import FalsePositives from keras_core.metrics.metric import Metric from keras_core.metrics.reduction_metrics import Mean from keras_core.metrics.reduction_metrics import MeanMetricWrapper diff --git a/keras_core/metrics/confusion_metrics.py b/keras_core/metrics/confusion_metrics.py new file mode 100644 index 000000000..3b7ad5d22 --- /dev/null +++ b/keras_core/metrics/confusion_metrics.py @@ -0,0 +1,115 @@ 
+from keras_core import backend
+from keras_core import initializers
+from keras_core import operations as ops
+from keras_core.api_export import keras_core_export
+from keras_core.metrics import metrics_utils
+from keras_core.metrics.metric import Metric
+
+
+class _ConfusionMatrixConditionCount(Metric):
+    """Calculates the number of the given confusion matrix condition.
+
+    Args:
+        confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`
+            conditions.
+        thresholds: (Optional) Defaults to 0.5. A float value or a Python
+            list/tuple of float threshold values in [0, 1]. A threshold is
+            compared with prediction values to determine the truth value of
+            predictions (i.e., above the threshold is `true`, below is
+            `false`). One metric value is generated for each threshold
+            value.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+    """
+
+    def __init__(
+        self, confusion_matrix_cond, thresholds=None, name=None, dtype=None
+    ):
+        super().__init__(name=name, dtype=dtype)
+        self._confusion_matrix_cond = confusion_matrix_cond
+        self.init_thresholds = thresholds
+        self.thresholds = metrics_utils.parse_init_thresholds(
+            thresholds, default_threshold=0.5
+        )
+        self._thresholds_distributed_evenly = (
+            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
+        )
+        self.accumulator = self.add_variable(
+            shape=(len(self.thresholds),),
+            initializer=initializers.Zeros(),
+            name="accumulator",
+        )
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        """Accumulates the metric statistics.
+
+        Args:
+            y_true: The ground truth values.
+            y_pred: The predicted values.
+            sample_weight: Optional weighting of each example. Defaults to
+                1. Can be a tensor whose rank is either 0, or the same rank
+                as `y_true`, and must be broadcastable to `y_true`.
+        """
+        return metrics_utils.update_confusion_matrix_variables(
+            {self._confusion_matrix_cond: self.accumulator},
+            y_true,
+            y_pred,
+            thresholds=self.thresholds,
+            thresholds_distributed_evenly=self._thresholds_distributed_evenly,
+            sample_weight=sample_weight,
+        )
+
+    def result(self):
+        if len(self.thresholds) == 1:
+            result = self.accumulator[0]
+        else:
+            result = self.accumulator
+        return backend.convert_to_tensor(result)
+
+    def get_config(self):
+        config = {"thresholds": self.init_thresholds}
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
+@keras_core_export("keras_core.metrics.FalsePositives")
+class FalsePositives(_ConfusionMatrixConditionCount):
+    """Calculates the number of false positives.
+
+    If `sample_weight` is given, calculates the sum of the weights of false
+    positives. This metric creates one local variable, `accumulator`, that
+    is used to keep track of the number of false positives.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Args:
+        thresholds: (Optional) Defaults to 0.5. A float value, or a Python
+            list/tuple of float threshold values in [0, 1]. A threshold is
+            compared with prediction values to determine the truth value of
+            predictions (i.e., above the threshold is `true`, below is
+            `false`). If used with a loss function that sets
+            `from_logits=True` (i.e. no sigmoid applied to predictions),
+            `thresholds` should be set to 0. One metric value is generated
+            for each threshold value.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+ + Standalone usage: + + >>> m = keras_core.metrics.FalsePositives() + >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) + >>> m.result() + 2.0 + + >>> m.reset_state() + >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + """ + + def __init__(self, thresholds=None, name=None, dtype=None): + super().__init__( + confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, + thresholds=thresholds, + name=name, + dtype=dtype, + ) diff --git a/keras_core/metrics/confusion_metrics_test.py b/keras_core/metrics/confusion_metrics_test.py new file mode 100644 index 000000000..1ed5d2e27 --- /dev/null +++ b/keras_core/metrics/confusion_metrics_test.py @@ -0,0 +1,103 @@ +import numpy as np +from tensorflow.python.ops.numpy_ops import np_config + +from keras_core import metrics +from keras_core import testing + +np_config.enable_numpy_behavior() + + +class FalsePositivesTest(testing.TestCase): + def test_config(self): + fp_obj = metrics.FalsePositives(name="my_fp", thresholds=[0.4, 0.9]) + self.assertEqual(fp_obj.name, "my_fp") + self.assertLen(fp_obj.variables, 1) + self.assertEqual(fp_obj.thresholds, [0.4, 0.9]) + + # Check save and restore config + fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config()) + self.assertEqual(fp_obj2.name, "my_fp") + self.assertLen(fp_obj2.variables, 1) + self.assertEqual(fp_obj2.thresholds, [0.4, 0.9]) + + def test_unweighted(self): + fp_obj = metrics.FalsePositives() + + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + + fp_obj.update_state(y_true, y_pred) + result = fp_obj.result() + self.assertAllClose(7.0, result) + + def test_weighted(self): + fp_obj = metrics.FalsePositives() + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + sample_weight = np.array((1.0, 1.5, 2.0, 2.5)) + result = fp_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(14.0, result) + + def test_unweighted_with_thresholds(self): + fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + + fp_obj.update_state(y_true, y_pred) + result = fp_obj.result() + self.assertAllClose([7.0, 4.0, 2.0], result) + + def test_weighted_with_thresholds(self): + fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + sample_weight = ( + (1.0, 2.0, 3.0, 5.0), + (7.0, 11.0, 13.0, 17.0), + (19.0, 23.0, 29.0, 31.0), + (5.0, 15.0, 10.0, 0), + ) + + result = fp_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose([125.0, 42.0, 12.0], result) + + def test_threshold_limit(self): + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]", + ): + metrics.FalsePositives(thresholds=[-1, 0.5, 2]) + + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. 
Received: \[None\]",
        ):
            metrics.FalsePositives(thresholds=[None])
diff --git a/keras_core/metrics/metrics_utils.py b/keras_core/metrics/metrics_utils.py
new file mode 100644
index 000000000..2c46b0032
--- /dev/null
+++ b/keras_core/metrics/metrics_utils.py
@@ -0,0 +1,540 @@
+from enum import Enum
+
+import numpy as np
+
+from keras_core import backend
+from keras_core import operations as ops
+from keras_core.losses.loss import squeeze_to_same_rank
+from keras_core.utils.python_utils import to_list
+
+NEG_INF = -1e10
+
+
+def assert_thresholds_range(thresholds):
+    if thresholds is not None:
+        invalid_thresholds = [
+            t for t in thresholds if t is None or t < 0 or t > 1
+        ]
+        if invalid_thresholds:
+            raise ValueError(
+                "Threshold values must be in [0, 1]. "
+                f"Received: {invalid_thresholds}"
+            )
+
+
+def parse_init_thresholds(thresholds, default_threshold=0.5):
+    if thresholds is not None:
+        assert_thresholds_range(to_list(thresholds))
+    thresholds = to_list(
+        default_threshold if thresholds is None else thresholds
+    )
+    return thresholds
+
+
+class ConfusionMatrix(Enum):
+    TRUE_POSITIVES = "tp"
+    FALSE_POSITIVES = "fp"
+    TRUE_NEGATIVES = "tn"
+    FALSE_NEGATIVES = "fn"
+
+
+def _update_confusion_matrix_variables_optimized(
+    variables_to_update,
+    y_true,
+    y_pred,
+    thresholds,
+    multi_label=False,
+    sample_weights=None,
+    label_weights=None,
+    thresholds_with_epsilon=False,
+):
+    """Update confusion matrix variables with a memory-efficient alternative.
+
+    Note that the thresholds need to be evenly distributed within the list,
+    i.e., the difference between consecutive elements is the same.
+
+    To compute TP/FP/TN/FN, we are measuring a binary classifier
+      C(t) = (predictions >= t)
+    at each threshold 't'. So we have
+      TP(t) = sum( C(t) * true_labels )
+      FP(t) = sum( C(t) * false_labels )
+
+    But, computing C(t) requires computation for each t. To make it fast,
+    observe that C(t) is a cumulative integral, and so if we have
+      thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1}
+    where n = num_thresholds, and if we can compute the bucket function
+      B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
+    then we get
+      C(t_i) = sum( B(j), j >= i )
+    which is the reversed cumulative sum in ops.cumsum().
+
+    We can compute B(i) efficiently by taking advantage of the fact that
+    our thresholds are evenly distributed, in that
+      width = 1.0 / (num_thresholds - 1)
+      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
+    Given a prediction value p, we can map it to its bucket by
+      bucket_index(p) = floor( p * (num_thresholds - 1) )
+    so we can use ops.segment_sum() to update the buckets in one pass.
+
+    Consider the following example:
+      y_true = [0, 0, 1, 1]
+      y_pred = [0.1, 0.5, 0.3, 0.9]
+      thresholds = [0.0, 0.5, 1.0]
+      num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
+      bucket_index(y_pred) = ops.floor(y_pred * num_buckets)
+                           = ops.floor([0.2, 1.0, 0.6, 1.8])
+                           = [0, 0, 0, 1]
+    # The meaning of this bucket is that if any of the labels is true,
+    # then 1 will be added to the corresponding bucket with the index.
+    # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0.
+    # If the label for 1.8 is true, then 1 will be added to bucket 1.
+    #
+    # Note the second item "1.0" is floored to 0, since the value needs to
+    # be strictly larger than the bucket lower bound.
+    # In the implementation, we use ops.ceil() - 1 to achieve this.
+      tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,
+                                        num_segments=num_thresholds)
+                      = [1, 1, 0]
+    # For [1, 1, 0] here, it means there is 1 true value contributed by
+    # bucket 0, and 1 value contributed by bucket 1. When we aggregate them
+    # together, the result becomes [a + b + c, b + c, c], since larger
+    # thresholds will always contribute to the value for smaller thresholds.
+      true_positive = ops.cumsum(tp_bucket_value, reverse=True)
+                    = [2, 1, 0]
+
+    This implementation exhibits a run time and space complexity of
+    O(T + N), where T is the number of thresholds and N is the size of
+    predictions. Metrics that rely on the standard implementation instead
+    exhibit a complexity of O(T * N).
+
+    Args:
+        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
+            keys and corresponding variables to update as values.
+        y_true: A floating point `Tensor` whose shape matches `y_pred`. Will
+            be cast to `bool`.
+        y_pred: A floating point `Tensor` of arbitrary shape and whose values
+            are in the range `[0, 1]`.
+        thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
+            It needs to be evenly distributed (the difference between
+            consecutive elements needs to be the same).
+        multi_label: Optional boolean indicating whether multidimensional
+            prediction/labels should be treated as multilabel responses, or
+            flattened into a single label. When True, the values of
+            `variables_to_update` must have a second dimension equal to the
+            number of labels in y_true and y_pred, and those tensors must not
+            be RaggedTensors.
+        sample_weights: Optional `Tensor` whose rank is either 0, or the same
+            rank as `y_true`, and must be broadcastable to `y_true` (i.e.,
+            all dimensions must be either `1`, or the same as the
+            corresponding `y_true` dimension).
+        label_weights: Optional tensor of non-negative weights for multilabel
+            data. The weights are applied when calculating TP, FP, FN, and TN
+            without explicit multilabel handling (i.e. when the data is to be
+            flattened).
+        thresholds_with_epsilon: Optional boolean indicating whether the
+            leading and trailing thresholds have any epsilon added for
+            floating point imprecision. It changes how the leading and
+            trailing buckets are handled.
+    """
+    num_thresholds = thresholds.shape[0]
+
+    if sample_weights is None:
+        sample_weights = 1.0
+    else:
+        sample_weights = ops.broadcast_to(
+            ops.cast(sample_weights, dtype=y_pred.dtype), y_pred.shape
+        )
+        if not multi_label:
+            sample_weights = ops.reshape(sample_weights, [-1])
+    if label_weights is None:
+        label_weights = 1.0
+    else:
+        label_weights = ops.expand_dims(label_weights, 0)
+        label_weights = ops.broadcast_to(label_weights, y_pred.shape)
+        if not multi_label:
+            label_weights = ops.reshape(label_weights, [-1])
+    weights = ops.cast(
+        ops.multiply(sample_weights, label_weights), y_true.dtype
+    )
+
+    # We shouldn't need this, but it guards against prediction values outside
+    # the range [0.0, 1.0].
+    y_pred = ops.clip(y_pred, 0.0, 1.0)
+
+    y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype)
+    if not multi_label:
+        y_true = ops.reshape(y_true, [-1])
+        y_pred = ops.reshape(y_pred, [-1])
+
+    true_labels = ops.multiply(y_true, weights)
+    false_labels = ops.multiply((1.0 - y_true), weights)
+
+    # Compute the bucket indices for each prediction value.
+    # Since a prediction value has to be strictly greater than a threshold
+    # (e.g., with buckets [0, 0.5] and (0.5, 1], the value 0.5 belongs to
+    # the first bucket), we use ops.ceil(val) - 1 for the bucket index.
+    bucket_indices = ops.ceil(y_pred * (num_thresholds - 1)) - 1
+
+    if thresholds_with_epsilon:
+        # In this case, the first bucket should actually be taken into
+        # account, since any prediction in [0.0, 1.0] is larger than the
+        # first threshold. We change the bucket value from -1 to 0.
+        bucket_indices = ops.relu(bucket_indices)
+
+    bucket_indices = ops.cast(bucket_indices, "int32")
+
+    if multi_label:
+        # We need to run the bucket segment sum for each of the label
+        # classes. In the multi_label case, the rank of the label is 2. We
+        # first transpose it so that the label dim becomes the first, and we
+        # can run through the labels in parallel.
+        true_labels = ops.transpose(true_labels)
+        false_labels = ops.transpose(false_labels)
+        bucket_indices = ops.transpose(bucket_indices)
+
+        def gather_bucket(label_and_bucket_index):
+            label, bucket_index = (
+                label_and_bucket_index[0],
+                label_and_bucket_index[1],
+            )
+            return ops.segment_sum(
+                data=label,
+                segment_ids=bucket_index,
+                num_segments=num_thresholds,
+            )
+
+        tp_bucket_v = ops.vectorized_map(
+            gather_bucket, (true_labels, bucket_indices)
+        )
+        fp_bucket_v = ops.vectorized_map(
+            gather_bucket, (false_labels, bucket_indices)
+        )
+        tp = ops.transpose(ops.cumsum(tp_bucket_v, reverse=True, axis=1))
+        fp = ops.transpose(ops.cumsum(fp_bucket_v, reverse=True, axis=1))
+    else:
+        tp_bucket_v = ops.segment_sum(
+            data=true_labels,
+            segment_ids=bucket_indices,
+            num_segments=num_thresholds,
+        )
+        fp_bucket_v = ops.segment_sum(
+            data=false_labels,
+            segment_ids=bucket_indices,
+            num_segments=num_thresholds,
+        )
+        tp = ops.cumsum(tp_bucket_v, reverse=True)
+        fp = ops.cumsum(fp_bucket_v, reverse=True)
+
+    # fn = sum(true_labels) - tp
+    # tn = sum(false_labels) - fp
+    if (
+        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
+        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
+    ):
+        if multi_label:
+            total_true_labels = ops.sum(true_labels, axis=1)
+            total_false_labels = ops.sum(false_labels, axis=1)
+        else:
+            total_true_labels = ops.sum(true_labels)
+            total_false_labels = ops.sum(false_labels)
+
+    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
+        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
+        variable.assign(variable + tp)
+    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
+        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
+        variable.assign(variable + fp)
+    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
+        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
+        tn = total_false_labels - fp
+        variable.assign(variable + tn)
+    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
+        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
+        fn = total_true_labels - tp
+        variable.assign(variable + fn)
+
+
+def is_evenly_distributed_thresholds(thresholds):
+    """Check if the thresholds list is evenly distributed.
+
+    We could leverage evenly distributed thresholds to use less memory when
+    calculating metrics like AUC, where each individual threshold needs to
+    be evaluated.
+
+    Args:
+        thresholds: A python list or tuple, or 1D numpy array whose values
+            are in [0, 1].
+
+    Returns:
+        boolean, whether the values in the inputs are evenly distributed.
+    """
+    # Check the list values and see if they are evenly distributed.
+    num_thresholds = len(thresholds)
+    if num_thresholds < 3:
+        return False
+    even_thresholds = np.arange(num_thresholds, dtype=np.float32) / (
+        num_thresholds - 1
+    )
+    return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
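For intuition, `is_evenly_distributed_thresholds` only returns `True` for lists of three or more thresholds matching `np.arange(n) / (n - 1)` up to `backend.epsilon()`. A quick standalone sketch (illustrative values only):

    import numpy as np
    from keras_core.metrics import metrics_utils

    metrics_utils.is_evenly_distributed_thresholds(
        np.array([0.0, 0.25, 0.5, 0.75, 1.0])
    )  # True: matches np.arange(5) / 4
    metrics_utils.is_evenly_distributed_thresholds(
        np.array([0.0, 0.4, 1.0])
    )  # False: uneven spacing
    metrics_utils.is_evenly_distributed_thresholds(
        np.array([0.0, 1.0])
    )  # False: fewer than 3 thresholds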
+
+
+def update_confusion_matrix_variables(
+    variables_to_update,
+    y_true,
+    y_pred,
+    thresholds,
+    top_k=None,
+    class_id=None,
+    sample_weight=None,
+    multi_label=False,
+    label_weights=None,
+    thresholds_distributed_evenly=False,
+):
+    """Updates the given confusion matrix variables.
+
+    For every pair of values in y_true and y_pred:
+
+    true_positives: y_true == True and y_pred > thresholds
+    false_negatives: y_true == True and y_pred <= thresholds
+    true_negatives: y_true == False and y_pred <= thresholds
+    false_positives: y_true == False and y_pred > thresholds
+
+    The results will be weighted and added together. When multiple thresholds
+    are provided, we will repeat the same for every threshold.
+
+    For estimation of these metrics over a stream of data, the function
+    creates an `update_op` operation that updates the given variables.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use weights of 0 to mask values.
+
+    Args:
+        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
+            keys and corresponding variables to update as values.
+        y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to
+            `bool`.
+        y_pred: A floating point `Tensor` of arbitrary shape and whose values
+            are in the range `[0, 1]`.
+        thresholds: A float value, float tensor, python list, or tuple of
+            float thresholds in `[0, 1]`, or NEG_INF (used when top_k is
+            set).
+        top_k: Optional int, indicates that the positive labels should be
+            limited to the top k predictions.
+        class_id: Optional int, limits the prediction and labels to the class
+            specified by this argument.
+        sample_weight: Optional `Tensor` whose rank is either 0, or the same
+            rank as `y_true`, and must be broadcastable to `y_true` (i.e.,
+            all dimensions must be either `1`, or the same as the
+            corresponding `y_true` dimension).
+        multi_label: Optional boolean indicating whether multidimensional
+            prediction/labels should be treated as multilabel responses, or
+            flattened into a single label. When True, the values of
+            `variables_to_update` must have a second dimension equal to the
+            number of labels in y_true and y_pred, and those tensors must not
+            be RaggedTensors.
+        label_weights: (Optional) tensor of non-negative weights for
+            multilabel data. The weights are applied when calculating TP, FP,
+            FN, and TN without explicit multilabel handling (i.e. when the
+            data is to be flattened).
+        thresholds_distributed_evenly: Boolean, whether the thresholds are
+            evenly distributed within the list. An optimized method will be
+            used if this is the case. See
+            _update_confusion_matrix_variables_optimized() for more details.
+
+    Raises:
+        ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
+            `sample_weight` is not `None` and its shape doesn't match
+            `y_pred`, or if `variables_to_update` contains invalid keys.
+    """
+    if multi_label and label_weights is not None:
+        raise ValueError(
+            "`label_weights` for multilabel data should be handled "
+            "outside of `update_confusion_matrix_variables` when "
+            "`multi_label` is True."
+        )
+    if variables_to_update is None:
+        return
+    if not any(
+        key for key in variables_to_update if key in list(ConfusionMatrix)
+    ):
+        raise ValueError(
+            "Please provide at least one valid confusion matrix "
+            "variable to update. Valid variable key options are: "
+            f'"{list(ConfusionMatrix)}". '
+            f'Received: "{variables_to_update.keys()}"'
+        )
+
+    variable_dtype = list(variables_to_update.values())[0].dtype
+
+    y_true = ops.cast(y_true, dtype=variable_dtype)
+    y_pred = ops.cast(y_pred, dtype=variable_dtype)
+
+    if thresholds_distributed_evenly:
+        # Check whether the thresholds have any leading or trailing epsilon
+        # added for floating point imprecision. The leading and trailing
+        # thresholds will be handled a bit differently as corner cases. At
+        # this point, thresholds should be a list/array with more than 2
+        # items, ranged between [0, 1]. See
+        # is_evenly_distributed_thresholds() for more details.
+        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0
+
+    thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype)
+    num_thresholds = ops.shape(thresholds)[0]
+
+    if multi_label:
+        one_thresh = ops.equal(
+            ops.cast(1, dtype="int32"),
+            thresholds.ndim,
+        )
+    else:
+        one_thresh = ops.cast(True, dtype="bool")
+
+    invalid_keys = [
+        key for key in variables_to_update if key not in list(ConfusionMatrix)
+    ]
+    if invalid_keys:
+        raise ValueError(
+            f'Invalid keys: "{invalid_keys}". '
+            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
+        )
+
+    y_pred, y_true = squeeze_to_same_rank(y_pred, y_true)
+    if sample_weight is not None:
+        sample_weight = ops.expand_dims(
+            ops.cast(sample_weight, dtype=variable_dtype), axis=-1
+        )
+        _, sample_weight = squeeze_to_same_rank(y_true, sample_weight)
+
+    if top_k is not None:
+        y_pred = _filter_top_k(y_pred, top_k)
+    if class_id is not None:
+        # Preserve dimension to match with sample_weight
+        y_true = y_true[..., class_id, None]
+        y_pred = y_pred[..., class_id, None]
+
+    if thresholds_distributed_evenly:
+        return _update_confusion_matrix_variables_optimized(
+            variables_to_update,
+            y_true,
+            y_pred,
+            thresholds,
+            multi_label=multi_label,
+            sample_weights=sample_weight,
+            label_weights=label_weights,
+            thresholds_with_epsilon=thresholds_with_epsilon,
+        )
+
+    pred_shape = ops.shape(y_pred)
+    num_predictions = pred_shape[0]
+    if y_pred.ndim == 1:
+        num_labels = 1
+    else:
+        num_labels = ops.cast(
+            ops.prod(ops.array(pred_shape[1:]), axis=0), "int32"
+        )
+    thresh_label_tile = ops.where(one_thresh, num_labels, 1)
+
+    # Reshape predictions and labels, adding a dim for thresholding.
+    if multi_label:
+        predictions_extra_dim = ops.expand_dims(y_pred, 0)
+        labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype="bool"), 0)
+    else:
+        # Flatten predictions and labels when not multilabel.
+        predictions_extra_dim = ops.reshape(y_pred, [1, -1])
+        labels_extra_dim = ops.reshape(ops.cast(y_true, dtype="bool"), [1, -1])
+
+    # Tile the thresholds for every prediction.
+    if multi_label:
+        thresh_pretile_shape = [num_thresholds, 1, -1]
+        thresh_tiles = [1, num_predictions, thresh_label_tile]
+        data_tiles = [num_thresholds, 1, 1]
+    else:
+        thresh_pretile_shape = [num_thresholds, -1]
+        thresh_tiles = [1, num_predictions * num_labels]
+        data_tiles = [num_thresholds, 1]
+
+    thresh_tiled = ops.tile(
+        ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
+    )
+
+    # Tile the predictions for every threshold.
+    preds_tiled = ops.tile(predictions_extra_dim, data_tiles)
+
+    # Compare predictions and thresholds.
+    pred_is_pos = ops.greater(preds_tiled, thresh_tiled)
+
+    # Tile labels by number of thresholds.
+    label_is_pos = ops.tile(labels_extra_dim, data_tiles)
+
+    if sample_weight is not None:
+        sample_weight = ops.broadcast_to(
+            ops.cast(sample_weight, dtype=y_pred.dtype), y_pred.shape
+        )
+        weights_tiled = ops.tile(
+            ops.reshape(sample_weight, thresh_tiles), data_tiles
+        )
+    else:
+        weights_tiled = None
+
+    if label_weights is not None and not multi_label:
+        label_weights = ops.expand_dims(label_weights, 0)
+        label_weights = ops.broadcast_to(label_weights, y_pred.shape)
+        label_weights_tiled = ops.tile(
+            ops.reshape(label_weights, thresh_tiles), data_tiles
+        )
+        if weights_tiled is None:
+            weights_tiled = label_weights_tiled
+        else:
+            weights_tiled = ops.multiply(weights_tiled, label_weights_tiled)
+
+    def weighted_assign_add(label, pred, weights, var):
+        label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)
+        if weights is not None:
+            label_and_pred *= ops.cast(weights, dtype=var.dtype)
+        var.assign(var + ops.sum(label_and_pred, 1))
+
+    loop_vars = {
+        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
+    }
+    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
+    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
+    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
+
+    if update_fn or update_tn:
+        pred_is_neg = ops.logical_not(pred_is_pos)
+        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)
+
+    if update_fp or update_tn:
+        label_is_neg = ops.logical_not(label_is_pos)
+        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
+        if update_tn:
+            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
+                label_is_neg,
+                pred_is_neg,
+            )
+
+    for matrix_cond, (label, pred) in loop_vars.items():
+        if matrix_cond in variables_to_update:
+            weighted_assign_add(
+                label, pred, weights_tiled, variables_to_update[matrix_cond]
+            )
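To make the tiled comparison concrete: for each threshold, the function effectively evaluates the four conditions from the docstring against every prediction. A plain-NumPy sanity sketch of that bookkeeping (illustrative only, not the library code path):

    import numpy as np

    y_true = np.array([0, 1, 0, 1], dtype=bool)
    y_pred = np.array([0.2, 0.9, 0.7, 0.4])
    thresholds = np.array([0.15, 0.5, 0.85])

    # Shape (num_thresholds, num_predictions): entry [i, j] says whether
    # prediction j clears threshold i.
    pred_is_pos = y_pred[None, :] > thresholds[:, None]
    tp = np.sum(pred_is_pos & y_true[None, :], axis=1)   # [2, 1, 1]
    fp = np.sum(pred_is_pos & ~y_true[None, :], axis=1)  # [2, 1, 0]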
+
+
+def _filter_top_k(x, k):
+    """Filters top-k values in the last dim of x and sets the rest to
+    NEG_INF.
+
+    Used for computing top-k prediction values in dense labels (which have
+    the same shape as predictions) for recall and precision top-k metrics.
+
+    Args:
+        x: tensor with any dimensions.
+        k: the number of values to keep.
+
+    Returns:
+        tensor with same shape and dtype as x.
+    """
+    _, top_k_idx = ops.top_k(x, k, sorted=False)
+    top_k_mask = ops.sum(
+        ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2
+    )
+    return x * top_k_mask + NEG_INF * (1 - top_k_mask)
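The masking trick above keeps the top-k entries of each row and pushes everything else down to NEG_INF, so the filtered values can never clear a threshold in [0, 1]. A minimal NumPy sketch of the same idea (hypothetical values):

    import numpy as np

    NEG_INF = -1e10
    x = np.array([0.1, 0.9, 0.3, 0.6])
    k = 2

    top_k_idx = np.argsort(x)[-k:]  # indices of the two largest values
    top_k_mask = np.zeros_like(x)
    top_k_mask[top_k_idx] = 1.0
    filtered = x * top_k_mask + NEG_INF * (1 - top_k_mask)
    # filtered -> [-1e10, 0.9, -1e10, 0.6]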
+ """ + _, top_k_idx = ops.top_k(x, k, sorted=False) + top_k_mask = ops.sum( + ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2 + ) + return x * top_k_mask + NEG_INF * (1 - top_k_mask) diff --git a/keras_core/saving/serialization_lib.py b/keras_core/saving/serialization_lib.py index 79edd3adf..687d9c13d 100644 --- a/keras_core/saving/serialization_lib.py +++ b/keras_core/saving/serialization_lib.py @@ -146,7 +146,7 @@ def serialize_keras_object(obj): "class_name": "__tensor__", "config": { "value": np.array(obj).tolist(), - "dtype": backend.standardize_dtype(obj.dtype), + "dtype": str(obj.dtype), }, } if type(obj).__module__ == np.__name__: @@ -155,7 +155,7 @@ def serialize_keras_object(obj): "class_name": "__numpy__", "config": { "value": obj.tolist(), - "dtype": backend.standardize_dtype(obj.dtype), + "dtype": backend.standardize_dytpe(obj.dtype), }, } else: diff --git a/keras_core/testing/test_case.py b/keras_core/testing/test_case.py index b2b72ab1a..fdb9124ec 100644 --- a/keras_core/testing/test_case.py +++ b/keras_core/testing/test_case.py @@ -9,3 +9,16 @@ class TestCase(unittest.TestCase): def assertAlmostEqual(self, x1, x2, decimal=3): np.testing.assert_almost_equal(x1, x2, decimal=decimal) + + def assertEqual(self, x1, x2): + np.testing.assert_equal(x1, x2) + + def assertLen(self, iterable, expected_len): + np.testing.assert_equal(len(iterable), expected_len) + + def assertRaisesRegex( + self, exception_class, expected_regexp, *args, **kwargs + ): + return np.testing.assert_raises_regex( + exception_class, expected_regexp, *args, **kwargs + ) diff --git a/keras_core/utils/python_utils.py b/keras_core/utils/python_utils.py index 5cfd0acb0..4e13dcc0d 100644 --- a/keras_core/utils/python_utils.py +++ b/keras_core/utils/python_utils.py @@ -86,3 +86,20 @@ def func_load(code, defaults=None, closure=None, globs=None): return python_types.FunctionType( code, globs, name=code.co_name, argdefs=defaults, closure=closure ) + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Args: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x]