Align history metrics output with tf.keras (#392)

* Align compile metrics with tf.keras
* Update tests for weighted metrics
* Update metrics for losses as single element list

parent c373b6e1da
commit b2e7bf28bc
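
In user terms, this change makes the keys in `model.fit(...).history` follow the tf.keras naming scheme: per-output metrics are reported as "<output_name>_<metric>" and weighted metrics as "<output_name>_weighted_<metric>". A minimal sketch mirroring the new tests in this diff (the `Input`/`Model` import paths are assumptions, not part of the diff):

    import numpy as np

    from keras_core import layers
    from keras_core.layers import Input
    from keras_core.models import Model

    x = Input(shape=(3,), name="input_a")
    output_a = layers.Dense(1, name="output_a")(x)
    output_b = layers.Dense(1, name="output_b")(x)
    model = Model(x, [output_a, output_b])
    model.compile(
        optimizer="sgd",
        loss={
            "output_a": "mean_squared_error",
            "output_b": "binary_crossentropy",
        },
        metrics={"output_a": ["mean_squared_error"], "output_b": ["accuracy"]},
        weighted_metrics={"output_a": ["mean_squared_error"]},
    )
    hist = model.fit(
        np.random.rand(8, 3),
        (np.random.rand(8, 1), np.random.randint(0, 2, (8, 1))),
        batch_size=2,
        epochs=1,
        verbose=0,
    )
    # Expected with this change:
    # ["loss", "output_a_mean_squared_error",
    #  "output_a_weighted_mean_squared_error", "output_b_accuracy"]
    print(sorted(hist.history.keys()))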
@@ -1,4 +1,5 @@
 import numpy as np
+from absl.testing import parameterized

 from keras_core import layers
 from keras_core import testing
@@ -34,6 +35,27 @@ def _get_model_multi_outputs_list_no_output_names():
     return model


+def _get_model_single_output():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    model = Model(x, output_a)
+    return model
+
+
+def _get_model_single_output_list():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    model = Model(x, [output_a])
+    return model
+
+
+def _get_model_single_output_dict():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    model = Model(x, {"output_a": output_a})
+    return model
+
+
 def _get_model_multi_outputs_dict():
     x = Input(shape=(3,), name="input_a")
     output_a = layers.Dense(1, name="output_a")(x)
@@ -42,7 +64,7 @@ def _get_model_multi_outputs_dict():
     return model


-class ModelTest(testing.TestCase):
+class ModelTest(testing.TestCase, parameterized.TestCase):
     def test_functional_rerouting(self):
         model = _get_model()
         self.assertTrue(isinstance(model, Functional))
@@ -91,6 +113,61 @@ class ModelTest(testing.TestCase):
         )
         self.assertTrue(isinstance(new_model, Functional))

+    @parameterized.named_parameters(
+        ("single_output_1", _get_model_single_output, None),
+        ("single_output_2", _get_model_single_output, "list"),
+        ("single_output_3", _get_model_single_output, "dict"),
+        ("single_output_4", _get_model_single_output, "dict_list"),
+        ("single_list_output_1", _get_model_single_output_list, None),
+        ("single_list_output_2", _get_model_single_output_list, "list"),
+        ("single_list_output_3", _get_model_single_output_list, "dict"),
+        ("single_list_output_4", _get_model_single_output_list, "dict_list"),
+        ("single_dict_output_1", _get_model_single_output_dict, None),
+        ("single_dict_output_2", _get_model_single_output_dict, "list"),
+        ("single_dict_output_3", _get_model_single_output_dict, "dict"),
+        ("single_dict_output_4", _get_model_single_output_dict, "dict_list"),
+    )
+    def test_functional_single_output(self, model_fn, loss_type):
+        model = model_fn()
+        self.assertTrue(isinstance(model, Functional))
+        loss = "mean_squared_error"
+        if loss_type == "list":
+            loss = [loss]
+        elif loss_type == "dict":
+            loss = {"output_a": loss}
+        elif loss_type == "dict_list":
+            loss = {"output_a": [loss]}
+        model.compile(
+            optimizer="sgd",
+            loss=loss,
+            metrics={
+                "output_a": ["mean_squared_error", "mean_absolute_error"],
+            },
+            weighted_metrics={
+                "output_a": "mean_squared_error",
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        x = np.random.rand(8, 3)
+        y = np.random.rand(8, 1)
+        hist = model.fit(
+            x,
+            y,
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            [
+                "loss",
+                "mean_absolute_error",
+                "mean_squared_error",
+                "weighted_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
     def test_functional_list_outputs_list_losses(self):
         model = _get_model_multi_outputs_list()
         self.assertTrue(isinstance(model, Functional))
@@ -101,9 +178,41 @@ class ModelTest(testing.TestCase):
             optimizer="sgd",
             loss=["mean_squared_error", "binary_crossentropy"],
             metrics=[
-                ["mean_squared_error"],
+                "mean_squared_error",
                 ["mean_squared_error", "accuracy"],
             ],
+            loss_weights=[0.1, 2],
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_nested_list_losses(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss=["mean_squared_error", ["binary_crossentropy"]],
+            metrics=[
+                "mean_squared_error",
+                ["mean_squared_error", "accuracy"],
+            ],
+            loss_weights=[0.1, 2],
         )
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
@@ -131,12 +240,16 @@ class ModelTest(testing.TestCase):
             optimizer="sgd",
             loss={
                 "output_a": "mean_squared_error",
-                "output_b": "binary_crossentropy",
+                "output_b": ["binary_crossentropy"],
             },
             metrics={
                 "output_a": ["mean_squared_error"],
                 "output_b": ["mean_squared_error", "accuracy"],
             },
+            weighted_metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
         )
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(
@@ -153,9 +266,12 @@ class ModelTest(testing.TestCase):
                 "loss",
                 # "output_a_loss",
                 "output_a_mean_squared_error",
+                "output_a_weighted_mean_squared_error",
                 "output_b_accuracy",
                 # "output_b_loss",
                 "output_b_mean_squared_error",
+                "output_b_weighted_accuracy",
+                "output_b_weighted_mean_squared_error",
             ]
         )
         self.assertListEqual(hist_keys, ref_keys)
@@ -176,6 +292,10 @@ class ModelTest(testing.TestCase):
                 "output_a": ["mean_squared_error"],
                 "output_b": ["mean_squared_error", "accuracy"],
             },
+            weighted_metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
         )
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
@@ -186,6 +306,50 @@ class ModelTest(testing.TestCase):
                 "loss",
                 # "output_a_loss",
                 "output_a_mean_squared_error",
+                "output_a_weighted_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+                "output_b_weighted_accuracy",
+                "output_b_weighted_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error"],
+            },
+            weighted_metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        # `output_b_accuracy` doesn't have `weighted_` in metric name.
+        # When a metric is only in weighted metrics, it skips `weighted_`
+        # prefix. This behavior matches `tf.keras`.
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_a_weighted_mean_squared_error",
                 "output_b_accuracy",
                 # "output_b_loss",
                 "output_b_mean_squared_error",
@@ -331,3 +495,24 @@ class ModelTest(testing.TestCase):
             "key 'output_c' does not correspond to any model output",
         ):
             model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_list_outputs_invalid_nested_list_losses(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss=[
+                "mean_squared_error",
+                ["mean_squared_error", "binary_crossentropy"],
+            ],
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "when providing the `loss` argument as a list, "
+            "it should have as many entries as the model has outputs",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
@@ -8,9 +8,10 @@ from keras_core.utils.naming import get_object_name


 class MetricsList(metrics_module.Metric):
-    def __init__(self, metrics, name="metrics_list"):
+    def __init__(self, metrics, name="metrics_list", output_name=None):
         super().__init__(name=name)
         self.metrics = metrics
+        self.output_name = output_name

     def update_state(self, y_true, y_pred, sample_weight=None):
         for m in self.metrics:
@@ -54,7 +55,7 @@ def is_binary_or_sparse_categorical(y_true, y_pred):
     return is_binary, is_sparse_categorical


-def get_metric(identifier, y_true, y_pred, name_prefix=None):
+def get_metric(identifier, y_true, y_pred):
     if identifier is None:
         return None  # Ok to have no metric for an output.

@@ -85,8 +86,6 @@ def get_metric(identifier, y_true, y_pred, name_prefix=None):
         metric_obj = metrics_module.MeanMetricWrapper(
             metric_obj, name=metric_name
         )
-    if name_prefix and not metric_obj.name.startswith(name_prefix):
-        metric_obj.name = "_".join([name_prefix, metric_obj.name])
     return metric_obj

@@ -202,17 +201,22 @@ class CompileMetrics(metrics_module.Metric):
         self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
     ):
         flat_metrics = []
+        if isinstance(metrics, dict):
+            for name in metrics.keys():
+                if name not in output_names:
+                    raise ValueError(
+                        f"In the dict argument `{argument_name}`, key "
+                        f"'{name}' does not correspond to any model "
+                        f"output. Received:\n{argument_name}={metrics}"
+                    )
         if num_outputs == 1:
             if not metrics:
                 flat_metrics.append(None)
             else:
+                if isinstance(metrics, dict):
+                    metrics = nest.flatten(metrics)
                 if not isinstance(metrics, list):
-                    raise ValueError(
-                        "When there is only a single output, the "
-                        f"`{argument_name}` argument must be a list of metric "
-                        f"objects. Received instead:\n"
-                        f"{argument_name}={metrics} of type {type(metrics)}"
-                    )
+                    metrics = [metrics]
                 if not all(is_function_like(m) for m in metrics):
                     raise ValueError(
                         f"Expected all entries in the `{argument_name}` list "
@@ -239,20 +243,12 @@ class CompileMetrics(metrics_module.Metric):
                     f"length {len(metrics)} whereas the model has "
                     f"{len(y_pred)} outputs."
                 )
-            if not all(isinstance(mls, list) for mls in metrics):
-                raise ValueError(
-                    "For a model with multiple outputs, "
-                    f"when providing the `{argument_name}` argument as a "
-                    "list, each list entry should itself be a list "
-                    "(the list of metrics corresponding to that output). "
-                    f"Received:\n{argument_name}={metrics}"
-                )
-            name = None
             for idx, (mls, yt, yp) in enumerate(
                 zip(metrics, y_true, y_pred)
             ):
-                if output_names:
-                    name = output_names[idx]
+                if not isinstance(mls, list):
+                    mls = [mls]
+                name = output_names[idx] if output_names else None
                 if not all(is_function_like(e) for e in mls):
                     raise ValueError(
                         f"All entries in the sublists of the "
|
|||||||
flat_metrics.append(
|
flat_metrics.append(
|
||||||
MetricsList(
|
MetricsList(
|
||||||
[
|
[
|
||||||
get_metric(m, yt, yp, name)
|
get_metric(m, yt, yp)
|
||||||
for m in mls
|
for m in mls
|
||||||
if m is not None
|
if m is not None
|
||||||
]
|
],
|
||||||
|
output_name=name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(metrics, dict):
|
elif isinstance(metrics, dict):
|
||||||
@@ -277,21 +274,8 @@ class CompileMetrics(metrics_module.Metric):
                     f"Received {argument_name}={metrics}"
                 )
             for name in metrics.keys():
-                if name not in output_names:
-                    raise ValueError(
-                        f"In the dict argument `{argument_name}`, key "
-                        f"'{name}' does not correspond to any model "
-                        f"output. Received:\n{argument_name}={metrics}"
-                    )
                 if not isinstance(metrics[name], list):
-                    raise ValueError(
-                        "For a model with multiple outputs, "
-                        f"when providing the `{argument_name}` argument as "
-                        "a dict, each dict entry should be a list (the "
-                        "list of metrics corresponding to that output). "
-                        f"At key '{name}', received invalid type:\n"
-                        f"{metrics[name]}"
-                    )
+                    metrics[name] = [metrics[name]]
                 if not all(is_function_like(e) for e in metrics[name]):
                     raise ValueError(
                         f"All entries in the sublists of the "
|
|||||||
flat_metrics.append(
|
flat_metrics.append(
|
||||||
MetricsList(
|
MetricsList(
|
||||||
[
|
[
|
||||||
get_metric(m, yt, yp, name)
|
get_metric(m, yt, yp)
|
||||||
for m in metrics[name]
|
for m in metrics[name]
|
||||||
if m is not None
|
if m is not None
|
||||||
]
|
],
|
||||||
|
output_name=name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -365,23 +350,32 @@ class CompileMetrics(metrics_module.Metric):
             if not mls:
                 continue
             for m in mls.metrics:
-                if m.name not in unique_name_counters:
-                    results[m.name] = m.result()
-                    unique_name_counters[m.name] = 1
+                name = m.name
+                if mls.output_name:
+                    name = f"{mls.output_name}_{name}"
+                if name not in unique_name_counters:
+                    results[name] = m.result()
+                    unique_name_counters[name] = 1
                 else:
-                    name = f"{m.name}_{unique_name_counters[m.name]}"
-                    unique_name_counters[m.name] += 1
+                    index = unique_name_counters[name]
+                    unique_name_counters[name] += 1
+                    name = f"{name}_{index}"
                     results[name] = m.result()

         for mls in self._flat_weighted_metrics:
             if not mls:
                 continue
             for m in mls.metrics:
-                if m.name not in unique_name_counters:
-                    results[m.name] = m.result()
-                    unique_name_counters[m.name] = 1
+                name = m.name
+                if mls.output_name:
+                    name = f"{mls.output_name}_{name}"
+                if name not in unique_name_counters:
+                    results[name] = m.result()
+                    unique_name_counters[name] = 1
                 else:
                     name = f"weighted_{m.name}"
+                    if mls.output_name:
+                        name = f"{mls.output_name}_{name}"
                     if name not in unique_name_counters:
                         unique_name_counters[name] = 1
                     else:
@@ -442,12 +436,25 @@ class CompileLoss(losses_module.Loss):
         flat_losses = []
         flat_loss_weights = []

-        if num_outputs == 1 and not is_function_like(loss):
-            raise ValueError(
-                "When there is only a single output, the `loss` argument "
-                "must be a callable. "
-                f"Received instead:\nloss={loss} of type {type(loss)}"
-            )
+        if isinstance(loss, dict):
+            for name in loss.keys():
+                if name not in output_names:
+                    raise ValueError(
+                        "In the dict argument `loss`, key "
+                        f"'{name}' does not correspond to any model output. "
+                        f"Received:\nloss={loss}"
+                    )
+        if num_outputs == 1:
+            if isinstance(loss, dict):
+                loss = nest.flatten(loss)
+            if isinstance(loss, list) and len(loss) == 1:
+                loss = loss[0]
+            if not is_function_like(loss):
+                raise ValueError(
+                    "When there is only a single output, the `loss` argument "
+                    "must be a callable. "
+                    f"Received instead:\nloss={loss} of type {type(loss)}"
+                )

         if is_function_like(loss) and nest.is_nested(y_pred):
             # The model has multiple outputs but only one loss fn
@@ -471,6 +478,7 @@ class CompileLoss(losses_module.Loss):
             else:
                 flat_loss_weights.append(1.0)
         elif isinstance(loss, (list, tuple)):
+            loss = nest.flatten(loss)
             if len(loss) != len(y_pred):
                 raise ValueError(
                     "For a model with multiple outputs, "
@@ -507,11 +515,11 @@ class CompileLoss(losses_module.Loss):
                     f"length {len(loss_weights)} whereas the model has "
                     f"{len(y_pred)} outputs."
                 )
-            if not all(isinstance(e, float) for e in loss_weights):
+            if not all(isinstance(e, (int, float)) for e in loss_weights):
                 raise ValueError(
-                    "For a model with multiple outputs, "
-                    "when providing the `loss_weights` argument as a "
-                    "list, each list entry should be a Python float (the "
+                    "For a model with multiple outputs, when providing "
+                    "the `loss_weights` argument as a list, "
+                    "each list entry should be a Python int or float (the "
                     "weighting coefficient corresponding to the loss for "
                     f"that output). Received: loss_weights={loss_weights}"
                 )
@@ -526,12 +534,8 @@ class CompileLoss(losses_module.Loss):
                     f"Received loss={loss}"
                 )
             for name in loss.keys():
-                if name not in output_names:
-                    raise ValueError(
-                        "In the dict argument `loss`, key "
-                        f"'{name}' does not correspond to any model output. "
-                        f"Received:\nloss={loss}"
-                    )
+                if isinstance(loss[name], list) and len(loss[name]) == 1:
+                    loss[name] = loss[name][0]
                 if not is_function_like(loss[name]):
                     raise ValueError(
                         "For a model with multiple outputs, "
@@ -179,13 +179,13 @@ class TestCompileMetrics(testing.TestCase):
         self.assertAllClose(result["output_1_mse"], 0.055833336)
         self.assertAllClose(result["output_2_mse"], 0.055833336)
         self.assertAllClose(
-            result["weighted_output_1_mean_squared_error"], 0.0725
+            result["output_1_weighted_mean_squared_error"], 0.0725
         )
         self.assertAllClose(
-            result["weighted_output_2_mean_squared_error"], 0.0725
+            result["output_2_weighted_mean_squared_error"], 0.0725
         )
-        self.assertAllClose(result["weighted_output_1_mse"], 0.0725)
-        self.assertAllClose(result["weighted_output_2_mse"], 0.0725)
+        self.assertAllClose(result["output_1_weighted_mse"], 0.0725)
+        self.assertAllClose(result["output_2_weighted_mse"], 0.0725)

         compile_metrics.reset_state()
         result = compile_metrics.result()
@@ -193,8 +193,8 @@ class TestCompileMetrics(testing.TestCase):
         self.assertEqual(len(result), 8)
         self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
         self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
-        self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0)
-        self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0)
+        self.assertAllClose(result["output_1_weighted_mean_squared_error"], 0.0)
+        self.assertAllClose(result["output_2_weighted_mean_squared_error"], 0.0)

     def test_name_conversions(self):
         compile_metrics = CompileMetrics(
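
For reference, the single-output normalization in CompileLoss.build above means all of the following `loss` specifications are now accepted as equivalent for a one-output model, as exercised by the parameterized test; before this commit only the bare string/callable form passed validation. A sketch (`model` is assumed to be a single-output functional model whose output is named "output_a"):

    # Equivalent single-output loss specs after this change.
    for loss in (
        "mean_squared_error",                  # bare string (always worked)
        ["mean_squared_error"],                # single-element list
        {"output_a": "mean_squared_error"},    # dict keyed by output name
        {"output_a": ["mean_squared_error"]},  # dict of single-element list
    ):
        model.compile(optimizer="sgd", loss=loss)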