Keep "mse" as the metric name in the log (#812)

* add the test

* add the fix

* fix other broken tests

---------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
Authored by Haifeng Jin on 2023-08-29 14:51:19 -07:00; committed by Francois Chollet
parent d42700a528
commit 549e260ade
4 changed files with 24 additions and 17 deletions
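
In effect, a metric passed as a string alias such as "mse" now keeps that alias as its key in the training log, instead of being expanded to the wrapped metric's canonical name. A minimal sketch of the new behavior (the toy model and shapes are illustrative, not from this commit):

import numpy as np
import keras_core
from keras_core import layers

inputs = layers.Input((2,))
outputs = layers.Dense(1)(inputs)
model = keras_core.Model(inputs, outputs)

# The alias "mse" is kept as the metric key in the history.
model.compile(optimizer="sgd", loss="mse", metrics=["mse"])
history = model.fit(np.ones((4, 2)), np.ones((4, 1)), verbose=0)
assert "mse" in history.history  # was "mean_squared_error" before this fix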

@@ -220,17 +220,16 @@ class ModelTest(testing.TestCase, parameterized.TestCase):
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
         # TODO `tf.keras` also outputs individual losses for outputs
-        # TODO Align output names with 'bce', `mse`, `mae` of `tf.keras`
         ref_keys = sorted(
             [
                 "loss",
                 # "output_a_loss",
-                "output_a_binary_crossentropy",
-                "output_a_mean_absolute_error",
-                "output_a_mean_squared_error",
+                "output_a_bce",
+                "output_a_mae",
+                "output_a_mse",
                 "output_b_acc",
                 # "output_b_loss",
-                "output_b_mean_squared_error",
+                "output_b_mse",
             ]
         )
         self.assertListEqual(hist_keys, ref_keys)
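
The expected keys shrink because per-output metric keys are formed as "<output name>_<metric name>", and string aliases now pass through unchanged. A hedged sketch of a compile call that would yield the keys above (the optimizer, losses, and output names are assumptions, not taken from the test):

model.compile(
    optimizer="sgd",
    loss={"output_a": "binary_crossentropy", "output_b": "mse"},
    metrics={
        "output_a": ["bce", "mae", "mse"],
        "output_b": ["acc", "mse"],
    },
)
# history keys: loss, output_a_bce, output_a_mae, output_a_mse,
# output_b_acc, output_b_mse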

@@ -79,13 +79,14 @@ def get_metric(identifier, y_true, y_pred):
         )
     if not isinstance(metric_obj, metrics_module.Metric):
-        if isinstance(identifier, str):
-            metric_name = identifier
-        else:
-            metric_name = get_object_name(metric_obj)
-        metric_obj = metrics_module.MeanMetricWrapper(
-            metric_obj, name=metric_name
-        )
+        metric_obj = metrics_module.MeanMetricWrapper(metric_obj)
+
+    if isinstance(identifier, str):
+        metric_name = identifier
+    else:
+        metric_name = get_object_name(metric_obj)
+    metric_obj.name = metric_name
     return metric_obj
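
The fix reorders wrapping and naming: a plain callable is wrapped in MeanMetricWrapper first, and the name is assigned afterwards from the original identifier, so a string alias survives the wrapping. An illustrative check of the resulting contract (the custom function below is hypothetical):

# A string identifier keeps its alias as the metric name.
metric = get_metric("mse", y_true, y_pred)
assert metric.name == "mse"  # previously "mean_squared_error"

# A non-string identifier still falls back to get_object_name().
def my_metric(y_true, y_pred):  # hypothetical custom metric
    return (y_true - y_pred) ** 2

assert get_metric(my_metric, y_true, y_pred).name == "my_metric"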

@@ -198,7 +198,7 @@ class TestCompileMetrics(testing.TestCase):

     def test_name_conversions(self):
         compile_metrics = CompileMetrics(
-            metrics=["acc", "accuracy"],
+            metrics=["acc", "accuracy", "mse"],
             weighted_metrics=[],
         )
         y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
@@ -207,9 +207,10 @@ class TestCompileMetrics(testing.TestCase):
         compile_metrics.update_state(y_true, y_pred, sample_weight=None)
         result = compile_metrics.result()
         self.assertTrue(isinstance(result, dict))
-        self.assertEqual(len(result), 2)
+        self.assertEqual(len(result), 3)
         self.assertAllClose(result["acc"], 0.333333)
         self.assertAllClose(result["accuracy"], 0.333333)
+        self.assertTrue("mse" in result)


 class TestCompileLoss(testing.TestCase, parameterized.TestCase):

@@ -525,7 +525,9 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
         assert keys == ["outputs"]

         model = ExampleModel(units=3)
-        model.compile(optimizer="adam", loss="mse", metrics=["mae"])
+        model.compile(
+            optimizer="adam", loss="mse", metrics=["mean_absolute_error"]
+        )
         x = np.ones((16, 4))
         y = np.zeros((16, 3))
         x_test = np.ones((16, 4))
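
Since the alias is no longer expanded, tests that assert on the long key must now request it explicitly; the "fix other broken tests" item in the commit message is exactly this substitution. In short (keys shown are what each spelling logs after the fix):

model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# -> logs the key "mae"

model.compile(
    optimizer="adam", loss="mse", metrics=["mean_absolute_error"]
)
# -> logs the key "mean_absolute_error", matching the existing assertions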
@@ -651,12 +653,16 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
         inputs = layers.Input((2,))
         outputs = layers.Dense(3)(inputs)
         model = keras_core.Model(inputs, outputs)
-        model.compile(optimizer="sgd", loss="mse", metrics=["mse"])
+        model.compile(
+            optimizer="sgd", loss="mse", metrics=["mean_squared_error"]
+        )
         history_1 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history
         eval_out_1 = model.evaluate(
             np.ones((3, 2)), np.ones((3, 3)), return_dict=True
         )
-        model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
+        model.compile(
+            optimizer="sgd", loss="mse", metrics=["mean_absolute_error"]
+        )
         history_2 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history
         eval_out_2 = model.evaluate(
             np.ones((3, 2)), np.ones((3, 3)), return_dict=True