Add Categorical Accuracy Metric (#47)
* chore: adding categorical accuracy metric
* chore: reformat docstrings
* chore: reformat
* chore: ndims with len
* refactor the docstring
This commit is contained in:
parent
fd3323b875
commit
3edbba488a
@@ -112,3 +112,87 @@ class BinaryAccuracy(reduction_metrics.MeanMetricWrapper):
    def get_config(self):
        return {"name": self.name, "dtype": self.dtype}

def categorical_accuracy(y_true, y_pred):
    y_true = ops.argmax(y_true, axis=-1)

    reshape_matches = False
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_true.dtype)
    y_true_org_shape = ops.shape(y_true)
    y_pred_rank = len(y_pred.shape)
    y_true_rank = len(y_true.shape)

    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
    if (
        (y_true_rank is not None)
        and (y_pred_rank is not None)
        and (len(y_true.shape) == len(y_pred.shape))
    ):
        y_true = ops.squeeze(y_true, [-1])
        reshape_matches = True
    y_pred = ops.argmax(y_pred, axis=-1)

    # If the predicted output and actual output types don't match, force cast
    # them to match.
    if y_pred.dtype != y_true.dtype:
        y_pred = ops.cast(y_pred, dtype=y_true.dtype)
    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
    if reshape_matches:
        matches = ops.reshape(matches, shape=y_true_org_shape)
    return matches

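As an editorial illustration (not part of the commit), here is a minimal NumPy sketch of what `categorical_accuracy` computes in the common one-hot case:

```python
import numpy as np

# One-hot labels and predicted probabilities, matching the doctests below.
y_true = np.array([[0, 0, 1], [0, 1, 0]])
y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])

# A prediction counts as a match when the index of its largest probability
# equals the index of the 1 in the one-hot label.
matches = (y_true.argmax(axis=-1) == y_pred.argmax(axis=-1)).astype("float32")
print(matches)         # [0. 1.] -- sample 0 predicts class 1, label is class 2
print(matches.mean())  # 0.5, the unweighted result reported by the metric
```
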
@keras_core_export("keras_core.metrics.CategoricalAccuracy")
class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Calculates how often predictions match one-hot labels.

    You can provide logits of classes as `y_pred`, since the argmax of
    logits and probabilities is the same.

    This metric creates two local variables, `total` and `count`, that are
    used to compute the frequency with which `y_pred` matches `y_true`. This
    frequency is ultimately returned as `categorical accuracy`: an idempotent
    operation that simply divides `total` by `count`.

    `y_pred` and `y_true` should be passed in as vectors of probabilities,
    rather than as labels. If necessary, use `tf.one_hot` to expand `y_true`
    as a vector.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = keras_core.metrics.CategoricalAccuracy()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
    ...                [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
    ...                [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='mse',
                  metrics=[keras_core.metrics.CategoricalAccuracy()])
    ```
    """

    def __init__(self, name="categorical_accuracy", dtype=None):
        super().__init__(fn=categorical_accuracy, name=name, dtype=dtype)

    def get_config(self):
        return {"name": self.name, "dtype": self.dtype}

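Note on the weighted doctest above: sample 0 is mispredicted (argmax index 1 vs. label index 2) while sample 1 is correct, so the wrapped mean works out to (0.7 * 0 + 0.3 * 1) / (0.7 + 0.3) = 0.3.
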
@@ -60,3 +60,35 @@ class BinaryAccuracyTest(testing.TestCase):
        bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)
        result = bin_acc_obj.result()
        self.assertAllClose(result, 0.5, atol=1e-3)

class CategoricalAccuracyTest(testing.TestCase):
    def test_config(self):
        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )
        self.assertEqual(cat_acc_obj.name, "categorical_accuracy")
        self.assertEqual(len(cat_acc_obj.variables), 2)
        self.assertEqual(cat_acc_obj._dtype, "float32")
        # TODO: Check save and restore config

    def test_unweighted(self):
        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )
        y_true = np.array([[0, 0, 1], [0, 1, 0]])
        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
        cat_acc_obj.update_state(y_true, y_pred)
        result = cat_acc_obj.result()
        self.assertAllClose(result, 0.5, atol=1e-3)

    def test_weighted(self):
        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )
        y_true = np.array([[0, 0, 1], [0, 1, 0]])
        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
        sample_weight = np.array([0.7, 0.3])
        cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)
        result = cat_acc_obj.result()
        self.assertAllClose(result, 0.3, atol=1e-3)

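The `# TODO: Check save and restore config` in `test_config` could be covered by a round trip through `get_config()`. A hypothetical sketch, not part of the commit: the test name is invented, and `from_config` is assumed to behave as in Keras, i.e. `cls(**config)`.

```python
    # Hypothetical follow-up for the TODO above; would live inside
    # CategoricalAccuracyTest alongside the tests in this commit.
    def test_config_round_trip(self):
        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )
        # Rebuild the metric from its config and check the fields survive.
        restored = accuracy_metrics.CategoricalAccuracy.from_config(
            cat_acc_obj.get_config()
        )
        self.assertEqual(restored.name, "categorical_accuracy")
        self.assertEqual(restored._dtype, "float32")
```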