Weighted metrics without metrics (#474)
* Compile weighted_metrics even if metrics is None Fixes https://github.com/keras-team/keras-core/issues/454 * Update trainer_test.py * fixed formatting
This commit is contained in:
parent
bea61227f1
commit
e395f16bfd
@ -44,7 +44,7 @@ class Trainer:
|
||||
)
|
||||
else:
|
||||
self._compile_loss = None
|
||||
if metrics is not None:
|
||||
if metrics is not None or weighted_metrics is not None:
|
||||
self._compile_metrics = CompileMetrics(
|
||||
metrics, weighted_metrics, output_names=output_names
|
||||
)
|
||||
|
@ -109,6 +109,22 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
|
||||
# And those weights are tracked at the model level
|
||||
self.assertEqual(len(model.metrics_variables), 6)
|
||||
|
||||
# Models with only weighted_metrics should have the same 3 metrics
|
||||
model_weighted = ModelWithMetric(units=3)
|
||||
model_weighted.compile(
|
||||
optimizer=optimizers.SGD(),
|
||||
loss=losses.MeanSquaredError(),
|
||||
weighted_metrics=[metrics.MeanSquaredError()],
|
||||
)
|
||||
model_weighted.fit(
|
||||
x,
|
||||
y,
|
||||
batch_size=2,
|
||||
epochs=1,
|
||||
sample_weight=np.ones(2),
|
||||
)
|
||||
self.assertEqual(len(model_weighted.metrics), 3)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
[
|
||||
("eager", True, False, False),
|
||||
|
Loading…
Reference in New Issue
Block a user