Keep "mse" as the metric name in the log (#812)
* add the test * add the fix * fix other broken tests --------- Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
parent
d42700a528
commit
549e260ade
@ -220,17 +220,16 @@ class ModelTest(testing.TestCase, parameterized.TestCase):
|
||||
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
hist_keys = sorted(hist.history.keys())
|
||||
# TODO `tf.keras` also outputs individual losses for outputs
|
||||
# TODO Align output names with 'bce', `mse`, `mae` of `tf.keras`
|
||||
ref_keys = sorted(
|
||||
[
|
||||
"loss",
|
||||
# "output_a_loss",
|
||||
"output_a_binary_crossentropy",
|
||||
"output_a_mean_absolute_error",
|
||||
"output_a_mean_squared_error",
|
||||
"output_a_bce",
|
||||
"output_a_mae",
|
||||
"output_a_mse",
|
||||
"output_b_acc",
|
||||
# "output_b_loss",
|
||||
"output_b_mean_squared_error",
|
||||
"output_b_mse",
|
||||
]
|
||||
)
|
||||
self.assertListEqual(hist_keys, ref_keys)
|
||||
|
@ -79,13 +79,14 @@ def get_metric(identifier, y_true, y_pred):
|
||||
)
|
||||
|
||||
if not isinstance(metric_obj, metrics_module.Metric):
|
||||
metric_obj = metrics_module.MeanMetricWrapper(metric_obj)
|
||||
|
||||
if isinstance(identifier, str):
|
||||
metric_name = identifier
|
||||
else:
|
||||
metric_name = get_object_name(metric_obj)
|
||||
metric_obj = metrics_module.MeanMetricWrapper(
|
||||
metric_obj, name=metric_name
|
||||
)
|
||||
metric_obj.name = metric_name
|
||||
|
||||
return metric_obj
|
||||
|
||||
|
||||
|
@ -198,7 +198,7 @@ class TestCompileMetrics(testing.TestCase):
|
||||
|
||||
def test_name_conversions(self):
|
||||
compile_metrics = CompileMetrics(
|
||||
metrics=["acc", "accuracy"],
|
||||
metrics=["acc", "accuracy", "mse"],
|
||||
weighted_metrics=[],
|
||||
)
|
||||
y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
|
||||
@ -207,9 +207,10 @@ class TestCompileMetrics(testing.TestCase):
|
||||
compile_metrics.update_state(y_true, y_pred, sample_weight=None)
|
||||
result = compile_metrics.result()
|
||||
self.assertTrue(isinstance(result, dict))
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertAllClose(result["acc"], 0.333333)
|
||||
self.assertAllClose(result["accuracy"], 0.333333)
|
||||
self.assertTrue("mse" in result)
|
||||
|
||||
|
||||
class TestCompileLoss(testing.TestCase, parameterized.TestCase):
|
||||
|
@ -525,7 +525,9 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
|
||||
assert keys == ["outputs"]
|
||||
|
||||
model = ExampleModel(units=3)
|
||||
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
|
||||
model.compile(
|
||||
optimizer="adam", loss="mse", metrics=["mean_absolute_error"]
|
||||
)
|
||||
x = np.ones((16, 4))
|
||||
y = np.zeros((16, 3))
|
||||
x_test = np.ones((16, 4))
|
||||
@ -651,12 +653,16 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
|
||||
inputs = layers.Input((2,))
|
||||
outputs = layers.Dense(3)(inputs)
|
||||
model = keras_core.Model(inputs, outputs)
|
||||
model.compile(optimizer="sgd", loss="mse", metrics=["mse"])
|
||||
model.compile(
|
||||
optimizer="sgd", loss="mse", metrics=["mean_squared_error"]
|
||||
)
|
||||
history_1 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history
|
||||
eval_out_1 = model.evaluate(
|
||||
np.ones((3, 2)), np.ones((3, 3)), return_dict=True
|
||||
)
|
||||
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
|
||||
model.compile(
|
||||
optimizer="sgd", loss="mse", metrics=["mean_absolute_error"]
|
||||
)
|
||||
history_2 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history
|
||||
eval_out_2 = model.evaluate(
|
||||
np.ones((3, 2)), np.ones((3, 3)), return_dict=True
|
||||
|
Loading…
Reference in New Issue
Block a user