From 0728d4541497ddd5499bb3c4f7e2555fafd14cb5 Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Thu, 22 Jun 2023 01:21:55 +0530 Subject: [PATCH] Update compile loss and metrics to handle Dict and List Outputs (#362) * Update compile loss and metrics to handle multi-output dict and list * CompileMetrics to match with per-output metrics * Add model tests for invalid cases --- keras_core/models/functional.py | 2 + keras_core/models/model_test.py | 288 +++++++++++++++++++++- keras_core/trainers/compile_utils.py | 55 +++-- keras_core/trainers/compile_utils_test.py | 33 ++- keras_core/trainers/trainer.py | 12 +- 5 files changed, 355 insertions(+), 35 deletions(-) diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py index 8d7895bf1..20ea429a8 100644 --- a/keras_core/models/functional.py +++ b/keras_core/models/functional.py @@ -155,6 +155,8 @@ class Functional(Function, Model): # We will convert directly (to the correct dtype per input). self._convert_input_args = False self._allow_non_tensor_positional_args = True + output_layers = [x._keras_history[0] for x in self.outputs] + self.output_names = [x.name for x in output_layers] self._post_build() @property diff --git a/keras_core/models/model_test.py b/keras_core/models/model_test.py index a91db7fb7..abed61432 100644 --- a/keras_core/models/model_test.py +++ b/keras_core/models/model_test.py @@ -8,22 +8,47 @@ from keras_core.models.model import Model from keras_core.models.model import model_from_json -class ModelTest(testing.TestCase): - def _get_model(self): - input_a = Input(shape=(3,), batch_size=2, name="input_a") - input_b = Input(shape=(3,), batch_size=2, name="input_b") - x = input_a + input_b - x = layers.Dense(5)(x) - outputs = layers.Dense(4)(x) - model = Model([input_a, input_b], outputs) - return model +def _get_model(): + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + model = Model([input_a, input_b], outputs) + return model + +def _get_model_multi_outputs_list(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x) + model = Model(x, [output_a, output_b]) + return model + + +def _get_model_multi_outputs_list_no_output_names(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1)(x) + output_b = layers.Dense(1, activation="sigmoid")(x) + model = Model(x, [output_a, output_b]) + return model + + +def _get_model_multi_outputs_dict(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x) + model = Model(x, {"output_a": output_a, "output_b": output_b}) + return model + + +class ModelTest(testing.TestCase): def test_functional_rerouting(self): - model = self._get_model() + model = _get_model() self.assertTrue(isinstance(model, Functional)) def test_json_serialization(self): - model = self._get_model() + model = _get_model() json_string = model.to_json() new_model = model_from_json(json_string) self.assertEqual(json_string, new_model.to_json()) @@ -65,3 +90,244 @@ class ModelTest(testing.TestCase): config, custom_objects={"CustomDense": CustomDense} ) self.assertTrue(isinstance(new_model, Functional)) + + def test_functional_list_outputs_list_losses(self): + model = 
_get_model_multi_outputs_list() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss=["mean_squared_error", "binary_crossentropy"], + metrics=[ + ["mean_squared_error"], + ["mean_squared_error", "accuracy"], + ], + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + # TODO `tf.keras` also outputs individual losses for outputs + ref_keys = sorted( + [ + "loss", + # "output_a_loss", + "output_a_mean_squared_error", + "output_b_accuracy", + # "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_dict_outputs_dict_losses(self): + model = _get_model_multi_outputs_dict() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + {"output_a": y1, "output_b": y2}, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + # TODO `tf.keras` also outputs individual losses for outputs + ref_keys = sorted( + [ + "loss", + # "output_a_loss", + "output_a_mean_squared_error", + "output_b_accuracy", + # "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_dict_losses_metrics(self): + model = _get_model_multi_outputs_list() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + # TODO `tf.keras` also outputs individual losses for outputs + ref_keys = sorted( + [ + "loss", + # "output_a_loss", + "output_a_mean_squared_error", + "output_b_accuracy", + # "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_dict_losses_partial_metrics(self): + model = _get_model_multi_outputs_list() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + # TODO `tf.keras` also outputs individual losses for outputs + ref_keys = sorted( + [ + "loss", + # "output_a_loss", + "output_b_accuracy", + # "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + 
def test_functional_list_outputs_dict_losses_invalid_keys(self): + model = _get_model_multi_outputs_list() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_c": "binary_crossentropy", + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `loss`, " + "key 'output_c' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_list_outputs_dict_losses_no_output_names(self): + model = _get_model_multi_outputs_list_no_output_names() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={"output_a": "mean_squared_error"}, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `loss`, " + "key 'output_a' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_list_outputs_dict_metrics_invalid_keys(self): + model = _get_model_multi_outputs_list() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_c": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `metrics`, " + "key 'output_c' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_dict_outputs_dict_losses_invalid_keys(self): + model = _get_model_multi_outputs_dict() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_c": "binary_crossentropy", + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `loss`, " + "key 'output_c' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_dict_outputs_dict_metrics_invalid_keys(self): + model = _get_model_multi_outputs_dict() + self.assertTrue(isinstance(model, Functional)) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_c": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `metrics`, " + "key 'output_c' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) diff --git a/keras_core/trainers/compile_utils.py b/keras_core/trainers/compile_utils.py index 6a60e0918..9980833f2 100644 --- a/keras_core/trainers/compile_utils.py +++ b/keras_core/trainers/compile_utils.py @@ -54,7 +54,7 @@ def 
is_binary_or_sparse_categorical(y_true, y_pred): return is_binary, is_sparse_categorical -def get_metric(identifier, y_true, y_pred): +def get_metric(identifier, y_true, y_pred, name_prefix=None): if identifier is None: return None # Ok to have no metric for an output. @@ -85,6 +85,8 @@ def get_metric(identifier, y_true, y_pred): metric_obj = metrics_module.MeanMetricWrapper( metric_obj, name=metric_name ) + if name_prefix and not metric_obj.name.startswith(name_prefix): + metric_obj.name = "_".join([name_prefix, metric_obj.name]) return metric_obj @@ -117,7 +119,13 @@ def get_loss(identifier, y_true, y_pred): class CompileMetrics(metrics_module.Metric): - def __init__(self, metrics, weighted_metrics, name="compile_metric"): + def __init__( + self, + metrics, + weighted_metrics, + name="compile_metric", + output_names=None, + ): super().__init__(name=name) if metrics and not isinstance(metrics, (list, tuple, dict)): raise ValueError( @@ -136,6 +144,7 @@ class CompileMetrics(metrics_module.Metric): self._user_weighted_metrics = weighted_metrics self.built = False self.name = "compile_metrics" + self.output_names = output_names @property def variables(self): @@ -150,9 +159,10 @@ class CompileMetrics(metrics_module.Metric): return vars def build(self, y_true, y_pred): - if isinstance(y_pred, dict): + if self.output_names: + output_names = self.output_names + elif isinstance(y_pred, dict): output_names = sorted(list(y_pred.keys())) - num_outputs = len(output_names) elif isinstance(y_pred, (list, tuple)): num_outputs = len(y_pred) if all(hasattr(x, "_keras_history") for x in y_pred): @@ -162,15 +172,14 @@ class CompileMetrics(metrics_module.Metric): else: output_names = None num_outputs = 1 + if output_names: + num_outputs = len(output_names) y_pred = nest.flatten(y_pred) y_true = nest.flatten(y_true) metrics = self._user_metrics weighted_metrics = self._user_weighted_metrics - if output_names and not num_outputs: - num_outputs = len(output_names) - self._flat_metrics = self._build_metrics_set( metrics, num_outputs, @@ -238,7 +247,12 @@ class CompileMetrics(metrics_module.Metric): "(the list of metrics corresponding to that output). " f"Received:\n{argument_name}={metrics}" ) - for mls, yt, yp in zip(metrics, y_true, y_pred): + name = None + for idx, (mls, yt, yp) in enumerate( + zip(metrics, y_true, y_pred) + ): + if output_names: + name = output_names[idx] if not all(is_function_like(e) for e in mls): raise ValueError( f"All entries in the sublists of the " @@ -249,7 +263,7 @@ class CompileMetrics(metrics_module.Metric): flat_metrics.append( MetricsList( [ - get_metric(m, yt, yp) + get_metric(m, yt, yp, name) for m in mls if m is not None ] @@ -290,7 +304,7 @@ class CompileMetrics(metrics_module.Metric): flat_metrics.append( MetricsList( [ - get_metric(m, yt, yp) + get_metric(m, yt, yp, name) for m in metrics[name] if m is not None ] @@ -310,6 +324,9 @@ class CompileMetrics(metrics_module.Metric): m.update_state(y_t, y_p) if sample_weight is not None: sample_weight = nest.flatten(sample_weight) + # For multi-outputs, repeat sample weights for n outputs. 
+ if len(sample_weight) < len(y_true): + sample_weight = [sample_weight[0] for _ in range(len(y_true))] else: sample_weight = [None for _ in range(len(y_true))] for m, y_t, y_p, s_w in zip( @@ -375,7 +392,11 @@ class CompileMetrics(metrics_module.Metric): class CompileLoss(losses_module.Loss): def __init__( - self, loss, loss_weights=None, reduction="sum_over_batch_size" + self, + loss, + loss_weights=None, + reduction="sum_over_batch_size", + output_names=None, ): if loss_weights and not isinstance(loss_weights, (list, tuple, dict)): raise ValueError( @@ -386,12 +407,14 @@ class CompileLoss(losses_module.Loss): self._user_loss = loss self._user_loss_weights = loss_weights self.built = False + self.output_names = output_names super().__init__(name="compile_loss", reduction=reduction) def build(self, y_true, y_pred): - if isinstance(y_pred, dict): + if self.output_names: + output_names = self.output_names + elif isinstance(y_pred, dict): output_names = sorted(list(y_pred.keys())) - num_outputs = len(output_names) elif isinstance(y_pred, (list, tuple)): num_outputs = len(y_pred) if all(hasattr(x, "_keras_history") for x in y_pred): @@ -401,6 +424,8 @@ class CompileLoss(losses_module.Loss): else: output_names = None num_outputs = 1 + if output_names: + num_outputs = len(output_names) y_pred = nest.flatten(y_pred) loss = self._user_loss @@ -561,11 +586,13 @@ class CompileLoss(losses_module.Loss): if sample_weight is not None: sample_weight = nest.flatten(sample_weight) + # For multi-outputs, repeat sample weights for n outputs. + if len(sample_weight) < len(y_true): + sample_weight = [sample_weight[0] for _ in range(len(y_true))] else: sample_weight = [None for _ in y_true] loss_values = [] - for loss, y_t, y_p, loss_weight, sample_weight in zip( self.flat_losses, y_true, diff --git a/keras_core/trainers/compile_utils_test.py b/keras_core/trainers/compile_utils_test.py index 95e482799..853b95574 100644 --- a/keras_core/trainers/compile_utils_test.py +++ b/keras_core/trainers/compile_utils_test.py @@ -115,21 +115,21 @@ class TestCompileMetrics(testing.TestCase): metrics={ "output_1": [ metrics_module.MeanSquaredError(), - metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), ], "output_2": [ metrics_module.MeanSquaredError(), - metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), ], }, weighted_metrics={ "output_1": [ metrics_module.MeanSquaredError(), - metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), ], "output_2": [ metrics_module.MeanSquaredError(), - metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), ], }, ) @@ -169,15 +169,32 @@ class TestCompileMetrics(testing.TestCase): result = compile_metrics.result() self.assertTrue(isinstance(result, dict)) self.assertEqual(len(result), 8) - self.assertAllClose(result["mean_squared_error"], 0.055833336) - self.assertAllClose(result["weighted_mean_squared_error"], 0.0725) + # Result values obtained from `tf.keras` + # m = tf.keras.metrics.MeanSquaredError() + # m.update_state(y_true, y_pred1, sample_weight=weight) + # m.update_state(y_true, y_pred2, sample_weight=weight) + # m.result().numpy() + self.assertAllClose(result["output_1_mean_squared_error"], 0.055833336) + self.assertAllClose(result["output_2_mean_squared_error"], 0.055833336) + self.assertAllClose(result["output_1_mse"], 0.055833336) + self.assertAllClose(result["output_2_mse"], 0.055833336) + self.assertAllClose( + result["weighted_output_1_mean_squared_error"], 
0.0725 + ) + self.assertAllClose( + result["weighted_output_2_mean_squared_error"], 0.0725 + ) + self.assertAllClose(result["weighted_output_1_mse"], 0.0725) + self.assertAllClose(result["weighted_output_2_mse"], 0.0725) compile_metrics.reset_state() result = compile_metrics.result() self.assertTrue(isinstance(result, dict)) self.assertEqual(len(result), 8) - self.assertAllClose(result["mean_squared_error"], 0.0) - self.assertAllClose(result["weighted_mean_squared_error"], 0.0) + self.assertAllClose(result["output_1_mean_squared_error"], 0.0) + self.assertAllClose(result["output_2_mean_squared_error"], 0.0) + self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0) + self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0) def test_name_conversions(self): compile_metrics = CompileMetrics( diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index 5f8c29a0e..b15ad1611 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -34,12 +34,20 @@ class Trainer: jit_compile="auto", ): self.optimizer = optimizers.get(optimizer) + if hasattr(self, "output_names"): + output_names = self.output_names + else: + output_names = None if loss is not None: - self._compile_loss = CompileLoss(loss, loss_weights) + self._compile_loss = CompileLoss( + loss, loss_weights, output_names=output_names + ) else: self._compile_loss = None if metrics is not None: - self._compile_metrics = CompileMetrics(metrics, weighted_metrics) + self._compile_metrics = CompileMetrics( + metrics, weighted_metrics, output_names=output_names + ) else: self._compile_metrics = None if jit_compile == "auto":
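
Usage sketch of the multi-output compile behavior exercised by the new tests
(adapted from keras_core/models/model_test.py; the "output_a"/"output_b" names
are illustrative, and metric result keys are prefixed with the output name):

    import numpy as np

    from keras_core import layers
    from keras_core.models.model import Model

    # Two-output functional model; output names are taken from the layer names.
    inputs = layers.Input(shape=(3,), name="input_a")
    output_a = layers.Dense(1, name="output_a")(inputs)
    output_b = layers.Dense(1, name="output_b", activation="sigmoid")(inputs)
    model = Model(inputs, {"output_a": output_a, "output_b": output_b})

    # Losses and metrics keyed by output name; a key that does not match any
    # model output (e.g. "output_c") now raises a ValueError.
    model.compile(
        optimizer="sgd",
        loss={
            "output_a": "mean_squared_error",
            "output_b": "binary_crossentropy",
        },
        metrics={
            "output_a": ["mean_squared_error"],
            "output_b": ["mean_squared_error", "accuracy"],
        },
    )

    x = np.random.rand(8, 3)
    y_a = np.random.rand(8, 1)
    y_b = np.random.randint(0, 2, (8, 1))
    history = model.fit(
        x,
        {"output_a": y_a, "output_b": y_b},
        batch_size=2,
        epochs=1,
        verbose=0,
    )
    # history.history contains "loss", "output_a_mean_squared_error",
    # "output_b_mean_squared_error", and "output_b_accuracy"; per-output
    # losses are not yet reported (see the TODO comments in the tests above).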