# keras/keras_core/metrics/accuracy_metrics_test.py
# (127 lines, 5.0 KiB, Python)

import numpy as np
from keras_core import testing
from keras_core.metrics import accuracy_metrics
class AccuracyTest(testing.TestCase):
    """Tests for `accuracy_metrics.Accuracy`."""

    def test_config(self):
        """A new metric reports its name, two variables and its dtype."""
        metric = accuracy_metrics.Accuracy(name="accuracy", dtype="float32")

        self.assertEqual(metric.name, "accuracy")
        self.assertEqual(len(metric.variables), 2)
        self.assertEqual(metric._dtype, "float32")
        # TODO: Check save and restore config

    def test_unweighted(self):
        """Three of the four predictions equal their label: expect 0.75."""
        labels = np.array([[1], [2], [3], [4]])
        predictions = np.array([[0], [2], [3], [4]])

        metric = accuracy_metrics.Accuracy(name="accuracy", dtype="float32")
        metric.update_state(labels, predictions)

        self.assertAllClose(metric.result(), 0.75, atol=1e-3)

    def test_weighted(self):
        """Same data with per-sample weights: expect 0.5.

        The expected value is consistent with only the first two samples
        (weight 1) counting, one of which is a miss and one a match.
        """
        labels = np.array([[1], [2], [3], [4]])
        predictions = np.array([[0], [2], [3], [4]])
        weights = np.array([1, 1, 0, 0])

        metric = accuracy_metrics.Accuracy(name="accuracy", dtype="float32")
        metric.update_state(labels, predictions, sample_weight=weights)

        self.assertAllClose(metric.result(), 0.5, atol=1e-3)
class BinaryAccuracyTest(testing.TestCase):
    """Tests for `accuracy_metrics.BinaryAccuracy`."""

    def test_config(self):
        """A new metric reports its name, two variables and its dtype."""
        metric = accuracy_metrics.BinaryAccuracy(
            name="binary_accuracy", dtype="float32"
        )

        self.assertEqual(metric.name, "binary_accuracy")
        self.assertEqual(len(metric.variables), 2)
        self.assertEqual(metric._dtype, "float32")
        # TODO: Check save and restore config

    def test_unweighted(self):
        """Unweighted update over four samples: expect 0.75.

        NOTE(review): the expected value matches a 0.5 decision threshold
        (only the last sample, 0.6 vs. label 0, is wrong) -- confirm against
        the metric's default threshold.
        """
        labels = np.array([[1], [1], [0], [0]])
        scores = np.array([[0.98], [1], [0], [0.6]])

        metric = accuracy_metrics.BinaryAccuracy(
            name="binary_accuracy", dtype="float32"
        )
        metric.update_state(labels, scores)

        self.assertAllClose(metric.result(), 0.75, atol=1e-3)

    def test_weighted(self):
        """Same data with weights [1, 0, 0, 1]: expect 0.5."""
        labels = np.array([[1], [1], [0], [0]])
        scores = np.array([[0.98], [1], [0], [0.6]])
        weights = np.array([1, 0, 0, 1])

        metric = accuracy_metrics.BinaryAccuracy(
            name="binary_accuracy", dtype="float32"
        )
        metric.update_state(labels, scores, sample_weight=weights)

        self.assertAllClose(metric.result(), 0.5, atol=1e-3)
class CategoricalAccuracyTest(testing.TestCase):
    """Tests for `accuracy_metrics.CategoricalAccuracy`."""

    def test_config(self):
        """A new metric reports its name, two variables and its dtype."""
        metric = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )

        self.assertEqual(metric.name, "categorical_accuracy")
        self.assertEqual(len(metric.variables), 2)
        self.assertEqual(metric._dtype, "float32")
        # TODO: Check save and restore config

    def test_unweighted(self):
        """One-hot labels vs. class scores for two samples: expect 0.5.

        The highest score of the first row is at index 1 while its label is
        at index 2; the second row's highest score and label are both at
        index 1.
        """
        one_hot_labels = np.array([[0, 0, 1], [0, 1, 0]])
        scores = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])

        metric = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )
        metric.update_state(one_hot_labels, scores)

        self.assertAllClose(metric.result(), 0.5, atol=1e-3)

    def test_weighted(self):
        """Same data with weights [0.7, 0.3]: expect 0.3.

        The expected value is consistent with only the second sample
        (weight 0.3) being counted as correct.
        """
        one_hot_labels = np.array([[0, 0, 1], [0, 1, 0]])
        scores = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
        weights = np.array([0.7, 0.3])

        metric = accuracy_metrics.CategoricalAccuracy(
            name="categorical_accuracy", dtype="float32"
        )
        metric.update_state(one_hot_labels, scores, sample_weight=weights)

        self.assertAllClose(metric.result(), 0.3, atol=1e-3)
class SparseCategoricalAccuracyTest(testing.TestCase):
    """Tests for `accuracy_metrics.SparseCategoricalAccuracy`."""

    def test_config(self):
        """A new metric reports its name, two variables and its dtype."""
        metric = accuracy_metrics.SparseCategoricalAccuracy(
            name="sparse_categorical_accuracy", dtype="float32"
        )

        self.assertEqual(metric.name, "sparse_categorical_accuracy")
        self.assertEqual(len(metric.variables), 2)
        self.assertEqual(metric._dtype, "float32")
        # TODO: Check save and restore config

    def test_unweighted(self):
        """Integer labels vs. class scores for two samples: expect 0.5.

        The highest score of each row is at index 1; the labels are 2 and 1.
        """
        class_ids = np.array([[2], [1]])
        scores = np.array([[0.1, 0.6, 0.3], [0.05, 0.95, 0]])

        metric = accuracy_metrics.SparseCategoricalAccuracy(
            name="sparse_categorical_accuracy", dtype="float32"
        )
        metric.update_state(class_ids, scores)

        self.assertAllClose(metric.result(), 0.5, atol=1e-3)

    def test_weighted(self):
        """Same data with weights [0.7, 0.3]: expect 0.3.

        The expected value is consistent with only the second sample
        (weight 0.3) being counted as correct.
        """
        class_ids = np.array([[2], [1]])
        scores = np.array([[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
        weights = np.array([0.7, 0.3])

        metric = accuracy_metrics.SparseCategoricalAccuracy(
            name="sparse_categorical_accuracy", dtype="float32"
        )
        metric.update_state(class_ids, scores, sample_weight=weights)

        self.assertAllClose(metric.result(), 0.3, atol=1e-3)