Update compile loss and metrics to handle Dict and List Outputs (#362)

* Update compile loss and metrics to handle multi-output dict and list

* CompileMetrics matches per-output metrics to model output names

* Add model tests for invalid cases
Authored by Ramesh Sampath on 2023-06-22 01:21:55 +05:30; committed by Francois Chollet
parent d955292989
commit 0728d45414
5 changed files with 355 additions and 35 deletions
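
For context, a minimal sketch of the multi-output workflow this commit enables, distilled from the new tests in model_test.py below (import paths for Input and layers are assumptions; the Model import matches the test file):

import numpy as np
from keras_core import layers                 # assumed import path
from keras_core.layers import Input           # assumed import path
from keras_core.models.model import Model

# Functional model with two named outputs.
inputs = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(inputs)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
model = Model(inputs, {"output_a": output_a, "output_b": output_b})

# Losses and metrics can now be given per output, keyed by output name.
model.compile(
    optimizer="sgd",
    loss={"output_a": "mean_squared_error", "output_b": "binary_crossentropy"},
    metrics={"output_a": ["mean_squared_error"], "output_b": ["accuracy"]},
)

x = np.random.rand(8, 3)
y = {"output_a": np.random.rand(8, 1), "output_b": np.random.randint(0, 2, (8, 1))}
hist = model.fit(x, y, batch_size=2, epochs=1, verbose=0)
# History keys are prefixed per output, e.g.
# "output_a_mean_squared_error" and "output_b_accuracy".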

@@ -155,6 +155,8 @@ class Functional(Function, Model):
# We will convert directly (to the correct dtype per input).
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
output_layers = [x._keras_history[0] for x in self.outputs]
self.output_names = [x.name for x in output_layers]
self._post_build()
@property

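The two lines added to Functional.__init__ above derive the model's output_names from the layers that produced each output tensor; roughly, in terms of the test models below (illustrative only, reusing the imports from the sketch above):

# Every functional output tensor records its producing layer in _keras_history,
# so the output names are simply those layers' names.
x = Input(shape=(3,), name="input_a")
out_a = layers.Dense(1, name="output_a")(x)
out_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = Model(x, [out_a, out_b])
# out_a._keras_history[0] is the Dense layer named "output_a", hence
# model.output_names == ["output_a", "output_b"], which CompileLoss and
# CompileMetrics use below to match dict keys and prefix metric names.
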
@@ -8,22 +8,47 @@ from keras_core.models.model import Model
from keras_core.models.model import model_from_json
class ModelTest(testing.TestCase):
def _get_model(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
model = Model([input_a, input_b], outputs)
return model
def _get_model():
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
model = Model([input_a, input_b], outputs)
return model
def _get_model_multi_outputs_list():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = Model(x, [output_a, output_b])
return model
def _get_model_multi_outputs_list_no_output_names():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1)(x)
output_b = layers.Dense(1, activation="sigmoid")(x)
model = Model(x, [output_a, output_b])
return model
def _get_model_multi_outputs_dict():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = Model(x, {"output_a": output_a, "output_b": output_b})
return model
class ModelTest(testing.TestCase):
def test_functional_rerouting(self):
model = self._get_model()
model = _get_model()
self.assertTrue(isinstance(model, Functional))
def test_json_serialization(self):
model = self._get_model()
model = _get_model()
json_string = model.to_json()
new_model = model_from_json(json_string)
self.assertEqual(json_string, new_model.to_json())
@@ -65,3 +90,244 @@ class ModelTest(testing.TestCase):
config, custom_objects={"CustomDense": CustomDense}
)
self.assertTrue(isinstance(new_model, Functional))
def test_functional_list_outputs_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss=["mean_squared_error", "binary_crossentropy"],
metrics=[
["mean_squared_error"],
["mean_squared_error", "accuracy"],
],
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_dict_outputs_dict_losses(self):
model = _get_model_multi_outputs_dict()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(
x,
{"output_a": y1, "output_b": y2},
batch_size=2,
epochs=1,
verbose=0,
)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_list_outputs_dict_losses_metrics(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_list_outputs_dict_losses_partial_metrics(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
def test_functional_list_outputs_dict_losses_invalid_keys(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_c": "binary_crossentropy",
},
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
ValueError,
"In the dict argument `loss`, "
"key 'output_c' does not correspond to any model output",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
def test_functional_list_outputs_dict_losses_no_output_names(self):
model = _get_model_multi_outputs_list_no_output_names()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={"output_a": "mean_squared_error"},
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
ValueError,
"In the dict argument `loss`, "
"key 'output_a' does not correspond to any model output",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
def test_functional_list_outputs_dict_metrics_invalid_keys(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_c": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
ValueError,
"In the dict argument `metrics`, "
"key 'output_c' does not correspond to any model output",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
def test_functional_dict_outputs_dict_losses_invalid_keys(self):
model = _get_model_multi_outputs_dict()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_c": "binary_crossentropy",
},
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
ValueError,
"In the dict argument `loss`, "
"key 'output_c' does not correspond to any model output",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
def test_functional_dict_outputs_dict_metrics_invalid_keys(self):
model = _get_model_multi_outputs_dict()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_c": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
ValueError,
"In the dict argument `metrics`, "
"key 'output_c' does not correspond to any model output",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

@@ -54,7 +54,7 @@ def is_binary_or_sparse_categorical(y_true, y_pred):
return is_binary, is_sparse_categorical
def get_metric(identifier, y_true, y_pred):
def get_metric(identifier, y_true, y_pred, name_prefix=None):
if identifier is None:
return None # Ok to have no metric for an output.
@@ -85,6 +85,8 @@ def get_metric(identifier, y_true, y_pred):
metric_obj = metrics_module.MeanMetricWrapper(
metric_obj, name=metric_name
)
if name_prefix and not metric_obj.name.startswith(name_prefix):
metric_obj.name = "_".join([name_prefix, metric_obj.name])
return metric_obj
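The new name_prefix argument lets get_metric rename each resolved metric after the output it belongs to, which is what produces result keys such as output_a_mean_squared_error. A small sketch of the intended behaviour (internal helper; module path assumed):

import numpy as np
from keras_core.trainers.compile_utils import get_metric  # path assumed

y_true = np.zeros((4, 1))
y_pred = np.zeros((4, 1))

m = get_metric("mean_squared_error", y_true, y_pred, name_prefix="output_a")
# m.name == "output_a_mean_squared_error"
# A metric whose name already starts with the prefix is left untouched,
# thanks to the startswith() guard added above.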
@@ -117,7 +119,13 @@ def get_loss(identifier, y_true, y_pred):
class CompileMetrics(metrics_module.Metric):
def __init__(self, metrics, weighted_metrics, name="compile_metric"):
def __init__(
self,
metrics,
weighted_metrics,
name="compile_metric",
output_names=None,
):
super().__init__(name=name)
if metrics and not isinstance(metrics, (list, tuple, dict)):
raise ValueError(
@@ -136,6 +144,7 @@ class CompileMetrics(metrics_module.Metric):
self._user_weighted_metrics = weighted_metrics
self.built = False
self.name = "compile_metrics"
self.output_names = output_names
@property
def variables(self):
@@ -150,9 +159,10 @@
return vars
def build(self, y_true, y_pred):
if isinstance(y_pred, dict):
if self.output_names:
output_names = self.output_names
elif isinstance(y_pred, dict):
output_names = sorted(list(y_pred.keys()))
num_outputs = len(output_names)
elif isinstance(y_pred, (list, tuple)):
num_outputs = len(y_pred)
if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -162,15 +172,14 @@
else:
output_names = None
num_outputs = 1
if output_names:
num_outputs = len(output_names)
y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true)
metrics = self._user_metrics
weighted_metrics = self._user_weighted_metrics
if output_names and not num_outputs:
num_outputs = len(output_names)
self._flat_metrics = self._build_metrics_set(
metrics,
num_outputs,
@@ -238,7 +247,12 @@
"(the list of metrics corresponding to that output). "
f"Received:\n{argument_name}={metrics}"
)
for mls, yt, yp in zip(metrics, y_true, y_pred):
name = None
for idx, (mls, yt, yp) in enumerate(
zip(metrics, y_true, y_pred)
):
if output_names:
name = output_names[idx]
if not all(is_function_like(e) for e in mls):
raise ValueError(
f"All entries in the sublists of the "
@@ -249,7 +263,7 @@
flat_metrics.append(
MetricsList(
[
get_metric(m, yt, yp)
get_metric(m, yt, yp, name)
for m in mls
if m is not None
]
@@ -290,7 +304,7 @@
flat_metrics.append(
MetricsList(
[
get_metric(m, yt, yp)
get_metric(m, yt, yp, name)
for m in metrics[name]
if m is not None
]
@@ -310,6 +324,9 @@
m.update_state(y_t, y_p)
if sample_weight is not None:
sample_weight = nest.flatten(sample_weight)
# For multi-outputs, repeat sample weights for n outputs.
if len(sample_weight) < len(y_true):
sample_weight = [sample_weight[0] for _ in range(len(y_true))]
else:
sample_weight = [None for _ in range(len(y_true))]
for m, y_t, y_p, s_w in zip(
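The update_state change above (and the matching change in CompileLoss.call further down) broadcasts a single sample-weight array across all outputs instead of assuming one weight entry per output; the same logic in isolation:

import numpy as np

weights = np.ones((8,))
sample_weight = [weights]                       # what nest.flatten() yields for one array
y_true = [np.zeros((8, 1)), np.zeros((8, 1))]   # two outputs
if len(sample_weight) < len(y_true):
    sample_weight = [sample_weight[0] for _ in range(len(y_true))]
assert len(sample_weight) == len(y_true)        # same weights reused for every output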
@@ -375,7 +392,11 @@
class CompileLoss(losses_module.Loss):
def __init__(
self, loss, loss_weights=None, reduction="sum_over_batch_size"
self,
loss,
loss_weights=None,
reduction="sum_over_batch_size",
output_names=None,
):
if loss_weights and not isinstance(loss_weights, (list, tuple, dict)):
raise ValueError(
@@ -386,12 +407,14 @@
self._user_loss = loss
self._user_loss_weights = loss_weights
self.built = False
self.output_names = output_names
super().__init__(name="compile_loss", reduction=reduction)
def build(self, y_true, y_pred):
if isinstance(y_pred, dict):
if self.output_names:
output_names = self.output_names
elif isinstance(y_pred, dict):
output_names = sorted(list(y_pred.keys()))
num_outputs = len(output_names)
elif isinstance(y_pred, (list, tuple)):
num_outputs = len(y_pred)
if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -401,6 +424,8 @@
else:
output_names = None
num_outputs = 1
if output_names:
num_outputs = len(output_names)
y_pred = nest.flatten(y_pred)
loss = self._user_loss
@@ -561,11 +586,13 @@
if sample_weight is not None:
sample_weight = nest.flatten(sample_weight)
# For multi-outputs, repeat sample weights for n outputs.
if len(sample_weight) < len(y_true):
sample_weight = [sample_weight[0] for _ in range(len(y_true))]
else:
sample_weight = [None for _ in y_true]
loss_values = []
for loss, y_t, y_p, loss_weight, sample_weight in zip(
self.flat_losses,
y_true,

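Taken together, the build() changes in both CompileMetrics and CompileLoss take output names from the model when available and only infer them from the structure of y_pred otherwise; dict keys in loss/metrics are then checked against those names, which is what raises the ValueError exercised in the tests above. A hypothetical condensation of the resolution order (not the literal code; the symbolic-output branch is partly outside these hunks):

def resolve_output_names(self, y_pred):
    if self.output_names:                     # provided by the Functional model
        return self.output_names
    if isinstance(y_pred, dict):              # dict outputs: sorted keys
        return sorted(y_pred.keys())
    if isinstance(y_pred, (list, tuple)) and all(
        hasattr(x, "_keras_history") for x in y_pred
    ):
        # symbolic outputs: fall back to the producing layers' names
        return [x._keras_history[0].name for x in y_pred]
    return None                               # plain tensors: positional matching only
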
@@ -115,21 +115,21 @@ class TestCompileMetrics(testing.TestCase):
metrics={
"output_1": [
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(name="mse"),
],
"output_2": [
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(name="mse"),
],
},
weighted_metrics={
"output_1": [
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(name="mse"),
],
"output_2": [
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(),
metrics_module.MeanSquaredError(name="mse"),
],
},
)
@@ -169,15 +169,32 @@
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertEqual(len(result), 8)
self.assertAllClose(result["mean_squared_error"], 0.055833336)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0725)
# Result values obtained from `tf.keras`
# m = tf.keras.metrics.MeanSquaredError()
# m.update_state(y_true, y_pred1, sample_weight=weight)
# m.update_state(y_true, y_pred2, sample_weight=weight)
# m.result().numpy()
self.assertAllClose(result["output_1_mean_squared_error"], 0.055833336)
self.assertAllClose(result["output_2_mean_squared_error"], 0.055833336)
self.assertAllClose(result["output_1_mse"], 0.055833336)
self.assertAllClose(result["output_2_mse"], 0.055833336)
self.assertAllClose(
result["weighted_output_1_mean_squared_error"], 0.0725
)
self.assertAllClose(
result["weighted_output_2_mean_squared_error"], 0.0725
)
self.assertAllClose(result["weighted_output_1_mse"], 0.0725)
self.assertAllClose(result["weighted_output_2_mse"], 0.0725)
compile_metrics.reset_state()
result = compile_metrics.result()
self.assertTrue(isinstance(result, dict))
self.assertEqual(len(result), 8)
self.assertAllClose(result["mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0)
self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0)
def test_name_conversions(self):
compile_metrics = CompileMetrics(

@@ -34,12 +34,20 @@ class Trainer:
jit_compile="auto",
):
self.optimizer = optimizers.get(optimizer)
if hasattr(self, "output_names"):
output_names = self.output_names
else:
output_names = None
if loss is not None:
self._compile_loss = CompileLoss(loss, loss_weights)
self._compile_loss = CompileLoss(
loss, loss_weights, output_names=output_names
)
else:
self._compile_loss = None
if metrics is not None:
self._compile_metrics = CompileMetrics(metrics, weighted_metrics)
self._compile_metrics = CompileMetrics(
metrics, weighted_metrics, output_names=output_names
)
else:
self._compile_metrics = None
if jit_compile == "auto":
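
Only the Functional model defines output_names (set in functional.py at the top of this diff), so the hasattr guard keeps every other trainer on its previous behaviour. The effective wiring for the test models looks roughly like this (illustrative only; constructor arguments as in the diff):

# Functional model with named outputs: the names flow into compile().
model = _get_model_multi_outputs_list()       # defined in model_test.py above
model.compile(
    optimizer="sgd",
    loss={"output_a": "mean_squared_error", "output_b": "binary_crossentropy"},
    metrics={"output_b": ["accuracy"]},
)
# Internally, per the Trainer change above:
#   CompileLoss(loss, None, output_names=["output_a", "output_b"])
#   CompileMetrics(metrics, None, output_names=["output_a", "output_b"])
# A model without an output_names attribute (e.g. a plain subclassed model)
# gets output_names=None and keeps the pre-existing behaviour.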