From 3edbba488ac8bab73efc44834a8822e74ab5341b Mon Sep 17 00:00:00 2001
From: Aritra Roy Gosthipaty
Date: Fri, 28 Apr 2023 21:49:04 +0530
Subject: [PATCH] Add Categorical Accuracy Metric (#47)

* chore: adding categorical accuracy metric

* chore: reformat docstrings

* chore: reformat

* chore: ndims with len

* refactor the docstring
---
 keras_core/metrics/accuracy_metrics.py      | 84 +++++++++++++++++++++
 keras_core/metrics/accuracy_metrics_test.py | 32 ++++++++
 2 files changed, 116 insertions(+)

diff --git a/keras_core/metrics/accuracy_metrics.py b/keras_core/metrics/accuracy_metrics.py
index c14b42c94..4f6355ec8 100644
--- a/keras_core/metrics/accuracy_metrics.py
+++ b/keras_core/metrics/accuracy_metrics.py
@@ -112,3 +112,87 @@ class BinaryAccuracy(reduction_metrics.MeanMetricWrapper):
 
     def get_config(self):
         return {"name": self.name, "dtype": self.dtype}
+
+
+def categorical_accuracy(y_true, y_pred):
+    y_true = ops.argmax(y_true, axis=-1)
+
+    reshape_matches = False
+    y_pred = ops.convert_to_tensor(y_pred)
+    y_true = ops.convert_to_tensor(y_true, dtype=y_true.dtype)
+    y_true_org_shape = ops.shape(y_true)
+    y_pred_rank = len(y_pred.shape)
+    y_true_rank = len(y_true.shape)
+
+    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+    if (
+        (y_true_rank is not None)
+        and (y_pred_rank is not None)
+        and (len(y_true.shape) == len(y_pred.shape))
+    ):
+        y_true = ops.squeeze(y_true, [-1])
+        reshape_matches = True
+    y_pred = ops.argmax(y_pred, axis=-1)
+
+    # If the predicted output and actual output types don't match, force cast
+    # them to match.
+    if y_pred.dtype != y_true.dtype:
+        y_pred = ops.cast(y_pred, dtype=y_true.dtype)
+    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
+    if reshape_matches:
+        matches = ops.reshape(matches, shape=y_true_org_shape)
+    return matches
+
+
+@keras_core_export("keras_core.metrics.CategoricalAccuracy")
+class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
+    """Calculates how often predictions match one-hot labels.
+
+    You can provide logits of classes as `y_pred`, since the argmax of
+    logits and probabilities is the same.
+
+    This metric creates two local variables, `total` and `count`, that are
+    used to compute the frequency with which `y_pred` matches `y_true`. This
+    frequency is ultimately returned as `categorical accuracy`: an idempotent
+    operation that simply divides `total` by `count`.
+
+    `y_pred` and `y_true` should be passed in as vectors of probabilities,
+    rather than as labels. If necessary, use `tf.one_hot` to expand `y_true`
+    as a vector.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Args:
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+
+    Standalone usage:
+
+    >>> m = keras_core.metrics.CategoricalAccuracy()
+    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
+    ...                 [0.05, 0.95, 0]])
+    >>> m.result()
+    0.5
+
+    >>> m.reset_state()
+    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
+    ...                 [0.05, 0.95, 0]],
+    ...                sample_weight=[0.7, 0.3])
+    >>> m.result()
+    0.3
+
+    Usage with `compile()` API:
+
+    ```python
+    model.compile(optimizer='sgd',
+                  loss='mse',
+                  metrics=[keras_core.metrics.CategoricalAccuracy()])
+    ```
+    """
+
+    def __init__(self, name="categorical_accuracy", dtype=None):
+        super().__init__(fn=categorical_accuracy, name=name, dtype=dtype)
+
+    def get_config(self):
+        return {"name": self.name, "dtype": self.dtype}
diff --git a/keras_core/metrics/accuracy_metrics_test.py b/keras_core/metrics/accuracy_metrics_test.py
index ae660f2ab..6338f31d4 100644
--- a/keras_core/metrics/accuracy_metrics_test.py
+++ b/keras_core/metrics/accuracy_metrics_test.py
@@ -60,3 +60,35 @@ class BinaryAccuracyTest(testing.TestCase):
         bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)
         result = bin_acc_obj.result()
         self.assertAllClose(result, 0.5, atol=1e-3)
+
+
+class CategoricalAccuracyTest(testing.TestCase):
+    def test_config(self):
+        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
+            name="categorical_accuracy", dtype="float32"
+        )
+        self.assertEqual(cat_acc_obj.name, "categorical_accuracy")
+        self.assertEqual(len(cat_acc_obj.variables), 2)
+        self.assertEqual(cat_acc_obj._dtype, "float32")
+        # TODO: Check save and restore config
+
+    def test_unweighted(self):
+        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
+            name="categorical_accuracy", dtype="float32"
+        )
+        y_true = np.array([[0, 0, 1], [0, 1, 0]])
+        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
+        cat_acc_obj.update_state(y_true, y_pred)
+        result = cat_acc_obj.result()
+        self.assertAllClose(result, 0.5, atol=1e-3)
+
+    def test_weighted(self):
+        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
+            name="categorical_accuracy", dtype="float32"
+        )
+        y_true = np.array([[0, 0, 1], [0, 1, 0]])
+        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
+        sample_weight = np.array([0.7, 0.3])
+        cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)
+        result = cat_acc_obj.result()
+        self.assertAllClose(result, 0.3, atol=1e-3)
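
Note (not part of the patch): the doctest and test values above can be sanity-checked with a minimal NumPy sketch of the same matching logic. This is illustrative only; the metric itself runs on keras_core `ops` so that it works across backends.

```python
import numpy as np

# Doctest inputs: one-hot labels and predicted probabilities.
y_true = np.array([[0, 0, 1], [0, 1, 0]])
y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])

# A sample "matches" when the argmax of the probabilities equals the
# argmax of the one-hot label (which is why logits work equally well).
matches = (y_true.argmax(axis=-1) == y_pred.argmax(axis=-1)).astype("float32")
print(matches)  # [0. 1.] -> only the second sample matches

# MeanMetricWrapper reduces the per-sample matches to a (weighted) mean.
print(matches.mean())  # 0.5, the unweighted doctest result

w = np.array([0.7, 0.3])
print((w * matches).sum() / w.sum())  # 0.3, the weighted doctest result
```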
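
The test file leaves `# TODO: Check save and restore config` open. A hedged sketch of what that check might look like, relying only on the constructor and `get_config()` shown in this patch; the test function below is hypothetical and not part of the patch:

```python
def test_config_round_trip():
    # Hypothetical follow-up for the TODO in test_config: rebuild the
    # metric from its own config and confirm nothing is lost.
    cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
        name="categorical_accuracy", dtype="float32"
    )
    config = cat_acc_obj.get_config()  # {"name": ..., "dtype": ...}
    restored = accuracy_metrics.CategoricalAccuracy(**config)
    assert restored.name == cat_acc_obj.name
    assert restored._dtype == cat_acc_obj._dtype
```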