Weighted metrics without metrics (#474)

* Compile weighted_metrics even if metrics is None

Fixes https://github.com/keras-team/keras-core/issues/454

* Update trainer_test.py

* fixed formatting
This commit is contained in:
mihirparadkar 2023-07-13 21:17:16 -07:00 committed by Francois Chollet
parent bea61227f1
commit e395f16bfd
2 changed files with 17 additions and 1 deletions

@ -44,7 +44,7 @@ class Trainer:
)
else:
self._compile_loss = None
if metrics is not None:
if metrics is not None or weighted_metrics is not None:
self._compile_metrics = CompileMetrics(
metrics, weighted_metrics, output_names=output_names
)

@ -109,6 +109,22 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
# And those weights are tracked at the model level
self.assertEqual(len(model.metrics_variables), 6)
# Models with only weighted_metrics should have the same 3 metrics
model_weighted = ModelWithMetric(units=3)
model_weighted.compile(
optimizer=optimizers.SGD(),
loss=losses.MeanSquaredError(),
weighted_metrics=[metrics.MeanSquaredError()],
)
model_weighted.fit(
x,
y,
batch_size=2,
epochs=1,
sample_weight=np.ones(2),
)
self.assertEqual(len(model_weighted.metrics), 3)
@parameterized.named_parameters(
[
("eager", True, False, False),