Align history metrics output with tf.keras
(#392)
* Align compile metrics with * Update tests for weighted metrics * Update metrics for losses as single element list * Update metrics for losses as single element list
This commit is contained in:
parent
c373b6e1da
commit
b2e7bf28bc
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
@ -34,6 +35,27 @@ def _get_model_multi_outputs_list_no_output_names():
|
||||
return model
|
||||
|
||||
|
||||
def _get_model_single_output():
|
||||
x = Input(shape=(3,), name="input_a")
|
||||
output_a = layers.Dense(1, name="output_a")(x)
|
||||
model = Model(x, output_a)
|
||||
return model
|
||||
|
||||
|
||||
def _get_model_single_output_list():
|
||||
x = Input(shape=(3,), name="input_a")
|
||||
output_a = layers.Dense(1, name="output_a")(x)
|
||||
model = Model(x, [output_a])
|
||||
return model
|
||||
|
||||
|
||||
def _get_model_single_output_dict():
|
||||
x = Input(shape=(3,), name="input_a")
|
||||
output_a = layers.Dense(1, name="output_a")(x)
|
||||
model = Model(x, {"output_a": output_a})
|
||||
return model
|
||||
|
||||
|
||||
def _get_model_multi_outputs_dict():
|
||||
x = Input(shape=(3,), name="input_a")
|
||||
output_a = layers.Dense(1, name="output_a")(x)
|
||||
@ -42,7 +64,7 @@ def _get_model_multi_outputs_dict():
|
||||
return model
|
||||
|
||||
|
||||
class ModelTest(testing.TestCase):
|
||||
class ModelTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_functional_rerouting(self):
|
||||
model = _get_model()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
@ -91,6 +113,61 @@ class ModelTest(testing.TestCase):
|
||||
)
|
||||
self.assertTrue(isinstance(new_model, Functional))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("single_output_1", _get_model_single_output, None),
|
||||
("single_output_2", _get_model_single_output, "list"),
|
||||
("single_output_3", _get_model_single_output, "dict"),
|
||||
("single_output_4", _get_model_single_output, "dict_list"),
|
||||
("single_list_output_1", _get_model_single_output_list, None),
|
||||
("single_list_output_2", _get_model_single_output_list, "list"),
|
||||
("single_list_output_3", _get_model_single_output_list, "dict"),
|
||||
("single_list_output_4", _get_model_single_output_list, "dict_list"),
|
||||
("single_dict_output_1", _get_model_single_output_dict, None),
|
||||
("single_dict_output_2", _get_model_single_output_dict, "list"),
|
||||
("single_dict_output_3", _get_model_single_output_dict, "dict"),
|
||||
("single_dict_output_4", _get_model_single_output_dict, "dict_list"),
|
||||
)
|
||||
def test_functional_single_output(self, model_fn, loss_type):
|
||||
model = model_fn()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
loss = "mean_squared_error"
|
||||
if loss_type == "list":
|
||||
loss = [loss]
|
||||
elif loss_type == "dict":
|
||||
loss = {"output_a": loss}
|
||||
elif loss_type == "dict_lsit":
|
||||
loss = {"output_a": [loss]}
|
||||
model.compile(
|
||||
optimizer="sgd",
|
||||
loss=loss,
|
||||
metrics={
|
||||
"output_a": ["mean_squared_error", "mean_absolute_error"],
|
||||
},
|
||||
weighted_metrics={
|
||||
"output_a": "mean_squared_error",
|
||||
},
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
x = np.random.rand(8, 3)
|
||||
y = np.random.rand(8, 1)
|
||||
hist = model.fit(
|
||||
x,
|
||||
y,
|
||||
batch_size=2,
|
||||
epochs=1,
|
||||
verbose=0,
|
||||
)
|
||||
hist_keys = sorted(hist.history.keys())
|
||||
ref_keys = sorted(
|
||||
[
|
||||
"loss",
|
||||
"mean_absolute_error",
|
||||
"mean_squared_error",
|
||||
"weighted_mean_squared_error",
|
||||
]
|
||||
)
|
||||
self.assertListEqual(hist_keys, ref_keys)
|
||||
|
||||
def test_functional_list_outputs_list_losses(self):
|
||||
model = _get_model_multi_outputs_list()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
@ -101,9 +178,41 @@ class ModelTest(testing.TestCase):
|
||||
optimizer="sgd",
|
||||
loss=["mean_squared_error", "binary_crossentropy"],
|
||||
metrics=[
|
||||
["mean_squared_error"],
|
||||
"mean_squared_error",
|
||||
["mean_squared_error", "accuracy"],
|
||||
],
|
||||
loss_weights=[0.1, 2],
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
hist_keys = sorted(hist.history.keys())
|
||||
# TODO `tf.keras` also outputs individual losses for outputs
|
||||
ref_keys = sorted(
|
||||
[
|
||||
"loss",
|
||||
# "output_a_loss",
|
||||
"output_a_mean_squared_error",
|
||||
"output_b_accuracy",
|
||||
# "output_b_loss",
|
||||
"output_b_mean_squared_error",
|
||||
]
|
||||
)
|
||||
self.assertListEqual(hist_keys, ref_keys)
|
||||
|
||||
def test_functional_list_outputs_nested_list_losses(self):
|
||||
model = _get_model_multi_outputs_list()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
x = np.random.rand(8, 3)
|
||||
y1 = np.random.rand(8, 1)
|
||||
y2 = np.random.randint(0, 2, (8, 1))
|
||||
model.compile(
|
||||
optimizer="sgd",
|
||||
loss=["mean_squared_error", ["binary_crossentropy"]],
|
||||
metrics=[
|
||||
"mean_squared_error",
|
||||
["mean_squared_error", "accuracy"],
|
||||
],
|
||||
loss_weights=[0.1, 2],
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
@ -131,12 +240,16 @@ class ModelTest(testing.TestCase):
|
||||
optimizer="sgd",
|
||||
loss={
|
||||
"output_a": "mean_squared_error",
|
||||
"output_b": "binary_crossentropy",
|
||||
"output_b": ["binary_crossentropy"],
|
||||
},
|
||||
metrics={
|
||||
"output_a": ["mean_squared_error"],
|
||||
"output_b": ["mean_squared_error", "accuracy"],
|
||||
},
|
||||
weighted_metrics={
|
||||
"output_a": ["mean_squared_error"],
|
||||
"output_b": ["mean_squared_error", "accuracy"],
|
||||
},
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
hist = model.fit(
|
||||
@ -153,9 +266,12 @@ class ModelTest(testing.TestCase):
|
||||
"loss",
|
||||
# "output_a_loss",
|
||||
"output_a_mean_squared_error",
|
||||
"output_a_weighted_mean_squared_error",
|
||||
"output_b_accuracy",
|
||||
# "output_b_loss",
|
||||
"output_b_mean_squared_error",
|
||||
"output_b_weighted_accuracy",
|
||||
"output_b_weighted_mean_squared_error",
|
||||
]
|
||||
)
|
||||
self.assertListEqual(hist_keys, ref_keys)
|
||||
@ -176,6 +292,10 @@ class ModelTest(testing.TestCase):
|
||||
"output_a": ["mean_squared_error"],
|
||||
"output_b": ["mean_squared_error", "accuracy"],
|
||||
},
|
||||
weighted_metrics={
|
||||
"output_a": ["mean_squared_error"],
|
||||
"output_b": ["mean_squared_error", "accuracy"],
|
||||
},
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
@ -186,6 +306,50 @@ class ModelTest(testing.TestCase):
|
||||
"loss",
|
||||
# "output_a_loss",
|
||||
"output_a_mean_squared_error",
|
||||
"output_a_weighted_mean_squared_error",
|
||||
"output_b_accuracy",
|
||||
# "output_b_loss",
|
||||
"output_b_mean_squared_error",
|
||||
"output_b_weighted_accuracy",
|
||||
"output_b_weighted_mean_squared_error",
|
||||
]
|
||||
)
|
||||
self.assertListEqual(hist_keys, ref_keys)
|
||||
|
||||
def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
|
||||
model = _get_model_multi_outputs_list()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
x = np.random.rand(8, 3)
|
||||
y1 = np.random.rand(8, 1)
|
||||
y2 = np.random.randint(0, 2, (8, 1))
|
||||
model.compile(
|
||||
optimizer="sgd",
|
||||
loss={
|
||||
"output_a": "mean_squared_error",
|
||||
"output_b": "binary_crossentropy",
|
||||
},
|
||||
metrics={
|
||||
"output_a": ["mean_squared_error"],
|
||||
"output_b": ["mean_squared_error"],
|
||||
},
|
||||
weighted_metrics={
|
||||
"output_a": ["mean_squared_error"],
|
||||
"output_b": ["accuracy"],
|
||||
},
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
hist_keys = sorted(hist.history.keys())
|
||||
# TODO `tf.keras` also outputs individual losses for outputs
|
||||
# `output_b_accuracy` doesn't have `weighted_` in metric name.
|
||||
# When a metric is only in weighted metrics, it skips `weighted_`
|
||||
# prefix. This behavior matches`tf.keras`.
|
||||
ref_keys = sorted(
|
||||
[
|
||||
"loss",
|
||||
# "output_a_loss",
|
||||
"output_a_mean_squared_error",
|
||||
"output_a_weighted_mean_squared_error",
|
||||
"output_b_accuracy",
|
||||
# "output_b_loss",
|
||||
"output_b_mean_squared_error",
|
||||
@ -331,3 +495,24 @@ class ModelTest(testing.TestCase):
|
||||
"key 'output_c' does not correspond to any model output",
|
||||
):
|
||||
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
|
||||
def test_functional_list_outputs_invalid_nested_list_losses(self):
|
||||
model = _get_model_multi_outputs_list()
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
x = np.random.rand(8, 3)
|
||||
y1 = np.random.rand(8, 1)
|
||||
y2 = np.random.randint(0, 2, (8, 1))
|
||||
model.compile(
|
||||
optimizer="sgd",
|
||||
loss=[
|
||||
"mean_squared_error",
|
||||
["mean_squared_error", "binary_crossentropy"],
|
||||
],
|
||||
)
|
||||
# Fit the model to make sure compile_metrics are built
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"when providing the `loss` argument as a list, "
|
||||
"it should have as many entries as the model has outputs",
|
||||
):
|
||||
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
|
||||
|
@ -8,9 +8,10 @@ from keras_core.utils.naming import get_object_name
|
||||
|
||||
|
||||
class MetricsList(metrics_module.Metric):
|
||||
def __init__(self, metrics, name="metrics_list"):
|
||||
def __init__(self, metrics, name="metrics_list", output_name=None):
|
||||
super().__init__(name=name)
|
||||
self.metrics = metrics
|
||||
self.output_name = output_name
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
for m in self.metrics:
|
||||
@ -54,7 +55,7 @@ def is_binary_or_sparse_categorical(y_true, y_pred):
|
||||
return is_binary, is_sparse_categorical
|
||||
|
||||
|
||||
def get_metric(identifier, y_true, y_pred, name_prefix=None):
|
||||
def get_metric(identifier, y_true, y_pred):
|
||||
if identifier is None:
|
||||
return None # Ok to have no metric for an output.
|
||||
|
||||
@ -85,8 +86,6 @@ def get_metric(identifier, y_true, y_pred, name_prefix=None):
|
||||
metric_obj = metrics_module.MeanMetricWrapper(
|
||||
metric_obj, name=metric_name
|
||||
)
|
||||
if name_prefix and not metric_obj.name.startswith(name_prefix):
|
||||
metric_obj.name = "_".join([name_prefix, metric_obj.name])
|
||||
return metric_obj
|
||||
|
||||
|
||||
@ -202,17 +201,22 @@ class CompileMetrics(metrics_module.Metric):
|
||||
self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
|
||||
):
|
||||
flat_metrics = []
|
||||
if isinstance(metrics, dict):
|
||||
for name in metrics.keys():
|
||||
if name not in output_names:
|
||||
raise ValueError(
|
||||
f"In the dict argument `{argument_name}`, key "
|
||||
f"'{name}' does not correspond to any model "
|
||||
f"output. Received:\n{argument_name}={metrics}"
|
||||
)
|
||||
if num_outputs == 1:
|
||||
if not metrics:
|
||||
flat_metrics.append(None)
|
||||
else:
|
||||
if isinstance(metrics, dict):
|
||||
metrics = nest.flatten(metrics)
|
||||
if not isinstance(metrics, list):
|
||||
raise ValueError(
|
||||
"When there is only a single output, the "
|
||||
f"`{argument_name}` argument must be a list of metric "
|
||||
f"objects. Received instead:\n"
|
||||
f"{argument_name}={metrics} of type {type(metrics)}"
|
||||
)
|
||||
metrics = [metrics]
|
||||
if not all(is_function_like(m) for m in metrics):
|
||||
raise ValueError(
|
||||
f"Expected all entries in the `{argument_name}` list "
|
||||
@ -239,20 +243,12 @@ class CompileMetrics(metrics_module.Metric):
|
||||
f"length {len(metrics)} whereas the model has "
|
||||
f"{len(y_pred)} outputs."
|
||||
)
|
||||
if not all(isinstance(mls, list) for mls in metrics):
|
||||
raise ValueError(
|
||||
"For a model with multiple outputs, "
|
||||
f"when providing the `{argument_name}` argument as a "
|
||||
"list, each list entry should itself be a list "
|
||||
"(the list of metrics corresponding to that output). "
|
||||
f"Received:\n{argument_name}={metrics}"
|
||||
)
|
||||
name = None
|
||||
for idx, (mls, yt, yp) in enumerate(
|
||||
zip(metrics, y_true, y_pred)
|
||||
):
|
||||
if output_names:
|
||||
name = output_names[idx]
|
||||
if not isinstance(mls, list):
|
||||
mls = [mls]
|
||||
name = output_names[idx] if output_names else None
|
||||
if not all(is_function_like(e) for e in mls):
|
||||
raise ValueError(
|
||||
f"All entries in the sublists of the "
|
||||
@ -263,10 +259,11 @@ class CompileMetrics(metrics_module.Metric):
|
||||
flat_metrics.append(
|
||||
MetricsList(
|
||||
[
|
||||
get_metric(m, yt, yp, name)
|
||||
get_metric(m, yt, yp)
|
||||
for m in mls
|
||||
if m is not None
|
||||
]
|
||||
],
|
||||
output_name=name,
|
||||
)
|
||||
)
|
||||
elif isinstance(metrics, dict):
|
||||
@ -277,21 +274,8 @@ class CompileMetrics(metrics_module.Metric):
|
||||
f"Received {argument_name}={metrics}"
|
||||
)
|
||||
for name in metrics.keys():
|
||||
if name not in output_names:
|
||||
raise ValueError(
|
||||
f"In the dict argument `{argument_name}`, key "
|
||||
f"'{name}' does not correspond to any model "
|
||||
f"output. Received:\n{argument_name}={metrics}"
|
||||
)
|
||||
if not isinstance(metrics[name], list):
|
||||
raise ValueError(
|
||||
"For a model with multiple outputs, "
|
||||
f"when providing the `{argument_name}` argument as "
|
||||
"a dict, each dict entry should be a list (the "
|
||||
"list of metrics corresponding to that output). "
|
||||
f"At key '{name}', received invalid type:\n"
|
||||
f"{metrics[name]}"
|
||||
)
|
||||
metrics[name] = [metrics[name]]
|
||||
if not all(is_function_like(e) for e in metrics[name]):
|
||||
raise ValueError(
|
||||
f"All entries in the sublists of the "
|
||||
@ -304,10 +288,11 @@ class CompileMetrics(metrics_module.Metric):
|
||||
flat_metrics.append(
|
||||
MetricsList(
|
||||
[
|
||||
get_metric(m, yt, yp, name)
|
||||
get_metric(m, yt, yp)
|
||||
for m in metrics[name]
|
||||
if m is not None
|
||||
]
|
||||
],
|
||||
output_name=name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -365,23 +350,32 @@ class CompileMetrics(metrics_module.Metric):
|
||||
if not mls:
|
||||
continue
|
||||
for m in mls.metrics:
|
||||
if m.name not in unique_name_counters:
|
||||
results[m.name] = m.result()
|
||||
unique_name_counters[m.name] = 1
|
||||
name = m.name
|
||||
if mls.output_name:
|
||||
name = f"{mls.output_name}_{name}"
|
||||
if name not in unique_name_counters:
|
||||
results[name] = m.result()
|
||||
unique_name_counters[name] = 1
|
||||
else:
|
||||
name = f"{m.name}_{unique_name_counters[m.name]}"
|
||||
unique_name_counters[m.name] += 1
|
||||
index = unique_name_counters[name]
|
||||
unique_name_counters[name] += 1
|
||||
name = f"{name}_{index}"
|
||||
results[name] = m.result()
|
||||
|
||||
for mls in self._flat_weighted_metrics:
|
||||
if not mls:
|
||||
continue
|
||||
for m in mls.metrics:
|
||||
if m.name not in unique_name_counters:
|
||||
results[m.name] = m.result()
|
||||
unique_name_counters[m.name] = 1
|
||||
name = m.name
|
||||
if mls.output_name:
|
||||
name = f"{mls.output_name}_{name}"
|
||||
if name not in unique_name_counters:
|
||||
results[name] = m.result()
|
||||
unique_name_counters[name] = 1
|
||||
else:
|
||||
name = f"weighted_{m.name}"
|
||||
if mls.output_name:
|
||||
name = f"{mls.output_name}_{name}"
|
||||
if name not in unique_name_counters:
|
||||
unique_name_counters[name] = 1
|
||||
else:
|
||||
@ -442,7 +436,20 @@ class CompileLoss(losses_module.Loss):
|
||||
flat_losses = []
|
||||
flat_loss_weights = []
|
||||
|
||||
if num_outputs == 1 and not is_function_like(loss):
|
||||
if isinstance(loss, dict):
|
||||
for name in loss.keys():
|
||||
if name not in output_names:
|
||||
raise ValueError(
|
||||
"In the dict argument `loss`, key "
|
||||
f"'{name}' does not correspond to any model output. "
|
||||
f"Received:\nloss={loss}"
|
||||
)
|
||||
if num_outputs == 1:
|
||||
if isinstance(loss, dict):
|
||||
loss = nest.flatten(loss)
|
||||
if isinstance(loss, list) and len(loss) == 1:
|
||||
loss = loss[0]
|
||||
if not is_function_like(loss):
|
||||
raise ValueError(
|
||||
"When there is only a single output, the `loss` argument "
|
||||
"must be a callable. "
|
||||
@ -471,6 +478,7 @@ class CompileLoss(losses_module.Loss):
|
||||
else:
|
||||
flat_loss_weights.append(1.0)
|
||||
elif isinstance(loss, (list, tuple)):
|
||||
loss = nest.flatten(loss)
|
||||
if len(loss) != len(y_pred):
|
||||
raise ValueError(
|
||||
"For a model with multiple outputs, "
|
||||
@ -507,11 +515,11 @@ class CompileLoss(losses_module.Loss):
|
||||
f"length {len(loss_weights)} whereas the model has "
|
||||
f"{len(y_pred)} outputs."
|
||||
)
|
||||
if not all(isinstance(e, float) for e in loss_weights):
|
||||
if not all(isinstance(e, (int, float)) for e in loss_weights):
|
||||
raise ValueError(
|
||||
"For a model with multiple outputs, "
|
||||
"when providing the `loss_weights` argument as a "
|
||||
"list, each list entry should be a Python float (the "
|
||||
"For a model with multiple outputs, when providing "
|
||||
"the `loss_weights` argument as a list, "
|
||||
"each list entry should be a Python int or float (the "
|
||||
"weighting coefficient corresponding to the loss for "
|
||||
f"that output). Received: loss_weights={loss_weights}"
|
||||
)
|
||||
@ -526,12 +534,8 @@ class CompileLoss(losses_module.Loss):
|
||||
f"Received loss={loss}"
|
||||
)
|
||||
for name in loss.keys():
|
||||
if name not in output_names:
|
||||
raise ValueError(
|
||||
"In the dict argument `loss`, key "
|
||||
f"'{name}' does not correspond to any model output. "
|
||||
f"Received:\nloss={loss}"
|
||||
)
|
||||
if isinstance(loss[name], list) and len(loss[name]) == 1:
|
||||
loss[name] = loss[name][0]
|
||||
if not is_function_like(loss[name]):
|
||||
raise ValueError(
|
||||
"For a model with multiple outputs, "
|
||||
|
@ -179,13 +179,13 @@ class TestCompileMetrics(testing.TestCase):
|
||||
self.assertAllClose(result["output_1_mse"], 0.055833336)
|
||||
self.assertAllClose(result["output_2_mse"], 0.055833336)
|
||||
self.assertAllClose(
|
||||
result["weighted_output_1_mean_squared_error"], 0.0725
|
||||
result["output_1_weighted_mean_squared_error"], 0.0725
|
||||
)
|
||||
self.assertAllClose(
|
||||
result["weighted_output_2_mean_squared_error"], 0.0725
|
||||
result["output_2_weighted_mean_squared_error"], 0.0725
|
||||
)
|
||||
self.assertAllClose(result["weighted_output_1_mse"], 0.0725)
|
||||
self.assertAllClose(result["weighted_output_2_mse"], 0.0725)
|
||||
self.assertAllClose(result["output_1_weighted_mse"], 0.0725)
|
||||
self.assertAllClose(result["output_2_weighted_mse"], 0.0725)
|
||||
|
||||
compile_metrics.reset_state()
|
||||
result = compile_metrics.result()
|
||||
@ -193,8 +193,8 @@ class TestCompileMetrics(testing.TestCase):
|
||||
self.assertEqual(len(result), 8)
|
||||
self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
|
||||
self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
|
||||
self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0)
|
||||
self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0)
|
||||
self.assertAllClose(result["output_1_weighted_mean_squared_error"], 0.0)
|
||||
self.assertAllClose(result["output_2_weighted_mean_squared_error"], 0.0)
|
||||
|
||||
def test_name_conversions(self):
|
||||
compile_metrics = CompileMetrics(
|
||||
|
Loading…
Reference in New Issue
Block a user