From e395f16bfd03e0ec65dd77247f465caa69cc83ab Mon Sep 17 00:00:00 2001 From: mihirparadkar Date: Thu, 13 Jul 2023 21:17:16 -0700 Subject: [PATCH] Weighted metrics without metrics (#474) * Compile weighted_metrics even if metrics is None Fixes https://github.com/keras-team/keras-core/issues/454 * Update trainer_test.py * fixed formatting --- keras_core/trainers/trainer.py | 2 +- keras_core/trainers/trainer_test.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index 6eeb85ec2..6ec53fe6e 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -44,7 +44,7 @@ class Trainer: ) else: self._compile_loss = None - if metrics is not None: + if metrics is not None or weighted_metrics is not None: self._compile_metrics = CompileMetrics( metrics, weighted_metrics, output_names=output_names ) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 7248f4faf..ad83b66f7 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -109,6 +109,22 @@ class TestTrainer(testing.TestCase, parameterized.TestCase): # And those weights are tracked at the model level self.assertEqual(len(model.metrics_variables), 6) + # Models with only weighted_metrics should have the same 3 metrics + model_weighted = ModelWithMetric(units=3) + model_weighted.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + weighted_metrics=[metrics.MeanSquaredError()], + ) + model_weighted.fit( + x, + y, + batch_size=2, + epochs=1, + sample_weight=np.ones(2), + ) + self.assertEqual(len(model_weighted.metrics), 3) + @parameterized.named_parameters( [ ("eager", True, False, False),