Fix load_model for multiple output metrics in dictionary (#6122)
load_model fails when a model has multiple output layers that have more than one metric. Solve this problem by adding a clause that checks if metrics are a list. For more elaborate description see issue #3958 Include a unit test confirming that model with multiple outputs that have more than one metric can indeed be saved and reloaded.
This commit is contained in:
parent
4fe78f3400
commit
0930ca9eb7
@ -213,7 +213,14 @@ def load_model(filepath, custom_objects=None):
|
||||
if isinstance(obj, dict):
|
||||
deserialized = {}
|
||||
for key, value in obj.items():
|
||||
if value in custom_objects:
|
||||
deserialized[key] = []
|
||||
if isinstance(value, list):
|
||||
for element in value:
|
||||
if element in custom_objects:
|
||||
deserialized[key].append(custom_objects[element])
|
||||
else:
|
||||
deserialized[key].append(element)
|
||||
elif value in custom_objects:
|
||||
deserialized[key] = custom_objects[value]
|
||||
else:
|
||||
deserialized[key] = value
|
||||
|
@ -100,6 +100,35 @@ def test_fuctional_model_saving():
|
||||
assert_allclose(out, out2, atol=1e-05)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_saving_multiple_metrics_outputs():
|
||||
input = Input(shape=(5,))
|
||||
x = Dense(5)(input)
|
||||
output1 = Dense(1, name='output1')(x)
|
||||
output2 = Dense(1, name='output2')(x)
|
||||
|
||||
model = Model(inputs=input, outputs=[output1, output2])
|
||||
|
||||
metrics = {'output1': ['mse', 'binary_accuracy'],
|
||||
'output2': ['mse', 'binary_accuracy']
|
||||
}
|
||||
loss = {'output1': 'mse', 'output2': 'mse'}
|
||||
|
||||
model.compile(loss=loss, optimizer='sgd', metrics=metrics)
|
||||
|
||||
# assure that model is working
|
||||
x = np.array([[1, 1, 1, 1, 1]])
|
||||
out = model.predict(x)
|
||||
_, fname = tempfile.mkstemp('.h5')
|
||||
save_model(model, fname)
|
||||
|
||||
model = load_model(fname)
|
||||
os.remove(fname)
|
||||
|
||||
out2 = model.predict(x)
|
||||
assert_allclose(out, out2, atol=1e-05)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_saving_without_compilation():
|
||||
model = Sequential()
|
||||
|
Loading…
Reference in New Issue
Block a user