Update compile loss and metrics to handle Dict and List Outputs (#362)
* Update compile loss and metrics to handle multi-output dict and list
* CompileMetrics to match with per-output metrics
* Add model tests for invalid cases
parent d955292989
commit 0728d45414
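In effect, this change lets a multi-output functional model be compiled with losses and metrics keyed by output name, and makes every reported metric carry its output's name as a prefix. A minimal sketch of the new behavior (import paths assumed from the test file below; not part of the diff):

import numpy as np

from keras_core import layers
from keras_core.layers import Input
from keras_core.models.model import Model

# Two named outputs; the layer names become the dict keys for loss/metrics.
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = Model(x, [output_a, output_b])

model.compile(
    optimizer="sgd",
    loss={
        "output_a": "mean_squared_error",
        "output_b": "binary_crossentropy",
    },
    metrics={"output_b": ["accuracy"]},
)
hist = model.fit(
    np.random.rand(8, 3),
    (np.random.rand(8, 1), np.random.randint(0, 2, (8, 1))),
    epochs=1,
    verbose=0,
)
# History keys are now prefixed per output, e.g. "output_b_accuracy".
print(sorted(hist.history.keys()))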
@@ -155,6 +155,8 @@ class Functional(Function, Model):
         # We will convert directly (to the correct dtype per input).
         self._convert_input_args = False
         self._allow_non_tensor_positional_args = True
+        output_layers = [x._keras_history[0] for x in self.outputs]
+        self.output_names = [x.name for x in output_layers]
         self._post_build()
 
     @property
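The two added lines give every Functional model an output_names attribute, recovered from each output tensor's _keras_history (whose first element is the layer that produced the tensor). Continuing the sketch above:

# For the model built above, the recorded names come from the Dense
# layers that produced the two output tensors.
output_layers = [t._keras_history[0] for t in model.outputs]
print([layer.name for layer in output_layers])  # ['output_a', 'output_b']
print(model.output_names)  # same list, captured at construction time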
@@ -8,22 +8,47 @@ from keras_core.models.model import Model
 from keras_core.models.model import model_from_json
 
 
-class ModelTest(testing.TestCase):
-    def _get_model(self):
-        input_a = Input(shape=(3,), batch_size=2, name="input_a")
-        input_b = Input(shape=(3,), batch_size=2, name="input_b")
-        x = input_a + input_b
-        x = layers.Dense(5)(x)
-        outputs = layers.Dense(4)(x)
-        model = Model([input_a, input_b], outputs)
-        return model
-
+def _get_model():
+    input_a = Input(shape=(3,), batch_size=2, name="input_a")
+    input_b = Input(shape=(3,), batch_size=2, name="input_b")
+    x = input_a + input_b
+    x = layers.Dense(5)(x)
+    outputs = layers.Dense(4)(x)
+    model = Model([input_a, input_b], outputs)
+    return model
+
+
+def _get_model_multi_outputs_list():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
+    model = Model(x, [output_a, output_b])
+    return model
+
+
+def _get_model_multi_outputs_list_no_output_names():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1)(x)
+    output_b = layers.Dense(1, activation="sigmoid")(x)
+    model = Model(x, [output_a, output_b])
+    return model
+
+
+def _get_model_multi_outputs_dict():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
+    model = Model(x, {"output_a": output_a, "output_b": output_b})
+    return model
+
+
+class ModelTest(testing.TestCase):
     def test_functional_rerouting(self):
-        model = self._get_model()
+        model = _get_model()
         self.assertTrue(isinstance(model, Functional))
 
     def test_json_serialization(self):
-        model = self._get_model()
+        model = _get_model()
         json_string = model.to_json()
         new_model = model_from_json(json_string)
         self.assertEqual(json_string, new_model.to_json())
@@ -65,3 +90,244 @@ class ModelTest(testing.TestCase):
             config, custom_objects={"CustomDense": CustomDense}
         )
         self.assertTrue(isinstance(new_model, Functional))
+
+    def test_functional_list_outputs_list_losses(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss=["mean_squared_error", "binary_crossentropy"],
+            metrics=[
+                ["mean_squared_error"],
+                ["mean_squared_error", "accuracy"],
+            ],
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_dict_outputs_dict_losses(self):
+        model = _get_model_multi_outputs_dict()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            {"output_a": y1, "output_b": y2},
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_metrics(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_partial_metrics(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_invalid_keys(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_c": "binary_crossentropy",
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `loss`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_list_outputs_dict_losses_no_output_names(self):
+        model = _get_model_multi_outputs_list_no_output_names()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={"output_a": "mean_squared_error"},
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `loss`, "
+            "key 'output_a' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_list_outputs_dict_metrics_invalid_keys(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_c": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `metrics`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_dict_outputs_dict_losses_invalid_keys(self):
+        model = _get_model_multi_outputs_dict()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_c": "binary_crossentropy",
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `loss`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_dict_outputs_dict_metrics_invalid_keys(self):
+        model = _get_model_multi_outputs_dict()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_c": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `metrics`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
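Across these tests the history keys follow one scheme: the total "loss" plus one "{output_name}_{metric_name}" entry per configured metric; per-output loss entries are still missing, as the commented-out "output_a_loss" keys note. A hypothetical helper (not part of the diff) that reproduces the expected key set:

def expected_history_keys(metrics_by_output):
    # Mirrors the key scheme asserted above: "loss" plus one
    # "{output_name}_{metric_name}" entry per metric. Per-output
    # losses ("output_a_loss", ...) are not emitted yet.
    keys = ["loss"]
    for output_name, metric_names in metrics_by_output.items():
        keys += [f"{output_name}_{m}" for m in metric_names]
    return sorted(keys)

print(expected_history_keys({
    "output_a": ["mean_squared_error"],
    "output_b": ["mean_squared_error", "accuracy"],
}))
# ['loss', 'output_a_mean_squared_error', 'output_b_accuracy',
#  'output_b_mean_squared_error']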
@@ -54,7 +54,7 @@ def is_binary_or_sparse_categorical(y_true, y_pred):
     return is_binary, is_sparse_categorical
 
 
-def get_metric(identifier, y_true, y_pred):
+def get_metric(identifier, y_true, y_pred, name_prefix=None):
     if identifier is None:
         return None  # Ok to have no metric for an output.
 
@@ -85,6 +85,8 @@ def get_metric(identifier, y_true, y_pred):
         metric_obj = metrics_module.MeanMetricWrapper(
             metric_obj, name=metric_name
         )
+    if name_prefix and not metric_obj.name.startswith(name_prefix):
+        metric_obj.name = "_".join([name_prefix, metric_obj.name])
     return metric_obj
 
 
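The guard only prepends the prefix when the metric's name does not already start with it, so a metric explicitly named "output_a_mse" is not doubled up. The same renaming logic in isolation (hypothetical standalone function, not part of the diff):

def prefixed_name(metric_name, name_prefix=None):
    # Mirrors the added guard in get_metric.
    if name_prefix and not metric_name.startswith(name_prefix):
        return "_".join([name_prefix, metric_name])
    return metric_name

print(prefixed_name("mean_squared_error", "output_a"))  # output_a_mean_squared_error
print(prefixed_name("output_a_mse", "output_a"))  # output_a_mse (unchanged)
print(prefixed_name("accuracy"))  # accuracy (no prefix given)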
@@ -117,7 +119,13 @@ def get_loss(identifier, y_true, y_pred):
 
 
 class CompileMetrics(metrics_module.Metric):
-    def __init__(self, metrics, weighted_metrics, name="compile_metric"):
+    def __init__(
+        self,
+        metrics,
+        weighted_metrics,
+        name="compile_metric",
+        output_names=None,
+    ):
         super().__init__(name=name)
         if metrics and not isinstance(metrics, (list, tuple, dict)):
             raise ValueError(
@@ -136,6 +144,7 @@ class CompileMetrics(metrics_module.Metric):
         self._user_weighted_metrics = weighted_metrics
         self.built = False
         self.name = "compile_metrics"
+        self.output_names = output_names
 
     @property
     def variables(self):
@@ -150,9 +159,10 @@ class CompileMetrics(metrics_module.Metric):
         return vars
 
     def build(self, y_true, y_pred):
-        if isinstance(y_pred, dict):
+        if self.output_names:
+            output_names = self.output_names
+        elif isinstance(y_pred, dict):
             output_names = sorted(list(y_pred.keys()))
-            num_outputs = len(output_names)
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -162,15 +172,14 @@ class CompileMetrics(metrics_module.Metric):
         else:
             output_names = None
             num_outputs = 1
+        if output_names:
+            num_outputs = len(output_names)
 
         y_pred = nest.flatten(y_pred)
         y_true = nest.flatten(y_true)
 
         metrics = self._user_metrics
         weighted_metrics = self._user_weighted_metrics
-        if output_names and not num_outputs:
-            num_outputs = len(output_names)
-
         self._flat_metrics = self._build_metrics_set(
             metrics,
             num_outputs,
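build now resolves output names with a fixed precedence: names supplied at construction time win, then sorted dict keys, then _keras_history layer names for lists of symbolic tensors, and finally no names at all; num_outputs is derived from whichever source matched. A condensed standalone mirror of that resolution order (simplified sketch; plain objects stand in for tensors):

def resolve_output_names(y_pred, configured_names=None):
    # Simplified mirror of CompileMetrics.build / CompileLoss.build.
    if configured_names:  # 1. names wired through from compile()
        output_names = configured_names
    elif isinstance(y_pred, dict):  # 2. dict keys, sorted
        output_names = sorted(y_pred.keys())
    elif isinstance(y_pred, (list, tuple)):  # 3. layer names, if symbolic
        output_names = None
        if all(hasattr(x, "_keras_history") for x in y_pred):
            output_names = [x._keras_history[0].name for x in y_pred]
        num_outputs = len(y_pred)
    else:  # 4. single unnamed output
        output_names = None
        num_outputs = 1
    if output_names:
        num_outputs = len(output_names)
    return output_names, num_outputs

print(resolve_output_names({"b": 0, "a": 1}))  # (['a', 'b'], 2)
print(resolve_output_names([0, 1], ["out_1", "out_2"]))  # (['out_1', 'out_2'], 2)
print(resolve_output_names([0, 1]))  # (None, 2)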
@@ -238,7 +247,12 @@ class CompileMetrics(metrics_module.Metric):
                     "(the list of metrics corresponding to that output). "
                     f"Received:\n{argument_name}={metrics}"
                 )
-            for mls, yt, yp in zip(metrics, y_true, y_pred):
+            name = None
+            for idx, (mls, yt, yp) in enumerate(
+                zip(metrics, y_true, y_pred)
+            ):
+                if output_names:
+                    name = output_names[idx]
                 if not all(is_function_like(e) for e in mls):
                     raise ValueError(
                         f"All entries in the sublists of the "
@@ -249,7 +263,7 @@ class CompileMetrics(metrics_module.Metric):
                 flat_metrics.append(
                     MetricsList(
                         [
-                            get_metric(m, yt, yp)
+                            get_metric(m, yt, yp, name)
                            for m in mls
                            if m is not None
                        ]
@@ -290,7 +304,7 @@ class CompileMetrics(metrics_module.Metric):
                 flat_metrics.append(
                     MetricsList(
                         [
-                            get_metric(m, yt, yp)
+                            get_metric(m, yt, yp, name)
                            for m in metrics[name]
                            if m is not None
                        ]
@@ -310,6 +324,9 @@ class CompileMetrics(metrics_module.Metric):
                 m.update_state(y_t, y_p)
         if sample_weight is not None:
             sample_weight = nest.flatten(sample_weight)
+            # For multi-outputs, repeat sample weights for n outputs.
+            if len(sample_weight) < len(y_true):
+                sample_weight = [sample_weight[0] for _ in range(len(y_true))]
         else:
             sample_weight = [None for _ in range(len(y_true))]
         for m, y_t, y_p, s_w in zip(
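When one sample-weight array is supplied for a multi-output model, update_state now repeats it once per output so the later zip over metrics, targets, predictions, and weights stays aligned. The broadcast in isolation (a minimal sketch):

import numpy as np

y_true = [np.zeros((8, 1)), np.ones((8, 1))]  # two outputs
sample_weight = [np.full(8, 0.5)]  # a single weight array, already flattened

# For multi-outputs, repeat sample weights for n outputs.
if len(sample_weight) < len(y_true):
    sample_weight = [sample_weight[0] for _ in range(len(y_true))]

assert len(sample_weight) == len(y_true) == 2
assert sample_weight[0] is sample_weight[1]  # same array object, not a copy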
@@ -375,7 +392,11 @@ class CompileMetrics(metrics_module.Metric):
 
 class CompileLoss(losses_module.Loss):
     def __init__(
-        self, loss, loss_weights=None, reduction="sum_over_batch_size"
+        self,
+        loss,
+        loss_weights=None,
+        reduction="sum_over_batch_size",
+        output_names=None,
     ):
         if loss_weights and not isinstance(loss_weights, (list, tuple, dict)):
             raise ValueError(
@@ -386,12 +407,14 @@ class CompileLoss(losses_module.Loss):
         self._user_loss = loss
         self._user_loss_weights = loss_weights
         self.built = False
+        self.output_names = output_names
         super().__init__(name="compile_loss", reduction=reduction)
 
     def build(self, y_true, y_pred):
-        if isinstance(y_pred, dict):
+        if self.output_names:
+            output_names = self.output_names
+        elif isinstance(y_pred, dict):
             output_names = sorted(list(y_pred.keys()))
-            num_outputs = len(output_names)
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -401,6 +424,8 @@ class CompileLoss(losses_module.Loss):
         else:
             output_names = None
             num_outputs = 1
+        if output_names:
+            num_outputs = len(output_names)
 
         y_pred = nest.flatten(y_pred)
         loss = self._user_loss
@@ -561,11 +586,13 @@ class CompileLoss(losses_module.Loss):
 
         if sample_weight is not None:
             sample_weight = nest.flatten(sample_weight)
+            # For multi-outputs, repeat sample weights for n outputs.
+            if len(sample_weight) < len(y_true):
+                sample_weight = [sample_weight[0] for _ in range(len(y_true))]
         else:
             sample_weight = [None for _ in y_true]
 
         loss_values = []
-
         for loss, y_t, y_p, loss_weight, sample_weight in zip(
             self.flat_losses,
             y_true,
@@ -115,21 +115,21 @@ class TestCompileMetrics(testing.TestCase):
             metrics={
                 "output_1": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
                 "output_2": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
             },
             weighted_metrics={
                 "output_1": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
                 "output_2": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
             },
         )
@@ -169,15 +169,32 @@ class TestCompileMetrics(testing.TestCase):
         result = compile_metrics.result()
         self.assertTrue(isinstance(result, dict))
         self.assertEqual(len(result), 8)
-        self.assertAllClose(result["mean_squared_error"], 0.055833336)
-        self.assertAllClose(result["weighted_mean_squared_error"], 0.0725)
+        # Result values obtained from `tf.keras`
+        # m = tf.keras.metrics.MeanSquaredError()
+        # m.update_state(y_true, y_pred1, sample_weight=weight)
+        # m.update_state(y_true, y_pred2, sample_weight=weight)
+        # m.result().numpy()
+        self.assertAllClose(result["output_1_mean_squared_error"], 0.055833336)
+        self.assertAllClose(result["output_2_mean_squared_error"], 0.055833336)
+        self.assertAllClose(result["output_1_mse"], 0.055833336)
+        self.assertAllClose(result["output_2_mse"], 0.055833336)
+        self.assertAllClose(
+            result["weighted_output_1_mean_squared_error"], 0.0725
+        )
+        self.assertAllClose(
+            result["weighted_output_2_mean_squared_error"], 0.0725
+        )
+        self.assertAllClose(result["weighted_output_1_mse"], 0.0725)
+        self.assertAllClose(result["weighted_output_2_mse"], 0.0725)
 
         compile_metrics.reset_state()
         result = compile_metrics.result()
         self.assertTrue(isinstance(result, dict))
         self.assertEqual(len(result), 8)
-        self.assertAllClose(result["mean_squared_error"], 0.0)
-        self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
+        self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
+        self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
+        self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0)
+        self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0)
 
     def test_name_conversions(self):
         compile_metrics = CompileMetrics(
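The expected count of 8 follows directly from per-output prefixing: 2 outputs x 2 metrics each (the auto-named mean_squared_error plus the explicit mse) x 2 variants (plain and weighted_). Enumerated explicitly:

keys = [
    prefix + f"{output}_{metric}"
    for prefix in ("", "weighted_")
    for output in ("output_1", "output_2")
    for metric in ("mean_squared_error", "mse")
]
assert len(keys) == 8  # 2 outputs x 2 metrics x (unweighted + weighted)
print(keys)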
@@ -34,12 +34,20 @@ class Trainer:
         jit_compile="auto",
     ):
         self.optimizer = optimizers.get(optimizer)
+        if hasattr(self, "output_names"):
+            output_names = self.output_names
+        else:
+            output_names = None
         if loss is not None:
-            self._compile_loss = CompileLoss(loss, loss_weights)
+            self._compile_loss = CompileLoss(
+                loss, loss_weights, output_names=output_names
+            )
         else:
             self._compile_loss = None
         if metrics is not None:
-            self._compile_metrics = CompileMetrics(metrics, weighted_metrics)
+            self._compile_metrics = CompileMetrics(
+                metrics, weighted_metrics, output_names=output_names
+            )
         else:
             self._compile_metrics = None
         if jit_compile == "auto":
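Trainer.compile is where the pieces meet: when the model defines output_names (which Functional.__init__ now always does, per the first hunk), the names are forwarded to both CompileLoss and CompileMetrics, so dict keys are validated against real output names even when targets arrive as plain lists or tuples; the hasattr guard keeps models without output_names working as before. Continuing the first sketch, a key that matches no output now fails at fit() time, when CompileLoss.build runs (message text as asserted in the tests above):

model.compile(optimizer="sgd", loss={"output_c": "mean_squared_error"})
try:
    model.fit(
        np.random.rand(8, 3),
        (np.random.rand(8, 1), np.random.rand(8, 1)),
        verbose=0,
    )
except ValueError as e:
    print(e)  # In the dict argument `loss`, key 'output_c' does not
              # correspond to any model output...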