Align history metrics output with tf.keras (#392)

* Align compile metrics with

* Update tests for weighted metrics

* Update metrics for losses as single element list

* Update metrics for losses as single element list
Ramesh Sampath 2023-06-23 04:14:59 +05:30 committed by Francois Chollet
parent c373b6e1da
commit b2e7bf28bc
3 changed files with 261 additions and 72 deletions
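A minimal sketch of the behavior this commit targets (not part of the diff; it assumes `Input` and `Model` are exported at the keras_core top level, as the tests below use them): per-output history keys now place the output name before the `weighted_` marker, and a loss given as a single-element list is treated like a plain loss.

import numpy as np
from keras_core import Input, Model, layers

x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = Model(x, [output_a, output_b])
model.compile(
    optimizer="sgd",
    loss={
        "output_a": "mean_squared_error",
        "output_b": ["binary_crossentropy"],  # single-element list, now accepted
    },
    metrics={"output_a": ["mean_squared_error"]},
    weighted_metrics={"output_a": ["mean_squared_error"]},
)
hist = model.fit(
    np.random.rand(8, 3),
    (np.random.rand(8, 1), np.random.randint(0, 2, (8, 1))),
    batch_size=2,
    epochs=1,
    verbose=0,
)
# History keys follow tf.keras naming, e.g. "output_a_weighted_mean_squared_error"
# rather than the previous "weighted_output_a_mean_squared_error".
print(sorted(hist.history.keys()))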

@@ -1,4 +1,5 @@
import numpy as np
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
@@ -34,6 +35,27 @@ def _get_model_multi_outputs_list_no_output_names():
return model
def _get_model_single_output():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, output_a)
return model
def _get_model_single_output_list():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, [output_a])
return model
def _get_model_single_output_dict():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, {"output_a": output_a})
return model
def _get_model_multi_outputs_dict():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
@@ -42,7 +64,7 @@ def _get_model_multi_outputs_dict():
return model
class ModelTest(testing.TestCase):
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
model = _get_model()
self.assertTrue(isinstance(model, Functional))
@@ -91,6 +113,61 @@ class ModelTest(testing.TestCase):
)
self.assertTrue(isinstance(new_model, Functional))
@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
("single_output_2", _get_model_single_output, "list"),
("single_output_3", _get_model_single_output, "dict"),
("single_output_4", _get_model_single_output, "dict_list"),
("single_list_output_1", _get_model_single_output_list, None),
("single_list_output_2", _get_model_single_output_list, "list"),
("single_list_output_3", _get_model_single_output_list, "dict"),
("single_list_output_4", _get_model_single_output_list, "dict_list"),
("single_dict_output_1", _get_model_single_output_dict, None),
("single_dict_output_2", _get_model_single_output_dict, "list"),
("single_dict_output_3", _get_model_single_output_dict, "dict"),
("single_dict_output_4", _get_model_single_output_dict, "dict_list"),
)
def test_functional_single_output(self, model_fn, loss_type):
model = model_fn()
self.assertTrue(isinstance(model, Functional))
loss = "mean_squared_error"
if loss_type == "list":
loss = [loss]
elif loss_type == "dict":
loss = {"output_a": loss}
elif loss_type == "dict_lsit":
loss = {"output_a": [loss]}
model.compile(
optimizer="sgd",
loss=loss,
metrics={
"output_a": ["mean_squared_error", "mean_absolute_error"],
},
weighted_metrics={
"output_a": "mean_squared_error",
},
)
# Fit the model to make sure compile_metrics are built
x = np.random.rand(8, 3)
y = np.random.rand(8, 1)
hist = model.fit(
x,
y,
batch_size=2,
epochs=1,
verbose=0,
)
hist_keys = sorted(hist.history.keys())
ref_keys = sorted(
[
"loss",
"mean_absolute_error",
"mean_squared_error",
"weighted_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_list_outputs_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
@@ -101,9 +178,41 @@ class ModelTest(testing.TestCase):
optimizer="sgd",
loss=["mean_squared_error", "binary_crossentropy"],
metrics=[
["mean_squared_error"],
"mean_squared_error",
["mean_squared_error", "accuracy"],
],
loss_weights=[0.1, 2],
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_list_outputs_nested_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss=["mean_squared_error", ["binary_crossentropy"]],
metrics=[
"mean_squared_error",
["mean_squared_error", "accuracy"],
],
loss_weights=[0.1, 2],
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
@@ -131,12 +240,16 @@ class ModelTest(testing.TestCase):
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
"output_b": ["binary_crossentropy"],
},
metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
weighted_metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(
@@ -153,9 +266,12 @@ class ModelTest(testing.TestCase):
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
"output_b_weighted_accuracy",
"output_b_weighted_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
@@ -176,6 +292,10 @@ class ModelTest(testing.TestCase):
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
weighted_metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
@@ -186,6 +306,50 @@ class ModelTest(testing.TestCase):
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
"output_b_weighted_accuracy",
"output_b_weighted_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error"],
},
weighted_metrics={
"output_a": ["mean_squared_error"],
"output_b": ["accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
# `output_b_accuracy` does not have `weighted_` in its metric name.
# When a metric appears only in `weighted_metrics`, it skips the
# `weighted_` prefix. This behavior matches `tf.keras`.
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
@@ -331,3 +495,24 @@ class ModelTest(testing.TestCase):
"key 'output_c' does not correspond to any model output",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
def test_functional_list_outputs_invalid_nested_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss=[
"mean_squared_error",
["mean_squared_error", "binary_crossentropy"],
],
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
ValueError,
"when providing the `loss` argument as a list, "
"it should have as many entries as the model has outputs",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

@@ -8,9 +8,10 @@ from keras_core.utils.naming import get_object_name
class MetricsList(metrics_module.Metric):
def __init__(self, metrics, name="metrics_list"):
def __init__(self, metrics, name="metrics_list", output_name=None):
super().__init__(name=name)
self.metrics = metrics
self.output_name = output_name
def update_state(self, y_true, y_pred, sample_weight=None):
for m in self.metrics:
@@ -54,7 +55,7 @@ def is_binary_or_sparse_categorical(y_true, y_pred):
return is_binary, is_sparse_categorical
def get_metric(identifier, y_true, y_pred, name_prefix=None):
def get_metric(identifier, y_true, y_pred):
if identifier is None:
return None # Ok to have no metric for an output.
@@ -85,8 +86,6 @@ def get_metric(identifier, y_true, y_pred, name_prefix=None):
metric_obj = metrics_module.MeanMetricWrapper(
metric_obj, name=metric_name
)
if name_prefix and not metric_obj.name.startswith(name_prefix):
metric_obj.name = "_".join([name_prefix, metric_obj.name])
return metric_obj
@@ -202,17 +201,22 @@ class CompileMetrics(metrics_module.Metric):
self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
):
flat_metrics = []
if isinstance(metrics, dict):
for name in metrics.keys():
if name not in output_names:
raise ValueError(
f"In the dict argument `{argument_name}`, key "
f"'{name}' does not correspond to any model "
f"output. Received:\n{argument_name}={metrics}"
)
if num_outputs == 1:
if not metrics:
flat_metrics.append(None)
else:
if isinstance(metrics, dict):
metrics = nest.flatten(metrics)
if not isinstance(metrics, list):
raise ValueError(
"When there is only a single output, the "
f"`{argument_name}` argument must be a list of metric "
f"objects. Received instead:\n"
f"{argument_name}={metrics} of type {type(metrics)}"
)
metrics = [metrics]
if not all(is_function_like(m) for m in metrics):
raise ValueError(
f"Expected all entries in the `{argument_name}` list "
@@ -239,20 +243,12 @@ class CompileMetrics(metrics_module.Metric):
f"length {len(metrics)} whereas the model has "
f"{len(y_pred)} outputs."
)
if not all(isinstance(mls, list) for mls in metrics):
raise ValueError(
"For a model with multiple outputs, "
f"when providing the `{argument_name}` argument as a "
"list, each list entry should itself be a list "
"(the list of metrics corresponding to that output). "
f"Received:\n{argument_name}={metrics}"
)
name = None
for idx, (mls, yt, yp) in enumerate(
zip(metrics, y_true, y_pred)
):
if output_names:
name = output_names[idx]
if not isinstance(mls, list):
mls = [mls]
name = output_names[idx] if output_names else None
if not all(is_function_like(e) for e in mls):
raise ValueError(
f"All entries in the sublists of the "
@@ -263,10 +259,11 @@ class CompileMetrics(metrics_module.Metric):
flat_metrics.append(
MetricsList(
[
get_metric(m, yt, yp, name)
get_metric(m, yt, yp)
for m in mls
if m is not None
]
],
output_name=name,
)
)
elif isinstance(metrics, dict):
@@ -277,21 +274,8 @@ class CompileMetrics(metrics_module.Metric):
f"Received {argument_name}={metrics}"
)
for name in metrics.keys():
if name not in output_names:
raise ValueError(
f"In the dict argument `{argument_name}`, key "
f"'{name}' does not correspond to any model "
f"output. Received:\n{argument_name}={metrics}"
)
if not isinstance(metrics[name], list):
raise ValueError(
"For a model with multiple outputs, "
f"when providing the `{argument_name}` argument as "
"a dict, each dict entry should be a list (the "
"list of metrics corresponding to that output). "
f"At key '{name}', received invalid type:\n"
f"{metrics[name]}"
)
metrics[name] = [metrics[name]]
if not all(is_function_like(e) for e in metrics[name]):
raise ValueError(
f"All entries in the sublists of the "
@@ -304,10 +288,11 @@ class CompileMetrics(metrics_module.Metric):
flat_metrics.append(
MetricsList(
[
get_metric(m, yt, yp, name)
get_metric(m, yt, yp)
for m in metrics[name]
if m is not None
]
],
output_name=name,
)
)
else:
@@ -365,23 +350,32 @@ class CompileMetrics(metrics_module.Metric):
if not mls:
continue
for m in mls.metrics:
if m.name not in unique_name_counters:
results[m.name] = m.result()
unique_name_counters[m.name] = 1
name = m.name
if mls.output_name:
name = f"{mls.output_name}_{name}"
if name not in unique_name_counters:
results[name] = m.result()
unique_name_counters[name] = 1
else:
name = f"{m.name}_{unique_name_counters[m.name]}"
unique_name_counters[m.name] += 1
index = unique_name_counters[name]
unique_name_counters[name] += 1
name = f"{name}_{index}"
results[name] = m.result()
for mls in self._flat_weighted_metrics:
if not mls:
continue
for m in mls.metrics:
if m.name not in unique_name_counters:
results[m.name] = m.result()
unique_name_counters[m.name] = 1
name = m.name
if mls.output_name:
name = f"{mls.output_name}_{name}"
if name not in unique_name_counters:
results[name] = m.result()
unique_name_counters[name] = 1
else:
name = f"weighted_{m.name}"
if mls.output_name:
name = f"{mls.output_name}_{name}"
if name not in unique_name_counters:
unique_name_counters[name] = 1
else:
@@ -442,12 +436,25 @@ class CompileLoss(losses_module.Loss):
flat_losses = []
flat_loss_weights = []
if num_outputs == 1 and not is_function_like(loss):
raise ValueError(
"When there is only a single output, the `loss` argument "
"must be a callable. "
f"Received instead:\nloss={loss} of type {type(loss)}"
)
if isinstance(loss, dict):
for name in loss.keys():
if name not in output_names:
raise ValueError(
"In the dict argument `loss`, key "
f"'{name}' does not correspond to any model output. "
f"Received:\nloss={loss}"
)
if num_outputs == 1:
if isinstance(loss, dict):
loss = nest.flatten(loss)
if isinstance(loss, list) and len(loss) == 1:
loss = loss[0]
if not is_function_like(loss):
raise ValueError(
"When there is only a single output, the `loss` argument "
"must be a callable. "
f"Received instead:\nloss={loss} of type {type(loss)}"
)
if is_function_like(loss) and nest.is_nested(y_pred):
# The model has multiple outputs but only one loss fn
@@ -471,6 +478,7 @@ class CompileLoss(losses_module.Loss):
else:
flat_loss_weights.append(1.0)
elif isinstance(loss, (list, tuple)):
loss = nest.flatten(loss)
if len(loss) != len(y_pred):
raise ValueError(
"For a model with multiple outputs, "
@@ -507,11 +515,11 @@ class CompileLoss(losses_module.Loss):
f"length {len(loss_weights)} whereas the model has "
f"{len(y_pred)} outputs."
)
if not all(isinstance(e, float) for e in loss_weights):
if not all(isinstance(e, (int, float)) for e in loss_weights):
raise ValueError(
"For a model with multiple outputs, "
"when providing the `loss_weights` argument as a "
"list, each list entry should be a Python float (the "
"For a model with multiple outputs, when providing "
"the `loss_weights` argument as a list, "
"each list entry should be a Python int or float (the "
"weighting coefficient corresponding to the loss for "
f"that output). Received: loss_weights={loss_weights}"
)
@@ -526,12 +534,8 @@ class CompileLoss(losses_module.Loss):
f"Received loss={loss}"
)
for name in loss.keys():
if name not in output_names:
raise ValueError(
"In the dict argument `loss`, key "
f"'{name}' does not correspond to any model output. "
f"Received:\nloss={loss}"
)
if isinstance(loss[name], list) and len(loss[name]) == 1:
loss[name] = loss[name][0]
if not is_function_like(loss[name]):
raise ValueError(
"For a model with multiple outputs, "

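The CompileMetrics.result() changes above order each history key as output name first, then the `weighted_` marker, and only add `weighted_` when the plain name is already taken by an unweighted metric. A standalone sketch of that naming rule (the function name and signature are illustrative, not part of the library):

def weighted_result_key(metric_name, output_name=None, existing_keys=()):
    # First try the same name an unweighted metric would get.
    name = f"{output_name}_{metric_name}" if output_name else metric_name
    if name not in existing_keys:
        # A metric that appears only in `weighted_metrics` keeps the plain
        # name (no "weighted_" marker), matching tf.keras.
        return name
    # On a clash, insert "weighted_" after the output prefix.
    name = f"weighted_{metric_name}"
    return f"{output_name}_{name}" if output_name else name

assert weighted_result_key("mse", "output_1") == "output_1_mse"
assert weighted_result_key("mse", "output_1", {"output_1_mse"}) == "output_1_weighted_mse"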
@@ -179,13 +179,13 @@ class TestCompileMetrics(testing.TestCase):
self.assertAllClose(result["output_1_mse"], 0.055833336)
self.assertAllClose(result["output_2_mse"], 0.055833336)
self.assertAllClose(
result["weighted_output_1_mean_squared_error"], 0.0725
result["output_1_weighted_mean_squared_error"], 0.0725
)
self.assertAllClose(
result["weighted_output_2_mean_squared_error"], 0.0725
result["output_2_weighted_mean_squared_error"], 0.0725
)
self.assertAllClose(result["weighted_output_1_mse"], 0.0725)
self.assertAllClose(result["weighted_output_2_mse"], 0.0725)
self.assertAllClose(result["output_1_weighted_mse"], 0.0725)
self.assertAllClose(result["output_2_weighted_mse"], 0.0725)
compile_metrics.reset_state()
result = compile_metrics.result()
@@ -193,8 +193,8 @@ class TestCompileMetrics(testing.TestCase):
self.assertEqual(len(result), 8)
self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0)
self.assertAllClose(result["output_1_weighted_mean_squared_error"], 0.0)
self.assertAllClose(result["output_2_weighted_mean_squared_error"], 0.0)
def test_name_conversions(self):
compile_metrics = CompileMetrics(