diff --git a/keras/layers/core.py b/keras/layers/core.py index 86b9ab771..17689cfb8 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -622,20 +622,27 @@ class Lambda(Layer): return dict(list(base_config.items()) + list(config.items())) @classmethod - def from_config(cls, config): + def from_config(cls, config, custom_objects={}): + # Insert custom objects into globals. + if custom_objects: + globs = globals().copy() + globs.update(custom_objects) + else: + globs = globals() + function_type = config.pop('function_type') if function_type == 'function': - function = globals()[config['function']] + function = globs[config['function']] elif function_type == 'lambda': - function = func_load(config['function'], globs=globals()) + function = func_load(config['function'], globs=globs) else: raise TypeError('Unknown function type:', function_type) output_shape_type = config.pop('output_shape_type') if output_shape_type == 'function': - output_shape = globals()[config['output_shape']] + output_shape = globs[config['output_shape']] elif output_shape_type == 'lambda': - output_shape = func_load(config['output_shape'], globs=globals()) + output_shape = func_load(config['output_shape'], globs=globs) else: output_shape = config['output_shape'] diff --git a/keras/utils/layer_utils.py b/keras/utils/layer_utils.py index d67ce9c7e..26804453a 100644 --- a/keras/utils/layer_utils.py +++ b/keras/utils/layer_utils.py @@ -1,4 +1,5 @@ from __future__ import print_function +import inspect from .generic_utils import get_from_module from .np_utils import convert_kernel @@ -31,7 +32,12 @@ def layer_from_config(config, custom_objects={}): else: layer_class = get_from_module(class_name, globals(), 'layer', instantiate=False) - return layer_class.from_config(config['config']) + + arg_spec = inspect.getargspec(layer_class.from_config) + if 'custom_objects' in arg_spec.args: + return layer_class.from_config(config['config'], custom_objects=custom_objects) + else: + return layer_class.from_config(config['config']) def print_summary(layers, relevant_nodes=None, diff --git a/tests/test_model_saving.py b/tests/test_model_saving.py index 3610f2868..0d4649beb 100644 --- a/tests/test_model_saving.py +++ b/tests/test_model_saving.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import assert_allclose from keras.models import Model, Sequential -from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed +from keras.layers import Dense, Dropout, Lambda, RepeatVector, TimeDistributed from keras.layers import Input from keras import optimizers from keras import objectives @@ -232,5 +232,35 @@ def test_loading_weights_by_name_2(): assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0 +# a function to be called from the Lambda layer +def square_fn(x): + return x * x + + +@keras_test +def test_saving_lambda_custom_objects(): + input = Input(shape=(3,)) + x = Lambda(lambda x: square_fn(x), output_shape=(3,))(input) + output = Dense(3)(x) + + model = Model(input, output) + model.compile(loss=objectives.MSE, + optimizer=optimizers.RMSprop(lr=0.0001), + metrics=[metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + _, fname = tempfile.mkstemp('.h5') + save_model(model, fname) + + model = load_model(fname, custom_objects={'square_fn': square_fn}) + os.remove(fname) + + out2 = model.predict(x) + assert_allclose(out, out2, atol=1e-05) + + if __name__ == '__main__': pytest.main([__file__])