keras/keras_core/metrics/regression_metrics_test.py

36 lines
1.1 KiB
Python
Raw Normal View History

import numpy as np
from keras_core import testing
from keras_core.metrics import regression_metrics
class MeanSquaredErrorTest(testing.TestCase):
    """Value checks for `regression_metrics.MeanSquaredError` on a 4x5 batch."""

    @staticmethod
    def _make_batch():
        """Return a fresh `(targets, predictions)` pair of 4x5 integer arrays.

        Averaging the squared differences along each row gives
        0.4, 0.4, 0.4 and 0.8 respectively; the expected values asserted in
        the tests below are derived from these per-row figures.
        """
        targets = np.array(
            [
                [0, 1, 0, 1, 0],
                [0, 0, 1, 1, 1],
                [1, 1, 1, 1, 0],
                [0, 0, 0, 0, 1],
            ]
        )
        predictions = np.array(
            [
                [0, 0, 1, 1, 0],
                [1, 1, 1, 1, 1],
                [0, 1, 0, 1, 0],
                [1, 1, 1, 1, 1],
            ]
        )
        return targets, predictions

    def test_config(self):
        # TODO
        pass

    def test_unweighted(self):
        # Exercises the explicit update_state() -> result() path without
        # sample weights.
        targets, predictions = self._make_batch()
        metric = regression_metrics.MeanSquaredError()
        metric.update_state(targets, predictions)
        # Expected: (0.4 + 0.4 + 0.4 + 0.8) / 4 = 0.5
        self.assertAllClose(0.5, metric.result(), atol=1e-5)

    def test_weighted(self):
        # Exercises the direct-call path with one weight per row.
        targets, predictions = self._make_batch()
        weights = np.array([1.0, 1.5, 2.0, 2.5])
        metric = regression_metrics.MeanSquaredError()
        value = metric(targets, predictions, sample_weight=weights)
        # Expected: (1.0*0.4 + 1.5*0.4 + 2.0*0.4 + 2.5*0.8) / 7.0
        #         = 3.8 / 7.0 ~= 0.542857
        self.assertAllClose(0.54285, value, atol=1e-5)