Compat fix

This commit is contained in:
Francois Chollet 2023-06-28 09:44:14 -07:00
parent 1d81c47283
commit 93e3eb6b43
3 changed files with 14 additions and 53 deletions

@ -107,14 +107,7 @@ class CustomModel(keras.Model):
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value
metric_values = {}
for metric in self.metrics:
result = metric.result()
if isinstance(result, dict):
metric_values.update(result)
else:
metric_values[metric.name] = result
return metric_values
return {m.name: m.result() for m in self.metrics}
"""
@ -258,14 +251,7 @@ class CustomModel(keras.Model):
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
metric_values = {}
for metric in self.metrics:
result = metric.result()
if isinstance(result, dict):
metric_values.update(result)
else:
metric_values[metric.name] = result
return metric_values
return {m.name: m.result() for m in self.metrics}
# Construct and compile an instance of CustomModel
@ -304,14 +290,7 @@ class CustomModel(keras.Model):
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
metric_values = {}
for metric in self.metrics:
result = metric.result()
if isinstance(result, dict):
metric_values.update(result)
else:
metric_values[metric.name] = result
return metric_values
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel

@ -115,14 +115,7 @@ class CustomModel(keras.Model):
# Return a dict mapping metric names to current value
# Note that it will include the loss (tracked in self.metrics).
metric_values = {}
for metric in self.metrics:
result = metric.result()
if isinstance(result, dict):
metric_values.update(result)
else:
metric_values[metric.name] = result
return metric_values
return {m.name: m.result() for m in self.metrics}
"""
@ -278,14 +271,7 @@ class CustomModel(keras.Model):
# Return a dict mapping metric names to current value
# Note that it will include the loss (tracked in self.metrics).
metric_values = {}
for metric in self.metrics:
result = metric.result()
if isinstance(result, dict):
metric_values.update(result)
else:
metric_values[metric.name] = result
return metric_values
return {m.name: m.result() for m in self.metrics}
# Construct and compile an instance of CustomModel
@ -324,14 +310,7 @@ class CustomModel(keras.Model):
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
metric_values = {}
for metric in self.metrics:
result = metric.result()
if isinstance(result, dict):
metric_values.update(result)
else:
metric_values[metric.name] = result
return metric_values
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel

@ -714,11 +714,14 @@ class Trainer:
def _pythonify_logs(self, logs):
result = {}
for key, value in sorted(logs.items()):
try:
value = float(value)
except:
pass
result[key] = value
if isinstance(value, dict):
result.update(self._pythonify_logs(value))
else:
try:
value = float(value)
except:
pass
result[key] = value
return result
def _flatten_metrics_in_order(self, logs):