Use custom_objects to deserialize Lambda functions. (#4770)

This commit is contained in:
Gijs van Tulder 2016-12-19 21:15:06 +01:00 committed by François Chollet
parent 1278bf9cfa
commit 1fcb74f218
3 changed files with 50 additions and 7 deletions

@ -622,20 +622,27 @@ class Lambda(Layer):
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
def from_config(cls, config, custom_objects={}):
# Insert custom objects into globals.
if custom_objects:
globs = globals().copy()
globs.update(custom_objects)
else:
globs = globals()
function_type = config.pop('function_type')
if function_type == 'function':
function = globals()[config['function']]
function = globs[config['function']]
elif function_type == 'lambda':
function = func_load(config['function'], globs=globals())
function = func_load(config['function'], globs=globs)
else:
raise TypeError('Unknown function type:', function_type)
output_shape_type = config.pop('output_shape_type')
if output_shape_type == 'function':
output_shape = globals()[config['output_shape']]
output_shape = globs[config['output_shape']]
elif output_shape_type == 'lambda':
output_shape = func_load(config['output_shape'], globs=globals())
output_shape = func_load(config['output_shape'], globs=globs)
else:
output_shape = config['output_shape']

@ -1,4 +1,5 @@
from __future__ import print_function
import inspect
from .generic_utils import get_from_module
from .np_utils import convert_kernel
@ -31,7 +32,12 @@ def layer_from_config(config, custom_objects={}):
else:
layer_class = get_from_module(class_name, globals(), 'layer',
instantiate=False)
return layer_class.from_config(config['config'])
arg_spec = inspect.getargspec(layer_class.from_config)
if 'custom_objects' in arg_spec.args:
return layer_class.from_config(config['config'], custom_objects=custom_objects)
else:
return layer_class.from_config(config['config'])
def print_summary(layers, relevant_nodes=None,

@ -5,7 +5,7 @@ import numpy as np
from numpy.testing import assert_allclose
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed
from keras.layers import Dense, Dropout, Lambda, RepeatVector, TimeDistributed
from keras.layers import Input
from keras import optimizers
from keras import objectives
@ -232,5 +232,35 @@ def test_loading_weights_by_name_2():
assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0
# a function to be called from the Lambda layer
def square_fn(x):
return x * x
@keras_test
def test_saving_lambda_custom_objects():
input = Input(shape=(3,))
x = Lambda(lambda x: square_fn(x), output_shape=(3,))(input)
output = Dense(3)(x)
model = Model(input, output)
model.compile(loss=objectives.MSE,
optimizer=optimizers.RMSprop(lr=0.0001),
metrics=[metrics.categorical_accuracy])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
_, fname = tempfile.mkstemp('.h5')
save_model(model, fname)
model = load_model(fname, custom_objects={'square_fn': square_fn})
os.remove(fname)
out2 = model.predict(x)
assert_allclose(out, out2, atol=1e-05)
if __name__ == '__main__':
pytest.main([__file__])