37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import numpy as np
|
|
|
|
from keras_core import operations as ops
|
|
from keras_core import testing
|
|
from keras_core.metrics import regression_metrics
|
|
|
|
|
|
class MeanSquaredErrorTest(testing.TestCase):
    """Tests for `regression_metrics.MeanSquaredError`.

    The shared fixtures (see `_get_inputs`) are binary arrays of shape (4, 5).
    Averaging the squared differences over the last axis gives per-row errors
    of [0.4, 0.4, 0.4, 0.8]:

    - row 0: 2 of 5 entries differ -> 0.4
    - row 1: 2 of 5 entries differ -> 0.4
    - row 2: 2 of 5 entries differ -> 0.4
    - row 3: 4 of 5 entries differ -> 0.8

    The expected values in the tests below are derived from these numbers.
    """

    @staticmethod
    def _get_inputs():
        """Return fresh `(y_true, y_pred)` arrays, shape (4, 5), for each test."""
        # Built per call so no test can observe another test's mutations.
        y_true = np.array(
            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]
        )
        y_pred = np.array(
            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]
        )
        return y_true, y_pred

    def test_config(self):
        """The metric keeps its name and survives a config round-trip."""
        mse_obj = regression_metrics.MeanSquaredError(name="my_mse")
        self.assertEqual(mse_obj.name, "my_mse")

        # Rebuilding from the serialized config must preserve the name.
        restored = regression_metrics.MeanSquaredError.from_config(
            mse_obj.get_config()
        )
        self.assertEqual(restored.name, "my_mse")

    def test_unweighted(self):
        """Without weights the result is the plain mean of per-row errors."""
        mse_obj = regression_metrics.MeanSquaredError()
        y_true, y_pred = self._get_inputs()

        mse_obj.update_state(y_true, y_pred)
        result = mse_obj.result()

        # (0.4 + 0.4 + 0.4 + 0.8) / 4 = 0.5
        self.assertAllClose(0.5, result, atol=1e-5)

    def test_weighted(self):
        """With per-row sample weights the result is a weighted mean."""
        mse_obj = regression_metrics.MeanSquaredError()
        y_true, y_pred = self._get_inputs()
        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])

        result = mse_obj(y_true, y_pred, sample_weight=sample_weight)

        # (0.4*1.0 + 0.4*1.5 + 0.4*2.0 + 0.8*2.5) / (1.0 + 1.5 + 2.0 + 2.5)
        #   = 3.8 / 7 ~= 0.542857
        self.assertAllClose(3.8 / 7, result, atol=1e-5)
|