commit d137d00182
Author: Francois Chollet
Date:   2016-12-19 14:17:50 -08:00

3 changed files with 50 additions and 7 deletions

@@ -622,20 +622,27 @@ class Lambda(Layer):
         return dict(list(base_config.items()) + list(config.items()))
 
     @classmethod
-    def from_config(cls, config):
+    def from_config(cls, config, custom_objects={}):
+        # Insert custom objects into globals.
+        if custom_objects:
+            globs = globals().copy()
+            globs.update(custom_objects)
+        else:
+            globs = globals()
+
         function_type = config.pop('function_type')
         if function_type == 'function':
-            function = globals()[config['function']]
+            function = globs[config['function']]
         elif function_type == 'lambda':
-            function = func_load(config['function'], globs=globals())
+            function = func_load(config['function'], globs=globs)
         else:
             raise TypeError('Unknown function type:', function_type)
 
         output_shape_type = config.pop('output_shape_type')
         if output_shape_type == 'function':
-            output_shape = globals()[config['output_shape']]
+            output_shape = globs[config['output_shape']]
         elif output_shape_type == 'lambda':
-            output_shape = func_load(config['output_shape'], globs=globals())
+            output_shape = func_load(config['output_shape'], globs=globs)
         else:
             output_shape = config['output_shape']
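The hunk above merges user-supplied `custom_objects` into a copy of the module globals before resolving the serialized function name. A minimal standalone sketch of that lookup pattern (the `resolve_name` helper and `my_fn` are hypothetical, not part of the commit):

```python
def resolve_name(name, custom_objects=None):
    # Copy so user objects never leak into the real module globals.
    globs = globals().copy()
    if custom_objects:
        globs.update(custom_objects)  # user-supplied names win on collision
    return globs[name]

def my_fn(x):
    return x * x

# Resolves even if 'my_fn' does not live in the deserializer's own module:
fn = resolve_name('my_fn', custom_objects={'my_fn': my_fn})
assert fn is my_fn
```

Note that the commit keeps `custom_objects={}` as a mutable default argument; since the dict is only read, never mutated, the usual shared-default pitfall does not apply here.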

@@ -1,4 +1,5 @@
 from __future__ import print_function
+import inspect
 
 from .generic_utils import get_from_module
 from .np_utils import convert_kernel
@@ -31,7 +32,12 @@ def layer_from_config(config, custom_objects={}):
     else:
         layer_class = get_from_module(class_name, globals(), 'layer',
                                       instantiate=False)
-    return layer_class.from_config(config['config'])
+
+    arg_spec = inspect.getargspec(layer_class.from_config)
+    if 'custom_objects' in arg_spec.args:
+        return layer_class.from_config(config['config'], custom_objects=custom_objects)
+    else:
+        return layer_class.from_config(config['config'])
 
 
 def print_summary(layers, relevant_nodes=None,
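`layer_from_config` now forwards `custom_objects` only to layers whose `from_config` declares that parameter, so third-party layers with the old one-argument signature keep working. A self-contained sketch of this introspection-based dispatch (the class and function names here are hypothetical):

```python
import inspect

class OldStyleLayer(object):
    @classmethod
    def from_config(cls, config):
        return cls()

class NewStyleLayer(object):
    @classmethod
    def from_config(cls, config, custom_objects={}):
        return cls()

def dispatch(layer_class, config, custom_objects={}):
    # getargspec reports the parameter names of from_config, so we can
    # pass custom_objects only when the target signature accepts it.
    arg_spec = inspect.getargspec(layer_class.from_config)
    if 'custom_objects' in arg_spec.args:
        return layer_class.from_config(config, custom_objects=custom_objects)
    return layer_class.from_config(config)

dispatch(OldStyleLayer, {})                       # old signature: kwarg omitted
dispatch(NewStyleLayer, {}, {'fn': lambda x: x})  # new signature: kwarg forwarded
```

(`inspect.getargspec` is what the commit uses; it is deprecated on Python 3, where `inspect.signature` covers the same need.)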

@@ -5,7 +5,7 @@ import numpy as np
 from numpy.testing import assert_allclose
 
 from keras.models import Model, Sequential
-from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed
+from keras.layers import Dense, Dropout, Lambda, RepeatVector, TimeDistributed
 from keras.layers import Input
 from keras import optimizers
 from keras import objectives
@@ -232,5 +232,35 @@ def test_loading_weights_by_name_2():
     assert_allclose(np.zeros_like(jessica[1]), jessica[1])  # biases init to 0
 
 
+# a function to be called from the Lambda layer
+def square_fn(x):
+    return x * x
+
+
+@keras_test
+def test_saving_lambda_custom_objects():
+    input = Input(shape=(3,))
+    x = Lambda(lambda x: square_fn(x), output_shape=(3,))(input)
+    output = Dense(3)(x)
+
+    model = Model(input, output)
+    model.compile(loss=objectives.MSE,
+                  optimizer=optimizers.RMSprop(lr=0.0001),
+                  metrics=[metrics.categorical_accuracy])
+
+    x = np.random.random((1, 3))
+    y = np.random.random((1, 3))
+    model.train_on_batch(x, y)
+
+    out = model.predict(x)
+    _, fname = tempfile.mkstemp('.h5')
+    save_model(model, fname)
+
+    model = load_model(fname, custom_objects={'square_fn': square_fn})
+    os.remove(fname)
+
+    out2 = model.predict(x)
+    assert_allclose(out, out2, atol=1e-05)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
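For contrast, a sketch of the failure this test guards against (file name hypothetical): reloading the same model without the mapping leaves `square_fn` unresolvable, since the rebuilt lambda's globals never receive it.

```python
from keras.models import load_model

# Presumably fails: the deserialized lambda references square_fn, which is
# neither in keras' module globals nor supplied via custom_objects, so a
# NameError surfaces when the Lambda layer is rebuilt or first called.
model = load_model('lambda_model.h5')
```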