2023-04-12 18:31:58 +00:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from keras_core import backend
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core import initializers
|
2023-05-30 19:05:18 +00:00
|
|
|
from keras_core import metrics as metrics_module
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core import operations as ops
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core import testing
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core.metrics.metric import Metric
|
|
|
|
|
|
|
|
|
|
|
|
class ExampleMetric(Metric):
    """Minimal squared-error metric used to exercise the base `Metric` API.

    State:
        sum: running sum of squared errors over every element seen.
        total: running count of samples, i.e. entries along the leading
            axis of `y_true`, stored as int32.

    `result()` returns `sum / total`, so the squared errors are summed over
    all non-leading axes and averaged over samples only (not over elements).
    """

    def __init__(self, name="mean_square_error", dtype=None):
        super().__init__(name=name, dtype=dtype)
        # No explicit dtype here, so `add_variable` picks its default
        # (presumably the metric dtype -- TODO confirm).
        self.sum = self.add_variable(
            name="sum", shape=(), initializer=initializers.Zeros()
        )
        self.total = self.add_variable(
            name="total",
            shape=(),
            initializer=initializers.Zeros(),
            dtype="int32",
        )

    def update_state(self, y_true, y_pred):
        """Accumulate squared error and sample count for one batch.

        Args:
            y_true: Ground-truth values; the leading axis is the batch axis.
            y_pred: Predictions, assumed broadcast-compatible with `y_true`.
        """
        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        # Named `batch_sse` (not `sum`) so the builtin is not shadowed.
        batch_sse = ops.sum((y_true - y_pred) ** 2)
        self.sum.assign(self.sum + batch_sse)
        batch_size = ops.shape(y_true)[0]
        self.total.assign(self.total + batch_size)

    def result(self):
        """Return the accumulated squared error divided by the sample count.

        The 1e-7 epsilon keeps the division finite when no samples have been
        seen (e.g. right after `reset_state`); the result is then 0.
        """
        return self.sum / (ops.cast(self.total, dtype="float32") + 1e-7)

    def reset_state(self):
        """Zero both the squared-error sum and the sample count."""
        self.sum.assign(0.0)
        self.total.assign(0)
|
|
|
|
|
|
|
|
|
|
|
|
class MetricTest(testing.TestCase):
    """Tests for the base `Metric` class, using `ExampleMetric` as fixture."""

    def test_end_to_end_flow(self):
        """Batched updates accumulate correctly; `reset_state` clears them."""
        metric = ExampleMetric(name="mse")
        self.assertEqual(metric.name, "mse")
        self.assertEqual(len(metric.variables), 2)

        num_samples = 20
        y_true = np.random.random((num_samples, 3))
        y_pred = np.random.random((num_samples, 3))
        batch_size = 8
        # Step by `batch_size` so no slice is ever empty; the final batch is
        # the partial remainder (here 8 + 8 + 4 samples).
        for start in range(0, num_samples, batch_size):
            stop = start + batch_size
            metric.update_state(y_true[start:stop], y_pred[start:stop])

        self.assertAllClose(metric.total, num_samples)
        result = metric.result()
        self.assertAllClose(
            result, np.sum((y_true - y_pred) ** 2) / num_samples
        )

        metric.reset_state()
        # `result()` is a backend tensor; compare numerically.
        self.assertAllClose(metric.result(), 0.0)

    def test_variable_tracking(self):
        """Variables assigned as attributes in containers are tracked."""
        # In list
        metric = ExampleMetric(name="mse")
        metric.more_vars = [backend.Variable(0.0), backend.Variable(1.0)]
        self.assertEqual(len(metric.variables), 4)

        # In dict
        metric = ExampleMetric(name="mse")
        metric.more_vars = {
            "a": backend.Variable(0.0),
            "b": backend.Variable(1.0),
        }
        self.assertEqual(len(metric.variables), 4)

        # In nested structure
        metric = ExampleMetric(name="mse")
        metric.more_vars = {"a": [backend.Variable(0.0), backend.Variable(1.0)]}
        self.assertEqual(len(metric.variables), 4)

    def test_submetric_tracking(self):
        """Variables of metrics assigned as attributes are tracked."""
        # Plain attr
        metric = ExampleMetric(name="mse")
        metric.submetric = ExampleMetric(name="submse")
        self.assertEqual(len(metric.variables), 4)

        # In list
        metric = ExampleMetric(name="mse")
        metric.submetrics = [
            ExampleMetric(name="submse1"),
            ExampleMetric(name="submse2"),
        ]
        self.assertEqual(len(metric.variables), 6)

        # In dict
        metric = ExampleMetric(name="mse")
        metric.submetrics = {
            "1": ExampleMetric(name="submse1"),
            "2": ExampleMetric(name="submse2"),
        }
        self.assertEqual(len(metric.variables), 6)

    def test_serialization(self):
        """A custom metric round-trips through class serialization."""
        self.run_class_serialization_test(
            ExampleMetric(name="mse"),
            custom_objects={"ExampleMetric": ExampleMetric},
        )

    def test_get_method(self):
        """`metrics.get` resolves aliases, passes None through, rejects typos."""
        metric = metrics_module.get("mse")
        self.assertIsInstance(metric, metrics_module.MeanSquaredError)

        metric = metrics_module.get("mean_squared_error")
        self.assertIsInstance(metric, metrics_module.MeanSquaredError)

        metric = metrics_module.get("categorical_accuracy")
        self.assertIsInstance(metric, metrics_module.CategoricalAccuracy)

        metric = metrics_module.get(None)
        self.assertIsNone(metric)

        with self.assertRaises(ValueError):
            metrics_module.get("typo")
|